aboutsummaryrefslogtreecommitdiff
path: root/guests/eddsa_poseidon/src/lib.rs
blob: 53df361f77faa8eacb906ff67ed9527d3a69b445 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#![cfg_attr(feature = "no_std", no_std)]

use std::default::Default;
use std::hash::Hasher;
use std::hash::poseidon::PoseidonHasher;

mod ec;
use ec::{ baby_jubjub, TEPoint, Field };

// Proving entrypoint: checks an EdDSA signature on `msg` with the Poseidon hasher.
//
// This is a thin adapter around `eddsa_verify`. The entrypoint receives the message
// first and the scalar `s` last, whereas `eddsa_verify` expects the public key, then
// `s`, then the `R8` coordinates, and finally the message; the call below reorders
// the arguments accordingly.
#[guests_macro::proving_entrypoint]
pub fn main(
    msg: Field,
    pub_key_x: Field,
    pub_key_y: Field,
    r8_x: Field,
    r8_y: Field,
    s: Field,
) -> bool {
    eddsa_verify::<PoseidonHasher>(
        // Public key coordinates.
        pub_key_x,
        pub_key_y,
        // Signature: scalar `s` followed by the `R8` point coordinates.
        s,
        r8_x,
        r8_y,
        // Signed message.
        msg,
    )
}

/// Verifies an EdDSA signature using the parameters returned by `baby_jubjub()`,
/// hashing with `H`.
///
/// The equation tested is `S * B8 == R8 + H(R8, A, m) * A8`, where `A` is the public
/// key `(pub_key_x, pub_key_y)`, `R8` is `(signature_r8_x, signature_r8_y)`, `S` is
/// `signature_s`, `m` is `message`, `B8` is `base8` from `baby_jubjub()` and
/// `A8 = 8 * A`.
///
/// Returns `true` when both sides of the equation compare equal, `false` otherwise.
///
/// # Panics
///
/// Panics when any precondition assertion fails:
/// - the public key or `R8` is rejected by `curve.contains`,
/// - `signature_s` is not less than `suborder`,
/// - `8 * A` is the zero point.
pub fn eddsa_verify<H>(
    pub_key_x: Field,
    pub_key_y: Field,
    signature_s: Field,
    signature_r8_x: Field,
    signature_r8_y: Field,
    message: Field,
) -> bool
where
    H: Hasher + Default,
{
    let bjj = baby_jubjub();

    // Both input points have to be accepted by the curve before any arithmetic is done.
    let pub_key = TEPoint::new(pub_key_x, pub_key_y);
    assert!(bjj.curve.contains(pub_key));

    let signature_r8 = TEPoint::new(signature_r8_x, signature_r8_y);
    assert!(bjj.curve.contains(signature_r8));

    // Enforce S < subgroup order.
    assert!(signature_s.lt(&bjj.suborder));

    // Challenge h = H(R8.x, R8.y, A.x, A.y, m); the inputs are absorbed in this exact order.
    let mut hasher = H::default();
    for input in [signature_r8_x, signature_r8_y, pub_key_x, pub_key_y, message] {
        hasher.write(input);
    }
    let challenge: Field = hasher.finish().into();

    // A8 = 8 * A, obtained by doubling three times. Multiplying by 8 also ensures that
    // the result lies in the subgroup.
    let mut pub_key_mul_8 = pub_key;
    for _ in 0..3 {
        pub_key_mul_8 = bjj.curve.add(pub_key_mul_8, pub_key_mul_8);
    }
    // Reject public keys for which A8 is zero.
    assert!(!pub_key_mul_8.is_zero());

    // Right-hand side: R8 + h * A8.
    let challenge_times_a8 = bjj.curve.mul(challenge, pub_key_mul_8);
    let rhs = bjj.curve.add(signature_r8, challenge_times_a8);

    // Left-hand side: S * B8.
    let lhs = bjj.curve.mul(signature_s, bjj.base8);

    lhs.eq(&rhs)
}