diff options
| -rw-r--r-- | zkvms/zkwasm/wrapper_macro/src/lib.rs | 21 | ||||
| -rw-r--r-- | zkvms_host_io/src/lib.rs | 2 |
2 files changed, 15 insertions, 8 deletions
diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs index 76f49fc..301ba60 100644 --- a/zkvms/zkwasm/wrapper_macro/src/lib.rs +++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs @@ -86,27 +86,27 @@ fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { ").parse().unwrap() } -fn return_vec(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_cont(readfn: &TokenTree, container: &Ident, pushfn: &str, inner: &TokenStream) -> TokenStream { format!(" {{ - let mut ret = Vec::new(); + let mut ret = {container}::new(); let size = read!({readfn} usize); for _ in 0..size {{ - ret.push(read!({readfn} {inner})); + ret.{pushfn}(read!({readfn} {inner})); }} ret }} ").parse().unwrap() } -fn return_hashmap(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { +fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) -> TokenStream { let mut inner = inner.clone().into_iter(); let key_type = inner.next().unwrap(); inner.next().unwrap(); let value_type = inner.next().unwrap(); format!(r#" {{ - let mut ret = HashMap::new(); + let mut ret = {container}::new(); let size = read!({readfn} usize); for _ in 0..size {{ ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type})); @@ -156,8 +156,15 @@ pub fn read(item: TokenStream) -> TokenStream { let rest = inner_group.stream(); match ident.to_string().as_str() { - "Vec" => return_vec(&readfn, &rest), - "HashMap" => return_hashmap(&readfn, &rest), + // https://doc.rust-lang.org/std/collections/ + "Vec" | "BinaryHeap" => + return_cont(&readfn, &ident, "push", &rest), + "VecDeque" | "LinkedList" => + return_cont(&readfn, &ident, "push_back", &rest), + "HashSet" | "BTreeSet" => + return_cont(&readfn, &ident, "insert", &rest), + "HashMap" | "BTreeMap" => + return_hashmap(&readfn, &ident, &rest), _ => todo!("Unsupported container {ident}"), } } diff --git a/zkvms_host_io/src/lib.rs b/zkvms_host_io/src/lib.rs index abf9edf..7090023 100644 --- a/zkvms_host_io/src/lib.rs +++ b/zkvms_host_io/src/lib.rs @@ -2,7 +2,7 @@ use clap::{Parser, ValueEnum}; use num_traits::NumCast; use serde::{ Serialize, Deserialize }; use env_file_reader::read_str; -use std::{env, option::Option, fs::read_to_string, collections::HashMap}; +use std::{env, option::Option, fs::read_to_string, collections::*}; pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field }; static DEFAULT_PUBLIC_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")); |
