aboutsummaryrefslogtreecommitdiff
path: root/zkvms_host_io/input_macros/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zkvms_host_io/input_macros/src/lib.rs')
-rw-r--r--zkvms_host_io/input_macros/src/lib.rs53
1 files changed, 53 insertions, 0 deletions
diff --git a/zkvms_host_io/input_macros/src/lib.rs b/zkvms_host_io/input_macros/src/lib.rs
index 868c755..80a6a54 100644
--- a/zkvms_host_io/input_macros/src/lib.rs
+++ b/zkvms_host_io/input_macros/src/lib.rs
@@ -108,3 +108,56 @@ pub fn foreach_private_input_field(item: TokenStream) -> TokenStream {
foreach_field(item, private_patterns)
}
+
+#[proc_macro]
+pub fn benchmarkable(item: TokenStream) -> TokenStream {
+ format!(r#"
+ {{
+ use std::time::Instant;
+ use std::fs::File;
+ use std::io::Write;
+
+ let mut starts = Vec::new();
+ let mut ends = Vec::new();
+
+ for i in 1..=run_info.repeats {{
+ if run_info.benchmarking {{
+ starts.push(Instant::now());
+ }}
+
+ {item}
+
+ if run_info.benchmarking {{
+ ends.push(Instant::now());
+ }}
+ }}
+
+ let mut output = String::new();
+
+ if run_info.benchmarking {{
+ output += &format!("Total Duration: {{}}",
+ (*ends.last().unwrap() - *starts.first().unwrap())
+ .as_secs());
+ }}
+ if run_info.repeats > 1 {{
+ let durations = starts
+ .into_iter()
+ .zip(ends.into_iter())
+ .map(|(s,e)| (e - s).as_secs())
+ .collect::<Vec<u64>>();
+ output += &format!(";Average: {{}}",
+ durations.iter().sum::<u64>() / durations.len() as u64);
+ }}
+
+ if run_info.benchmarking {{
+ if let Some(file) = run_info.output_file {{
+ let mut outfile = File::create(file).unwrap();
+ writeln!(outfile, "{{}}", output);
+ }}
+ else {{
+ println!("{{}}", output);
+ }}
+ }}
+ }}
+ "#).parse().unwrap()
+}