diff options
| author | Kamen Mladenov <kamen@syndamia.com> | 2025-02-06 13:48:17 +0200 |
|---|---|---|
| committer | Kamen Mladenov <kamen@syndamia.com> | 2025-02-06 13:48:17 +0200 |
| commit | f56218406dbe5e4918d560778b24b366cd3bda9f (patch) | |
| tree | ec5cd372a81860bf89439dc8c1ae791ea061605f /zkvms/zkwasm/wrapper_macro | |
| parent | 5cbd68a39d7a91160fbb8ed81fe9aef5b7dc14b6 (diff) | |
| download | zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar.gz zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.zip | |
feat(zkvms/zkwasm): Use path input type and move container size information to host
Diffstat (limited to 'zkvms/zkwasm/wrapper_macro')
| -rw-r--r-- | zkvms/zkwasm/wrapper_macro/src/lib.rs | 85 |
1 files changed, 29 insertions, 56 deletions
diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs index f5cdfbb..e7a201c 100644 --- a/zkvms/zkwasm/wrapper_macro/src/lib.rs +++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs @@ -5,40 +5,11 @@ mod parse_fn; use crate::parse_fn::{ split_fn, args_divide_grouped, args_divide_public, group_streams }; use toml::Table; -fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, inputs: &Table, readfn: &str) { +fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, readfn: &str) { for i in 0..patterns.len() { - let mut value = &inputs[&patterns[i].to_string()]; let type_note: String = format!("{}", types[i]) .replace('<', "[") - .replace('>', "]") - .split("[") - .map(|x| x.trim()) - .map(|typ| - // Array - if typ.is_empty() { - let array = value.as_array() - .expect("value is of type Array but isn't an array"); - value = &array[0]; - "[".to_string() - } - // STD Vec - else if typ.ends_with("Vec") { - let array = value.as_array() - .expect("value is of type Vec but isn't an array"); - value = &array[0]; - format!("{} [ {} ", typ, array.len()) - } - // STD HashMap - else if typ.ends_with("HashMap") { - let array = value.as_table() - .expect("value is of type HashMap but isn't an array"); - // value = &array[0]; - format!("{} [ {} ", typ, array.len()) - } - else { - typ.to_string() - }) - .collect(); + .replace('>', "]"); out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note).parse::<TokenStream>()); } } @@ -51,20 +22,20 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream { include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")) ) .unwrap(); - let private_inputs = toml::from_str::<Table>( - include_str!(concat!(env!("INPUTS_DIR"), "/default_private_input.toml")) - ) - .unwrap(); let ((pub_pat, pub_typ), (prv_pat, prv_typ)) = args_divide_public(&args, &public_inputs.keys().collect()); let mut out = TokenStream::new(); - insert_reads(&mut out, &pub_pat, &pub_typ, &public_inputs, "read_public"); - insert_reads(&mut out, &prv_pat, &prv_typ, &private_inputs, "read_private"); + insert_reads(&mut out, &pub_pat, &pub_typ, "read_public"); + insert_reads(&mut out, &prv_pat, &prv_typ, "read_private"); let (ts_patterns, _) = args_divide_grouped(&args); - out.extend(format!("let result = zkp::{}{}; assert(result); write(result as u64);", name, ts_patterns).parse::<TokenStream>()); + out.extend(format!(" + let result = zkp::{}{}; + assert(result); + write(result as u64); + ", name, ts_patterns).parse::<TokenStream>()); let mut block = TokenStream::new(); block.extend(format!("{{ {} }}", out).parse::<TokenStream>()); @@ -103,11 +74,12 @@ fn return_string(readfn: &TokenTree) -> TokenStream { ").parse().unwrap() } -fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { format!(" {{ let mut ret = Vec::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.push(read!({readfn} {inner})); }} ret.try_into().unwrap() @@ -115,11 +87,12 @@ fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> To ").parse().unwrap() } -fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_vec(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { format!(" {{ let mut ret = Vec::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.push(read!({readfn} {inner})); }} ret @@ -127,7 +100,7 @@ fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> Toke ").parse().unwrap() } -fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_hashmap(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { let mut inner = inner.clone().into_iter(); let key_type = inner.next().unwrap(); inner.next().unwrap(); @@ -135,7 +108,8 @@ fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> format!(r#" {{ let mut ret = HashMap::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type})); }} ret @@ -148,7 +122,12 @@ fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { for subtype in inner.clone().into_iter() { value += &format!("read!({readfn} {subtype}), "); } - format!("( {value} )").parse().unwrap() + format!(" + {{ + let _ = read!({readfn} usize); + ( {value} ) + }} + ").parse().unwrap() } #[proc_macro] @@ -175,15 +154,11 @@ pub fn read(item: TokenStream) -> TokenStream { let mut group = parts.next() .expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str()); if let TokenTree::Group(inner_group) = group { - let mut group = inner_group.stream().into_iter(); - let size = group.next().unwrap(); - - let mut rest = TokenStream::new(); - rest.extend(group); + let rest = inner_group.stream(); match ident.to_string().as_str() { - "Vec" => return_vec(&readfn, &size, &rest), - "HashMap" => return_hashmap(&readfn, &size, &rest), + "Vec" => return_vec(&readfn, &rest), + "HashMap" => return_hashmap(&readfn, &rest), _ => todo!("Unsupported container {ident}"), } } @@ -199,10 +174,8 @@ pub fn read(item: TokenStream) -> TokenStream { match current { TokenTree::Punct(punct) => match punct.as_char() { // Array - ';' => { - let size = group.next().unwrap(); - return return_array(&readfn, &size, &inner); - }, + ';' => + return return_array(&readfn, &inner), // Tuple ',' => continue, _ => unreachable!("Group contains unexpected \"{punct}\""), |
