diff options
Diffstat (limited to 'zkvms_host_io/input_macros/src')
| -rw-r--r-- | zkvms_host_io/input_macros/src/lib.rs | 113 |
1 files changed, 36 insertions, 77 deletions
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs index 72c4d11..e66a0a4 100644 --- a/zkvms_host_io/input_macros/src/lib.rs +++ b/zkvms_host_io/input_macros/src/lib.rs @@ -2,20 +2,12 @@ use proc_macro::TokenStream; #[path = "../../../guests_macro/src/parse_fn.rs"] mod parse_fn; -use crate::parse_fn::{ - args_divide, args_divide_public, args_split, args_split_public, group_streams, -}; +use crate::parse_fn::FunctionDefinition; /// Parses the `guests/type.txt` type note, created from the guest /// Returns a tuple of the arguments group and the return type -fn get_types() -> (TokenStream, TokenStream) { - let types: Vec<&str> = include_str!("../../../guests/type.txt") - .split('\n') - .collect(); - ( - types[0].parse::<TokenStream>().unwrap(), - types[1].parse::<TokenStream>().unwrap(), - ) +fn new_fd() -> FunctionDefinition { + FunctionDefinition::new(&include_str!("../../../guests/type.txt").parse().unwrap()) } static DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]"; @@ -72,33 +64,23 @@ static DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deseria /// ``` #[proc_macro] pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { - let (args, ret) = get_types(); - let (patterns, types) = args_divide(&args); + let fd = new_fd(); - let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!( - env!("INPUTS_DIR"), - "/default_public_input.toml" - ))) - .unwrap(); - let public_types = args_divide_public(&args, &public_inputs.keys().collect()) - .0 - .1 - .iter() - .map(|x| x.to_string() + ", ") - .collect::<String>(); - let output_type = format!("pub type Output = ({} {});", public_types, ret).to_string(); + let sep = if fd.types().is_empty() { "" } else { ", " }; + let output_type = format!("pub type Output = ({} {} {});", fd.grouped_public_types(), sep, fd.return_type).to_string(); - let return_type = format!("pub type Return = {};", ret).to_string(); + let return_type = format!("pub type Return = {};", fd.return_type).to_string(); - let (public_args, private_args) = args_split_public(&args, &public_inputs.keys().collect()); - let public_attrs = public_args + let public_attrs = fd + .public_arguments() .iter() .map(|x| format!("pub {x},")) .collect::<String>(); let public_input_type = format!("{} pub struct PublicInput {{ {} }}", DERIVES, public_attrs).to_string(); - let private_attrs = private_args + let private_attrs = fd + .private_arguments() .iter() .map(|x| format!("pub {x},")) .collect::<String>(); @@ -108,26 +90,28 @@ pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream { ) .to_string(); - let all_args = args_split(&args); - - let mut struct_def = format!("{} pub struct Input {{", DERIVES); - for arg in all_args { - struct_def += &format!("pub {arg},"); - } - - let types = group_streams(&types); - struct_def += &format!( - "}} - impl From<Input> for {types} {{ - fn from(input: Input) -> {types} {{ - ( - " - ); - - for field in patterns { - struct_def += &format!("input.{field},"); - } - struct_def += ") } }"; + let attrs = fd + .arguments() + .iter() + .map(|x| format!("pub {x},")) + .collect::<String>(); + let convertion = fd + .patterns() + .clone() + .iter() + .map(|x| format!("input.{x},")) + .collect::<String>(); + let types = fd.grouped_types(); + let struct_def = &format!(" + {DERIVES} pub struct Input {{ + {attrs} + }} + impl From<Input> for ({types}) {{ + fn from(input: Input) -> ({types}) {{ + ({convertion}) + }} + }} + ").to_string(); (output_type + &return_type + &public_input_type + &private_input_type + &struct_def) .parse::<TokenStream>() @@ -151,10 +135,7 @@ fn foreach_field(item: TokenStream, fields: Vec<TokenStream>) -> TokenStream { /// field name. #[proc_macro] pub fn foreach_input_field(item: TokenStream) -> TokenStream { - let (args, _) = get_types(); - let arg_patterns = args_divide(&args).0; - - foreach_field(item, arg_patterns) + foreach_field(item, new_fd().patterns().clone()) } /// Repeats the given code as many times as fields there are in the @@ -162,18 +143,7 @@ pub fn foreach_input_field(item: TokenStream) -> TokenStream { /// concrete field name. #[proc_macro] pub fn foreach_public_input_field(item: TokenStream) -> TokenStream { - let (args, _) = get_types(); - - let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!( - env!("INPUTS_DIR"), - "/default_public_input.toml" - ))) - .unwrap(); - let public_patterns = args_divide_public(&args, &public_inputs.keys().collect()) - .0 - .0; - - foreach_field(item, public_patterns) + foreach_field(item, new_fd().public_patterns().clone()) } /// Repeats the given code as many times as fields there are in the @@ -181,18 +151,7 @@ pub fn foreach_public_input_field(item: TokenStream) -> TokenStream { /// concrete field name. #[proc_macro] pub fn foreach_private_input_field(item: TokenStream) -> TokenStream { - let (args, _) = get_types(); - - let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!( - env!("INPUTS_DIR"), - "/default_public_input.toml" - ))) - .unwrap(); - let private_patterns = args_divide_public(&args, &public_inputs.keys().collect()) - .1 - .0; - - foreach_field(item, private_patterns) + foreach_field(item, new_fd().private_patterns().clone()) } /// Assuming the `run_info` variable is present, it creates a block with all |
