diff options
| -rw-r--r-- | zkvms_host_io/input_macros/src/lib.rs | 53 | ||||
| -rw-r--r-- | zkvms_host_io/src/lib.rs | 21 |
2 files changed, 73 insertions, 1 deletions
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs index 868c755..80a6a54 100644 --- a/zkvms_host_io/input_macros/src/lib.rs +++ b/zkvms_host_io/input_macros/src/lib.rs @@ -108,3 +108,56 @@ pub fn foreach_private_input_field(item: TokenStream) -> TokenStream { foreach_field(item, private_patterns) } + +#[proc_macro] +pub fn benchmarkable(item: TokenStream) -> TokenStream { + format!(r#" + {{ + use std::time::Instant; + use std::fs::File; + use std::io::Write; + + let mut starts = Vec::new(); + let mut ends = Vec::new(); + + for i in 1..=run_info.repeats {{ + if run_info.benchmarking {{ + starts.push(Instant::now()); + }} + + {item} + + if run_info.benchmarking {{ + ends.push(Instant::now()); + }} + }} + + let mut output = String::new(); + + if run_info.benchmarking {{ + output += &format!("Total Duration: {{}}", + (*ends.last().unwrap() - *starts.first().unwrap()) + .as_secs()); + }} + if run_info.repeats > 1 {{ + let durations = starts + .into_iter() + .zip(ends.into_iter()) + .map(|(s,e)| (e - s).as_secs()) + .collect::<Vec<u64>>(); + output += &format!(";Average: {{}}", + durations.iter().sum::<u64>() / durations.len() as u64); + }} + + if run_info.benchmarking {{ + if let Some(file) = run_info.output_file {{ + let mut outfile = File::create(file).unwrap(); + writeln!(outfile, "{{}}", output); + }} + else {{ + println!("{{}}", output); + }} + }} + }} + "#).parse().unwrap() +} diff --git a/zkvms_host_io/src/lib.rs b/zkvms_host_io/src/lib.rs index 7090023..46c4740 100644 --- a/zkvms_host_io/src/lib.rs +++ b/zkvms_host_io/src/lib.rs @@ -3,7 +3,7 @@ use num_traits::NumCast; use serde::{ Serialize, Deserialize }; use env_file_reader::read_str; use std::{env, option::Option, fs::read_to_string, collections::*}; -pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field }; +pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field, benchmarkable }; static DEFAULT_PUBLIC_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml")); static DEFAULT_PRIVATE_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_private_input.toml")); @@ -18,6 +18,15 @@ struct Cli { private_input: Option<String>, public_input: Option<String>, + + #[arg(short, long)] + benchmark: bool, + + #[arg(short, long, requires = "benchmark")] + repeat: Option<usize>, + + #[arg(short, long, requires = "benchmark")] + metrics_output: Option<String>, } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -30,9 +39,14 @@ pub enum RunType { #[derive(Debug, Clone, PartialEq, Eq)] pub struct RunWith { pub run_type: RunType, + pub benchmarking: bool, + pub repeats: usize, + pub output_file: Option<String>, + pub input: Input, pub public_input: PublicInput, pub private_input: PrivateInput, + pub default_env: HashMap<String, String>, } @@ -76,9 +90,14 @@ pub fn read_args() -> RunWith { RunWith { run_type: cli.run_type, + benchmarking: cli.benchmark, + repeats: cli.repeat.unwrap_or(1), + output_file: cli.metrics_output, + input, public_input, private_input, + default_env, } } |
