aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--zkvms_host_io/input_macros/src/lib.rs53
-rw-r--r--zkvms_host_io/src/lib.rs21
2 files changed, 73 insertions, 1 deletions
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs
index 868c755..80a6a54 100644
--- a/zkvms_host_io/input_macros/src/lib.rs
+++ b/zkvms_host_io/input_macros/src/lib.rs
@@ -108,3 +108,56 @@ pub fn foreach_private_input_field(item: TokenStream) -> TokenStream {
foreach_field(item, private_patterns)
}
+
+#[proc_macro]
+pub fn benchmarkable(item: TokenStream) -> TokenStream {
+ format!(r#"
+ {{
+ use std::time::Instant;
+ use std::fs::File;
+ use std::io::Write;
+
+ let mut starts = Vec::new();
+ let mut ends = Vec::new();
+
+ for i in 1..=run_info.repeats {{
+ if run_info.benchmarking {{
+ starts.push(Instant::now());
+ }}
+
+ {item}
+
+ if run_info.benchmarking {{
+ ends.push(Instant::now());
+ }}
+ }}
+
+ let mut output = String::new();
+
+ if run_info.benchmarking {{
+ output += &format!("Total Duration: {{}}",
+ (*ends.last().unwrap() - *starts.first().unwrap())
+ .as_secs());
+ }}
+ if run_info.repeats > 1 {{
+ let durations = starts
+ .into_iter()
+ .zip(ends.into_iter())
+ .map(|(s,e)| (e - s).as_secs())
+ .collect::<Vec<u64>>();
+ output += &format!(";Average: {{}}",
+ durations.iter().sum::<u64>() / durations.len() as u64);
+ }}
+
+ if run_info.benchmarking {{
+ if let Some(file) = run_info.output_file {{
+ let mut outfile = File::create(file).unwrap();
+ writeln!(outfile, "{{}}", output);
+ }}
+ else {{
+ println!("{{}}", output);
+ }}
+ }}
+ }}
+ "#).parse().unwrap()
+}
diff --git a/zkvms_host_io/src/lib.rs b/zkvms_host_io/src/lib.rs
index 7090023..46c4740 100644
--- a/zkvms_host_io/src/lib.rs
+++ b/zkvms_host_io/src/lib.rs
@@ -3,7 +3,7 @@ use num_traits::NumCast;
use serde::{ Serialize, Deserialize };
use env_file_reader::read_str;
use std::{env, option::Option, fs::read_to_string, collections::*};
-pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field };
+pub use input_macros::{ foreach_input_field, foreach_public_input_field, foreach_private_input_field, benchmarkable };
static DEFAULT_PUBLIC_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_public_input.toml"));
static DEFAULT_PRIVATE_INPUT: &str = include_str!(concat!(env!("INPUTS_DIR"), "/default_private_input.toml"));
@@ -18,6 +18,15 @@ struct Cli {
private_input: Option<String>,
public_input: Option<String>,
+
+ #[arg(short, long)]
+ benchmark: bool,
+
+ #[arg(short, long, requires = "benchmark")]
+ repeat: Option<usize>,
+
+ #[arg(short, long, requires = "benchmark")]
+ metrics_output: Option<String>,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
@@ -30,9 +39,14 @@ pub enum RunType {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunWith {
pub run_type: RunType,
+ pub benchmarking: bool,
+ pub repeats: usize,
+ pub output_file: Option<String>,
+
pub input: Input,
pub public_input: PublicInput,
pub private_input: PrivateInput,
+
pub default_env: HashMap<String, String>,
}
@@ -76,9 +90,14 @@ pub fn read_args() -> RunWith {
RunWith {
run_type: cli.run_type,
+ benchmarking: cli.benchmark,
+ repeats: cli.repeat.unwrap_or(1),
+ output_file: cli.metrics_output,
+
input,
public_input,
private_input,
+
default_env,
}
}