diff options
| -rw-r--r-- | guests_macro/src/parse_fn.rs | 9 | ||||
| -rw-r--r-- | zkvms_host_io/input_macros/src/lib.rs | 24 | ||||
| -rw-r--r-- | zkvms_host_io/src/lib.rs | 8 |
3 files changed, 36 insertions, 5 deletions
diff --git a/guests_macro/src/parse_fn.rs b/guests_macro/src/parse_fn.rs index 0e57879..e5558ad 100644 --- a/guests_macro/src/parse_fn.rs +++ b/guests_macro/src/parse_fn.rs @@ -79,6 +79,15 @@ pub fn args_split(item: &TokenStream) -> Vec<TokenStream> { } /// Input: "(p1 : t1, p2: t2, ...)" +/// Output: vec!["p1 : t1", "p2: t2", ...] +pub fn args_split_public(item: &TokenStream, public: &Vec<&String>) -> Vec<TokenStream> { + args_split(item) + .into_iter() + .filter(|a| public.iter().any(|x| a.to_string().starts_with(*x))) + .collect() +} + +/// Input: "(p1 : t1, p2: t2, ...)" /// Output: vec![p1, p2, ...], vec![t1, t2, ...] pub fn args_divide(item: &TokenStream) -> (Vec<TokenStream>, Vec<TokenStream>) { let contents; diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs index 5e36466..ea59460 100644 --- a/zkvms_host_io/input_macros/src/lib.rs +++ b/zkvms_host_io/input_macros/src/lib.rs @@ -2,7 +2,7 @@ use proc_macro::TokenStream; #[path = "../../../guests_macro/src/parse_fn.rs"] mod parse_fn; -use crate::parse_fn::{ args_split, args_divide, args_divide_public, group_streams }; +use crate::parse_fn::{ args_split, args_split_public, args_divide, args_divide_public, group_streams }; fn get_types() -> (TokenStream, TokenStream) { let types: Vec<&str> = include_str!("../../../guests/type.txt") @@ -11,9 +11,12 @@ fn get_types() -> (TokenStream, TokenStream) { (types[0].parse::<TokenStream>().unwrap(), types[1].parse::<TokenStream>().unwrap()) } +static DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]"; + #[proc_macro] pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { let (args, ret) = get_types(); + let (patterns, types) = args_divide(&args); let public_inputs = toml::from_str::<toml::Table>( include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")) @@ -28,12 +31,25 @@ pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { let all_args = args_split(&args); - let mut struct_def = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Input {".to_string(); + let public_args = args_split_public(&args, &public_inputs.keys().collect()); + let public_attrs = public_args + .iter() + .map(|x| format!("pub {x},")) + .collect::<String>(); + let public_input_type = format!("{} pub struct PublicInput {{ {} }}", DERIVES, public_attrs).to_string(); + + let private_attrs = all_args + .iter() + .filter(|t| !public_args.iter().any(|pt| *t.to_string() == pt.to_string())) + .map(|x| format!("pub {x},")) + .collect::<String>(); + let private_input_type = format!("{} pub struct PrivateInput {{ {} }}", DERIVES, private_attrs).to_string(); + + let mut struct_def = format!("{} pub struct Input {{", DERIVES); for arg in all_args { struct_def += &format!("pub {arg},"); } - let (patterns, types) = args_divide(&args); let types = group_streams(&types); struct_def += &format!("}} impl From<Input> for {types} {{ @@ -46,7 +62,7 @@ pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { } struct_def += ") } }"; - (output_type + &struct_def).parse::<TokenStream>().unwrap() + (output_type + &public_input_type + &private_input_type + &struct_def).parse::<TokenStream>().unwrap() } #[proc_macro] diff --git a/zkvms_host_io/src/lib.rs b/zkvms_host_io/src/lib.rs index d63cb56..2624352 100644 --- a/zkvms_host_io/src/lib.rs +++ b/zkvms_host_io/src/lib.rs @@ -31,6 +31,8 @@ pub enum RunType { pub struct RunWith { pub run_type: RunType, pub input: Input, + pub public_input: PublicInput, + pub private_input: PrivateInput, pub default_env: HashMap<String, String>, } @@ -66,13 +68,17 @@ pub fn read_args() -> RunWith { } else { DEFAULT_PRIVATE_INPUT.to_string() }; - let input: Input = toml::from_str(&(public_contents + &private_contents)).unwrap(); + let input: Input = toml::from_str(&(public_contents.clone() + &private_contents)).unwrap(); + let public_input: PublicInput = toml::from_str(&public_contents).unwrap(); + let private_input: PrivateInput = toml::from_str(&private_contents).unwrap(); let default_env = read_str(DEFAULT_ENV).unwrap(); RunWith { run_type: cli.run_type, input, + public_input, + private_input, default_env, } } |
