use proc_macro::TokenStream;
#[path = "../../../guests_macro/src/parse_fn.rs"]
mod parse_fn;
use crate::parse_fn::FunctionDefinition;
/// Builds the [`FunctionDefinition`] of the guest entry point from the
/// `guests/type.txt` type note, which is created from the guest and embedded
/// into this crate at compile time via `include_str!`.
///
/// # Panics
///
/// Panics if the contents of `guests/type.txt` fail to parse.
fn new_fd() -> FunctionDefinition {
    let type_note = include_str!("../../../guests/type.txt");
    FunctionDefinition::new(
        &type_note
            .parse()
            .expect("failed to parse guests/type.txt"),
    )
}
/// Derive attribute placed in front of every generated input struct.
///
/// The listed traits (including `Serialize` and `Deserialize`) are referenced
/// unqualified, so they must be in scope wherever the structs are generated.
const DERIVES: &str = "#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]";
/// Creates an Output type def and three Input structures from the guest
/// type.txt file.
///
/// # Usage
///
/// Inside zkvms_host_io:
///
/// ```rust
/// input_macros::generate_output_type_input_struct!();
/// ```
///
/// # Example output
///
/// ```rust
/// pub type Output = (... ...);
///
/// pub type Return = ...;
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub struct PublicInput {
/// pub ...: ...,
/// pub ...: ...,
/// ...
/// }
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub struct PrivateInput {
/// pub ...: ...,
/// pub ...: ...,
/// ...
/// }
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub struct Input {
/// pub ...: ...,
/// pub ...: ...,
/// ...
/// }
///
/// // Converts Input to a tuple
/// impl From for (...) {
/// fn from(input: Input) -> (...) {
/// (
/// input....,
/// input....,
/// ...
/// )
/// }
/// }
/// ```
#[proc_macro]
pub fn generate_output_type_input_struct(_: TokenStream) -> TokenStream {
let fd = new_fd();
let sep = if fd.types().is_empty() { "" } else { ", " };
let output_type = format!(
"pub type Output = ({} {} {});",
fd.grouped_public_types(),
sep,
fd.return_type
)
.to_string();
let return_type = format!("pub type Return = {};", fd.return_type).to_string();
let public_attrs = fd
.public_arguments()
.iter()
.map(|x| format!("pub {x},"))
.collect::();
let public_input_type =
format!("{} pub struct PublicInput {{ {} }}", DERIVES, public_attrs).to_string();
let private_attrs = fd
.private_arguments()
.iter()
.map(|x| format!("pub {x},"))
.collect::();
let private_input_type = format!(
"{} pub struct PrivateInput {{ {} }}",
DERIVES, private_attrs
)
.to_string();
let attrs = fd
.arguments()
.iter()
.map(|x| format!("pub {x},"))
.collect::();
let convertion = fd
.patterns()
.clone()
.iter()
.map(|x| format!("input.{x},"))
.collect::();
let types = fd.grouped_types();
let struct_def = &format!(
"
{DERIVES} pub struct Input {{
{attrs}
}}
impl From for ({types}) {{
fn from(input: Input) -> ({types}) {{
({convertion})
}}
}}
"
)
.to_string();
(output_type + &return_type + &public_input_type + &private_input_type + &struct_def)
.parse::()
.unwrap()
}
/// Repeats the given item once per entry of `fields`, replacing every
/// `.yield` placeholder with `.<field>` in each copy.
///
/// `yield` is a reserved keyword, so it can never be a real field name, which
/// makes `.yield` an unambiguous placeholder. Substitution is done on tokens
/// rather than on the stringified input. As a result it:
/// - does not depend on how the compiler spaces tokens when printing them,
/// - leaves string literals alone,
/// - does not touch longer identifiers such as `yield_count`.
///
/// # Panics
///
/// Panics if a field name does not lex as Rust tokens.
fn foreach_field<I>(item: TokenStream, fields: I) -> TokenStream
where
    I: IntoIterator,
    I::Item: std::fmt::Display,
{
    fields
        .into_iter()
        .map(|field| {
            let field = field
                .to_string()
                .parse::<TokenStream>()
                .expect("field name must be valid Rust tokens");
            replace_yield(item.clone(), &field)
        })
        .collect()
}

/// Returns `stream` with every `yield` identifier that directly follows a `.`
/// punctuation token replaced by the tokens of `field`, descending into
/// delimited groups.
fn replace_yield(stream: TokenStream, field: &TokenStream) -> TokenStream {
    let mut out = TokenStream::new();
    let mut after_dot = false;
    for tree in stream {
        let is_dot = matches!(&tree, proc_macro::TokenTree::Punct(p) if p.as_char() == '.');
        match tree {
            // Raw `r#yield` prints as "r#yield", so only the bare keyword matches.
            proc_macro::TokenTree::Ident(ident) if after_dot && ident.to_string() == "yield" => {
                out.extend(field.clone());
            }
            proc_macro::TokenTree::Group(group) => {
                let mut rebuilt = proc_macro::Group::new(
                    group.delimiter(),
                    replace_yield(group.stream(), field),
                );
                // Keep the original span so diagnostics still point at the user's code.
                rebuilt.set_span(group.span());
                out.extend(std::iter::once(proc_macro::TokenTree::Group(rebuilt)));
            }
            other => out.extend(std::iter::once(other)),
        }
        after_dot = is_dot;
    }
    out
}
/// Expands the given code once for every field of the generated `Input`
/// struct, substituting the concrete field name for each `.yield`
/// placeholder.
#[proc_macro]
pub fn foreach_input_field(item: TokenStream) -> TokenStream {
    let input_fields = new_fd().patterns().clone();
    foreach_field(item, input_fields)
}
/// Expands the given code once for every field of the generated
/// `PublicInput` struct, substituting the concrete field name for each
/// `.yield` placeholder.
#[proc_macro]
pub fn foreach_public_input_field(item: TokenStream) -> TokenStream {
    let public_fields = new_fd().public_patterns().clone();
    foreach_field(item, public_fields)
}
/// Expands the given code once for every field of the generated
/// `PrivateInput` struct, substituting the concrete field name for each
/// `.yield` placeholder.
#[proc_macro]
pub fn foreach_private_input_field(item: TokenStream) -> TokenStream {
    let private_fields = new_fd().private_patterns().clone();
    foreach_field(item, private_fields)
}
/// Wraps the given code in a benchmarking loop.
///
/// Assumes a `run_info` variable is in scope at the call site. The code is
/// repeated `run_info.runs` times. When `run_info.benchmarking` is set, an
/// [`Instant`](std::time::Instant) is recorded right before and right after
/// every run. After the loop, all of them are handed to
/// `zkvms_host_io::emit_benchmark_results`.
///
/// The expansion is parsed from a string, so its identifiers resolve at the
/// call site and are shared with the wrapped code. The timestamp vectors
/// therefore use `__benchmark_`-prefixed names, so that bindings such as
/// `starts` or `ends` in the wrapped code cannot shadow them. The loop
/// counter `i` and the `Instant` import keep their names because wrapped
/// code may already rely on them.
///
/// NOTE(review): a `continue` inside the wrapped code skips the end
/// timestamp, leaving the start and end vectors with different lengths.
///
/// # Panics
///
/// Panics if the expansion does not lex as Rust tokens.
#[proc_macro]
pub fn benchmarkable(item: TokenStream) -> TokenStream {
    format!(
        r#"
        {{
            use std::time::Instant;
            let mut __benchmark_starts = Vec::new();
            let mut __benchmark_ends = Vec::new();
            for i in 1..=run_info.runs {{
                if run_info.benchmarking {{
                    __benchmark_starts.push(Instant::now());
                }}
                {item}
                if run_info.benchmarking {{
                    __benchmark_ends.push(Instant::now());
                }}
            }}
            if run_info.benchmarking {{
                zkvms_host_io::emit_benchmark_results(run_info, __benchmark_starts, __benchmark_ends);
            }}
        }}
        "#
    )
    .parse()
    .expect("benchmarkable!: expansion must be valid Rust tokens")
}