aboutsummaryrefslogtreecommitdiff
path: root/zkvms/jolt/wrapper_macro/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zkvms/jolt/wrapper_macro/src/lib.rs')
-rw-r--r--zkvms/jolt/wrapper_macro/src/lib.rs38
1 files changed, 17 insertions, 21 deletions
diff --git a/zkvms/jolt/wrapper_macro/src/lib.rs b/zkvms/jolt/wrapper_macro/src/lib.rs
index 5f73e70..b35a79a 100644
--- a/zkvms/jolt/wrapper_macro/src/lib.rs
+++ b/zkvms/jolt/wrapper_macro/src/lib.rs
@@ -4,7 +4,7 @@ use quote::quote;
#[path = "../../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
-use crate::parse_fn::{args_divide, args_split, group_streams, split_fn};
+use crate::parse_fn::FunctionDefinition;
/// Create a set of three helper functions.
///
@@ -44,37 +44,33 @@ use crate::parse_fn::{args_divide, args_split, group_streams, split_fn};
/// ```
#[proc_macro]
pub fn make_wrapper(item: TokenStream) -> TokenStream {
- let (name, args, ret) = split_fn(&item);
-
- let (patterns, types) = args_divide(&args);
- let ts_patterns = group_streams(&patterns);
- let ts_types = group_streams(&types);
+ let fd = FunctionDefinition::new(&item);
let mut out = TokenStream::new();
- out.extend(format!("zkp::{}{}", name, ts_patterns).parse::<TokenStream>());
+ out.extend(format!("zkp::{}({})", fd.name, fd.grouped_patterns()).parse::<TokenStream>());
let mut func = TokenStream::new();
func.extend(
format!(
"#[jolt::provable(max_input_size = 100000)] fn guest{} -> {} {{ {} }}",
- args, ret, out
+ fd.args, fd.return_type, out
)
.parse::<TokenStream>(),
);
func.extend(make_build_fn(
- patterns.clone(),
- types.clone(),
- ts_patterns.clone(),
- ts_types.clone(),
- ret.clone(),
+ fd.patterns().clone(),
+ fd.types().clone(),
+ fd.grouped_patterns().clone(),
+ fd.grouped_types().clone(),
+ fd.return_type.clone(),
));
func.extend(make_preprocess_fn(
- patterns,
- types,
- ts_patterns,
- ts_types,
- ret,
+ fd.patterns().clone(),
+ fd.types().clone(),
+ fd.grouped_patterns().clone(),
+ fd.grouped_types().clone(),
+ fd.return_type.clone(),
));
func
}
@@ -104,7 +100,7 @@ fn make_build_fn(
quote! {
#[cfg(all(not(target_arch = "wasm32"), not(feature = "guest")))]
pub fn guest_closures(elf_path: String) -> (
- impl Fn(#ts_types) -> (#ret, jolt::JoltHyperKZGProof) + Sync + Send,
+ impl Fn((#ts_types)) -> (#ret, jolt::JoltHyperKZGProof) + Sync + Send,
impl Fn(jolt::JoltHyperKZGProof) -> bool + Sync + Send
) {
#imports
@@ -115,10 +111,10 @@ fn make_build_fn(
let program_cp = program.clone();
let preprocessing_cp = preprocessing.clone();
- let prove_closure = move |args: #ts_types| {
+ let prove_closure = move |args: (#ts_types)| {
let program = (*program).clone();
let preprocessing = (*preprocessing).clone();
- let #ts_patterns = args;
+ let (#ts_patterns) = args;
prove_guest(program, preprocessing, #(#patterns),*)
};