diff options
| -rw-r--r-- | guests_macro/src/parse_fn.rs | 34 | ||||
| -rw-r--r-- | zkvms/nexus/wrapper_macro/src/lib.rs | 2 | ||||
| -rw-r--r-- | zkvms/risc0/wrapper_macro/src/lib.rs | 2 | ||||
| -rw-r--r-- | zkvms_host_io/input_macros/src/lib.rs | 10 |
4 files changed, 32 insertions, 16 deletions
diff --git a/guests_macro/src/parse_fn.rs b/guests_macro/src/parse_fn.rs index e5558ad..535d27f 100644 --- a/guests_macro/src/parse_fn.rs +++ b/guests_macro/src/parse_fn.rs @@ -79,12 +79,19 @@ pub fn args_split(item: &TokenStream) -> Vec<TokenStream> { } /// Input: "(p1 : t1, p2: t2, ...)" -/// Output: vec!["p1 : t1", "p2: t2", ...] -pub fn args_split_public(item: &TokenStream, public: &Vec<&String>) -> Vec<TokenStream> { - args_split(item) +/// Output: vec!["p1 : t1", "p2: t2", ...], vec!["p1 : t1", "p2: t2", ...] +pub fn args_split_public(item: &TokenStream, public: &Vec<&String>) -> (Vec<TokenStream>, Vec<TokenStream>) { + let all_args = args_split(item); + let public_args: Vec<TokenStream> = all_args + .clone() .into_iter() .filter(|a| public.iter().any(|x| a.to_string().starts_with(*x))) - .collect() + .collect(); + let private_args: Vec<TokenStream> = all_args + .into_iter() + .filter(|t| !public_args.iter().any(|pt| *t.to_string() == pt.to_string())) + .collect(); + (public_args, private_args) } /// Input: "(p1 : t1, p2: t2, ...)" @@ -139,14 +146,23 @@ pub fn args_divide(item: &TokenStream) -> (Vec<TokenStream>, Vec<TokenStream>) { } /// Input: "(p1 : t1, p2: t2, ...)" -/// Output: vec![p1, p2, ...], vec![t1, t2, ...] -pub fn args_divide_public(item: &TokenStream, public: &Vec<&String>) -> (Vec<TokenStream>, Vec<TokenStream>) { +/// Output: (vec![p1, p2, ...], vec![t1, t2, ...]), (vec![p1, p2, ...], vec![t1, t2, ...]) +pub fn args_divide_public(item: &TokenStream, public: &Vec<&String>) -> ((Vec<TokenStream>, Vec<TokenStream>), (Vec<TokenStream>, Vec<TokenStream>)) { let (patterns, types) = args_divide(item); - patterns + + let (public_patterns, public_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns + .clone() .into_iter() - .zip(types.into_iter()) + .zip(types.clone().into_iter()) .filter(|(p, _)| public.iter().any(|x| p.to_string() == **x)) - .unzip() + .unzip(); + + let (private_patterns, private_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns + .into_iter() + .zip(types.into_iter()) + .filter(|(p, _)| !public_patterns.iter().any(|x| p.to_string() == x.to_string())) + .unzip(); + ((public_patterns, public_types), (private_patterns, private_types)) } /// Input: "(p1 : t1, p2: t2, ...)" diff --git a/zkvms/nexus/wrapper_macro/src/lib.rs b/zkvms/nexus/wrapper_macro/src/lib.rs index 0814e2a..89ddc05 100644 --- a/zkvms/nexus/wrapper_macro/src/lib.rs +++ b/zkvms/nexus/wrapper_macro/src/lib.rs @@ -17,7 +17,7 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream { include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")) ) .unwrap(); - let (public_patterns, public_types) = args_divide_public(&args, &public_inputs.keys().collect()); + let (public_patterns, public_types) = args_divide_public(&args, &public_inputs.keys().collect()).0; let public_patterns: Vec<(TokenStream, TokenStream)> = public_patterns .into_iter() .zip(public_types.into_iter()) diff --git a/zkvms/risc0/wrapper_macro/src/lib.rs b/zkvms/risc0/wrapper_macro/src/lib.rs index 5f924a8..22afa44 100644 --- a/zkvms/risc0/wrapper_macro/src/lib.rs +++ b/zkvms/risc0/wrapper_macro/src/lib.rs @@ -18,7 +18,7 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream { include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")) ) .unwrap(); - let public_patterns = args_divide_public(&args, &public_inputs.keys().collect()).0; + let public_patterns = args_divide_public(&args, &public_inputs.keys().collect()).0.0; for pattern in public_patterns.iter() { out.extend(format!("commit(&{});", pattern).parse::<TokenStream>()); } diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs index ea59460..49cff80 100644 --- a/zkvms_host_io/input_macros/src/lib.rs +++ b/zkvms_host_io/input_macros/src/lib.rs @@ -23,28 +23,28 @@ pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { ) .unwrap(); let public_types = args_divide_public(&args, &public_inputs.keys().collect()) + .0 .1 .iter() .map(|x| x.to_string() + ", ") .collect::<String>(); let output_type = format!("pub type Output = ({} {});", public_types, ret).to_string(); - let all_args = args_split(&args); - - let public_args = args_split_public(&args, &public_inputs.keys().collect()); + let (public_args, private_args) = args_split_public(&args, &public_inputs.keys().collect()); let public_attrs = public_args .iter() .map(|x| format!("pub {x},")) .collect::<String>(); let public_input_type = format!("{} pub struct PublicInput {{ {} }}", DERIVES, public_attrs).to_string(); - let private_attrs = all_args + let private_attrs = private_args .iter() - .filter(|t| !public_args.iter().any(|pt| *t.to_string() == pt.to_string())) .map(|x| format!("pub {x},")) .collect::<String>(); let private_input_type = format!("{} pub struct PrivateInput {{ {} }}", DERIVES, private_attrs).to_string(); + let all_args = args_split(&args); + let mut struct_def = format!("{} pub struct Input {{", DERIVES); for arg in all_args { struct_def += &format!("pub {arg},"); |
