diff options
| author | Kamen Mladenov <kamen@syndamia.com> | 2025-02-06 13:48:17 +0200 |
|---|---|---|
| committer | Kamen Mladenov <kamen@syndamia.com> | 2025-02-06 13:48:17 +0200 |
| commit | f56218406dbe5e4918d560778b24b366cd3bda9f (patch) | |
| tree | ec5cd372a81860bf89439dc8c1ae791ea061605f /zkvms/zkwasm | |
| parent | 5cbd68a39d7a91160fbb8ed81fe9aef5b7dc14b6 (diff) | |
| download | zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar.gz zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.zip | |
feat(zkvms/zkwasm): Use path input type and move container size information to host
Diffstat (limited to 'zkvms/zkwasm')
| -rw-r--r-- | zkvms/zkwasm/host/src/main.rs | 141 | ||||
| -rw-r--r-- | zkvms/zkwasm/wrapper_macro/src/lib.rs | 85 |
2 files changed, 127 insertions, 99 deletions
diff --git a/zkvms/zkwasm/host/src/main.rs b/zkvms/zkwasm/host/src/main.rs index d6c5381..179d388 100644 --- a/zkvms/zkwasm/host/src/main.rs +++ b/zkvms/zkwasm/host/src/main.rs @@ -1,52 +1,100 @@ -use zkvms_host_io::{PublicInput, PrivateInput, foreach_public_input_field, foreach_private_input_field, read_args, RunType::{Execute, Prove, Verify}}; +use zkvms_host_io::{PublicInput, PrivateInput, foreach_public_input_field, foreach_private_input_field, read_args, RunType::{Execute, Prove, Verify}, RunWith}; use std::io::{self, Write}; use std::process::{Command, Stdio}; use regex::Regex; -fn build_public_input(input: &PublicInput) -> String { - let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap(); - - let mut ret = String::new(); - foreach_public_input_field!{ - let flat = format!("{:?}", input.yield) - .replace("false", "0") - .replace("true", "1"); - - let numbers: Vec<&str> = numreg - .captures_iter(&flat) - .map(|cap| cap.get(1).unwrap().as_str()) - .collect(); - - for num in numbers { - ret.push_str(num); - ret.push_str(":i64,"); - } +static PUBLIC_INPUT_PATH: &str = "public_input.bin"; +static PRIVATE_INPUT_PATH: &str = "private_input.bin"; + +fn get_with_sizes(flat: &str) -> String { + let mut values = flat + .split('[') + .map(|x| x.trim()) + .skip(1); + let current = values + .next() + .unwrap_or(flat); + + // 1D collection or not a collection + if current != "" { + let size = 1 + current + .clone() + .to_string() + .chars() + .take_while(|x| *x != ']') + .map(|x| (x == ',') as usize) + .sum::<usize>(); + + (if size > 1 { size.to_string() } else { String::new() }) + + "[" + + current + + &values + .map(|x| "[".to_string() + x) + .collect::<String>() + } + // ND collection + else { + let size: usize = values + .clone() + .count(); + + let subcollections = values + .map(|x| get_with_sizes(x)) + .collect::<String>(); + + size.to_string() + + "[" + + &subcollections } - ret.pop(); // removes trailing comma - ret } -fn build_private_input(input: &PrivateInput) -> String { - let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap(); - - let mut ret = String::new(); - foreach_private_input_field!{ - let flat = format!("{:?}", input.yield) - .replace("false", "0") - .replace("true", "1"); - - let numbers: Vec<&str> = numreg - .captures_iter(&flat) - .map(|cap| cap.get(1).unwrap().as_str()) - .collect(); - - for num in numbers { - ret.push_str(num); - ret.push_str(":i64,"); +macro_rules! build_input { + ($input:expr , $path:ident , $type:ident) => { + |run_info: &RunWith| { + let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap(); + let stringreg: Regex = Regex::new("\\\"[^\"]*\\\"").unwrap(); + + let mut ret: Vec<u64> = Vec::new(); + $type!{ + let flat = format!("{:?}", $input.yield) + .replace("false", "0") + .replace("true", "1") + .replace('(', "[") + .replace(')', "]") + .replace('{', "[") + .replace('{', "]"); + + let flat = get_with_sizes(&flat); + + let numbers = numreg + .captures_iter(&flat) + .map(|cap| + cap.get(1) + .unwrap() + .as_str() + .to_string() + .parse::<u64>() + .unwrap()) + .collect::<Vec<u64>>(); + + ret.extend(numbers); + + // let strings: Vec<&str> = stringreg + // .captures_iter(&flat) + // .map(|cap| cap.get(0).unwrap().as_str()) + // .collect(); + // + // panic!("{:#?}", strings); + } + let bytes = ret + .iter() + .map(|x| x.to_be_bytes()) + .flatten() + .collect::<Vec<u8>>(); + std::fs::write($path, bytes); + format!("{}:file", $path) } - } - ret.pop(); // removes trailing comma - ret + }; } fn zkwasm_command(subcmd: &str) -> Command { @@ -81,8 +129,15 @@ fn main() { .arg("-k").arg(k) .arg("--scheme").arg(scheme)); - let public_input = build_public_input(&run_info.public_input); - let private_input = build_private_input(&run_info.private_input); + let public_input = build_input!( + run_info.public_input, + PUBLIC_INPUT_PATH, + foreach_public_input_field)(&run_info); + + let private_input = build_input!( + run_info.private_input, + PRIVATE_INPUT_PATH, + foreach_private_input_field)(&run_info); let output = run_info .env_or( diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs index f5cdfbb..e7a201c 100644 --- a/zkvms/zkwasm/wrapper_macro/src/lib.rs +++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs @@ -5,40 +5,11 @@ mod parse_fn; use crate::parse_fn::{ split_fn, args_divide_grouped, args_divide_public, group_streams }; use toml::Table; -fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, inputs: &Table, readfn: &str) { +fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, readfn: &str) { for i in 0..patterns.len() { - let mut value = &inputs[&patterns[i].to_string()]; let type_note: String = format!("{}", types[i]) .replace('<', "[") - .replace('>', "]") - .split("[") - .map(|x| x.trim()) - .map(|typ| - // Array - if typ.is_empty() { - let array = value.as_array() - .expect("value is of type Array but isn't an array"); - value = &array[0]; - "[".to_string() - } - // STD Vec - else if typ.ends_with("Vec") { - let array = value.as_array() - .expect("value is of type Vec but isn't an array"); - value = &array[0]; - format!("{} [ {} ", typ, array.len()) - } - // STD HashMap - else if typ.ends_with("HashMap") { - let array = value.as_table() - .expect("value is of type HashMap but isn't an array"); - // value = &array[0]; - format!("{} [ {} ", typ, array.len()) - } - else { - typ.to_string() - }) - .collect(); + .replace('>', "]"); out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note).parse::<TokenStream>()); } } @@ -51,20 +22,20 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream { include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")) ) .unwrap(); - let private_inputs = toml::from_str::<Table>( - include_str!(concat!(env!("INPUTS_DIR"), "/default_private_input.toml")) - ) - .unwrap(); let ((pub_pat, pub_typ), (prv_pat, prv_typ)) = args_divide_public(&args, &public_inputs.keys().collect()); let mut out = TokenStream::new(); - insert_reads(&mut out, &pub_pat, &pub_typ, &public_inputs, "read_public"); - insert_reads(&mut out, &prv_pat, &prv_typ, &private_inputs, "read_private"); + insert_reads(&mut out, &pub_pat, &pub_typ, "read_public"); + insert_reads(&mut out, &prv_pat, &prv_typ, "read_private"); let (ts_patterns, _) = args_divide_grouped(&args); - out.extend(format!("let result = zkp::{}{}; assert(result); write(result as u64);", name, ts_patterns).parse::<TokenStream>()); + out.extend(format!(" + let result = zkp::{}{}; + assert(result); + write(result as u64); + ", name, ts_patterns).parse::<TokenStream>()); let mut block = TokenStream::new(); block.extend(format!("{{ {} }}", out).parse::<TokenStream>()); @@ -103,11 +74,12 @@ fn return_string(readfn: &TokenTree) -> TokenStream { ").parse().unwrap() } -fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { format!(" {{ let mut ret = Vec::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.push(read!({readfn} {inner})); }} ret.try_into().unwrap() @@ -115,11 +87,12 @@ fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> To ").parse().unwrap() } -fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_vec(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { format!(" {{ let mut ret = Vec::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.push(read!({readfn} {inner})); }} ret @@ -127,7 +100,7 @@ fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> Toke ").parse().unwrap() } -fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_hashmap(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { let mut inner = inner.clone().into_iter(); let key_type = inner.next().unwrap(); inner.next().unwrap(); @@ -135,7 +108,8 @@ fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> format!(r#" {{ let mut ret = HashMap::new(); - for _ in 0..{size} {{ + let size = read!({readfn} usize); + for _ in 0..size {{ ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type})); }} ret @@ -148,7 +122,12 @@ fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { for subtype in inner.clone().into_iter() { value += &format!("read!({readfn} {subtype}), "); } - format!("( {value} )").parse().unwrap() + format!(" + {{ + let _ = read!({readfn} usize); + ( {value} ) + }} + ").parse().unwrap() } #[proc_macro] @@ -175,15 +154,11 @@ pub fn read(item: TokenStream) -> TokenStream { let mut group = parts.next() .expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str()); if let TokenTree::Group(inner_group) = group { - let mut group = inner_group.stream().into_iter(); - let size = group.next().unwrap(); - - let mut rest = TokenStream::new(); - rest.extend(group); + let rest = inner_group.stream(); match ident.to_string().as_str() { - "Vec" => return_vec(&readfn, &size, &rest), - "HashMap" => return_hashmap(&readfn, &size, &rest), + "Vec" => return_vec(&readfn, &rest), + "HashMap" => return_hashmap(&readfn, &rest), _ => todo!("Unsupported container {ident}"), } } @@ -199,10 +174,8 @@ pub fn read(item: TokenStream) -> TokenStream { match current { TokenTree::Punct(punct) => match punct.as_char() { // Array - ';' => { - let size = group.next().unwrap(); - return return_array(&readfn, &size, &inner); - }, + ';' => + return return_array(&readfn, &inner), // Tuple ',' => continue, _ => unreachable!("Group contains unexpected \"{punct}\""), |
