aboutsummaryrefslogtreecommitdiff
path: root/zkvms_host_io/input_macros/src
diff options
context:
space:
mode:
authorKamen Mladenov <kamen@syndamia.com>2025-04-07 16:01:45 +0300
committerKamen Mladenov <kamen@syndamia.com>2025-04-07 17:41:20 +0300
commit882f8e99642931dbed953755344e1ebee4092db8 (patch)
tree151f5ec01e5cc50b411f326c2bd8cdea52f52025 /zkvms_host_io/input_macros/src
parent0fbc78777ead39adba3950edd8e8b92ae1db3482 (diff)
downloadzkVMs-benchmarks-882f8e99642931dbed953755344e1ebee4092db8.tar
zkVMs-benchmarks-882f8e99642931dbed953755344e1ebee4092db8.tar.gz
zkVMs-benchmarks-882f8e99642931dbed953755344e1ebee4092db8.zip
feat(zkvms_host_io): Update to use FunctionDefinition
Diffstat (limited to 'zkvms_host_io/input_macros/src')
-rw-r--r--zkvms_host_io/input_macros/src/lib.rs113
1 files changed, 36 insertions, 77 deletions
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs
index 72c4d11..e66a0a4 100644
--- a/zkvms_host_io/input_macros/src/lib.rs
+++ b/zkvms_host_io/input_macros/src/lib.rs
@@ -2,20 +2,12 @@ use proc_macro::TokenStream;
#[path = "../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
-use crate::parse_fn::{
- args_divide, args_divide_public, args_split, args_split_public, group_streams,
-};
+use crate::parse_fn::FunctionDefinition;
/// Parses the `guests/type.txt` type note, created from the guest
/// Returns a tuple of the arguments group and the return type
-fn get_types() -> (TokenStream, TokenStream) {
- let types: Vec<&str> = include_str!("../../../guests/type.txt")
- .split('\n')
- .collect();
- (
- types[0].parse::<TokenStream>().unwrap(),
- types[1].parse::<TokenStream>().unwrap(),
- )
+fn new_fd() -> FunctionDefinition {
+ FunctionDefinition::new(&include_str!("../../../guests/type.txt").parse().unwrap())
}
static DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]";
@@ -72,33 +64,23 @@ static DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deseria
/// ```
#[proc_macro]
pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream {
- let (args, ret) = get_types();
- let (patterns, types) = args_divide(&args);
+ let fd = new_fd();
- let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!(
- env!("INPUTS_DIR"),
- "/default_public_input.toml"
- )))
- .unwrap();
- let public_types = args_divide_public(&args, &public_inputs.keys().collect())
- .0
- .1
- .iter()
- .map(|x| x.to_string() + ", ")
- .collect::<String>();
- let output_type = format!("pub type Output = ({} {});", public_types, ret).to_string();
+ let sep = if fd.types().is_empty() { "" } else { ", " };
+ let output_type = format!("pub type Output = ({} {} {});", fd.grouped_public_types(), sep, fd.return_type).to_string();
- let return_type = format!("pub type Return = {};", ret).to_string();
+ let return_type = format!("pub type Return = {};", fd.return_type).to_string();
- let (public_args, private_args) = args_split_public(&args, &public_inputs.keys().collect());
- let public_attrs = public_args
+ let public_attrs = fd
+ .public_arguments()
.iter()
.map(|x| format!("pub {x},"))
.collect::<String>();
let public_input_type =
format!("{} pub struct PublicInput {{ {} }}", DERIVES, public_attrs).to_string();
- let private_attrs = private_args
+ let private_attrs = fd
+ .private_arguments()
.iter()
.map(|x| format!("pub {x},"))
.collect::<String>();
@@ -108,26 +90,28 @@ pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream {
)
.to_string();
- let all_args = args_split(&args);
-
- let mut struct_def = format!("{} pub struct Input {{", DERIVES);
- for arg in all_args {
- struct_def += &format!("pub {arg},");
- }
-
- let types = group_streams(&types);
- struct_def += &format!(
- "}}
- impl From<Input> for {types} {{
- fn from(input: Input) -> {types} {{
- (
- "
- );
-
- for field in patterns {
- struct_def += &format!("input.{field},");
- }
- struct_def += ") } }";
+ let attrs = fd
+ .arguments()
+ .iter()
+ .map(|x| format!("pub {x},"))
+ .collect::<String>();
+ let convertion = fd
+ .patterns()
+ .clone()
+ .iter()
+ .map(|x| format!("input.{x},"))
+ .collect::<String>();
+ let types = fd.grouped_types();
+ let struct_def = &format!("
+ {DERIVES} pub struct Input {{
+ {attrs}
+ }}
+ impl From<Input> for ({types}) {{
+ fn from(input: Input) -> ({types}) {{
+ ({convertion})
+ }}
+ }}
+ ").to_string();
(output_type + &return_type + &public_input_type + &private_input_type + &struct_def)
.parse::<TokenStream>()
@@ -151,10 +135,7 @@ fn foreach_field(item: TokenStream, fields: Vec<TokenStream>) -> TokenStream {
/// field name.
#[proc_macro]
pub fn foreach_input_field(item: TokenStream) -> TokenStream {
- let (args, _) = get_types();
- let arg_patterns = args_divide(&args).0;
-
- foreach_field(item, arg_patterns)
+ foreach_field(item, new_fd().patterns().clone())
}
/// Repeats the given code as many times as fields there are in the
@@ -162,18 +143,7 @@ pub fn foreach_input_field(item: TokenStream) -> TokenStream {
/// concrete field name.
#[proc_macro]
pub fn foreach_public_input_field(item: TokenStream) -> TokenStream {
- let (args, _) = get_types();
-
- let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!(
- env!("INPUTS_DIR"),
- "/default_public_input.toml"
- )))
- .unwrap();
- let public_patterns = args_divide_public(&args, &public_inputs.keys().collect())
- .0
- .0;
-
- foreach_field(item, public_patterns)
+ foreach_field(item, new_fd().public_patterns().clone())
}
/// Repeats the given code as many times as fields there are in the
@@ -181,18 +151,7 @@ pub fn foreach_public_input_field(item: TokenStream) -> TokenStream {
/// concrete field name.
#[proc_macro]
pub fn foreach_private_input_field(item: TokenStream) -> TokenStream {
- let (args, _) = get_types();
-
- let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!(
- env!("INPUTS_DIR"),
- "/default_public_input.toml"
- )))
- .unwrap();
- let private_patterns = args_divide_public(&args, &public_inputs.keys().collect())
- .1
- .0;
-
- foreach_field(item, private_patterns)
+ foreach_field(item, new_fd().private_patterns().clone())
}
/// Assuming the `run_info` variable is present, it creates a block with all