diff options
| author | Kamen Mladenov <kamen@syndamia.com> | 2025-04-07 16:00:56 +0300 |
|---|---|---|
| committer | Kamen Mladenov <kamen@syndamia.com> | 2025-04-07 17:41:20 +0300 |
| commit | 0fbc78777ead39adba3950edd8e8b92ae1db3482 (patch) | |
| tree | 75db691d4f61fd8c375e583d3b26a35fbfa8ead3 | |
| parent | 7498d604be92a0b1a7a5603e0295f194aa8b05e7 (diff) | |
| download | zkVMs-benchmarks-0fbc78777ead39adba3950edd8e8b92ae1db3482.tar zkVMs-benchmarks-0fbc78777ead39adba3950edd8e8b92ae1db3482.tar.gz zkVMs-benchmarks-0fbc78777ead39adba3950edd8e8b92ae1db3482.zip | |
feat(guests_macro): Replace parse_fn mod with struct
The old method of having functions for every parsing action of a
function definition produced messy results. We're replacing it with a
struct which will hold a wide assortment of parsed values. This makes
the interface much sleeker with only a minor increase in code.
The downside is that a lot of data gets repeated; however, since this
struct will only be used in macros, i.e. at compile time, that doesn't
matter too much.
| -rw-r--r-- | guests_macro/Cargo.toml | 3 | ||||
| -rw-r--r-- | guests_macro/src/lib.rs | 10 | ||||
| -rw-r--r-- | guests_macro/src/parse_fn.rs | 451 |
3 files changed, 235 insertions, 229 deletions
diff --git a/guests_macro/Cargo.toml b/guests_macro/Cargo.toml index 6173b91..5831d1c 100644 --- a/guests_macro/Cargo.toml +++ b/guests_macro/Cargo.toml @@ -6,3 +6,6 @@ edition = "2021" [lib] proc-macro = true + +[dependencies] +toml = "0.8.19" diff --git a/guests_macro/src/lib.rs b/guests_macro/src/lib.rs index 2a35292..5e78ac0 100644 --- a/guests_macro/src/lib.rs +++ b/guests_macro/src/lib.rs @@ -37,23 +37,23 @@ mod parse_fn; /// ``` #[proc_macro_attribute] pub fn proving_entrypoint(_: TokenStream, mut item: TokenStream) -> TokenStream { - let (name, args, ret) = parse_fn::split_fn(&item); + let fd = parse_fn::FunctionDefinition::new(&item); + let fn_type = format!("fn {}{} -> {}", fd.name, fd.args, fd.return_type).replace('\n', " "); // We also need to pass some type information to the host program compile-time. // Put it in the file guests/type.txt. let mut output = File::create("../type.txt").unwrap(); - writeln!(output, "{}", &format!("{args}").replace('\n', " ")); - write!(output, "{}", &format!("{ret}").replace('\n', " ")); + write!(output, "{fn_type}"); item.extend( format!( "#[macro_export] macro_rules! entrypoint_expr {{ () => {{ - make_wrapper!{{{}{} -> {}}} + make_wrapper!{{ {} }} }}; }}", - name, args, ret + fn_type ) .parse::<TokenStream>(), ); diff --git a/guests_macro/src/parse_fn.rs b/guests_macro/src/parse_fn.rs index 98d1017..3d1bc4c 100644 --- a/guests_macro/src/parse_fn.rs +++ b/guests_macro/src/parse_fn.rs @@ -7,264 +7,267 @@ use proc_macro::{Delimiter, Group, Spacing, TokenStream, TokenTree}; -/// Split function definition into triplet of name, arguments and output types. -/// -/// **Input:** "fn name(...) -> ... { ..... }" -/// **Output:** "name", "(...)", "..." 
-pub fn split_fn(item: &TokenStream) -> (TokenStream, TokenStream, TokenStream) { - let item = item.clone().into_iter(); +pub struct FunctionDefinition { + pub name: TokenStream, + pub args: TokenStream, + pub return_type: TokenStream, - let mut name = TokenStream::new(); - let mut args = TokenStream::new(); - let mut ret = TokenStream::new(); - let mut out: &mut TokenStream = &mut name; + patterns: Vec<TokenStream>, + types: Vec<TokenStream>, - for tt in item { - match tt { - // The conditions will later be used to return - // errors when incorrect function type is used - TokenTree::Ident(ref ident) => { - if ident.to_string() == "fn" || ident.to_string() == "pub" { - continue; - } - } - TokenTree::Punct(ref punct) => { - if punct.as_char() == '-' { - out = &mut ret; - continue; - } - if punct.as_char() == '>' && out.is_empty() { - continue; - } - } - TokenTree::Group(ref group) => { - if group.delimiter() == Delimiter::Brace { - break; - } - if !out.is_empty() { - out = &mut args; - } - } - TokenTree::Literal(_) => unreachable!("Cannot have literal inside def!"), + public_patterns: Vec<TokenStream>, + public_types: Vec<TokenStream>, + + private_patterns: Vec<TokenStream>, + private_types: Vec<TokenStream>, +} + +impl FunctionDefinition { + pub fn new(item: &TokenStream) -> FunctionDefinition { + let (name, args, return_type) = Self::split_fn(item); + let (patterns, types) = Self::args_divide(&args); + + let public_inputs = toml::from_str::<toml::Table>(include_str!(concat!( + env!("INPUTS_DIR"), + "/default_public_input.toml" + ))) + .unwrap(); + let ((public_patterns, public_types), (private_patterns, private_types)) = + Self::args_divide_public(&patterns, &types, &public_inputs.keys().collect()); + + FunctionDefinition { + name, args, return_type, + patterns, types, + public_patterns, public_types, + private_patterns, private_types, } - out.extend([tt].into_iter()); } - if ret.is_empty() { - ret.extend( - [TokenTree::Group(Group::new( - Delimiter::Parenthesis, 
- TokenStream::new(), - ))] - .into_iter(), - ); + pub fn patterns(&self) -> &Vec<TokenStream> { + &self.patterns + } + pub fn public_patterns(&self) -> &Vec<TokenStream> { + &self.public_patterns + } + pub fn private_patterns(&self) -> &Vec<TokenStream> { + &self.private_patterns } - (name, args, ret) -} -/// Split arguments group into a vector of each argument with it's associated -/// type. -/// -/// **Input:** "(p1 : t1, p2 : t2, ...)" -/// **Output:** vec!["p1 : t1", "p2 : t2", ...] -pub fn args_split(item: &TokenStream) -> Vec<TokenStream> { - let contents; - if let TokenTree::Group(group) = item.clone().into_iter().next().unwrap() { - contents = group.stream().into_iter(); - } else { - unreachable!("Item passed to args_split is not a group: \"{item}\""); + pub fn types(&self) -> &Vec<TokenStream> { + &self.types + } + pub fn public_types(&self) -> &Vec<TokenStream> { + &self.public_types + } + pub fn private_types(&self) -> &Vec<TokenStream> { + &self.private_types } - let mut args = Vec::new(); - let mut ts = TokenStream::new(); - let mut angle_level = 0; + pub fn grouped_patterns(&self) -> TokenStream { + Self::group_stream(&self.patterns) + } + pub fn grouped_public_patterns(&self) -> TokenStream { + Self::group_stream(&self.public_patterns) + } + pub fn grouped_private_patterns(&self) -> TokenStream { + Self::group_stream(&self.private_patterns) + } - for tt in contents { - match tt { - TokenTree::Punct(ref punct) => match punct.as_char() { - // < and > do **not** form TokenTree groups, however their - // usage is like that of a group. Hence, we need extra - // logic to skip them. 
- '<' => angle_level += 1, - '>' => angle_level -= 1, - ',' => { - if angle_level == 0 { - args.push(ts); - ts = TokenStream::new(); + pub fn grouped_types(&self) -> TokenStream { + Self::group_stream(&self.types) + } + pub fn grouped_public_types(&self) -> TokenStream { + Self::group_stream(&self.public_types) + } + pub fn grouped_private_types(&self) -> TokenStream { + Self::group_stream(&self.private_types) + } + + pub fn arguments(&self) -> Vec<TokenStream> { + Self::combine(self.patterns.clone(), self.types.clone()) + } + pub fn public_arguments(&self) -> Vec<TokenStream> { + Self::combine(self.public_patterns.clone(), self.public_types.clone()) + } + pub fn private_arguments(&self) -> Vec<TokenStream> { + Self::combine(self.private_patterns.clone(), self.private_types.clone()) + } + + /// Split function definition into triplet of name, arguments and output types. + /// + /// **Input:** "fn name(...) -> ... { ..... }" + /// **Output:** "name", "(...)", "..." + fn split_fn(item: &TokenStream) -> (TokenStream, TokenStream, TokenStream) { + let item = item.clone().into_iter(); + + let mut name = TokenStream::new(); + let mut args = TokenStream::new(); + let mut ret = TokenStream::new(); + let mut out: &mut TokenStream = &mut name; + + for tt in item { + match tt { + // The conditions will later be used to return + // errors when incorrect function type is used + TokenTree::Ident(ref ident) => { + if ident.to_string() == "fn" || ident.to_string() == "pub" { continue; } } - _ => {} - }, - _ => {} + TokenTree::Punct(ref punct) => { + if punct.as_char() == '-' { + out = &mut ret; + continue; + } + if punct.as_char() == '>' && out.is_empty() { + continue; + } + } + TokenTree::Group(ref group) => { + if group.delimiter() == Delimiter::Brace { + break; + } + if !out.is_empty() { + out = &mut args; + } + } + TokenTree::Literal(_) => unreachable!("Cannot have literal inside def!"), + } + out.extend([tt].into_iter()); } - ts.extend([tt].into_iter()); - } + if 
ret.is_empty() { + ret.extend( + [TokenTree::Group(Group::new( + Delimiter::Parenthesis, + TokenStream::new(), + ))] + .into_iter(), + ); + } - if !ts.is_empty() { - args.push(ts); + (name, args, ret) } - args -} - -/// Like `args_split`, however two vectors are returned: the first for public -/// arguments (and their types) and the second for private ones. -/// -/// `public` is a vector of argument names. -/// -/// **Input:** "(p1 : t1, p2: t2, ...)", vec!["p3", "p4", ...] -/// **Output:** vec!["p1 : t1", "p2: t2", ...], vec!["p3 : t3", "p4: t4", ...] -pub fn args_split_public( - item: &TokenStream, - public: &Vec<&String>, -) -> (Vec<TokenStream>, Vec<TokenStream>) { - let all_args = args_split(item); - let public_args: Vec<TokenStream> = all_args - .clone() - .into_iter() - .filter(|a| public.iter().any(|x| a.to_string().starts_with(*x))) - .collect(); - let private_args: Vec<TokenStream> = all_args - .into_iter() - .filter(|t| { - !public_args - .iter() - .any(|pt| *t.to_string() == pt.to_string()) - }) - .collect(); - (public_args, private_args) -} -/// Split arguments group into two vectors: one for all argument names and one -/// for every argument type. -/// -/// **Input:** "(p1 : t1, p2: t2, ...)" -/// **Output:** vec!["p1", "p2", ...], vec!["t1", "t2", ...] -pub fn args_divide(item: &TokenStream) -> (Vec<TokenStream>, Vec<TokenStream>) { - let contents; - if let TokenTree::Group(group) = item.clone().into_iter().next().unwrap() { - contents = group.stream().into_iter(); - } else { - unreachable!("Item passed to args_divide is not a group: \"{item}\""); - } + /// Split arguments group into two vectors: one for all argument names and one + /// for every argument type. + /// + /// **Input:** "(p1 : t1, p2: t2, ...)" + /// **Output:** vec!["p1", "p2", ...], vec!["t1", "t2", ...] 
+ fn args_divide(item: &TokenStream) -> (Vec<TokenStream>, Vec<TokenStream>) { + let contents; + if let TokenTree::Group(group) = item.clone().into_iter().next().unwrap() { + contents = group.stream().into_iter(); + } else { + unreachable!("Item passed to args_divide is not a group: \"{item}\""); + } - let mut patterns = Vec::new(); - let mut types = Vec::new(); - let mut ts = TokenStream::new(); - let mut ignore_next = false; - let mut angle_level = 0; + let mut patterns = Vec::new(); + let mut types = Vec::new(); + let mut ts = TokenStream::new(); + let mut ignore_next = false; + let mut angle_level = 0; - for tt in contents { - match tt { - TokenTree::Punct(ref punct) => { - // Ignore "::" - if punct.spacing() == Spacing::Joint && punct.as_char() == ':' { - ignore_next = true; - } else if !ignore_next { - match punct.as_char() { - // < and > do **not** form TokenTree groups, however their - // usage is like that of a group. Hence, we need extra - // logic to skip them. - '<' => angle_level += 1, - '>' => angle_level -= 1, - ':' => { - patterns.push(ts); - ts = TokenStream::new(); - continue; - } - ',' => { - if angle_level == 0 { - types.push(ts); + for tt in contents { + match tt { + TokenTree::Punct(ref punct) => { + // Ignore "::" + if punct.spacing() == Spacing::Joint && punct.as_char() == ':' { + ignore_next = true; + } else if !ignore_next { + match punct.as_char() { + // < and > do **not** form TokenTree groups, however their + // usage is like that of a group. Hence, we need extra + // logic to skip them. 
+ '<' => angle_level += 1, + '>' => angle_level -= 1, + ':' => { + patterns.push(ts); ts = TokenStream::new(); continue; } + ',' => { + if angle_level == 0 { + types.push(ts); + ts = TokenStream::new(); + continue; + } + } + _ => {} } - _ => {} + } else { + ignore_next = false; } - } else { - ignore_next = false; } + _ => {} } - _ => {} + + ts.extend([tt].into_iter()); } - ts.extend([tt].into_iter()); + types.push(ts); + (patterns, types) } - types.push(ts); - (patterns, types) -} - -/// Like `args_divide`, however two tuples of vectors are returned: the first -/// for public arguments and types, and the second for private ones. -/// -/// `public` is a vector of argument names. -/// -/// **Input:** "(p1 : t1, p2: t2, ...)", vec!["p3", "p4", ...] -/// **Output:** (vec!["p1", "p2", ...], vec!["t1", "t2", ...]), (vec!["p3", "p4", ...], vec!["t3", "t4", ...]) -pub fn args_divide_public( - item: &TokenStream, - public: &Vec<&String>, -) -> ( - (Vec<TokenStream>, Vec<TokenStream>), - (Vec<TokenStream>, Vec<TokenStream>), -) { - let (patterns, types) = args_divide(item); - - let (public_patterns, public_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns - .clone() - .into_iter() - .zip(types.clone().into_iter()) - .filter(|(p, _)| public.iter().any(|x| p.to_string() == **x)) - .unzip(); - - let (private_patterns, private_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns - .into_iter() - .zip(types.into_iter()) - .filter(|(p, _)| { - !public_patterns - .iter() - .any(|x| p.to_string() == x.to_string()) - }) - .unzip(); - ( - (public_patterns, public_types), - (private_patterns, private_types), - ) -} - -/// Like `args_divide`, but group arguments and types (via `group_streams`). 
-/// -/// **Input:** "(p1 : t1, p2: t2, ...)" -/// **Output:** "(p1, p2, ...)", "(t1, t2, ...)" -pub fn args_divide_grouped(item: &TokenStream) -> (TokenStream, TokenStream) { - let (patterns, types) = args_divide(&item); - (group_streams(&patterns), group_streams(&types)) -} + /// Like `args_divide`, however two tuples of vectors are returned: the first + /// for public arguments and types, and the second for private ones. + /// + /// `public` is a vector of argument names. + /// + /// **Input:** "(p1 : t1, p2: t2, ...)", vec!["p3", "p4", ...] + /// **Output:** (vec!["p1", "p2", ...], vec!["t1", "t2", ...]), (vec!["p3", "p4", ...], vec!["t3", "t4", ...]) + fn args_divide_public( + patterns: &Vec<TokenStream>, + types: &Vec<TokenStream>, + public: &Vec<&String>, + ) -> ( + (Vec<TokenStream>, Vec<TokenStream>), + (Vec<TokenStream>, Vec<TokenStream>), + ) { + let (public_patterns, public_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns + .clone() + .into_iter() + .zip(types.clone().into_iter()) + .filter(|(p, _)| public.iter().any(|x| p.to_string() == **x)) + .unzip(); -/// Transform a vector of elements into a (TokenTree) group of elements -/// -/// **Input:** vec!["p1", "p2", ...] 
-/// **Output:** "(p1, p2, ...)" -pub fn group_streams(patterns: &Vec<TokenStream>) -> TokenStream { - let mut inner_ts = TokenStream::new(); - inner_ts.extend( - patterns + let (private_patterns, private_types): (Vec<TokenStream>, Vec<TokenStream>) = patterns .clone() .into_iter() - .flat_map(|i| [",".parse().unwrap(), i]) - .skip(1), - ); + .zip(types.clone().into_iter()) + .filter(|(p, _)| { + !public_patterns + .iter() + .any(|x| p.to_string() == x.to_string()) + }) + .unzip(); + ( + (public_patterns, public_types), + (private_patterns, private_types), + ) + } - let mut out = TokenStream::new(); - out.extend( - [TokenTree::Group(Group::new( - Delimiter::Parenthesis, - inner_ts, - ))] - .into_iter(), - ); + /// Transform a vector of elements into a (TokenTree) group of elements + /// + /// **Input:** vec!["p1", "p2", ...] + /// **Output:** "p1, p2, ..." + fn group_stream(patterns: &Vec<TokenStream>) -> TokenStream { + let mut elems = TokenStream::new(); + elems.extend( + patterns + .clone() + .into_iter() + .flat_map(|i| [",".parse().unwrap(), i]) + .skip(1), + ); + elems + } - out + fn combine(patterns: Vec<TokenStream>, types: Vec<TokenStream>) -> Vec<TokenStream> { + patterns + .into_iter() + .zip(types.into_iter()) + .map(|(p, t)| format!("{p} : {t}").parse().unwrap()) + .collect() + } } |
