aboutsummaryrefslogtreecommitdiff
path: root/zkvms/zkwasm
diff options
context:
space:
mode:
authorKamen Mladenov <kamen@syndamia.com>2025-02-06 13:48:17 +0200
committerKamen Mladenov <kamen@syndamia.com>2025-02-06 13:48:17 +0200
commitf56218406dbe5e4918d560778b24b366cd3bda9f (patch)
treeec5cd372a81860bf89439dc8c1ae791ea061605f /zkvms/zkwasm
parent5cbd68a39d7a91160fbb8ed81fe9aef5b7dc14b6 (diff)
downloadzkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar
zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.tar.gz
zkVMs-benchmarks-f56218406dbe5e4918d560778b24b366cd3bda9f.zip
feat(zkvms/zkwasm): Use path input type and move container size information to host
Diffstat (limited to 'zkvms/zkwasm')
-rw-r--r--zkvms/zkwasm/host/src/main.rs141
-rw-r--r--zkvms/zkwasm/wrapper_macro/src/lib.rs85
2 files changed, 127 insertions, 99 deletions
diff --git a/zkvms/zkwasm/host/src/main.rs b/zkvms/zkwasm/host/src/main.rs
index d6c5381..179d388 100644
--- a/zkvms/zkwasm/host/src/main.rs
+++ b/zkvms/zkwasm/host/src/main.rs
@@ -1,52 +1,100 @@
-use zkvms_host_io::{PublicInput, PrivateInput, foreach_public_input_field, foreach_private_input_field, read_args, RunType::{Execute, Prove, Verify}};
+use zkvms_host_io::{PublicInput, PrivateInput, foreach_public_input_field, foreach_private_input_field, read_args, RunType::{Execute, Prove, Verify}, RunWith};
use std::io::{self, Write};
use std::process::{Command, Stdio};
use regex::Regex;
-fn build_public_input(input: &PublicInput) -> String {
- let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap();
-
- let mut ret = String::new();
- foreach_public_input_field!{
- let flat = format!("{:?}", input.yield)
- .replace("false", "0")
- .replace("true", "1");
-
- let numbers: Vec<&str> = numreg
- .captures_iter(&flat)
- .map(|cap| cap.get(1).unwrap().as_str())
- .collect();
-
- for num in numbers {
- ret.push_str(num);
- ret.push_str(":i64,");
- }
+static PUBLIC_INPUT_PATH: &str = "public_input.bin";
+static PRIVATE_INPUT_PATH: &str = "private_input.bin";
+
+fn get_with_sizes(flat: &str) -> String {
+ let mut values = flat
+ .split('[')
+ .map(|x| x.trim())
+ .skip(1);
+ let current = values
+ .next()
+ .unwrap_or(flat);
+
+ // 1D collection or not a collection
+ if current != "" {
+ let size = 1 + current
+ .clone()
+ .to_string()
+ .chars()
+ .take_while(|x| *x != ']')
+ .map(|x| (x == ',') as usize)
+ .sum::<usize>();
+
+ (if size > 1 { size.to_string() } else { String::new() })
+ + "["
+ + current
+ + &values
+ .map(|x| "[".to_string() + x)
+ .collect::<String>()
+ }
+ // ND collection
+ else {
+ let size: usize = values
+ .clone()
+ .count();
+
+ let subcollections = values
+ .map(|x| get_with_sizes(x))
+ .collect::<String>();
+
+ size.to_string()
+ + "["
+ + &subcollections
}
- ret.pop(); // removes trailing comma
- ret
}
-fn build_private_input(input: &PrivateInput) -> String {
- let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap();
-
- let mut ret = String::new();
- foreach_private_input_field!{
- let flat = format!("{:?}", input.yield)
- .replace("false", "0")
- .replace("true", "1");
-
- let numbers: Vec<&str> = numreg
- .captures_iter(&flat)
- .map(|cap| cap.get(1).unwrap().as_str())
- .collect();
-
- for num in numbers {
- ret.push_str(num);
- ret.push_str(":i64,");
+macro_rules! build_input {
+ ($input:expr , $path:ident , $type:ident) => {
+ |run_info: &RunWith| {
+ let numreg: Regex = Regex::new("(?:^|[^A-Za-z])([0-9]+)").unwrap();
+ let stringreg: Regex = Regex::new("\\\"[^\"]*\\\"").unwrap();
+
+ let mut ret: Vec<u64> = Vec::new();
+ $type!{
+ let flat = format!("{:?}", $input.yield)
+ .replace("false", "0")
+ .replace("true", "1")
+ .replace('(', "[")
+ .replace(')', "]")
+ .replace('{', "[")
+ .replace('{', "]");
+
+ let flat = get_with_sizes(&flat);
+
+ let numbers = numreg
+ .captures_iter(&flat)
+ .map(|cap|
+ cap.get(1)
+ .unwrap()
+ .as_str()
+ .to_string()
+ .parse::<u64>()
+ .unwrap())
+ .collect::<Vec<u64>>();
+
+ ret.extend(numbers);
+
+ // let strings: Vec<&str> = stringreg
+ // .captures_iter(&flat)
+ // .map(|cap| cap.get(0).unwrap().as_str())
+ // .collect();
+ //
+ // panic!("{:#?}", strings);
+ }
+ let bytes = ret
+ .iter()
+ .map(|x| x.to_be_bytes())
+ .flatten()
+ .collect::<Vec<u8>>();
+ std::fs::write($path, bytes);
+ format!("{}:file", $path)
}
- }
- ret.pop(); // removes trailing comma
- ret
+ };
}
fn zkwasm_command(subcmd: &str) -> Command {
@@ -81,8 +129,15 @@ fn main() {
.arg("-k").arg(k)
.arg("--scheme").arg(scheme));
- let public_input = build_public_input(&run_info.public_input);
- let private_input = build_private_input(&run_info.private_input);
+ let public_input = build_input!(
+ run_info.public_input,
+ PUBLIC_INPUT_PATH,
+ foreach_public_input_field)(&run_info);
+
+ let private_input = build_input!(
+ run_info.private_input,
+ PRIVATE_INPUT_PATH,
+ foreach_private_input_field)(&run_info);
let output = run_info
.env_or(
diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs
index f5cdfbb..e7a201c 100644
--- a/zkvms/zkwasm/wrapper_macro/src/lib.rs
+++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs
@@ -5,40 +5,11 @@ mod parse_fn;
use crate::parse_fn::{ split_fn, args_divide_grouped, args_divide_public, group_streams };
use toml::Table;
-fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, inputs: &Table, readfn: &str) {
+fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, readfn: &str) {
for i in 0..patterns.len() {
- let mut value = &inputs[&patterns[i].to_string()];
let type_note: String = format!("{}", types[i])
.replace('<', "[")
- .replace('>', "]")
- .split("[")
- .map(|x| x.trim())
- .map(|typ|
- // Array
- if typ.is_empty() {
- let array = value.as_array()
- .expect("value is of type Array but isn't an array");
- value = &array[0];
- "[".to_string()
- }
- // STD Vec
- else if typ.ends_with("Vec") {
- let array = value.as_array()
- .expect("value is of type Vec but isn't an array");
- value = &array[0];
- format!("{} [ {} ", typ, array.len())
- }
- // STD HashMap
- else if typ.ends_with("HashMap") {
- let array = value.as_table()
- .expect("value is of type HashMap but isn't an array");
- // value = &array[0];
- format!("{} [ {} ", typ, array.len())
- }
- else {
- typ.to_string()
- })
- .collect();
+ .replace('>', "]");
out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note).parse::<TokenStream>());
}
}
@@ -51,20 +22,20 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream {
include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"))
)
.unwrap();
- let private_inputs = toml::from_str::<Table>(
- include_str!(concat!(env!("INPUTS_DIR"), "/default_private_input.toml"))
- )
- .unwrap();
let ((pub_pat, pub_typ), (prv_pat, prv_typ)) = args_divide_public(&args, &public_inputs.keys().collect());
let mut out = TokenStream::new();
- insert_reads(&mut out, &pub_pat, &pub_typ, &public_inputs, "read_public");
- insert_reads(&mut out, &prv_pat, &prv_typ, &private_inputs, "read_private");
+ insert_reads(&mut out, &pub_pat, &pub_typ, "read_public");
+ insert_reads(&mut out, &prv_pat, &prv_typ, "read_private");
let (ts_patterns, _) = args_divide_grouped(&args);
- out.extend(format!("let result = zkp::{}{}; assert(result); write(result as u64);", name, ts_patterns).parse::<TokenStream>());
+ out.extend(format!("
+ let result = zkp::{}{};
+ assert(result);
+ write(result as u64);
+ ", name, ts_patterns).parse::<TokenStream>());
let mut block = TokenStream::new();
block.extend(format!("{{ {} }}", out).parse::<TokenStream>());
@@ -103,11 +74,12 @@ fn return_string(readfn: &TokenTree) -> TokenStream {
").parse().unwrap()
}
-fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream {
+fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
format!("
{{
let mut ret = Vec::new();
- for _ in 0..{size} {{
+ let size = read!({readfn} usize);
+ for _ in 0..size {{
ret.push(read!({readfn} {inner}));
}}
ret.try_into().unwrap()
@@ -115,11 +87,12 @@ fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> To
").parse().unwrap()
}
-fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream {
+fn return_vec(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
format!("
{{
let mut ret = Vec::new();
- for _ in 0..{size} {{
+ let size = read!({readfn} usize);
+ for _ in 0..size {{
ret.push(read!({readfn} {inner}));
}}
ret
@@ -127,7 +100,7 @@ fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> Toke
").parse().unwrap()
}
-fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream {
+fn return_hashmap(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
let mut inner = inner.clone().into_iter();
let key_type = inner.next().unwrap();
inner.next().unwrap();
@@ -135,7 +108,8 @@ fn return_hashmap(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) ->
format!(r#"
{{
let mut ret = HashMap::new();
- for _ in 0..{size} {{
+ let size = read!({readfn} usize);
+ for _ in 0..size {{
ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type}));
}}
ret
@@ -148,7 +122,12 @@ fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
for subtype in inner.clone().into_iter() {
value += &format!("read!({readfn} {subtype}), ");
}
- format!("( {value} )").parse().unwrap()
+ format!("
+ {{
+ let _ = read!({readfn} usize);
+ ( {value} )
+ }}
+ ").parse().unwrap()
}
#[proc_macro]
@@ -175,15 +154,11 @@ pub fn read(item: TokenStream) -> TokenStream {
let mut group = parts.next()
.expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str());
if let TokenTree::Group(inner_group) = group {
- let mut group = inner_group.stream().into_iter();
- let size = group.next().unwrap();
-
- let mut rest = TokenStream::new();
- rest.extend(group);
+ let rest = inner_group.stream();
match ident.to_string().as_str() {
- "Vec" => return_vec(&readfn, &size, &rest),
- "HashMap" => return_hashmap(&readfn, &size, &rest),
+ "Vec" => return_vec(&readfn, &rest),
+ "HashMap" => return_hashmap(&readfn, &rest),
_ => todo!("Unsupported container {ident}"),
}
}
@@ -199,10 +174,8 @@ pub fn read(item: TokenStream) -> TokenStream {
match current {
TokenTree::Punct(punct) => match punct.as_char() {
// Array
- ';' => {
- let size = group.next().unwrap();
- return return_array(&readfn, &size, &inner);
- },
+ ';' =>
+ return return_array(&readfn, &inner),
// Tuple
',' => continue,
_ => unreachable!("Group contains unexpected \"{punct}\""),