aboutsummaryrefslogtreecommitdiff
path: root/zkvms/zkm/sdk/src/local/util.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zkvms/zkm/sdk/src/local/util.rs')
-rw-r--r--zkvms/zkm/sdk/src/local/util.rs180
1 files changed, 180 insertions, 0 deletions
diff --git a/zkvms/zkm/sdk/src/local/util.rs b/zkvms/zkm/sdk/src/local/util.rs
new file mode 100644
index 0000000..04ec68f
--- /dev/null
+++ b/zkvms/zkm/sdk/src/local/util.rs
@@ -0,0 +1,180 @@
+use std::fs::File;
+use std::io::BufReader;
+use std::ops::Range;
+use std::time::Duration;
+
+use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
+use plonky2::util::timing::TimingTree;
+use plonky2x::backend::circuit::Groth16WrapperParameters;
+use plonky2x::backend::wrapper::wrap::WrappedCircuit;
+use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder;
+use plonky2x::prelude::DefaultParameters;
+
+use zkm_prover::all_stark::AllStark;
+use zkm_prover::config::StarkConfig;
+use zkm_prover::cpu::kernel::assembler::segment_kernel;
+use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits;
+use zkm_prover::generation::state::{AssumptionReceipts, Receipt};
+
+// Value from before
+// https://github.com/zkMIPS/zkm-project-template/commit/313bea17309bc0bc62bad10e19bf1943b568e00f
+const DEGREE_BITS_RANGE: [Range<usize>; 6] = [10..21, 12..22, 10..21, 10..21, 6..21, 13..23];
+
+const D: usize = 2;
+type C = PoseidonGoldilocksConfig;
+type F = <C as GenericConfig<D>>::F;
+
+#[allow(clippy::too_many_arguments)]
+pub fn prove_segments(
+ seg_dir: &str,
+ basedir: &str,
+ outdir: &str,
+ block: &str,
+ file: &str,
+ seg_file_number: usize,
+ seg_start_id: usize,
+ assumptions: AssumptionReceipts<F, C, D>,
+) -> anyhow::Result<Receipt<F, C, D>> {
+ type InnerParameters = DefaultParameters;
+ type OuterParameters = Groth16WrapperParameters;
+
+ let total_timing = TimingTree::new("prove total time", log::Level::Info);
+ let all_stark = AllStark::<F, D>::default();
+ let config = StarkConfig::standard_fast_config();
+ // Preprocess all circuits.
+ let all_circuits =
+ AllRecursiveCircuits::<F, C, D>::new(&all_stark, &DEGREE_BITS_RANGE, &config);
+
+ let seg_file = format!("{}/{}", seg_dir, seg_start_id);
+ log::info!("Process segment {}", seg_file);
+ let seg_reader = BufReader::new(File::open(seg_file)?);
+ let input_first = segment_kernel(basedir, block, file, seg_reader);
+ let mut timing = TimingTree::new("prove root first", log::Level::Info);
+ let mut agg_receipt = all_circuits.prove_root_with_assumption(
+ &all_stark,
+ &input_first,
+ &config,
+ &mut timing,
+ assumptions.clone(),
+ )?;
+
+ timing.filter(Duration::from_millis(100)).print();
+ all_circuits.verify_root(agg_receipt.clone())?;
+
+ let mut base_seg = seg_start_id + 1;
+ let mut seg_num = seg_file_number - 1;
+ let mut is_agg = false;
+
+ if seg_file_number % 2 == 0 {
+ let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1);
+ log::info!("Process segment {}", seg_file);
+ let seg_reader = BufReader::new(File::open(seg_file)?);
+ let input = segment_kernel(basedir, block, file, seg_reader);
+ timing = TimingTree::new("prove root second", log::Level::Info);
+ let receipt = all_circuits.prove_root_with_assumption(
+ &all_stark,
+ &input,
+ &config,
+ &mut timing,
+ assumptions,
+ )?;
+ timing.filter(Duration::from_millis(100)).print();
+
+ all_circuits.verify_root(receipt.clone())?;
+
+ timing = TimingTree::new("prove aggression", log::Level::Info);
+ // We can duplicate the proofs here because the state hasn't mutated.
+ agg_receipt = all_circuits.prove_aggregation(false, &agg_receipt, false, &receipt)?;
+ timing.filter(Duration::from_millis(100)).print();
+ all_circuits.verify_aggregation(&agg_receipt)?;
+
+ is_agg = true;
+ base_seg = seg_start_id + 2;
+ seg_num -= 1;
+ }
+
+ for i in 0..seg_num / 2 {
+ let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1));
+ log::info!("Process segment {}", seg_file);
+ let seg_reader = BufReader::new(File::open(&seg_file)?);
+ let input_first = segment_kernel(basedir, block, file, seg_reader);
+ let mut timing = TimingTree::new("prove root first", log::Level::Info);
+ let root_receipt_first =
+ all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
+
+ timing.filter(Duration::from_millis(100)).print();
+ all_circuits.verify_root(root_receipt_first.clone())?;
+
+ let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1);
+ log::info!("Process segment {}", seg_file);
+ let seg_reader = BufReader::new(File::open(&seg_file)?);
+ let input = segment_kernel(basedir, block, file, seg_reader);
+ let mut timing = TimingTree::new("prove root second", log::Level::Info);
+ let root_receipt = all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
+ timing.filter(Duration::from_millis(100)).print();
+
+ all_circuits.verify_root(root_receipt.clone())?;
+
+ timing = TimingTree::new("prove aggression", log::Level::Info);
+ // We can duplicate the proofs here because the state hasn't mutated.
+ let new_agg_receipt =
+ all_circuits.prove_aggregation(false, &root_receipt_first, false, &root_receipt)?;
+ timing.filter(Duration::from_millis(100)).print();
+ all_circuits.verify_aggregation(&new_agg_receipt)?;
+
+ timing = TimingTree::new("prove nested aggression", log::Level::Info);
+
+ // We can duplicate the proofs here because the state hasn't mutated.
+ agg_receipt =
+ all_circuits.prove_aggregation(is_agg, &agg_receipt, true, &new_agg_receipt)?;
+ is_agg = true;
+ timing.filter(Duration::from_millis(100)).print();
+
+ all_circuits.verify_aggregation(&agg_receipt)?;
+ }
+
+ log::info!(
+ "proof size: {:?}",
+ serde_json::to_string(&agg_receipt.proof().proof)
+ .unwrap()
+ .len()
+ );
+ let final_receipt = if seg_file_number > 1 {
+ let block_receipt = all_circuits.prove_block(None, &agg_receipt)?;
+ all_circuits.verify_block(&block_receipt)?;
+ let builder = WrapperBuilder::<DefaultParameters, 2>::new();
+ let mut circuit = builder.build();
+ circuit.set_data(all_circuits.block.circuit);
+ let mut bit_size = vec![32usize; 16];
+ bit_size.extend(vec![8; 32]);
+ bit_size.extend(vec![64; 68]);
+ let wrapped_circuit = WrappedCircuit::<InnerParameters, OuterParameters, D>::build(
+ circuit,
+ Some((vec![], bit_size)),
+ );
+ let wrapped_proof = wrapped_circuit.prove(&block_receipt.proof()).unwrap();
+ wrapped_proof.save(outdir).unwrap();
+
+ let src_public_inputs = match &block_receipt {
+ Receipt::Segments(receipt) => &receipt.proof.public_inputs,
+ Receipt::Composite(recepit) => &recepit.program_receipt.proof.public_inputs,
+ };
+ let block_public_inputs = serde_json::json!({
+ "public_inputs": src_public_inputs,
+ });
+ let outdir_path = std::path::Path::new(outdir);
+ let public_values_file = File::create(outdir_path.join("public_values.json"))?;
+ serde_json::to_writer(&public_values_file, &block_receipt.values())?;
+ let block_public_inputs_file = File::create(outdir_path.join("block_public_inputs.json"))?;
+ serde_json::to_writer(&block_public_inputs_file, &block_public_inputs)?;
+
+ block_receipt
+ } else {
+ agg_receipt
+ };
+
+ log::info!("build finish");
+
+ total_timing.filter(Duration::from_millis(100)).print();
+ Ok(final_receipt)
+}