aboutsummaryrefslogtreecommitdiff
path: root/zkvms/zkwasm/wrapper_macro
diff options
context:
space:
mode:
Diffstat (limited to 'zkvms/zkwasm/wrapper_macro')
-rw-r--r--zkvms/zkwasm/wrapper_macro/src/lib.rs163
1 files changed, 104 insertions, 59 deletions
diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs
index 8bdb581..288eddb 100644
--- a/zkvms/zkwasm/wrapper_macro/src/lib.rs
+++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs
@@ -1,8 +1,8 @@
-use proc_macro::{ TokenStream, TokenTree, Ident };
+use proc_macro::{Ident, TokenStream, TokenTree};
#[path = "../../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
-use crate::parse_fn::{ split_fn, args_divide_grouped, args_divide_public, group_streams };
+use crate::parse_fn::{args_divide_grouped, args_divide_public, group_streams, split_fn};
use toml::Table;
/// Extends an out TokenStream with `let` directives for all patterns (and
@@ -11,12 +11,21 @@ use toml::Table;
/// Each `let` binding calls the read! macro (defined in zkWasm wrapper_macro
/// crate), feeding it an autogenerated "type note", which is then used to
/// "deserialize" the input.
-fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<TokenStream>, readfn: &str) {
+fn insert_reads(
+ out: &mut TokenStream,
+ patterns: &Vec<TokenStream>,
+ types: &Vec<TokenStream>,
+ readfn: &str,
+) {
for i in 0..patterns.len() {
- let type_note: String = format!("{}", types[i])
- .replace('<', "[")
- .replace('>', "]");
- out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note).parse::<TokenStream>());
+ let type_note: String = format!("{}", types[i]).replace('<', "[").replace('>', "]");
+ out.extend(
+ format!(
+ "let {} : {} = read!({} {});",
+ patterns[i], types[i], readfn, type_note
+ )
+ .parse::<TokenStream>(),
+ );
}
}
@@ -50,11 +59,13 @@ fn insert_reads(out: &mut TokenStream, patterns: &Vec<TokenStream>, types: &Vec<
pub fn make_wrapper(item: TokenStream) -> TokenStream {
let (name, args, ret) = split_fn(&item);
- let public_inputs = toml::from_str::<Table>(
- include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"))
- )
- .unwrap();
- let ((pub_pat, pub_typ), (prv_pat, prv_typ)) = args_divide_public(&args, &public_inputs.keys().collect());
+ let public_inputs = toml::from_str::<Table>(include_str!(concat!(
+ env!("INPUTS_DIR"),
+ "/default_public_input.toml"
+ )))
+ .unwrap();
+ let ((pub_pat, pub_typ), (prv_pat, prv_typ)) =
+ args_divide_public(&args, &public_inputs.keys().collect());
let mut out = TokenStream::new();
@@ -63,10 +74,16 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream {
let (ts_patterns, _) = args_divide_grouped(&args);
- out.extend(format!("
+ out.extend(
+ format!(
+ "
let result = zkp::{}{};
write(result as u64);
- ", name, ts_patterns).parse::<TokenStream>());
+ ",
+ name, ts_patterns
+ )
+ .parse::<TokenStream>(),
+ );
let mut block = TokenStream::new();
block.extend(format!("{{ {} }}", out).parse::<TokenStream>());
@@ -74,25 +91,38 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream {
}
fn return_primitive(readfn: &TokenTree, typ: &Ident) -> TokenStream {
- format!("
+ format!(
+ "
({readfn}() as {typ})
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
fn return_bool(readfn: &TokenTree) -> TokenStream {
- format!("
+ format!(
+ "
({readfn}() != 0)
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
fn return_char(readfn: &TokenTree) -> TokenStream {
- format!("
+ format!(
+ "
(({readfn}() as u8) as char)
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
fn return_string(readfn: &TokenTree) -> TokenStream {
- format!("
+ format!(
+ "
{{
let mut ret = Vec::new();
let size = read!({readfn} usize);
@@ -101,11 +131,15 @@ fn return_string(readfn: &TokenTree) -> TokenStream {
}}
ret.into_iter().collect()
}}
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
- format!("
+ format!(
+ "
{{
let mut ret = Vec::new();
let size = read!({readfn} usize);
@@ -114,11 +148,20 @@ fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
}}
ret.try_into().unwrap()
}}
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
-fn return_cont(readfn: &TokenTree, container: &Ident, pushfn: &str, inner: &TokenStream) -> TokenStream {
- format!("
+fn return_cont(
+ readfn: &TokenTree,
+ container: &Ident,
+ pushfn: &str,
+ inner: &TokenStream,
+) -> TokenStream {
+ format!(
+ "
{{
let mut ret = {container}::new();
let size = read!({readfn} usize);
@@ -127,7 +170,10 @@ fn return_cont(readfn: &TokenTree, container: &Ident, pushfn: &str, inner: &Toke
}}
ret
}}
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) -> TokenStream {
@@ -135,7 +181,8 @@ fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) ->
let key_type = inner.next().unwrap();
inner.next().unwrap();
let value_type = inner.next().unwrap();
- format!(r#"
+ format!(
+ r#"
{{
let mut ret = {container}::new();
let size = read!({readfn} usize);
@@ -144,7 +191,10 @@ fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) ->
}}
ret
}}
- "#).parse().unwrap()
+ "#
+ )
+ .parse()
+ .unwrap()
}
fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
@@ -152,12 +202,16 @@ fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream {
for subtype in inner.clone().into_iter() {
value += &format!("read!({readfn} {subtype}), ");
}
- format!("
+ format!(
+ "
{{
let _ = read!({readfn} usize);
( {value} )
}}
- ").parse().unwrap()
+ "
+ )
+ .parse()
+ .unwrap()
}
/// Creates a body which returns a value of the type, defined as a "type note"
@@ -181,41 +235,32 @@ pub fn read(item: TokenStream) -> TokenStream {
// Primitive or STD Container
TokenTree::Ident(ident) => {
match ident.to_string().as_str() {
- "u8" | "u16" | "u32" | "u64" | "u128" | "usize" |
- "i8" | "i16" | "i32" | "i64" | "i128" | "isize" |
- "f32" | "f64" =>
- return return_primitive(&readfn, &ident),
- "char" =>
- return return_char(&readfn),
- "bool" =>
- return return_bool(&readfn),
- "String" =>
- return return_string(&readfn),
- _ => {},
+ "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
+ | "i128" | "isize" | "f32" | "f64" => return return_primitive(&readfn, &ident),
+ "char" => return return_char(&readfn),
+ "bool" => return return_bool(&readfn),
+ "String" => return return_string(&readfn),
+ _ => {}
}
- let mut group = parts.next()
+ let mut group = parts
+ .next()
.expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str());
if let TokenTree::Group(inner_group) = group {
let rest = inner_group.stream();
match ident.to_string().as_str() {
// https://doc.rust-lang.org/std/collections/
- "Vec" | "BinaryHeap" =>
- return_cont(&readfn, &ident, "push", &rest),
- "VecDeque" | "LinkedList" =>
- return_cont(&readfn, &ident, "push_back", &rest),
- "HashSet" | "BTreeSet" =>
- return_cont(&readfn, &ident, "insert", &rest),
- "HashMap" | "BTreeMap" =>
- return_hashmap(&readfn, &ident, &rest),
+ "Vec" | "BinaryHeap" => return_cont(&readfn, &ident, "push", &rest),
+ "VecDeque" | "LinkedList" => return_cont(&readfn, &ident, "push_back", &rest),
+ "HashSet" | "BTreeSet" => return_cont(&readfn, &ident, "insert", &rest),
+ "HashMap" | "BTreeMap" => return_hashmap(&readfn, &ident, &rest),
_ => todo!("Unsupported container {ident}"),
}
- }
- else {
+ } else {
unreachable!("{group} is not a TokenTree::Group!");
}
- },
+ }
// Array or tuple
TokenTree::Group(group) => {
let mut group = group.stream().into_iter();
@@ -224,19 +269,19 @@ pub fn read(item: TokenStream) -> TokenStream {
match current {
TokenTree::Punct(punct) => match punct.as_char() {
// Array
- ';' =>
- return return_array(&readfn, &inner),
+ ';' => return return_array(&readfn, &inner),
// Tuple
',' => continue,
_ => unreachable!("Group contains unexpected \"{punct}\""),
},
- TokenTree::Ident(_) | TokenTree::Group(_) =>
- inner.extend([current].into_iter()),
+ TokenTree::Ident(_) | TokenTree::Group(_) => {
+ inner.extend([current].into_iter())
+ }
_ => unreachable!(),
}
}
return_tuple(&readfn, &inner)
- },
+ }
_ => unreachable!(),
}
}