diff options
Diffstat (limited to 'zkvms/jolt/wrapper_macro/src/lib.rs')
| -rw-r--r-- | zkvms/jolt/wrapper_macro/src/lib.rs | 38 |
1 files changed, 17 insertions, 21 deletions
diff --git a/zkvms/jolt/wrapper_macro/src/lib.rs b/zkvms/jolt/wrapper_macro/src/lib.rs index 5f73e70..b35a79a 100644 --- a/zkvms/jolt/wrapper_macro/src/lib.rs +++ b/zkvms/jolt/wrapper_macro/src/lib.rs @@ -4,7 +4,7 @@ use quote::quote; #[path = "../../../../guests_macro/src/parse_fn.rs"] mod parse_fn; -use crate::parse_fn::{args_divide, args_split, group_streams, split_fn}; +use crate::parse_fn::FunctionDefinition; /// Create a set of three helper functions. /// @@ -44,37 +44,33 @@ use crate::parse_fn::{args_divide, args_split, group_streams, split_fn}; /// ``` #[proc_macro] pub fn make_wrapper(item: TokenStream) -> TokenStream { - let (name, args, ret) = split_fn(&item); - - let (patterns, types) = args_divide(&args); - let ts_patterns = group_streams(&patterns); - let ts_types = group_streams(&types); + let fd = FunctionDefinition::new(&item); let mut out = TokenStream::new(); - out.extend(format!("zkp::{}{}", name, ts_patterns).parse::<TokenStream>()); + out.extend(format!("zkp::{}({})", fd.name, fd.grouped_patterns()).parse::<TokenStream>()); let mut func = TokenStream::new(); func.extend( format!( "#[jolt::provable(max_input_size = 100000)] fn guest{} -> {} {{ {} }}", - args, ret, out + fd.args, fd.return_type, out ) .parse::<TokenStream>(), ); func.extend(make_build_fn( - patterns.clone(), - types.clone(), - ts_patterns.clone(), - ts_types.clone(), - ret.clone(), + fd.patterns().clone(), + fd.types().clone(), + fd.grouped_patterns().clone(), + fd.grouped_types().clone(), + fd.return_type.clone(), )); func.extend(make_preprocess_fn( - patterns, - types, - ts_patterns, - ts_types, - ret, + fd.patterns().clone(), + fd.types().clone(), + fd.grouped_patterns().clone(), + fd.grouped_types().clone(), + fd.return_type.clone(), )); func } @@ -104,7 +100,7 @@ fn make_build_fn( quote! { #[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))] pub fn guest_closures(elf_path: String) -> ( - impl Fn(#ts_types) -> (#ret, jolt::JoltHyperKZGProof) + Sync + Send, + impl Fn((#ts_types)) -> (#ret, jolt::JoltHyperKZGProof) + Sync + Send, impl Fn(jolt::JoltHyperKZGProof) -> bool + Sync + Send ) { #imports @@ -115,10 +111,10 @@ fn make_build_fn( let program_cp = program.clone(); let preprocessing_cp = preprocessing.clone(); - let prove_closure = move |args: #ts_types| { + let prove_closure = move |args: (#ts_types)| { let program = (*program).clone(); let preprocessing = (*preprocessing).clone(); - let #ts_patterns = args; + let (#ts_patterns) = args; prove_guest(program, preprocessing, #(#patterns),*) }; |
