From 8e6ec28ab29ebb97e1c59e79d4f143dc4563a07c Mon Sep 17 00:00:00 2001 From: Kamen Mladenov Date: Wed, 5 Feb 2025 11:14:19 +0200 Subject: feat(zkvms/zkwasm): Rework read macro into a procedural macro and improve type composition --- zkvms/zkwasm/guest/src/lib.rs | 25 +----- zkvms/zkwasm/wrapper_macro/src/lib.rs | 138 +++++++++++++++++++++++++++++----- 2 files changed, 120 insertions(+), 43 deletions(-) (limited to 'zkvms') diff --git a/zkvms/zkwasm/guest/src/lib.rs b/zkvms/zkwasm/guest/src/lib.rs index c92b97c..c7c0572 100644 --- a/zkvms/zkwasm/guest/src/lib.rs +++ b/zkvms/zkwasm/guest/src/lib.rs @@ -1,5 +1,5 @@ use wasm_bindgen::prelude::wasm_bindgen; -use wrapper_macro::make_wrapper; +use wrapper_macro::{ make_wrapper, read }; // https://github.com/DelphinusLab/zkWasm-rust/blob/main/src/lib.rs use zkwasm_rust_sdk::{require, wasm_input, wasm_output}; @@ -19,29 +19,6 @@ fn write(value: u64) { unsafe { wasm_output(value); } } -static VERTICES: u64 = 10; - -macro_rules! read { - // Vec>>> is converted by entrypoint_expr! to - // Vec,Vec,...,Vec,primitive - (Vec $size:literal , $($type:tt)*) => { - { - let mut ret = Vec::new(); - for _ in 0..$size { - ret.push(read!($($type)*)); - } - ret - } - }; - (bool $readfn:tt) => { - ($readfn() != 0) - }; - // Has to be primitive! - ($type:tt $readfn:tt) => { - ($readfn() as $type) - }; -} - #[wasm_bindgen] pub fn zkmain() { zkp::entrypoint_expr!() diff --git a/zkvms/zkwasm/wrapper_macro/src/lib.rs b/zkvms/zkwasm/wrapper_macro/src/lib.rs index caebc4d..c5ce4e2 100644 --- a/zkvms/zkwasm/wrapper_macro/src/lib.rs +++ b/zkvms/zkwasm/wrapper_macro/src/lib.rs @@ -1,4 +1,4 @@ -use proc_macro::TokenStream; +use proc_macro::{ TokenStream, TokenTree, Ident }; #[path = "../../../../guests_macro/src/parse_fn.rs"] mod parse_fn; @@ -9,26 +9,30 @@ fn insert_reads(out: &mut TokenStream, patterns: &Vec, types: &Vec< for i in 0..patterns.len() { let mut value = &inputs[&patterns[i].to_string()]; let type_note: String = format!("{}", types[i]) - .chars() - .map(|c| match c { - '<' => ',', - '>' => ' ', - _ => c, - }) - .collect::() - .split(" , ") - .map(|typ| match typ { - "Vec" => if let Some(array) = value.as_array() { + .replace('<', "[") + .replace('>', "]") + .split("[") + .map(|x| x.trim()) + .map(|typ| + // Array + if typ.is_empty() { + let array = value.as_array() + .expect("value is of type Array but isn't an array"); value = &array[0]; - format!("{} {}", typ, array.len()) - } else { + "[".to_string() + } + // STD Vec + else if typ.ends_with("Vec") { + let array = value.as_array() + .expect("value is of type Vec but isn't an array"); + value = &array[0]; + format!("{} [ {} ", typ, array.len()) + } + else { typ.to_string() - }, - _ => typ.to_string() - }) - .collect::>() - .join(","); - out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], type_note, readfn).parse::()); + }) + .collect(); + out.extend(format!("let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note).parse::()); } } @@ -59,3 +63,99 @@ pub fn make_wrapper(item: TokenStream) -> TokenStream { block.extend(format!("{{ {} }}", out).parse::()); block } + +fn return_primitive(readfn: &TokenTree, typ: &Ident) -> TokenStream { + format!(" + ({readfn}() as {typ}) + ").parse().unwrap() +} + +fn return_bool(readfn: &TokenTree) -> TokenStream { + format!(" + ({readfn}() != 0) + ").parse().unwrap() +} + +fn return_array(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { + format!(" + {{ + let mut ret = Vec::new(); + for _ in 0..{size} {{ + ret.push(read!({readfn} {inner})); + }} + ret.try_into().unwrap() + }} + ").parse().unwrap() +} + +fn return_vec(readfn: &TokenTree, size: &TokenTree, inner: &TokenStream) -> TokenStream { + format!(" + {{ + let mut ret = Vec::new(); + for _ in 0..{size} {{ + ret.push(read!({readfn} {inner})); + }} + ret + }} + ").parse().unwrap() +} + +#[proc_macro] +pub fn read(item: TokenStream) -> TokenStream { + let mut parts = item.clone().into_iter(); + let readfn = parts.next().unwrap(); + match parts.next().unwrap() { + // Primitive or STD Container + TokenTree::Ident(ident) => { + match ident.to_string().as_str() { + "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | + "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | + "f32" | "f64" | + "char" => + return return_primitive(&readfn, &ident), + "bool" => + return return_bool(&readfn), + _ => {}, + } + + let mut group = parts.next() + .expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str()); + if let TokenTree::Group(inner_group) = group { + let mut group = inner_group.stream().into_iter(); + let size = group.next().unwrap(); + + let mut rest = TokenStream::new(); + rest.extend(group); + + match ident.to_string().as_str() { + "Vec" => return_vec(&readfn, &size, &rest), + _ => todo!("Unsupported container {ident}"), + } + } + else { + unreachable!("{group} is not a TokenTree::Group!"); + } + }, + // Array + TokenTree::Group(group) => { + let mut group = group.stream().into_iter(); + let mut inner = TokenStream::new(); + loop { + let current = group.next().unwrap(); + match current { + TokenTree::Punct(punct) => if punct.as_char() == ';' { + break; + } else { + unreachable!(); + }, + TokenTree::Ident(_) | TokenTree::Group(_) => + inner.extend([current].into_iter()), + _ => unreachable!(), + } + } + let size = group.next().unwrap(); + return_array(&readfn, &size, &inner) + }, + _ => unreachable!(), + } +} -- cgit v1.2.3