aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--guests_macro/src/parse_fn.rs11
-rw-r--r--zkvms/risc0/wrapper_macro/src/lib.rs13
-rw-r--r--zkvms_host_io/input_macros/src/lib.rs19
3 files changed, 25 insertions, 18 deletions
diff --git a/guests_macro/src/parse_fn.rs b/guests_macro/src/parse_fn.rs
index 00be250..0e57879 100644
--- a/guests_macro/src/parse_fn.rs
+++ b/guests_macro/src/parse_fn.rs
@@ -130,6 +130,17 @@ pub fn args_divide(item: &TokenStream) -> (Vec<TokenStream>, Vec<TokenStream>) {
}
/// Input: "(p1 : t1, p2: t2, ...)"
+/// Output: vec![p1, p2, ...], vec![t1, t2, ...]
+pub fn args_divide_public(item: &TokenStream, public: &Vec<&String>) -> (Vec<TokenStream>, Vec<TokenStream>) {
+ let (patterns, types) = args_divide(item);
+ patterns
+ .into_iter()
+ .zip(types.into_iter())
+ .filter(|(p, _)| public.iter().any(|x| p.to_string() == **x))
+ .unzip()
+}
+
+/// Input: "(p1 : t1, p2: t2, ...)"
/// Output: "(p1, p2, ...)", "(t1, t2, ...)"
pub fn args_divide_grouped(item: &TokenStream) -> (TokenStream, TokenStream) {
let (patterns, types) = args_divide(&item);
diff --git a/zkvms/risc0/wrapper_macro/src/lib.rs b/zkvms/risc0/wrapper_macro/src/lib.rs
index bbdf6e4..d4ba858 100644
--- a/zkvms/risc0/wrapper_macro/src/lib.rs
+++ b/zkvms/risc0/wrapper_macro/src/lib.rs
@@ -2,7 +2,7 @@ use proc_macro::TokenStream;
#[path = "../../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
-use crate::parse_fn::{ split_fn, args_split, args_divide_grouped };
+use crate::parse_fn::{ split_fn, args_split, args_divide_public, args_divide_grouped };
#[proc_macro]
pub fn make_wrapper(item: TokenStream) -> TokenStream {
@@ -18,14 +18,15 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream {
include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"))
)
.unwrap();
- let mut commitment = String::new();
- for input in public_inputs.keys() {
- commitment += &format!("{}.clone(), ", input);
- }
+ let public_patterns = args_divide_public(&args, &public_inputs.keys().collect())
+ .0
+ .iter()
+ .map(|x| x.to_string() + ".clone(), ")
+ .collect::<String>();
let (ts_patterns, _) = args_divide_grouped(&args);
- out.extend(format!("commit(&({} zkp::{}{}));", commitment, name, ts_patterns).parse::<TokenStream>());
+ out.extend(format!("commit(&({} zkp::{}{}));", public_patterns, name, ts_patterns).parse::<TokenStream>());
let mut block = TokenStream::new();
block.extend(format!("{{ {} }}", out).parse::<TokenStream>());
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs
index 5184192..9c15102 100644
--- a/zkvms_host_io/input_macros/src/lib.rs
+++ b/zkvms_host_io/input_macros/src/lib.rs
@@ -2,7 +2,7 @@ use proc_macro::TokenStream;
#[path = "../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
-use crate::parse_fn::{ args_split, args_divide, group_streams };
+use crate::parse_fn::{ args_split, args_divide, args_divide_public, group_streams };
fn get_types() -> (TokenStream, TokenStream) {
let types: Vec<&str> = include_str!("../../../guests/type.txt")
@@ -14,22 +14,17 @@ fn get_types() -> (TokenStream, TokenStream) {
#[proc_macro]
pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream {
let (args, ret) = get_types();
- let (patterns, types) = args_divide(&args);
- let mut patterns = patterns
- .iter()
- .map(|x| x.to_string());
let public_inputs = toml::from_str::<toml::Table>(
include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"))
)
.unwrap();
- let mut commitment = String::new();
- for input in public_inputs.keys() {
- if let Some(index) = patterns.clone().position(|x| x == *input) {
- commitment += &format!("{}, ", types[index]);
- }
- }
- let output_type = format!("pub type Output = ({} {});", commitment, ret).to_string();
+ let public_types = args_divide_public(&args, &public_inputs.keys().collect())
+ .1
+ .iter()
+ .map(|x| x.to_string() + ", ")
+ .collect::<String>();
+ let output_type = format!("pub type Output = ({} {});", public_types, ret).to_string();
let all_args = args_split(&args);