use proc_macro::{Ident, TokenStream, TokenTree}; #[path = "../../../../guests_macro/src/parse_fn.rs"] mod parse_fn; use crate::parse_fn::FunctionDefinition; use toml::Table; /// Extends an out TokenStream with `let` directives for all patterns (and /// types). readfn specifies wether the input is public or private. /// /// Each `let` binding calls the read! macro (defined in zkWasm wrapper_macro /// crate), feeding it an autogenerated "type note", which is then used to /// "deserialize" the input. fn insert_reads( out: &mut TokenStream, patterns: &Vec, types: &Vec, readfn: &str, ) { for i in 0..patterns.len() { let type_note: String = format!("{}", types[i]).replace('<', "[").replace('>', "]"); out.extend( format!( "let {} : {} = read!({} {});", patterns[i], types[i], readfn, type_note ) .parse::(), ); } } /// Creates a body, which reads all inputs, stores them in variables, then /// executes the entrypoint function with those arguments and writes the /// result. /// /// Inputs are read via the read! macro (defined in the zkWasm wrapper_macro /// crate). Their public status is dependent on `default_public_input.toml`. /// /// # Usage /// /// Inside zkWasm's guest (excluding the `entrypoint_expr` call): /// /// ```rust /// make_wrapper!{fn main(...) -> ...} /// ``` /// /// # Example output /// /// ```rust /// { /// let ... : ... = read!(... ...); /// let ... : ... = read!(... ...); /// ... /// let result = zkp::main(..., ..., ...); /// write(result as u64); /// } /// ``` #[proc_macro] pub fn make_wrapper(item: TokenStream) -> TokenStream { let fd = FunctionDefinition::new(&item); let mut out = TokenStream::new(); insert_reads(&mut out, fd.public_patterns(), fd.public_types(), "read_public"); insert_reads(&mut out, fd.private_patterns(), fd.private_types(), "read_private"); out.extend( format!( " let result = zkp::{}({}); let bytes = tobytes::to_bytes!(result); for val in bytes.into_iter() {{ write(val); }} ", fd.name, fd.grouped_patterns() ) .parse::(), ); let mut block = TokenStream::new(); block.extend(format!("{{ {} }}", out).parse::()); block } fn return_primitive(readfn: &TokenTree, typ: &Ident) -> TokenStream { format!( " ({readfn}() as {typ}) " ) .parse() .unwrap() } fn return_bool(readfn: &TokenTree) -> TokenStream { format!( " ({readfn}() != 0) " ) .parse() .unwrap() } fn return_char(readfn: &TokenTree) -> TokenStream { format!( " (({readfn}() as u8) as char) " ) .parse() .unwrap() } fn return_string(readfn: &TokenTree) -> TokenStream { format!( " {{ let mut ret = Vec::new(); let size = read!({readfn} usize); for _ in 0..size {{ ret.push(read!({readfn} char)); }} ret.into_iter().collect() }} " ) .parse() .unwrap() } fn return_array(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { format!( " {{ let mut ret = Vec::new(); let size = read!({readfn} usize); for _ in 0..size {{ ret.push(read!({readfn} {inner})); }} ret.try_into().unwrap() }} " ) .parse() .unwrap() } fn return_cont( readfn: &TokenTree, container: &Ident, pushfn: &str, inner: &TokenStream, ) -> TokenStream { format!( " {{ let mut ret = {container}::new(); let size = read!({readfn} usize); for _ in 0..size {{ ret.{pushfn}(read!({readfn} {inner})); }} ret }} " ) .parse() .unwrap() } fn return_hashmap(readfn: &TokenTree, container: &Ident, inner: &TokenStream) -> TokenStream { let mut inner = inner.clone().into_iter(); let key_type = inner.next().unwrap(); inner.next().unwrap(); let value_type = inner.next().unwrap(); format!( r#" {{ let mut ret = {container}::new(); let size = read!({readfn} usize); for _ in 0..size {{ ret.insert(read!({readfn} {key_type}), read!({readfn} {value_type})); }} ret }} "# ) .parse() .unwrap() } fn return_tuple(readfn: &TokenTree, inner: &TokenStream) -> TokenStream { let mut value = String::new(); for subtype in inner.clone().into_iter() { value += &format!("read!({readfn} {subtype}), "); } format!( " {{ let _ = read!({readfn} usize); ( {value} ) }} " ) .parse() .unwrap() } /// Creates a body which returns a value of the type, defined as a "type note" /// argument. /// /// The host "serializes" all input data by flattening it as a series of /// integers. This function, in turn, unflattens the input, by reading integers /// multiple times and combining them in the appropriate structures. /// /// It takes two arguments, separated by a space. The first is a string "type /// note" and the second is the name of the read function (either read_private /// or read_public). /// /// The type note is similar to a print!("{:?}") output, however angled braces /// are square. #[proc_macro] pub fn read(item: TokenStream) -> TokenStream { let mut parts = item.clone().into_iter(); let readfn = parts.next().unwrap(); match parts.next().unwrap() { // Primitive or STD Container TokenTree::Ident(ident) => { match ident.to_string().as_str() { "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "f32" | "f64" => return return_primitive(&readfn, &ident), "char" => return return_char(&readfn), "bool" => return return_bool(&readfn), "String" => return return_string(&readfn), _ => {} } let mut group = parts .next() .expect(format!("No group after \"{ident}\" while parsing \"{item}\"!").as_str()); if let TokenTree::Group(inner_group) = group { let rest = inner_group.stream(); match ident.to_string().as_str() { // https://doc.rust-lang.org/std/collections/ "Vec" | "BinaryHeap" => return_cont(&readfn, &ident, "push", &rest), "VecDeque" | "LinkedList" => return_cont(&readfn, &ident, "push_back", &rest), "HashSet" | "BTreeSet" => return_cont(&readfn, &ident, "insert", &rest), "HashMap" | "BTreeMap" => return_hashmap(&readfn, &ident, &rest), _ => todo!("Unsupported container {ident}"), } } else { unreachable!("{group} is not a TokenTree::Group!"); } } // Array or tuple TokenTree::Group(group) => { let mut group = group.stream().into_iter(); let mut inner = TokenStream::new(); while let Some(current) = group.next() { match current { TokenTree::Punct(punct) => match punct.as_char() { // Array ';' => return return_array(&readfn, &inner), // Tuple ',' => continue, _ => unreachable!("Group contains unexpected \"{punct}\""), }, TokenTree::Ident(_) | TokenTree::Group(_) => { inner.extend([current].into_iter()) } _ => unreachable!(), } } return_tuple(&readfn, &inner) } _ => unreachable!(), } }