aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--zkvms/zkwasm/wrapper_macro/src/lib.rs21
-rw-r--r--zkvms_host_io/src/lib.rs2
2 files changed, 15 insertions, 8 deletions
diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs
index 76f49fc..301ba60 100644
--- a/zkvms/zkwasm/wrapper_macro/src/lib.rs
+++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs
@@ -86,27 +86,27 @@ fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
").parse().unwrap()
}
-fn return_vec(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
+fn return_cont(readfn: &TokenTree, container: &Ident, pushfn: &str, inner: &TokenStream) -> TokenStream {
format!("
{{
- let mut ret = Vec::new();
+ let mut ret = {container}::new();
let size = read!({readfn} usize);
for _ in 0..size {{
- ret.push(read!({readfn} {inner}));
+ ret.{pushfn}(read!({readfn} {inner}));
}}
ret
}}
").parse().unwrap()
}
-fn return_hashmap(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
+fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) -> TokenStream {
let mut inner = inner.clone().into_iter();
let key_type = inner.next().unwrap();
inner.next().unwrap();
let value_type = inner.next().unwrap();
format!(r#"
{{
- let mut ret = HashMap::new();
+ let mut ret = {container}::new();
let size = read!({readfn} usize);
for _ in 0..size {{
ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type}));
@@ -156,8 +156,15 @@ pub fn read(item: TokenStream) -> TokenStream {
let rest = inner_group.stream();
match ident.to_string().as_str() {
- "Vec" => return_vec(&readfn, &rest),
- "HashMap" => return_hashmap(&readfn, &rest),
+ // https://doc.rust-lang.org/std/collections/
+ "Vec" | "BinaryHeap" =>
+ return_cont(&readfn, &ident, "push", &rest),
+ "VecDeque" | "LinkedList" =>
+ return_cont(&readfn, &ident, "push_back", &rest),
+ "HashSet" | "BTreeSet" =>
+ return_cont(&readfn, &ident, "insert", &rest),
+ "HashMap" | "BTreeMap" =>
+ return_hashmap(&readfn, &ident, &rest),
_ => todo!("Unsupported container {ident}"),
}
}
diff --git a/zkvms_host_io/src/lib.rs b/zkvms_host_io/src/lib.rs
index abf9edf..7090023 100644
--- a/zkvms_host_io/src/lib.rs
+++ b/zkvms_host_io/src/lib.rs
@@ -2,7 +2,7 @@ use clap::{Parser, ValueEnum};
use num_traits::NumCast;
use serde::{ Serialize, Deserialize };
use env_file_reader::read_str;
-use std::{env, option::Option, fs::read_to_string, collections::HashMap};
+use std::{env, option::Option, fs::read_to_string, collections::*};
pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field };
static DEFAULT_PUBLIC_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"));