diff --git a/Cargo.toml b/Cargo.toml index 34a0f41a..b68e1f6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,12 @@ opt-level = 3 codegen-units = 1 lto = "fat" +# Fast release builds for development iteration: cargo build --profile release-fast +[profile.release-fast] +inherits = "release" +codegen-units = 16 +lto = "thin" + # Doing light optimizations helps test performance more than it hurts build time. [profile.test] opt-level = 2 @@ -120,6 +126,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.10.9" test-case = "3.3.1" +tikv-jemallocator = "0.6" toml = "0.8.8" tokio = { version = "1.47.1", features = ["full"] } tokio-util = "0.7.13" @@ -155,7 +162,7 @@ ark-poly = "0.5" ark-serialize = "0.5" ark-std = { version = "0.5", features = ["std"] } spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ - "arkworks-algebra", -], rev = "ecb4f08373ed930175585c856517efdb1851fb47" } -spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish", rev = "ecb4f08373ed930175585c856517efdb1851fb47" } -whir = { git = "https://github.com/WizardOfMenlo/whir/", features = ["tracing"], rev = "cf1599b56ff50e09142ebe6d2e2fbd86875c9986" } + "ark-ff", "sha2", +], rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } +spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish", rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } +whir = { git = "https://github.com/WizardOfMenlo/whir/", rev = "3056565b90931c28f725f6655d954bd8f17eaaf6", features = ["tracing"] } diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index c6b49944..710aa49f 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -18,7 +18,6 @@ noirc_abi.workspace = true # Cryptography and proof systems ark-bn254.workspace = true -ark-crypto-primitives.workspace = true ark-ff.workspace = true ark-serialize.workspace = true ark-std.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index dc10ff4a..1090abc8 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -21,9 +21,32 @@ pub use { prover::Prover, r1cs::R1CS, verifier::Verifier, - whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, + whir_r1cs::{ + WhirConfig, WhirDomainSeparator, WhirProof, WhirProverState, WhirR1CSProof, WhirR1CSScheme, + }, witness::PublicInputs, }; +/// SHA-256 based transcript sponge for Fiat-Shamir. +pub type TranscriptSponge = spongefish::instantiations::SHA256; + +/// Register provekit's custom implementations in whir's global registries. +/// +/// Must be called once before any prove/verify operations. +/// Idempotent — safe to call multiple times. 
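+///
+/// A minimal usage sketch (illustrative; the prover and verifier crates call
+/// this internally before driving whir):
+///
+/// ```ignore
+/// provekit_common::register_ntt();
+/// // Calling again is a no-op thanks to the internal `Once` guard.
+/// provekit_common::register_ntt();
+/// ```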
+pub fn register_ntt() { + use std::sync::{Arc, Once}; + static INIT: Once = Once::new(); + INIT.call_once(|| { + let ntt: Arc> = + Arc::new(whir::algebra::ntt::ArkNtt::::default()); + whir::algebra::ntt::NTT.insert(ntt); + + let skyscraper: Arc = + Arc::new(skyscraper::SkyscraperHashEngine); + whir::hash::ENGINES.register(skyscraper); + }); +} + #[cfg(test)] mod tests {} diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index 2caecdc2..eca2c2ae 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -5,5 +5,5 @@ mod whir; pub use self::{ pow::SkyscraperPoW, sponge::SkyscraperSponge, - whir::{SkyscraperCRH, SkyscraperMerkleConfig}, + whir::{SkyscraperHashEngine, SKYSCRAPER}, }; diff --git a/provekit/common/src/skyscraper/pow.rs b/provekit/common/src/skyscraper/pow.rs index d5476bae..7c0678c5 100644 --- a/provekit/common/src/skyscraper/pow.rs +++ b/provekit/common/src/skyscraper/pow.rs @@ -1,56 +1,44 @@ use { skyscraper::pow::{solve, verify}, - spongefish_pow::PowStrategy, + spongefish_pow::{PoWSolution, PowStrategy}, zerocopy::transmute, }; -/// Skyscraper proof of work #[derive(Clone, Copy)] pub struct SkyscraperPoW { - challenge: [u64; 4], + challenge: [u8; 32], bits: f64, } impl PowStrategy for SkyscraperPoW { fn new(challenge: [u8; 32], bits: f64) -> Self { assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60"); - Self { - challenge: transmute!(challenge), - bits, - } + Self { challenge, bits } } fn check(&mut self, nonce: u64) -> bool { - verify(self.challenge, self.bits, nonce) + verify(transmute!(self.challenge), self.bits, nonce) + } + + fn solution(&self, nonce: u64) -> PoWSolution { + PoWSolution { + challenge: self.challenge, + nonce, + } } - fn solve(&mut self) -> Option { - Some(solve(self.challenge, self.bits)) + fn solve(&mut self) -> Option { + let nonce = solve(transmute!(self.challenge), self.bits); + Some(self.solution(nonce)) } } #[test] fn test_pow_skyscraper() { - use { - spongefish::{ - ByteDomainSeparator, BytesToUnitDeserialize, BytesToUnitSerialize, DefaultHash, - DomainSeparator, - }, - spongefish_pow::{PoWChallenge, PoWDomainSeparator}, - }; - - const BITS: f64 = 10.0; - - let iopattern = DomainSeparator::::new("the proof of work lottery 🎰") - .add_bytes(1, "something") - .challenge_pow("rolling dices"); - - let mut prover = iopattern.to_prover_state(); - prover.add_bytes(b"\0").expect("Invalid IOPattern"); - prover.challenge_pow::(BITS).unwrap(); - - let mut verifier = iopattern.to_verifier_state(prover.narg_string()); - let byte = verifier.next_bytes::<1>().unwrap(); - assert_eq!(&byte, b"\0"); - verifier.challenge_pow::(BITS).unwrap(); + let challenge = [42u8; 32]; + let bits = 10.0; + let mut pow = SkyscraperPoW::new(challenge, bits); + let solution = pow.solve().expect("should find nonce"); + assert_eq!(solution.challenge, challenge); + assert!(pow.check(solution.nonce)); } diff --git a/provekit/common/src/skyscraper/sponge.rs b/provekit/common/src/skyscraper/sponge.rs index 774d160e..75ad1da3 100644 --- a/provekit/common/src/skyscraper/sponge.rs +++ b/provekit/common/src/skyscraper/sponge.rs @@ -1,60 +1,41 @@ use { - crate::FieldElement, ark_bn254::Fr, ark_ff::{BigInt, PrimeField}, - spongefish::duplex_sponge::{DuplexSponge, Permutation}, - zeroize::Zeroize, + spongefish::{DuplexSponge, Permutation}, }; -fn to_fr(x: FieldElement) -> Fr { - Fr::new(BigInt(x.into_bigint().0)) -} -fn from_fr(x: Fr) -> FieldElement { - FieldElement::new(x.into_bigint()) -} - -fn 
bigint_from_bytes_le(bytes: &[u8]) -> BigInt { - let limbs = bytes - .chunks_exact(8) - .map(|s| u64::from_le_bytes(s.try_into().unwrap())) - .collect::>(); - BigInt::new(limbs.try_into().unwrap()) -} - -type State = [FieldElement; 2]; - -#[derive(Clone, Default, Zeroize)] -pub struct Skyscraper { - state: State, -} - -impl AsRef<[FieldElement]> for Skyscraper { - fn as_ref(&self) -> &[FieldElement] { - &self.state +fn bytes_to_fr(bytes: &[u8]) -> Fr { + let mut limbs = [0u64; 4]; + for (i, chunk) in bytes.chunks_exact(8).enumerate() { + limbs[i] = u64::from_le_bytes(chunk.try_into().unwrap()); } + Fr::new(BigInt(limbs)) } -impl AsMut<[FieldElement]> for Skyscraper { - fn as_mut(&mut self) -> &mut [FieldElement] { - &mut self.state + +fn fr_to_bytes(f: Fr) -> [u8; 32] { + let limbs = f.into_bigint().0; + let mut out = [0u8; 32]; + for (i, &limb) in limbs.iter().enumerate() { + out[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes()); } + out } -impl Permutation for Skyscraper { - type U = FieldElement; - const N: usize = 2; - const R: usize = 1; +#[derive(Clone, Default)] +pub struct Skyscraper; - fn new(iv: [u8; 32]) -> Self { - let felt = FieldElement::new(bigint_from_bytes_le(&iv)); - Self { - state: [0.into(), felt], - } - } +impl Permutation<64> for Skyscraper { + type U = u8; - fn permute(&mut self) { - let (l2, r2) = skyscraper::reference::permute(to_fr(self.state[0]), to_fr(self.state[1])); - self.state = [from_fr(l2), from_fr(r2)]; + fn permute(&self, state: &[u8; 64]) -> [u8; 64] { + let left = bytes_to_fr(&state[..32]); + let right = bytes_to_fr(&state[32..]); + let (l2, r2) = skyscraper::reference::permute(left, right); + let mut out = [0u8; 64]; + out[..32].copy_from_slice(&fr_to_bytes(l2)); + out[32..].copy_from_slice(&fr_to_bytes(r2)); + out } } -pub type SkyscraperSponge = DuplexSponge; +pub type SkyscraperSponge = DuplexSponge; diff --git a/provekit/common/src/skyscraper/whir.rs b/provekit/common/src/skyscraper/whir.rs index d855f414..6ee863cf 100644 --- a/provekit/common/src/skyscraper/whir.rs +++ b/provekit/common/src/skyscraper/whir.rs @@ -1,111 +1,198 @@ +// Use the fastest available compress_many for this platform. +#[cfg(target_arch = "aarch64")] +use skyscraper::block4::compress_many; +#[cfg(not(target_arch = "aarch64"))] +use skyscraper::simple::compress_many; use { - crate::{skyscraper::SkyscraperSponge, FieldElement}, - ark_crypto_primitives::{ - crh::{CRHScheme, TwoToOneCRHScheme}, - merkle_tree::{Config, IdentityDigestConverter}, - Error, + std::borrow::Cow, + whir::{ + engines::EngineId, + hash::{Hash, HashEngine}, }, - ark_ff::{BigInt, PrimeField}, - rand08::Rng, - serde::{Deserialize, Serialize}, - spongefish::{ - codecs::arkworks_algebra::{ - FieldDomainSeparator, FieldToUnitDeserialize, FieldToUnitSerialize, - }, - DomainSeparator, ProofResult, ProverState, VerifierState, - }, - std::borrow::Borrow, }; -fn compress(l: FieldElement, r: FieldElement) -> FieldElement { - let l64 = l.into_bigint().0; - let r64 = r.into_bigint().0; - let out = skyscraper::simple::compress(l64, r64); - FieldElement::new(BigInt(out)) -} +/// Pre-computed `EngineId` for the Skyscraper hash engine. +/// +/// Derived as `SHA3-256("whir::hash" || "skyscraper")`. 
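+///
+/// A sketch of how the bytes could be reproduced (assumes the `sha3` crate;
+/// the hard-coded constant below is authoritative):
+///
+/// ```ignore
+/// use sha3::{Digest, Sha3_256};
+/// let id: [u8; 32] = Sha3_256::new()
+///     .chain_update(b"whir::hash")
+///     .chain_update(b"skyscraper")
+///     .finalize()
+///     .into();
+/// ```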
+pub const SKYSCRAPER: EngineId = EngineId::new([ + 0xa5, 0x0d, 0x5e, 0xe2, 0xa3, 0xfc, 0x52, 0xe9, 0x6f, 0x11, 0x10, 0x3c, 0xbb, 0x8a, 0x65, 0xa3, + 0x77, 0xb5, 0x82, 0xb0, 0xb2, 0xdd, 0x42, 0x1c, 0x66, 0x19, 0x13, 0xe6, 0xa5, 0x63, 0xf8, 0xa1, +]); + +#[derive(Clone, Copy, Debug)] +pub struct SkyscraperHashEngine; + +impl HashEngine for SkyscraperHashEngine { + fn name(&self) -> Cow<'_, str> { + "skyscraper".into() + } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct SkyscraperCRH; + fn supports_size(&self, size: usize) -> bool { + size > 0 && size % 32 == 0 + } -impl CRHScheme for SkyscraperCRH { - type Input = [FieldElement]; - type Output = FieldElement; - type Parameters = (); - fn setup(_r: &mut R) -> Result { - Ok(()) + fn preferred_batch_size(&self) -> usize { + skyscraper::WIDTH_LCM } - fn evaluate>( - _: &Self::Parameters, - input: T, - ) -> Result { - input - .borrow() - .iter() - .copied() - .reduce(compress) - .ok_or(Error::IncorrectInputLength(0)) + + fn hash_many(&self, size: usize, input: &[u8], output: &mut [Hash]) { + assert!( + self.supports_size(size), + "skyscraper: unsupported message size {size} (must be a positive multiple of 32)" + ); + + let count = output.len(); + assert_eq!( + input.len(), + size * count, + "skyscraper: input length {} != size {size} * count {count}", + input.len() + ); + + // SAFETY: `output` is `&mut [[u8; 32]]` with `count` elements, so it occupies + // exactly `count * 32` contiguous bytes. We reinterpret as a flat `&mut [u8]` + // to interface with `compress_many` which operates on byte slices. + let out_bytes = + unsafe { std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::(), count * 32) }; + + if size == 32 { + out_bytes.copy_from_slice(input); + return; + } + + if size == 64 { + compress_many(input, out_bytes); + return; + } + + // Leaf hashing: left-fold 32-byte chunks, batched across messages + // for SIMD throughput. Equivalent to main's SkyscraperCRH::evaluate: + // elements.reduce(compress) + // Processes in fixed-size groups to avoid heap allocation. 
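+        //
+        // Scalar reference for one message (documentation only; the batched
+        // loop below computes the same fold across a group of messages at a
+        // time, where chunk(k) is the k-th 32-byte chunk of the message):
+        //
+        //   let mut acc = compress(chunk(0), chunk(1));
+        //   for k in 2..chunks_per_msg {
+        //       acc = compress(acc, chunk(k));
+        //   }
+        //   out[i] = acc;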
+ const GROUP: usize = 4 * skyscraper::WIDTH_LCM; // fits in 3 KiB on stack + let chunks_per_msg = size / 32; + let mut pair_buf = [0u8; GROUP * 64]; + + for start in (0..count).step_by(GROUP) { + let n = (count - start).min(GROUP); + let pairs = &mut pair_buf[..n * 64]; + let accs = &mut out_bytes[start * 32..(start + n) * 32]; + + for i in 0..n { + let msg = &input[(start + i) * size..]; + pairs[i * 64..i * 64 + 32].copy_from_slice(&msg[..32]); + pairs[i * 64 + 32..i * 64 + 64].copy_from_slice(&msg[32..64]); + } + compress_many(pairs, accs); + + for k in 2..chunks_per_msg { + for i in 0..n { + let msg = &input[(start + i) * size..]; + pairs[i * 64..i * 64 + 32].copy_from_slice(&accs[i * 32..i * 32 + 32]); + pairs[i * 64 + 32..i * 64 + 64].copy_from_slice(&msg[k * 32..k * 32 + 32]); + } + compress_many(pairs, accs); + } + } } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct SkyscraperTwoToOne; +#[cfg(test)] +mod tests { + use {super::*, zerocopy::IntoBytes}; -impl TwoToOneCRHScheme for SkyscraperTwoToOne { - type Input = FieldElement; - type Output = FieldElement; - type Parameters = (); - fn setup(_r: &mut R) -> Result { - Ok(()) + fn limbs_to_bytes(limbs: [u64; 4]) -> [u8; 32] { + let mut out = [0u8; 32]; + out[0..8].copy_from_slice(&limbs[0].to_le_bytes()); + out[8..16].copy_from_slice(&limbs[1].to_le_bytes()); + out[16..24].copy_from_slice(&limbs[2].to_le_bytes()); + out[24..32].copy_from_slice(&limbs[3].to_le_bytes()); + out } - fn evaluate>( - _: &Self::Parameters, - l: T, - r: T, - ) -> Result { - Ok(compress(*l.borrow(), *r.borrow())) + + #[test] + fn engine_id_matches() { + use whir::engines::Engine; + assert_eq!(SkyscraperHashEngine.engine_id(), SKYSCRAPER); } - fn compress>( - p: &Self::Parameters, - l: T, - r: T, - ) -> Result { - ::evaluate(p, l, r) + + #[test] + fn supports_expected_sizes() { + let e = SkyscraperHashEngine; + assert!(!e.supports_size(0)); + assert!(!e.supports_size(1)); + assert!(!e.supports_size(31)); + assert!(e.supports_size(32)); + assert!(e.supports_size(64)); + assert!(e.supports_size(512)); + assert!(e.supports_size(1024)); } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct SkyscraperMerkleConfig; + #[test] + fn two_to_one_matches_simple_compress() { + let l: [u64; 4] = [1, 2, 3, 4]; + let r: [u64; 4] = [5, 6, 7, 8]; + let expected = skyscraper::simple::compress(l, r); -impl Config for SkyscraperMerkleConfig { - type Leaf = [FieldElement]; - type LeafDigest = FieldElement; - type LeafInnerDigestConverter = IdentityDigestConverter; - type InnerDigest = FieldElement; - type LeafHash = SkyscraperCRH; - type TwoToOneHash = SkyscraperTwoToOne; -} + let mut input = [0u8; 64]; + input[0..32].copy_from_slice(&limbs_to_bytes(l)); + input[32..64].copy_from_slice(&limbs_to_bytes(r)); + + let mut output = [Hash::default()]; + SkyscraperHashEngine.hash_many(64, &input, &mut output); -impl whir::whir::domainsep::DigestDomainSeparator - for DomainSeparator -{ - fn add_digest(self, label: &str) -> Self { - >::add_scalars(self, 1, label) + assert_eq!(output[0].0, limbs_to_bytes(expected)); } -} -impl whir::whir::utils::DigestToUnitSerialize - for ProverState -{ - fn add_digest(&mut self, digest: FieldElement) -> ProofResult<()> { - self.add_scalars(&[digest]) + #[test] + fn leaf_hash_matches_fold() { + let elems: [[u64; 4]; 4] = [[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]; + + let expected = elems + .into_iter() + .reduce(skyscraper::simple::compress) + .unwrap(); + + let mut output = 
[Hash::default()]; + SkyscraperHashEngine.hash_many(128, elems.as_bytes(), &mut output); + + assert_eq!(output[0].0, limbs_to_bytes(expected)); } -} -impl whir::whir::utils::DigestToUnitDeserialize - for VerifierState<'_, SkyscraperSponge, FieldElement> -{ - fn read_digest(&mut self) -> ProofResult { - let [r] = self.next_scalars()?; - Ok(r) + #[test] + fn batch_two_to_one_consistency() { + let pairs: [[[u64; 4]; 2]; 3] = [ + [[1, 2, 3, 4], [5, 6, 7, 8]], + [[9, 10, 11, 12], [13, 14, 15, 16]], + [[17, 18, 19, 20], [21, 22, 23, 24]], + ]; + + let mut batch_output = [Hash::default(); 3]; + SkyscraperHashEngine.hash_many(64, pairs.as_bytes(), &mut batch_output); + + for (i, pair) in pairs.iter().enumerate() { + let expected = skyscraper::simple::compress(pair[0], pair[1]); + assert_eq!(batch_output[i].0, limbs_to_bytes(expected)); + } + } + + #[test] + fn batch_leaf_hash_consistency() { + // 3 messages of 16 field elements each (512 bytes per message). + // Verify batched result matches per-message scalar reduce(compress). + let msgs: [[[u64; 4]; 16]; 3] = + std::array::from_fn(|i| std::array::from_fn(|j| [(i * 16 + j + 1) as u64, 0, 0, 0])); + + let mut batch_output = [Hash::default(); 3]; + SkyscraperHashEngine.hash_many(512, msgs.as_bytes(), &mut batch_output); + + for (i, msg) in msgs.iter().enumerate() { + let expected = msg + .iter() + .copied() + .reduce(skyscraper::simple::compress) + .unwrap(); + assert_eq!(batch_output[i].0, limbs_to_bytes(expected)); + } } } diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 6baef51d..1236789b 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -5,7 +5,6 @@ use { }, ark_std::{One, Zero}, rayon::iter::{IndexedParallelIterator as _, IntoParallelRefIterator, ParallelIterator as _}, - spongefish::codecs::arkworks_algebra::FieldDomainSeparator, std::array, tracing::instrument, }; @@ -103,54 +102,6 @@ fn sumcheck_fold_map_reduce_inner( } } -/// Trait which is used to add sumcheck functionality fo `IOPattern` -pub trait SumcheckIOPattern { - /// Prover sends coefficients of the cubic sumcheck polynomial and the - /// verifier sends randomness for the next sumcheck round - fn add_sumcheck_polynomials(self, num_vars: usize) -> Self; - - /// Verifier sends the randomness on which the supposed 0-polynomial is - /// evaluated - fn add_rand(self, num_rand: usize) -> Self; - - fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self; - - /// Prover sends the hash of the public inputs - /// Verifier sends randomness to construct weights - fn add_public_inputs(self) -> Self; -} - -impl SumcheckIOPattern for IOPattern -where - IOPattern: FieldDomainSeparator, -{ - fn add_zk_sumcheck_polynomials(mut self, num_vars: usize) -> Self { - self = self.add_scalars(1, "Sum of G over boolean hypercube"); - self = self.challenge_scalars(1, "Rho"); - self = self.add_sumcheck_polynomials(num_vars); - self = self.add_scalars(2, "Polynomial sums"); - self - } - - fn add_sumcheck_polynomials(mut self, num_vars: usize) -> Self { - for _ in 0..num_vars { - self = self.add_scalars(4, "Sumcheck Polynomials"); - self = self.challenge_scalars(1, "Sumcheck Random"); - } - self - } - - fn add_public_inputs(mut self) -> Self { - self = self.add_scalars(1, "Public Inputs Hash"); - self = self.challenge_scalars(1, "Public Weights Vector Random"); - self - } - - fn add_rand(self, num_rand: usize) -> Self { - self.challenge_scalars(num_rand, "rand") - } -} - /// List of evaluations for eq(r, x) 
over the boolean hypercube #[instrument(skip_all)] pub fn calculate_evaluations_over_boolean_hypercube_for_eq( diff --git a/provekit/common/src/utils/zk_utils.rs b/provekit/common/src/utils/zk_utils.rs index 87fc806a..f6b7b6f0 100644 --- a/provekit/common/src/utils/zk_utils.rs +++ b/provekit/common/src/utils/zk_utils.rs @@ -1,10 +1,29 @@ use { crate::FieldElement, - ark_ff::{Field, UniformRand}, + ark_ff::UniformRand, rayon::prelude::*, - whir::poly_utils::evals::EvaluationsList, + whir::algebra::{ + dot, + ntt::wavelet_transform, + polynomials::{CoefficientList, EvaluationsList}, + weights::Covector, + }, }; +/// Transform coefficients to evaluation form. Avoids the per-call +/// clone+transform inside `Covector::evaluate`. +pub fn coeffs_to_evals(poly: &CoefficientList) -> Vec { + let mut evals = poly.coeffs().to_vec(); + wavelet_transform(&mut evals); + evals +} + +/// Dot product of a covector's weight vector against pre-transformed +/// evaluations. +pub fn covector_dot(w: &Covector, evals: &[FieldElement]) -> FieldElement { + dot(&w.vector, evals) +} + pub fn create_masked_polynomial( original: EvaluationsList, mask: &[FieldElement], @@ -39,28 +58,3 @@ pub fn generate_random_multilinear_polynomial(num_vars: usize) -> Vec(mut a: F, n: usize, x: &[F]) -> F { - let k = x.len(); - assert!(n > 0 && n < (1 << k)); - let mut borrow_0 = F::one(); - let mut borrow_1 = F::zero(); - for (i, &xi) in x.iter().rev().enumerate() { - let bn = ((n - 1) >> i) & 1; - let b0 = F::one() - xi; - let b1 = a * xi; - (borrow_0, borrow_1) = if bn == 0 { - (b0 * borrow_0, (b0 + b1) * borrow_1 + b1 * borrow_0) - } else { - ((b0 + b1) * borrow_0 + b0 * borrow_1, b1 * borrow_1) - }; - a = a.square(); - } - borrow_0 -} diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index dc6ed361..2f813c0b 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -1,21 +1,23 @@ +#[cfg(debug_assertions)] +use whir::transcript::Interaction; use { - crate::{ - skyscraper::{SkyscraperMerkleConfig, SkyscraperPoW, SkyscraperSponge}, - utils::{serde_hex, sumcheck::SumcheckIOPattern}, - witness::WitnessIOPattern, - FieldElement, - }, + crate::{utils::serde_hex, FieldElement}, serde::{Deserialize, Serialize}, - spongefish::DomainSeparator, - std::fmt::{Debug, Formatter}, - tracing::instrument, - whir::whir::{domainsep::WhirDomainSeparator, parameters::WhirConfig as GenericWhirConfig}, + whir::{protocols::whir::Config as GenericWhirConfig, transcript}, }; -pub type WhirConfig = GenericWhirConfig; -pub type IOPattern = DomainSeparator; +pub type WhirConfig = GenericWhirConfig; -#[derive(Clone, PartialEq, Serialize, Deserialize)] +/// Type alias for the whir domain separator used in provekit's outer protocol. +pub type WhirDomainSeparator = transcript::DomainSeparator<'static, ()>; + +/// Type alias for the whir prover transcript state. +pub type WhirProverState = transcript::ProverState; + +/// Type alias for the whir proof. 
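+///
+/// Illustrative round trip (a sketch; the real call sites are in the prover
+/// and verifier crates):
+///
+/// ```ignore
+/// let proof: WhirProof = merlin.proof();  // prover side
+/// let mut arthur = VerifierState::new(&ds, &proof, TranscriptSponge::default());
+/// ```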
+pub type WhirProof = transcript::Proof; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct WhirR1CSScheme { pub m: usize, pub w1_size: usize, @@ -28,63 +30,22 @@ pub struct WhirR1CSScheme { } impl WhirR1CSScheme { - #[instrument(skip_all)] - pub fn create_io_pattern(&self) -> IOPattern { - let mut io = IOPattern::new("🌪️"); - - if self.num_challenges > 0 { - // Compute total constraints: OOD + statement - // OOD: 2 witnesses × committment_ood_samples each - // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, - // statement_2 has 3 constraints = 3, total = 7 - let num_witnesses = 2; - let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = if self.has_public_inputs { 7 } else { 6 }; - let num_constraints_total = num_ood_constraints + num_statement_constraints; - - io = io - .commit_statement(&self.whir_witness) // C1 - .add_logup_challenges(self.num_challenges) - .commit_statement(&self.whir_witness) // C2 - .add_rand(self.m_0) - .commit_statement(&self.whir_for_hiding_spartan) - .add_zk_sumcheck_polynomials(self.m_0) - .add_whir_proof(&self.whir_for_hiding_spartan) - .add_public_inputs() - .hint("claimed_evaluations_1") - .hint("claimed_evaluations_2") - .hint("public_weights_evaluations") - .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); - } else { - io = io - .commit_statement(&self.whir_witness) - .add_rand(self.m_0) - .commit_statement(&self.whir_for_hiding_spartan) - .add_zk_sumcheck_polynomials(self.m_0) - .add_whir_proof(&self.whir_for_hiding_spartan) - .add_public_inputs() - .hint("claimed_evaluations") - .hint("public_weights_evaluations") - .add_whir_proof(&self.whir_witness); - } - - io + /// Create a domain separator for the provekit outer protocol. + pub fn create_domain_separator(&self) -> WhirDomainSeparator { + transcript::DomainSeparator::protocol(self) } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct WhirR1CSProof { #[serde(with = "serde_hex")] - pub transcript: Vec, -} + pub narg_string: Vec, + #[serde(with = "serde_hex")] + pub hints: Vec, -// TODO: Implement Debug for WhirConfig and derive. -impl Debug for WhirR1CSScheme { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WhirR1CSScheme") - .field("m", &self.m) - .field("w1_size", &self.w1_size) - .field("m_0", &self.m_0) - .finish() - } + /// Transcript interaction pattern for debug-mode validation. + /// Populated by the prover; absent from serialized proofs on disk. 
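+    ///
+    /// Because of `skip_serializing` plus `serde(default)`, a proof read back
+    /// from disk simply carries an empty pattern here; only freshly produced
+    /// debug-build proofs have interactions to validate against.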
+ #[cfg(debug_assertions)] + #[serde(default, skip_serializing)] + pub pattern: Vec, } diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 5ebfaa24..0f7cb866 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -4,16 +4,13 @@ mod ram; mod scheduling; mod witness_builder; mod witness_generator; -mod witness_io_pattern; use { crate::{ - skyscraper::SkyscraperCRH, utils::{serde_ark, serde_ark_vec}, FieldElement, }, - ark_crypto_primitives::crh::CRHScheme, - ark_ff::One, + ark_ff::{BigInt, One, PrimeField}, serde::{Deserialize, Serialize}, }; pub use { @@ -26,7 +23,6 @@ pub use { WitnessCoefficient, }, witness_generator::NoirWitnessGenerator, - witness_io_pattern::WitnessIOPattern, }; /// The index of the constant 1 witness in the R1CS instance @@ -68,15 +64,15 @@ impl PublicInputs { } pub fn hash(&self) -> FieldElement { + fn compress(l: FieldElement, r: FieldElement) -> FieldElement { + let out = skyscraper::simple::compress(l.into_bigint().0, r.into_bigint().0); + FieldElement::new(BigInt(out)) + } + match self.0.len() { 0 => FieldElement::from(0u64), - 1 => { - // For single element, hash it with zero to ensure it gets properly hashed - let padded = vec![self.0[0], FieldElement::from(0u64)]; - SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed") - } - _ => SkyscraperCRH::evaluate(&(), &self.0[..]) - .expect("hash should succeed for multiple inputs"), + 1 => compress(self.0[0], FieldElement::from(0u64)), + _ => self.0.iter().copied().reduce(compress).unwrap(), } } } diff --git a/provekit/common/src/witness/witness_io_pattern.rs b/provekit/common/src/witness/witness_io_pattern.rs deleted file mode 100644 index b0554455..00000000 --- a/provekit/common/src/witness/witness_io_pattern.rs +++ /dev/null @@ -1,21 +0,0 @@ -use {crate::FieldElement, spongefish::codecs::arkworks_algebra::FieldDomainSeparator}; - -/// Trait which is used to add witness RNG for IOPattern -pub trait WitnessIOPattern { - /// Schedule absorption of `num_challenges` Fiat–Shamir challenges for - /// LogUp/Spice. 
- fn add_logup_challenges(self, num_challenges: usize) -> Self; -} - -impl WitnessIOPattern for IOPattern -where - IOPattern: FieldDomainSeparator, -{ - fn add_logup_challenges(self, num_challenges: usize) -> Self { - if num_challenges > 0 { - self.challenge_scalars(num_challenges, "wb:challenges") - } else { - self - } - } -} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index bb89b790..be86c8f3 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -6,9 +6,12 @@ use { nargo::foreign_calls::DefaultForeignCallBuilder, noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, - provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, + provekit_common::{ + FieldElement, NoirElement, NoirProof, Prover, PublicInputs, TranscriptSponge, + }, std::path::Path, tracing::instrument, + whir::transcript::{codecs::Empty, ProverState}, }; mod r1cs; @@ -52,6 +55,8 @@ impl Prove for Prover { #[instrument(skip_all)] fn prove(mut self, prover_toml: impl AsRef) -> Result { + provekit_common::register_ntt(); + let (input_map, _expected_return) = read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; @@ -59,9 +64,11 @@ impl Prove for Prover { let acir_public_inputs = self.program.functions[0].public_inputs().indices(); // Set up transcript - let io: IOPattern = self.whir_for_witness.create_io_pattern(); - let mut merlin = io.to_prover_state(); - drop(io); + let ds = self + .whir_for_witness + .create_domain_separator() + .instance(&Empty); + let mut merlin = ProverState::new(&ds, TranscriptSponge::default()); let mut witness: Vec> = vec![None; self.r1cs.num_witnesses()]; diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index c48f1f9c..30cf4726 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -4,13 +4,12 @@ use { crate::witness::witness_builder::WitnessBuilderSolver, acir::native_types::WitnessMap, provekit_common::{ - skyscraper::SkyscraperSponge, utils::batch_inverse_montgomery, witness::{LayerType, LayeredWitnessBuilders, WitnessBuilder}, - FieldElement, NoirElement, R1CS, + FieldElement, NoirElement, TranscriptSponge, R1CS, }, - spongefish::ProverState, tracing::instrument, + whir::transcript::ProverState, }; pub trait R1CSSolver { @@ -19,7 +18,7 @@ pub trait R1CSSolver { witness: &mut Vec>, plan: LayeredWitnessBuilders, acir_map: &WitnessMap, - transcript: &mut ProverState, + transcript: &mut ProverState, ); #[cfg(test)] @@ -53,14 +52,14 @@ impl R1CSSolver for R1CS { witness: &mut Vec>, plan: LayeredWitnessBuilders, acir_map: &WitnessMap, - transcript: &mut ProverState, + transcript: &mut ProverState, ) { for layer in &plan.layers { match layer.typ { LayerType::Other => { // Execute regular operations for builder in &layer.witness_builders { - builder.solve(&acir_map, witness, transcript); + builder.solve(acir_map, witness, transcript); } } LayerType::Inverse => { diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 9ded36f9..82c7c2b2 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,7 +3,6 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ - skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, utils::{ pad_to_power_of_two, sumcheck::{ @@ -11,39 +10,39 @@ use { calculate_external_row_of_r1cs_matrices, calculate_witness_bounds, eval_cubic_poly, sumcheck_fold_map_reduce, }, - zk_utils::{create_masked_polynomial, 
generate_random_multilinear_polynomial}, + zk_utils::{ + coeffs_to_evals, covector_dot, create_masked_polynomial, + generate_random_multilinear_polynomial, + }, HALF, }, - FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, - }, - spongefish::{ - codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, - ProverState, + FieldElement, PublicInputs, TranscriptSponge, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + R1CS, }, std::mem, - tracing::{info, instrument, warn}, + tracing::{debug, instrument}, whir::{ - poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::{ - committer::{CommitmentWriter, Witness}, - prover::Prover, - statement::{Statement, Weights}, - utils::HintSerialize, + algebra::{ + embedding::Basefield, + polynomials::{CoefficientList, EvaluationsList, MultilinearPoint}, + weights::{Covector, Evaluate}, }, + protocols::whir::Witness, + transcript::{ProverState, VerifierMessage}, }, }; pub struct WhirR1CSCommitment { - pub commitment_to_witness: Witness, - pub masked_polynomial: EvaluationsList, - pub random_polynomial: EvaluationsList, - pub padded_witness: Vec, + pub commitment_to_witness: Witness, + pub masked_polynomial_coeff: CoefficientList, + pub random_polynomial_coeff: CoefficientList, + pub padded_witness: Vec, } pub trait WhirR1CSProver { fn commit( &self, - merlin: &mut ProverState, + merlin: &mut ProverState, r1cs: &R1CS, witness: Vec, is_w1: bool, @@ -51,7 +50,7 @@ pub trait WhirR1CSProver { fn prove( &self, - merlin: ProverState, + merlin: ProverState, r1cs: R1CS, commitments: Vec, public_inputs: &PublicInputs, @@ -62,7 +61,7 @@ impl WhirR1CSProver for WhirR1CSScheme { #[instrument(skip_all)] fn commit( &self, - merlin: &mut ProverState, + merlin: &mut ProverState, r1cs: &R1CS, witness: Vec, is_w1: bool, @@ -87,7 +86,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ); // log2(domain) for WHIR witness evaluations. - let whir_num_vars = self.whir_witness.mv_parameters.num_variables; + let whir_num_vars = self.whir_witness.initial_num_variables(); // Expected evaluation length = 2^(log2(domain) - 1). 
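+        // Worked example: whir_num_vars = 20 gives target_len = 2^19 = 524_288,
+        // i.e. the padded witness occupies half of the 2^20-sized committed
+        // domain (the other half is reserved for the mask).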
let target_len = 1usize << (whir_num_vars - 1); @@ -100,7 +99,7 @@ impl WhirR1CSProver for WhirR1CSScheme { let witness_polynomial_evals = EvaluationsList::new(padded_witness.clone()); - let (commitment_to_witness, masked_polynomial, random_polynomial) = + let (commitment_to_witness, masked_polynomial_coeff, random_polynomial_coeff) = batch_commit_to_polynomial( self.m, &self.whir_witness, @@ -110,8 +109,8 @@ impl WhirR1CSProver for WhirR1CSScheme { Ok(WhirR1CSCommitment { commitment_to_witness, - masked_polynomial, - random_polynomial, + masked_polynomial_coeff, + random_polynomial_coeff, padded_witness, }) } @@ -119,7 +118,7 @@ impl WhirR1CSProver for WhirR1CSScheme { #[instrument(skip_all)] fn prove( &self, - mut merlin: ProverState, + mut merlin: ProverState, r1cs: R1CS, mut commitments: Vec, public_inputs: &PublicInputs, @@ -160,39 +159,51 @@ impl WhirR1CSProver for WhirR1CSScheme { if is_single { // Single commitment path let commitment = commitments.into_iter().next().unwrap(); - let alphas: [Vec; 3] = alphas.try_into().unwrap(); - - let (mut statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( - self.m, - &commitment.commitment_to_witness, - &commitment.masked_polynomial, - &commitment.random_polynomial, - &alphas, - ); + let (mut weights, f_sums, g_sums) = + create_weights_and_evaluations_for_two_polynomials::<3>( + self.m, + &commitment.masked_polynomial_coeff, + &commitment.random_polynomial_coeff, + &alphas, + ); - merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; + merlin.prover_hint_ark(&(f_sums, g_sums)); let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { - // If there are no public inputs, the hint is unused by the verifier and can be - // assigned an arbitrary value. - let public_f_sum = FieldElement::zero(); - let public_g_sum = FieldElement::zero(); - (public_f_sum, public_g_sum) + // If there are no public inputs, the hint is unused by the verifier + // and can be assigned an arbitrary value. 
+ (FieldElement::zero(), FieldElement::zero()) } else { - update_statement_with_public_weights( - &mut statement, - &commitment.commitment_to_witness, - &commitment.masked_polynomial, - &commitment.random_polynomial, + compute_public_weight_evaluations( + &mut weights, + &commitment.masked_polynomial_coeff, + &commitment.random_polynomial_coeff, public_weight, ) }; - merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + merlin.prover_hint_ark(&(public_f_sum, public_g_sum)); + + // Build evaluations: for each weight, eval on masked + eval on random + let evaluations = compute_evaluations_single( + &weights, + &commitment.masked_polynomial_coeff, + &commitment.random_polynomial_coeff, + ); + + let weight_refs: Vec<&dyn Evaluate>> = weights + .iter() + .map(|w| w as &dyn Evaluate>) + .collect(); run_zk_whir_pcs_prover( - commitment.commitment_to_witness, - statement, + &[&commitment.commitment_to_witness], + &[ + &commitment.masked_polynomial_coeff, + &commitment.random_polynomial_coeff, + ], + &weight_refs, + &evaluations, &self.whir_witness, &mut merlin, ); @@ -214,55 +225,112 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - let (mut statement_1, f_sums_1, g_sums_1) = - create_combined_statement_over_two_polynomials::<3>( + let (mut weights_1, f_sums_1, g_sums_1) = + create_weights_and_evaluations_for_two_polynomials::<3>( self.m, - &c1.commitment_to_witness, - &c1.masked_polynomial, - &c1.random_polynomial, + &c1.masked_polynomial_coeff, + &c1.random_polynomial_coeff, &alphas_1, ); drop(alphas_1); - let (statement_2, f_sums_2, g_sums_2) = - create_combined_statement_over_two_polynomials::<3>( + let (weights_2, f_sums_2, g_sums_2) = + create_weights_and_evaluations_for_two_polynomials::<3>( self.m, - &c2.commitment_to_witness, - &c2.masked_polynomial, - &c2.random_polynomial, + &c2.masked_polynomial_coeff, + &c2.random_polynomial_coeff, &alphas_2, ); drop(alphas_2); - merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; - merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; - - let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { - let public_f_sum = FieldElement::zero(); - let public_g_sum = FieldElement::zero(); - (public_f_sum, public_g_sum) + // Compute cross-evaluations: weights_1 on c2's polynomials and + // weights_2 on c1's polynomials. Whir's prove() expects evaluations + // for ALL (weight, polynomial) pairs in row-major order. 
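+            //
+            // Layout sketch for the four committed polynomials
+            // (c1 masked, c1 random, c2 masked, c2 random):
+            //
+            //   evaluations[w * 4 + 0] = <w, c1_masked>
+            //   evaluations[w * 4 + 1] = <w, c1_random>
+            //   evaluations[w * 4 + 2] = <w, c2_masked>
+            //   evaluations[w * 4 + 3] = <w, c2_random>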
+ let c1m_evals = coeffs_to_evals(&c1.masked_polynomial_coeff); + let c1r_evals = coeffs_to_evals(&c1.random_polynomial_coeff); + let c2m_evals = coeffs_to_evals(&c2.masked_polynomial_coeff); + let c2r_evals = coeffs_to_evals(&c2.random_polynomial_coeff); + let cross_f_12: Vec = weights_1 + .iter() + .map(|w| covector_dot(w, &c2m_evals)) + .collect(); + let cross_g_12: Vec = weights_1 + .iter() + .map(|w| covector_dot(w, &c2r_evals)) + .collect(); + let cross_f_21: Vec = weights_2 + .iter() + .map(|w| covector_dot(w, &c1m_evals)) + .collect(); + let cross_g_21: Vec = weights_2 + .iter() + .map(|w| covector_dot(w, &c1r_evals)) + .collect(); + + merlin.prover_hint_ark(&(f_sums_1, g_sums_1)); + merlin.prover_hint_ark(&(f_sums_2, g_sums_2)); + merlin.prover_hint_ark(&(cross_f_12, cross_g_12)); + merlin.prover_hint_ark(&(cross_f_21, cross_g_21)); + + let (public_f1, public_g1, public_f2, public_g2) = if public_inputs.is_empty() { + ( + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + ) } else { - update_statement_with_public_weights( - &mut statement_1, - &c1.commitment_to_witness, - &c1.masked_polynomial, - &c1.random_polynomial, + compute_public_weight_evaluations_dual( + &mut weights_1, + &c1.masked_polynomial_coeff, + &c1.random_polynomial_coeff, + &c2.masked_polynomial_coeff, + &c2.random_polynomial_coeff, public_weight, ) }; - merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + merlin.prover_hint_ark(&(public_f1, public_g1, public_f2, public_g2)); + + // Combine weights from both commitments + let mut all_weights = weights_1; + all_weights.extend(weights_2); - run_zk_whir_pcs_batch_prover( - &[c1.commitment_to_witness, c2.commitment_to_witness], - &[statement_1, statement_2], + // Build evaluations: for each weight, evaluate on all 4 polynomials + // (c1_masked, c1_random, c2_masked, c2_random) + // Row-major: evaluations[w_idx * 4 + p_idx] + let poly_evals = vec![c1m_evals, c1r_evals, c2m_evals, c2r_evals]; + let evaluations: Vec = all_weights + .iter() + .flat_map(|w| poly_evals.iter().map(|pe| covector_dot(w, pe))) + .collect(); + + let weight_refs: Vec<&dyn Evaluate>> = all_weights + .iter() + .map(|w| w as &dyn Evaluate>) + .collect(); + + run_zk_whir_pcs_prover( + &[&c1.commitment_to_witness, &c2.commitment_to_witness], + &[ + &c1.masked_polynomial_coeff, + &c1.random_polynomial_coeff, + &c2.masked_polynomial_coeff, + &c2.random_polynomial_coeff, + ], + &weight_refs, + &evaluations, &self.whir_witness, &mut merlin, ); } + let proof = merlin.proof(); Ok(WhirR1CSProof { - transcript: merlin.narg_string().to_vec(), + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, }) } } @@ -347,11 +415,11 @@ pub fn batch_commit_to_polynomial( m: usize, whir_config: &WhirConfig, witness: EvaluationsList, - merlin: &mut ProverState, + merlin: &mut ProverState, ) -> ( - Witness, - EvaluationsList, - EvaluationsList, + Witness, + CoefficientList, + CoefficientList, ) { let mask = generate_random_multilinear_polynomial(witness.num_variables()); let masked_polynomial_coeff = create_masked_polynomial(witness, &mask).to_coeffs(); @@ -360,18 +428,15 @@ pub fn batch_commit_to_polynomial( let random_polynomial_coeff = EvaluationsList::new(generate_random_multilinear_polynomial(m)).to_coeffs(); - let committer = CommitmentWriter::new(whir_config.clone()); - let witness_new = committer - .commit_batch(merlin, &[ - &masked_polynomial_coeff, - &random_polynomial_coeff, - ]) - 
.expect("WHIR prover failed to commit"); + let witness_new = whir_config.commit(merlin, &[ + &masked_polynomial_coeff, + &random_polynomial_coeff, + ]); ( witness_new, - masked_polynomial_coeff.into(), - random_polynomial_coeff.into(), + masked_polynomial_coeff, + random_polynomial_coeff, ) } @@ -410,15 +475,12 @@ pub fn pad_to_pow2_len_min2(v: &mut Vec) { pub fn run_zk_sumcheck_prover( r1cs: &R1CS, z: &[FieldElement], - merlin: &mut ProverState, + merlin: &mut ProverState, m_0: usize, whir_for_blinding_of_spartan_config: &WhirConfig, ) -> Vec { // r is the combination randomness from the 2nd item of the interaction phase - let mut r = vec![FieldElement::zero(); m_0]; - merlin - .fill_challenge_scalars(&mut r) - .expect("Failed to extract challenge scalars from Merlin"); + let r: Vec = merlin.verifier_message_vec(m_0); // let a = sum_fhat_1, b = sum_fhat_2, c = sum_fhat_3 for brevity let ((mut a, mut b, mut c), mut eq) = rayon::join( || calculate_witness_bounds(r1cs, z), @@ -436,9 +498,7 @@ pub fn run_zk_sumcheck_prover( let blinding_polynomial = generate_blinding_spartan_univariate_polys(m_0); // Spartan blinding: m = log2(domain), target_len = 2^(m-1). - let blinding_num_vars = whir_for_blinding_of_spartan_config - .mv_parameters - .num_variables; + let blinding_num_vars = whir_for_blinding_of_spartan_config.initial_num_variables(); let target_b = 1usize << (blinding_num_vars - 1); // Flatten and pad to exactly 1 << blinding_num_vars - 1 @@ -464,11 +524,9 @@ pub fn run_zk_sumcheck_prover( let sum_g_reduce = sum_over_hypercube(&blinding_polynomial); - let _ = merlin.add_scalars(&[sum_g_reduce]); + merlin.prover_message(&sum_g_reduce); - let mut rho_buf = [FieldElement::zero()]; - let _ = merlin.fill_challenge_scalars(&mut rho_buf); - let rho = rho_buf[0]; + let rho: FieldElement = merlin.verifier_message(); // Instead of proving that sum of F over the boolean hypercube is 0, we prove // that sum of F + rho * G over the boolean hypercube is rho * Sum G. 
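+    //
+    // In symbols: with H(x) = F(x) + rho * G(x),
+    //   sum_{x in {0,1}^m_0} H(x) = 0 + rho * sum_g_reduce,
+    // so the combined sumcheck below targets the claim rho * sum_g_reduce.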
@@ -531,10 +589,10 @@ pub fn run_zk_sumcheck_prover( + combined_hhat_i_coeffs[3] ); - let _ = merlin.add_scalars(&combined_hhat_i_coeffs[..]); - let mut alpha_i_wrapped_in_vector = [FieldElement::zero()]; - let _ = merlin.fill_challenge_scalars(&mut alpha_i_wrapped_in_vector); - let alpha_i = alpha_i_wrapped_in_vector[0]; + for coeff in &combined_hhat_i_coeffs { + merlin.prover_message(coeff); + } + let alpha_i: FieldElement = merlin.verifier_message(); alpha.push(alpha_i); fold = Some(alpha_i); @@ -544,24 +602,34 @@ pub fn run_zk_sumcheck_prover( } drop((a, b, c, eq)); - let (statement, blinding_mask_polynomial_sum, blinding_blind_polynomial_sum) = - create_combined_statement_over_two_polynomials::<1>( + let (blinding_weights, blinding_mask_polynomial_sum, blinding_blind_polynomial_sum) = + create_weights_and_evaluations_for_two_polynomials::<1>( blinding_polynomial_variables + 1, - &commitment_to_blinding_polynomial, &blindings_mask_polynomial, &blindings_blind_polynomial, &[expand_powers(alpha.as_slice())], ); - let _ = merlin.add_scalars(&[ - blinding_mask_polynomial_sum[0], - blinding_blind_polynomial_sum[0], - ]); + merlin.prover_message(&blinding_mask_polynomial_sum[0]); + merlin.prover_message(&blinding_blind_polynomial_sum[0]); + + let blinding_evaluations = compute_evaluations_single( + &blinding_weights, + &blindings_mask_polynomial, + &blindings_blind_polynomial, + ); + + let blinding_weight_refs: Vec<&dyn Evaluate>> = blinding_weights + .iter() + .map(|w| w as &dyn Evaluate>) + .collect(); let (_sums, _deferred) = run_zk_whir_pcs_prover( - commitment_to_blinding_polynomial, - statement, - &whir_for_blinding_of_spartan_config, + &[&commitment_to_blinding_polynomial], + &[&blindings_mask_polynomial, &blindings_blind_polynomial], + &blinding_weight_refs, + &blinding_evaluations, + whir_for_blinding_of_spartan_config, merlin, ); @@ -579,32 +647,31 @@ fn expand_powers(values: &[FieldElement]) -> Vec { result } -fn create_combined_statement_over_two_polynomials( +fn create_weights_and_evaluations_for_two_polynomials( cfg_nv: usize, - witness: &Witness, - f_polynomial: &EvaluationsList, - g_polynomial: &EvaluationsList, + f_polynomial: &CoefficientList, + g_polynomial: &CoefficientList, alphas: &[Vec; N], ) -> ( - Statement, + Vec>, Vec, Vec, ) { - // base_nv = cfg_nv - 1; lengths: 2^(cfg_nv-1) and 2^cfg_nv. 
let base_nv = cfg_nv.checked_sub(1).expect("cfg_nv >= 1"); let base_len = 1usize << base_nv; let final_len = 1usize << cfg_nv; - let mut statement = Statement::::new(cfg_nv); + let f_evals = coeffs_to_evals(f_polynomial); + let g_evals = coeffs_to_evals(g_polynomial); + + let mut weights = Vec::with_capacity(N); let mut f_sums = Vec::with_capacity(N); let mut g_sums = Vec::with_capacity(N); - for w in alphas.into_iter() { - // lift to 2^{cfg_nv} by zeroing the mask half: [w || 0] + for w in alphas.iter() { let mut w_full = Vec::with_capacity(final_len); w_full.extend_from_slice(w); - // Ensure w has length base_len (pad if shorter, assert if longer) if w_full.len() < base_len { w_full.resize(base_len, FieldElement::zero()); } else { @@ -612,98 +679,104 @@ fn create_combined_statement_over_two_polynomials( } w_full.resize(final_len, FieldElement::zero()); - let weight = Weights::linear(EvaluationsList::new(w_full)); - let f = weight.weighted_sum(f_polynomial); - let g = weight.weighted_sum(g_polynomial); + let weight = Covector::new(w_full); + f_sums.push(covector_dot(&weight, &f_evals)); + g_sums.push(covector_dot(&weight, &g_evals)); - statement.add_constraint(weight, f + witness.batching_randomness * g); - f_sums.push(f); - g_sums.push(g); + weights.push(weight); } - (statement, f_sums, g_sums) + (weights, f_sums, g_sums) } -#[instrument(skip_all)] -pub fn run_zk_whir_pcs_prover( - witnesses: Witness, - statements: Statement, - params: &WhirConfig, - merlin: &mut ProverState, -) -> (MultilinearPoint, Vec) { - info!("WHIR Parameters: {params}"); - - if !params.check_pow_bits() { - warn!("More PoW bits required than specified."); - } - - let prover = Prover::new(params.clone()); - let (randomness, deferred) = prover - .prove(merlin, statements, witnesses) - .expect("WHIR prover failed to generate a proof"); - - (randomness, deferred) +fn compute_evaluations_single( + weights: &[Covector], + masked_poly: &CoefficientList, + random_poly: &CoefficientList, +) -> Vec { + let masked_evals = coeffs_to_evals(masked_poly); + let random_evals = coeffs_to_evals(random_poly); + weights + .iter() + .flat_map(|w| { + [ + covector_dot(w, &masked_evals), + covector_dot(w, &random_evals), + ] + }) + .collect() } #[instrument(skip_all)] -pub fn run_zk_whir_pcs_batch_prover( - witnesses: &[Witness], - statements: &[Statement], +pub fn run_zk_whir_pcs_prover( + witnesses: &[&Witness], + polynomials: &[&CoefficientList], + weights: &[&dyn Evaluate>], + evaluations: &[FieldElement], params: &WhirConfig, - merlin: &mut ProverState, + merlin: &mut ProverState, ) -> (MultilinearPoint, Vec) { - info!("WHIR Parameters: {params}"); - - if !params.check_pow_bits() { - warn!("More PoW bits required than specified."); - } + debug!("WHIR Parameters: {params}"); - let prover = Prover::new(params.clone()); - let (randomness, deferred) = prover - .prove_batch(merlin, statements, witnesses) - .expect("WHIR prover failed to generate a proof"); + let (randomness, deferred) = params.prove(merlin, polynomials, witnesses, weights, evaluations); (randomness, deferred) } -fn update_statement_with_public_weights( - statement: &mut Statement, - witness: &Witness, - f_polynomial: &EvaluationsList, - g_polynomial: &EvaluationsList, - public_weights: Weights, +fn compute_public_weight_evaluations( + weights: &mut Vec>, + f_polynomial: &CoefficientList, + g_polynomial: &CoefficientList, + public_weights: Covector, ) -> (FieldElement, FieldElement) { - let f = public_weights.weighted_sum(f_polynomial); - let g = 
public_weights.weighted_sum(g_polynomial); - statement.add_constraint_in_front(public_weights, f + witness.batching_randomness * g); + let f_evals = coeffs_to_evals(f_polynomial); + let g_evals = coeffs_to_evals(g_polynomial); + let f = covector_dot(&public_weights, &f_evals); + let g = covector_dot(&public_weights, &g_evals); + weights.insert(0, public_weights); (f, g) } +fn compute_public_weight_evaluations_dual( + weights_1: &mut Vec>, + c1_masked: &CoefficientList, + c1_random: &CoefficientList, + c2_masked: &CoefficientList, + c2_random: &CoefficientList, + public_weights: Covector, +) -> (FieldElement, FieldElement, FieldElement, FieldElement) { + let c1m = coeffs_to_evals(c1_masked); + let c1r = coeffs_to_evals(c1_random); + let c2m = coeffs_to_evals(c2_masked); + let c2r = coeffs_to_evals(c2_random); + let f1 = covector_dot(&public_weights, &c1m); + let g1 = covector_dot(&public_weights, &c1r); + let f2 = covector_dot(&public_weights, &c2m); + let g2 = covector_dot(&public_weights, &c2r); + weights_1.insert(0, public_weights); + (f1, g1, f2, g2) +} + fn get_public_weights( public_inputs: &PublicInputs, - merlin: &mut ProverState, + merlin: &mut ProverState, m: usize, -) -> Weights { - // Add hash to transcript +) -> Covector { let public_inputs_hash = public_inputs.hash(); - let _ = merlin.add_scalars(&[public_inputs_hash]); + merlin.prover_message(&public_inputs_hash); - // Get random point x - let mut x_buf = [FieldElement::zero()]; - merlin - .fill_challenge_scalars(&mut x_buf) - .expect("Failed to get challenge from Merlin"); - let x = x_buf[0]; + let x: FieldElement = merlin.verifier_message(); let domain_size = 1 << m; let mut public_weights = vec![FieldElement::zero(); domain_size]; - // Set public weights for public inputs [1,x,x^2,x^3...x^n-1,0,0,0...0] let mut current_pow = FieldElement::one(); - for (idx, _) in public_inputs.0.iter().enumerate() { - public_weights[idx] = current_pow; - current_pow = current_pow * x; + for slot in public_weights.iter_mut().take(public_inputs.len()) { + *slot = current_pow; + current_pow *= x; } - Weights::geometric(x, public_inputs.len(), EvaluationsList::new(public_weights)) + let mut covector = Covector::new(public_weights); + covector.deferred = false; + covector } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 358e02e1..8ee617b3 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -4,15 +4,14 @@ use { ark_ff::{BigInteger, PrimeField}, ark_std::Zero, provekit_common::{ - skyscraper::SkyscraperSponge, utils::noir_to_native, witness::{ ConstantOrR1CSWitness, ConstantTerm, ProductLinearTerm, SumTerm, WitnessBuilder, WitnessCoefficient, BINOP_ATOMIC_BITS, }, - FieldElement, NoirElement, + FieldElement, NoirElement, TranscriptSponge, }, - spongefish::{codecs::arkworks_algebra::UnitToField, ProverState}, + whir::transcript::{ProverState, VerifierMessage}, }; pub trait WitnessBuilderSolver { @@ -20,7 +19,7 @@ pub trait WitnessBuilderSolver { &self, acir_witness_idx_to_value_map: &WitnessMap, witness: &mut [Option], - transcript: &mut ProverState, + transcript: &mut ProverState, ); } @@ -29,7 +28,7 @@ impl WitnessBuilderSolver for WitnessBuilder { &self, acir_witness_idx_to_value_map: &WitnessMap, witness: &mut [Option], - transcript: &mut ProverState, + transcript: &mut ProverState, ) { match self { WitnessBuilder::Constant(ConstantTerm(witness_idx, c)) => { @@ -93,9 +92,8 @@ impl WitnessBuilderSolver for 
WitnessBuilder { } } WitnessBuilder::Challenge(witness_idx) => { - let mut one = [FieldElement::zero(); 1]; - let _ = transcript.fill_challenge_scalars(&mut one); - witness[*witness_idx] = Some(one[0]); + let challenge: FieldElement = transcript.verifier_message(); + witness[*witness_idx] = Some(challenge); } WitnessBuilder::LogUpDenominator( witness_idx, diff --git a/provekit/r1cs-compiler/src/ntt.rs b/provekit/r1cs-compiler/src/ntt.rs index f977cd46..28eddaff 100644 --- a/provekit/r1cs-compiler/src/ntt.rs +++ b/provekit/r1cs-compiler/src/ntt.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] // Remove once RSFr is used for WHIR -use {ark_bn254::Fr, ark_ff::AdditiveGroup, ntt::ntt_nr, whir::ntt::ReedSolomon}; +use {ark_bn254::Fr, ark_ff::AdditiveGroup, ntt::ntt_nr, whir::algebra::ntt::ReedSolomon}; pub struct RSFr; impl ReedSolomon for RSFr { @@ -7,29 +7,28 @@ impl ReedSolomon for RSFr { &self, interleaved_coeffs: &[Fr], expansion: usize, - fold_factor: usize, + interleaving_depth: usize, ) -> Vec { debug_assert!(expansion > 0); - interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) + interleaved_rs_encode(interleaved_coeffs, expansion, interleaving_depth) } } fn interleaved_rs_encode( interleaved_coeffs: &[Fr], expansion: usize, - fold_factor: usize, + interleaving_depth: usize, ) -> Vec { - let fold_factor_exp = 2usize.pow(fold_factor as u32); let expanded_size = interleaved_coeffs.len() * expansion; - debug_assert_eq!(expanded_size % fold_factor_exp, 0); + debug_assert_eq!(expanded_size % interleaving_depth, 0); - // 1. Create zero-padded message of appropriate size let mut result = vec![Fr::ZERO; expanded_size]; result[..interleaved_coeffs.len()].copy_from_slice(interleaved_coeffs); - let mut ntt = ntt::NTT::new(result, fold_factor_exp) - .expect("interleaved_coeffs.len() * expansion / 2^fold_factor needs to be a power of two."); + let mut ntt = ntt::NTT::new(result, interleaving_depth).expect( + "interleaved_coeffs.len() * expansion / interleaving_depth needs to be a power of two.", + ); ntt_nr(&mut ntt); diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 385b11c9..6032f1d0 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -1,12 +1,7 @@ use { provekit_common::{utils::next_power_of_two, WhirConfig, WhirR1CSScheme, R1CS}, - std::sync::Arc, - whir::{ - ntt::RSDefault, - parameters::{ - default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, - MultivariateParameters, ProtocolParameters, SoundnessType, - }, + whir::parameters::{ + default_max_pow, FoldingFactor, MultivariateParameters, ProtocolParameters, SoundnessType, }, }; @@ -72,17 +67,11 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { security_level: 128, pow_bits: default_max_pow(nv, 1), folding_factor: FoldingFactor::Constant(4), - leaf_hash_params: (), - two_to_one_params: (), soundness_type: SoundnessType::ConjectureList, - _pow_parameters: Default::default(), starting_log_inv_rate: 1, batch_size, - deduplication_strategy: DeduplicationStrategy::Disabled, - merkle_proof_strategy: MerkleProofStrategy::Uncompressed, + hash_id: whir::hash::SHA2, }; - let reed_solomon = Arc::new(RSDefault); - let basefield_reed_solomon = reed_solomon.clone(); - WhirConfig::new(reed_solomon, basefield_reed_solomon, mv_params, whir_params) + WhirConfig::new(mv_params, &whir_params) } } diff --git a/provekit/verifier/Cargo.toml b/provekit/verifier/Cargo.toml index 7ee1196b..29b71ae7 100644 --- a/provekit/verifier/Cargo.toml +++ 
b/provekit/verifier/Cargo.toml @@ -14,7 +14,6 @@ provekit-common.workspace = true # Cryptography and proof systems ark-std.workspace = true -spongefish.workspace = true whir.workspace = true # 3rd party diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index 3cd662db..fd113b47 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -14,6 +14,8 @@ pub trait Verify { impl Verify for Verifier { #[instrument(skip_all)] fn verify(&mut self, proof: &NoirProof) -> Result<()> { + provekit_common::register_ntt(); + self.whir_for_witness.take().unwrap().verify( &proof.whir_r1cs_proof, &proof.public_inputs, diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 481f6813..83163c94 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -2,28 +2,20 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ - skyscraper::SkyscraperSponge, - utils::{ - sumcheck::{ - calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, - }, - zk_utils::geometric_till, + utils::sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, }, - FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, - }, - spongefish::{ - codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, - VerifierState, + FieldElement, PublicInputs, TranscriptSponge, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + R1CS, }, tracing::instrument, whir::{ - poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::{ - committer::{reader::ParsedCommitment, CommitmentReader}, - statement::{Statement, Weights}, - utils::HintDeserialize, - verifier::Verifier, + algebra::{ + polynomials::MultilinearPoint, + weights::{Covector, Weights}, }, + protocols::whir::Commitment, + transcript::{codecs::Empty, Proof, VerifierMessage, VerifierState}, }, }; @@ -44,24 +36,35 @@ pub trait WhirR1CSVerifier { impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] - #[allow(unused)] fn verify( &self, proof: &WhirR1CSProof, public_inputs: &PublicInputs, r1cs: &R1CS, ) -> Result<()> { - let io = self.create_io_pattern(); - let mut arthur = io.to_verifier_state(&proof.transcript); + let ds = self.create_domain_separator().instance(&Empty); + let whir_proof = Proof { + narg_string: proof.narg_string.clone(), + hints: proof.hints.clone(), + #[cfg(debug_assertions)] + pattern: proof.pattern.clone(), + }; + let mut arthur = VerifierState::new(&ds, &whir_proof, TranscriptSponge::default()); - let commitment_reader = CommitmentReader::new(&self.whir_witness); - let parsed_commitment_1 = commitment_reader.parse_commitment(&mut arthur)?; + let commitment_1 = self + .whir_witness + .receive_commitment(&mut arthur) + .map_err(|_| anyhow::anyhow!("Failed to parse commitment 1"))?; // Parse second commitment only if we have challenges - let parsed_commitment_2 = if self.num_challenges > 0 { - let mut _logup_challenges = vec![FieldElement::zero(); self.num_challenges]; - arthur.fill_challenge_scalars(&mut _logup_challenges)?; - Some(commitment_reader.parse_commitment(&mut arthur)?) 
+ let commitment_2 = if self.num_challenges > 0 { + let _logup_challenges: Vec = + arthur.verifier_message_vec(self.num_challenges); + Some( + self.whir_witness + .receive_commitment(&mut arthur) + .map_err(|_| anyhow::anyhow!("Failed to parse commitment 2"))?, + ) } else { None }; @@ -72,117 +75,155 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .context("while verifying sumcheck")?; // Verify public inputs hash - let mut public_inputs_hash_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let public_inputs_hash_buf: FieldElement = arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read public inputs hash"))?; let expected_public_inputs_hash = public_inputs.hash(); ensure!( - public_inputs_hash_buf[0] == expected_public_inputs_hash, + public_inputs_hash_buf == expected_public_inputs_hash, "Public inputs hash mismatch: expected {:?}, got {:?}", expected_public_inputs_hash, - public_inputs_hash_buf[0] + public_inputs_hash_buf ); - let mut public_weights_vector_random_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + let public_weights_vector_random: FieldElement = arthur.verifier_message(); // Read hints and verify WHIR proof - let ( - az_at_alpha, - bz_at_alpha, - cz_at_alpha, - whir_folding_randomness, - deferred_evals, - public_weights_challenge, - ) = if let Some(parsed_commitment_2) = parsed_commitment_2 { - // Dual commitment mode - let sums_1: (Vec, Vec) = arthur.hint()?; - let sums_2: (Vec, Vec) = arthur.hint()?; - - let whir_sums_1: ([FieldElement; 3], [FieldElement; 3]) = - (sums_1.0.try_into().unwrap(), sums_1.1.try_into().unwrap()); - let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = - (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - - let mut statement_1 = prepare_statement_for_witness_verifier::<3>( - self.m, - &parsed_commitment_1, - &whir_sums_1, - ); - let statement_2 = prepare_statement_for_witness_verifier::<3>( - self.m, - &parsed_commitment_2, - &whir_sums_2, - ); - - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur - .hint() - .context("failed to read WHIR public weights query answer")?; - - if !public_inputs.is_empty() { - update_statement_for_witness_verifier( + let (az_at_alpha, bz_at_alpha, cz_at_alpha, whir_folding_randomness, deferred_evals) = + if let Some(commitment_2) = commitment_2 { + // Dual commitment mode: read same-commitment and cross-evaluation hints + let sums_1: (Vec, Vec) = arthur + .prover_hint_ark() + .map_err(|_| anyhow::anyhow!("Failed to read sums_1 hint"))?; + let sums_2: (Vec, Vec) = arthur + .prover_hint_ark() + .map_err(|_| anyhow::anyhow!("Failed to read sums_2 hint"))?; + let cross_12: (Vec, Vec) = arthur + .prover_hint_ark() + .map_err(|_| anyhow::anyhow!("Failed to read cross_12 hint"))?; + let cross_21: (Vec, Vec) = arthur + .prover_hint_ark() + .map_err(|_| anyhow::anyhow!("Failed to read cross_21 hint"))?; + + let f_sums_1: [FieldElement; 3] = sums_1.0.try_into().unwrap(); + let g_sums_1: [FieldElement; 3] = sums_1.1.try_into().unwrap(); + let f_sums_2: [FieldElement; 3] = sums_2.0.try_into().unwrap(); + let g_sums_2: [FieldElement; 3] = sums_2.1.try_into().unwrap(); + let cross_f_12: [FieldElement; 3] = cross_12.0.try_into().unwrap(); + let cross_g_12: [FieldElement; 3] = cross_12.1.try_into().unwrap(); + let cross_f_21: [FieldElement; 3] = cross_21.0.try_into().unwrap(); + let cross_g_21: [FieldElement; 3] = cross_21.1.try_into().unwrap(); + + // Build weights and 
evaluations with full 4-polynomial layout per weight + // weights_1 evaluations: [f1, g1, cross_f12, cross_g12] per weight + let (mut weights_1, mut evaluations_1) = prepare_weights_and_evaluations_dual::<3>( self.m, - &mut statement_1, - &parsed_commitment_1, - whir_public_weights_query_answer, + &f_sums_1, + &g_sums_1, + &cross_f_12, + &cross_g_12, ); - } - - let (whir_folding_randomness, deferred_evals) = run_whir_pcs_batch_verifier( - &mut arthur, - &self.whir_witness, - &[parsed_commitment_1, parsed_commitment_2], - &[statement_1, statement_2], - ) - .context("while verifying WHIR batch proof")?; - - ( - whir_sums_1.0[0] + whir_sums_2.0[0], - whir_sums_1.0[1] + whir_sums_2.0[1], - whir_sums_1.0[2] + whir_sums_2.0[2], - whir_folding_randomness.0.to_vec(), - deferred_evals, - public_weights_vector_random_buf[0], - ) - } else { - // Single commitment mode - let sums: (Vec, Vec) = arthur.hint()?; - let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = - (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - - let mut statement = prepare_statement_for_witness_verifier::<3>( - self.m, - &parsed_commitment_1, - &whir_sums, - ); - - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur - .hint() - .context("failed to read WHIR public weights query answer")?; - if !public_inputs.is_empty() { - update_statement_for_witness_verifier( + // weights_2 evaluations: [cross_f21, cross_g21, f2, g2] per weight + let (weights_2, evaluations_2) = prepare_weights_and_evaluations_dual::<3>( self.m, - &mut statement, - &parsed_commitment_1, - whir_public_weights_query_answer, + &cross_f_21, + &cross_g_21, + &f_sums_2, + &g_sums_2, ); - } - let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( - &mut arthur, - &parsed_commitment_1, - &self.whir_witness, - &statement, - ) - .context("while verifying WHIR proof")?; - - ( - whir_sums.0[0], - whir_sums.0[1], - whir_sums.0[2], - whir_folding_randomness.0.to_vec(), - deferred_evals, - public_weights_vector_random_buf[0], - ) - }; + let public_hint: (FieldElement, FieldElement, FieldElement, FieldElement) = + arthur.prover_hint_ark().map_err(|_| { + anyhow::anyhow!("failed to read WHIR public weights query answer") + })?; + + if !public_inputs.is_empty() { + update_weights_and_evaluations_dual( + self.m, + &mut weights_1, + &mut evaluations_1, + public_hint, + public_inputs.len(), + public_weights_vector_random, + ); + } + + let mut all_weights = weights_1; + all_weights.extend(weights_2); + + let mut all_evaluations = evaluations_1; + all_evaluations.extend(evaluations_2); + + let weight_refs: Vec<&dyn Weights> = all_weights + .iter() + .map(|w| w as &dyn Weights) + .collect(); + let commitment_refs: Vec<&Commitment> = + vec![&commitment_1, &commitment_2]; + + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( + &mut arthur, + &self.whir_witness, + &commitment_refs, + &weight_refs, + &all_evaluations, + ) + .context("while verifying WHIR batch proof")?; + + ( + f_sums_1[0] + f_sums_2[0], + f_sums_1[1] + f_sums_2[1], + f_sums_1[2] + f_sums_2[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + ) + } else { + // Single commitment mode + let sums: (Vec, Vec) = arthur + .prover_hint_ark() + .map_err(|_| anyhow::anyhow!("Failed to read sums hint"))?; + let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = + (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); + + let (mut weights, mut evaluations) = + prepare_weights_and_evaluations::<3>(self.m, &whir_sums); + + let 
whir_public_weights_query_answer: (FieldElement, FieldElement) = + arthur.prover_hint_ark().map_err(|_| { + anyhow::anyhow!("failed to read WHIR public weights query answer") + })?; + if !public_inputs.is_empty() { + update_weights_and_evaluations( + self.m, + &mut weights, + &mut evaluations, + whir_public_weights_query_answer, + public_inputs.len(), + public_weights_vector_random, + ); + } + + let weight_refs: Vec<&dyn Weights> = weights + .iter() + .map(|w| w as &dyn Weights) + .collect(); + + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( + &mut arthur, + &self.whir_witness, + &[&commitment_1], + &weight_refs, + &evaluations, + ) + .context("while verifying WHIR proof")?; + + ( + whir_sums.0[0], + whir_sums.0[1], + whir_sums.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + ) + }; // Check the Spartan sumcheck relation ensure!( @@ -195,13 +236,10 @@ impl WhirR1CSVerifier for WhirR1CSScheme { "last sumcheck value does not match" ); - // Check deferred linear and geometric constraints - let offset = if public_inputs.is_empty() { 0 } else { 1 }; - - // Linear deferred + // Check deferred linear constraints. if self.num_challenges > 0 { assert!( - deferred_evals.len() == offset + 6, + deferred_evals.len() == 6, "Deferred evals length does not match" ); @@ -213,14 +251,14 @@ impl WhirR1CSVerifier for WhirR1CSScheme { ); for i in 0..6 { ensure!( - matrix_extension_evals[i] == deferred_evals[offset + i], + matrix_extension_evals[i] == deferred_evals[i], "Matrix extension evaluation {} does not match deferred value", i ); } } else { assert!( - deferred_evals.len() == offset + 3, + deferred_evals.len() == 3, "Deferred evals length does not match" ); @@ -232,115 +270,211 @@ impl WhirR1CSVerifier for WhirR1CSScheme { for i in 0..3 { ensure!( - matrix_extension_evals[i] == deferred_evals[offset + i], + matrix_extension_evals[i] == deferred_evals[i], "Matrix extension evaluation {} does not match deferred value", i ); } } - // Geometric deferred - if !public_inputs.is_empty() && deferred_evals.len() > 0 { - let public_weight_eval = compute_public_weight_evaluation( - public_inputs, - &whir_folding_randomness, - public_weights_challenge, - ); - ensure!( - public_weight_eval == deferred_evals[0], - "Public weight evaluation does not match deferred value" - ); - } - Ok(()) } } -fn prepare_statement_for_witness_verifier( - m: usize, - parsed_commitment: &ParsedCommitment, +/// Build weights and evaluations for the verifier, mirroring the prover's +/// `create_weights_and_evaluations_for_two_polynomials`. +/// +/// Each weight is a linear constraint with a zero-filled evaluation list (the +/// verifier doesn't know the polynomial, so the weight itself is deferred). +/// The claimed evaluations come from the prover's hints: f_sums and g_sums +/// interleaved as [f_sum_i, g_sum_i] for each constraint. 
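// Illustrative sketch (not part of the patch): the interleaving described in the doc
// comment above, with u64 stand-ins for `FieldElement`. For each of the N weights the
// claimed evaluations are pushed as the pair [f_sum_i, g_sum_i], so the verifier ends
// up with 2*N evaluations in weight order.

fn interleave_claims<const N: usize>(f_sums: [u64; N], g_sums: [u64; N]) -> Vec<u64> {
    let mut evaluations = Vec::with_capacity(N * 2);
    for i in 0..N {
        evaluations.push(f_sums[i]); // masked polynomial
        evaluations.push(g_sums[i]); // blinding (random) polynomial
    }
    evaluations
}

#[test]
fn interleaved_layout() {
    assert_eq!(
        interleave_claims([10, 11, 12], [20, 21, 22]),
        vec![10, 20, 11, 21, 12, 22]
    );
}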
+fn prepare_weights_and_evaluations<const N: usize>(
+    cfg_nv: usize,
     whir_query_answer_sums: &([FieldElement; N], [FieldElement; N]),
-) -> Statement<FieldElement> {
-    let mut statement_verifier = Statement::<FieldElement>::new(m);
-    for i in 0..whir_query_answer_sums.0.len() {
-        let claimed_sum = whir_query_answer_sums.0[i]
-            + whir_query_answer_sums.1[i] * parsed_commitment.batching_randomness;
-        statement_verifier.add_constraint(
-            Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m])),
-            claimed_sum,
-        );
+) -> (Vec<Covector<FieldElement>>, Vec<FieldElement>) {
+    let final_len = 1usize << cfg_nv;
+
+    let mut weights = Vec::with_capacity(N);
+    let mut evaluations = Vec::with_capacity(N * 2);
+
+    for i in 0..N {
+        let weight = Covector::new(vec![FieldElement::zero(); final_len]);
+        weights.push(weight);
+
+        // Each weight evaluates against 2 polynomials (masked + random) → 2 evaluations
+        // per weight
+        evaluations.push(whir_query_answer_sums.0[i]); // f_sum (masked polynomial)
+        evaluations.push(whir_query_answer_sums.1[i]); // g_sum (random
+                                                       // polynomial)
     }
-    statement_verifier
+
+    (weights, evaluations)
 }
-fn update_statement_for_witness_verifier(
+/// Add a public weight constraint at the front, mirroring the prover's
+/// `compute_public_weight_evaluations`, which inserts at position 0.
+///
+/// The weight vector is the geometric progression `[1, x, x^2, ...]` over the
+/// public-input slots and is marked non-deferred (`deferred = false`), unlike
+/// the zero-filled placeholder weights above: the verifier computes this
+/// evaluation itself, matching the prover's non-deferred public weight.
+fn update_weights_and_evaluations(
     m: usize,
-    statement_verifier: &mut Statement<FieldElement>,
-    parsed_commitment: &ParsedCommitment,
+    weights: &mut Vec<Covector<FieldElement>>,
+    evaluations: &mut Vec<FieldElement>,
     whir_public_weights_query_answer: (FieldElement, FieldElement),
+    public_inputs_len: usize,
+    x: FieldElement,
 ) {
+    let domain_size = 1usize << m;
+    let mut public_weight_evals = vec![FieldElement::zero(); domain_size];
+    let mut current_pow = FieldElement::one();
+    for slot in public_weight_evals.iter_mut().take(public_inputs_len) {
+        *slot = current_pow;
+        current_pow *= x;
+    }
+    let mut public_weight = Covector::new(public_weight_evals);
+    public_weight.deferred = false;
     let (public_f_sum, public_g_sum) = whir_public_weights_query_answer;
-    let public_weight = Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m]));
-    statement_verifier.add_constraint_in_front(
-        public_weight,
-        public_f_sum + public_g_sum * parsed_commitment.batching_randomness,
-    );
+    weights.insert(0, public_weight);
+    evaluations.insert(0, public_g_sum);
+    evaluations.insert(0, public_f_sum);
+}
+
+/// Build weights and evaluations for the dual-commitment verifier path.
+///
+/// Each weight produces 4 evaluations (one per polynomial across both
+/// commitments): [eval_c1_masked, eval_c1_random, eval_c2_masked,
+/// eval_c2_random]. This matches whir's row-major evaluation matrix layout.
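// Illustrative sketch (not part of the patch): the row-major layout described above,
// with u64 stand-ins for `FieldElement`. Each weight contributes a 4-entry row
// [c1_masked, c1_random, c2_masked, c2_random]; when public inputs are present, the
// public weight's row is placed in front, matching the index-0 insertion performed by
// `update_weights_and_evaluations_dual` below.

fn dual_layout<const N: usize>(
    c1_masked: [u64; N],
    c1_random: [u64; N],
    c2_masked: [u64; N],
    c2_random: [u64; N],
    public_row: Option<[u64; 4]>,
) -> Vec<u64> {
    let mut evaluations = Vec::with_capacity((N + 1) * 4);
    if let Some(row) = public_row {
        // The real code builds the N rows first and then inserts this row at index 0;
        // the resulting order is the same.
        evaluations.extend(row);
    }
    for i in 0..N {
        evaluations.extend([c1_masked[i], c1_random[i], c2_masked[i], c2_random[i]]);
    }
    evaluations
}

#[test]
fn dual_rows_with_public_front() {
    let evals = dual_layout([1, 2], [3, 4], [5, 6], [7, 8], Some([90, 91, 92, 93]));
    assert_eq!(evals, vec![90, 91, 92, 93, 1, 3, 5, 7, 2, 4, 6, 8]);
}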
+fn prepare_weights_and_evaluations_dual( + cfg_nv: usize, + evals_c1_masked: &[FieldElement; N], + evals_c1_random: &[FieldElement; N], + evals_c2_masked: &[FieldElement; N], + evals_c2_random: &[FieldElement; N], +) -> (Vec>, Vec) { + let final_len = 1usize << cfg_nv; + + let mut weights = Vec::with_capacity(N); + let mut evaluations = Vec::with_capacity(N * 4); + + for i in 0..N { + let weight = Covector::new(vec![FieldElement::zero(); final_len]); + weights.push(weight); + + evaluations.push(evals_c1_masked[i]); + evaluations.push(evals_c1_random[i]); + evaluations.push(evals_c2_masked[i]); + evaluations.push(evals_c2_random[i]); + } + + (weights, evaluations) +} + +/// Add a public weight for dual-commitment at the front, with 4 evaluations. +/// Must use `Weights::geometric` to match the prover's non-deferred weight +/// type. +fn update_weights_and_evaluations_dual( + m: usize, + weights: &mut Vec>, + evaluations: &mut Vec, + public_hint: (FieldElement, FieldElement, FieldElement, FieldElement), + public_inputs_len: usize, + x: FieldElement, +) { + let domain_size = 1usize << m; + let mut public_weight_evals = vec![FieldElement::zero(); domain_size]; + let mut current_pow = FieldElement::one(); + for slot in public_weight_evals.iter_mut().take(public_inputs_len) { + *slot = current_pow; + current_pow *= x; + } + let mut public_weight = Covector::new(public_weight_evals); + public_weight.deferred = false; + let (f1, g1, f2, g2) = public_hint; + weights.insert(0, public_weight); + evaluations.insert(0, g2); + evaluations.insert(0, f2); + evaluations.insert(0, g1); + evaluations.insert(0, f1); } #[instrument(skip_all)] pub fn run_sumcheck_verifier( - arthur: &mut VerifierState, + arthur: &mut VerifierState<'_, TranscriptSponge>, m_0: usize, whir_for_spartan_blinding_config: &WhirConfig, ) -> Result { - let mut r = vec![FieldElement::zero(); m_0]; - let _ = arthur.fill_challenge_scalars(&mut r); + let r: Vec = arthur.verifier_message_vec(m_0); - let commitment_reader = CommitmentReader::new(whir_for_spartan_blinding_config); - let parsed_commitment = commitment_reader.parse_commitment(arthur)?; + let commitment = whir_for_spartan_blinding_config + .receive_commitment(arthur) + .map_err(|_| anyhow::anyhow!("Failed to parse spartan blinding commitment"))?; - let mut sum_g_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut sum_g_buf)?; + let sum_g: FieldElement = arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read sum_g"))?; - let mut rho_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut rho_buf)?; - let rho = rho_buf[0]; + let rho: FieldElement = arthur.verifier_message(); - let mut saved_val_for_sumcheck_equality_assertion = rho * sum_g_buf[0]; + let mut saved_val_for_sumcheck_equality_assertion = rho * sum_g; let mut alpha = vec![FieldElement::zero(); m_0]; for item in alpha.iter_mut().take(m_0) { - let mut hhat_i = [FieldElement::zero(); 4]; - let mut alpha_i = [FieldElement::zero(); 1]; - let _ = arthur.fill_next_scalars(&mut hhat_i); - let _ = arthur.fill_challenge_scalars(&mut alpha_i); - *item = alpha_i[0]; + let hhat_i: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read hhat coeff"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read hhat coeff"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read hhat coeff"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read hhat coeff"))?, + ]; + let alpha_i: 
FieldElement = arthur.verifier_message(); + *item = alpha_i; let hhat_i_at_zero = eval_cubic_poly(hhat_i, FieldElement::zero()); let hhat_i_at_one = eval_cubic_poly(hhat_i, FieldElement::one()); ensure!( saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one, "Sumcheck equality assertion failed" ); - saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i[0]); + saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i); } - let mut values_of_polynomial_sums = [FieldElement::zero(); 2]; - let _ = arthur.fill_next_scalars(&mut values_of_polynomial_sums); + let values_of_polynomial_sums: [FieldElement; 2] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read polynomial sum"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read polynomial sum"))?, + ]; - let statement_verifier = prepare_statement_for_witness_verifier::<1>( - whir_for_spartan_blinding_config.mv_parameters.num_variables, - &parsed_commitment, + let blinding_nv = whir_for_spartan_blinding_config.initial_num_variables(); + + let (blinding_weights, blinding_evaluations) = prepare_weights_and_evaluations::<1>( + blinding_nv, &([values_of_polynomial_sums[0]], [ values_of_polynomial_sums[1] ]), ); + let blinding_weight_refs: Vec<&dyn Weights> = blinding_weights + .iter() + .map(|w| w as &dyn Weights) + .collect(); + run_whir_pcs_verifier( arthur, - &parsed_commitment, whir_for_spartan_blinding_config, - &statement_verifier, + &[&commitment], + &blinding_weight_refs, + &blinding_evaluations, ) .context("while verifying WHIR")?; @@ -355,29 +489,16 @@ pub fn run_sumcheck_verifier( #[instrument(skip_all)] pub fn run_whir_pcs_verifier( - arthur: &mut VerifierState, - parsed_commitment: &ParsedCommitment, + arthur: &mut VerifierState<'_, TranscriptSponge>, params: &WhirConfig, - statement_verifier: &Statement, + commitments: &[&Commitment], + weights: &[&dyn Weights], + evaluations: &[FieldElement], ) -> Result<(MultilinearPoint, Vec)> { - let verifier = Verifier::new(params); - let (folding_randomness, deferred) = verifier - .verify(arthur, parsed_commitment, statement_verifier) - .context("while verifying WHIR")?; - Ok((folding_randomness, deferred)) -} - -#[instrument(skip_all)] -pub fn run_whir_pcs_batch_verifier( - arthur: &mut VerifierState, - params: &WhirConfig, - parsed_commitments: &[ParsedCommitment], - statements: &[Statement], -) -> Result<(MultilinearPoint, Vec)> { - let verifier = Verifier::new(params); - let (folding_randomness, deferred) = verifier - .verify_batch(arthur, parsed_commitments, statements) - .context("while verifying batch WHIR")?; + let (folding_randomness, deferred) = + params + .verify(arthur, commitments, weights, evaluations) + .map_err(|_| anyhow::anyhow!("WHIR verification failed"))?; Ok((folding_randomness, deferred)) } @@ -446,11 +567,3 @@ fn evaluate_r1cs_matrix_extension_batch( ans } - -fn compute_public_weight_evaluation( - public_inputs: &PublicInputs, - folding_randomness: &[FieldElement], - x: FieldElement, -) -> FieldElement { - geometric_till(x, public_inputs.len(), folding_randomness) -} diff --git a/tooling/cli/Cargo.toml b/tooling/cli/Cargo.toml index 76436aee..9759420d 100644 --- a/tooling/cli/Cargo.toml +++ b/tooling/cli/Cargo.toml @@ -26,6 +26,7 @@ base64.workspace = true postcard.workspace = true serde.workspace = true serde_json.workspace = true +tikv-jemallocator = { workspace = true, optional = true } tracing.workspace = true tracing-subscriber.workspace = true 
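// Illustrative sketch (not part of the patch): the per-round consistency check used in
// `run_sumcheck_verifier` above. The verifier keeps a running claim, checks that the
// received cubic satisfies hhat_i(0) + hhat_i(1) == claim, and then folds the challenge
// in by setting the next claim to hhat_i(alpha_i). i128 stands in for `FieldElement`,
// and the low-to-high coefficient order of `eval_cubic_poly` is an assumption here.

fn eval_cubic(c: [i128; 4], x: i128) -> i128 {
    c[0] + c[1] * x + c[2] * x * x + c[3] * x * x * x
}

fn sumcheck_round(claim: i128, hhat: [i128; 4], alpha: i128) -> Result<i128, &'static str> {
    if claim != eval_cubic(hhat, 0) + eval_cubic(hhat, 1) {
        return Err("sumcheck equality assertion failed");
    }
    Ok(eval_cubic(hhat, alpha))
}

#[test]
fn round_check() {
    // hhat(x) = 1 + 2x + 3x^2 + 4x^3, so hhat(0) + hhat(1) = 1 + 10 = 11.
    assert_eq!(sumcheck_round(11, [1, 2, 3, 4], 2), Ok(49));
    assert!(sumcheck_round(12, [1, 2, 3, 4], 2).is_err());
}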
tracing-tracy = { workspace = true, optional = true, features = ["default", "sampling","manual-lifetime"] } @@ -36,4 +37,5 @@ workspace = true [features] default = ["profiling-allocator"] profiling-allocator = [] +jemalloc = ["profiling-allocator", "dep:tikv-jemallocator"] tracy = ["dep:tracing-tracy"] diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs index 0e3cc6d0..5ceb2679 100644 --- a/tooling/cli/src/cmd/generate_gnark_inputs.rs +++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs @@ -55,8 +55,7 @@ impl Command for Args { write_gnark_parameters_to_file( &prover.whir_for_witness.whir_witness, &prover.whir_for_witness.whir_for_hiding_spartan, - &proof.whir_r1cs_proof.transcript, - &prover.whir_for_witness.create_io_pattern(), + &proof.whir_r1cs_proof, prover.whir_for_witness.m_0, prover.whir_for_witness.m, prover.whir_for_witness.a_num_terms, diff --git a/tooling/cli/src/main.rs b/tooling/cli/src/main.rs index e896aada..d273bace 100644 --- a/tooling/cli/src/main.rs +++ b/tooling/cli/src/main.rs @@ -13,7 +13,7 @@ use { anyhow::Result, span_stats::SpanStats, tracing::subscriber, - tracing_subscriber::{self, layer::SubscriberExt as _, Registry}, + tracing_subscriber::{self, filter::LevelFilter, layer::SubscriberExt as _, Layer, Registry}, }; #[cfg(feature = "profiling-allocator")] @@ -22,7 +22,13 @@ static ALLOCATOR: ProfilingAllocator = ProfilingAllocator::new(); fn main() -> Result<()> { let args = argh::from_env::(); - let subscriber = Registry::default().with(SpanStats); + // Debug builds: track ALL spans for detailed profiling. + // Release builds: only INFO+ to reduce overhead. + #[cfg(debug_assertions)] + let level = LevelFilter::TRACE; + #[cfg(not(debug_assertions))] + let level = LevelFilter::INFO; + let subscriber = Registry::default().with(SpanStats.with_filter(level)); #[cfg(feature = "tracy")] let subscriber = { diff --git a/tooling/cli/src/profiling_alloc.rs b/tooling/cli/src/profiling_alloc.rs index c1fae147..8452423f 100644 --- a/tooling/cli/src/profiling_alloc.rs +++ b/tooling/cli/src/profiling_alloc.rs @@ -1,10 +1,15 @@ use std::{ - alloc::{GlobalAlloc, Layout, System as SystemAlloc}, + alloc::{GlobalAlloc, Layout}, sync::atomic::{AtomicUsize, Ordering}, }; #[cfg(feature = "tracy")] use {std::sync::atomic::AtomicBool, tracing_tracy::client::sys as tracy_sys}; +#[cfg(feature = "jemalloc")] +static BACKING: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +#[cfg(not(feature = "jemalloc"))] +static BACKING: std::alloc::System = std::alloc::System; + /// Custom allocator that keeps track of statistics to see program memory /// consumption. 
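// Illustrative sketch (not part of the patch): the pattern behind `ProfilingAllocator`
// in this hunk — forward every call to a backing allocator while tracking live bytes.
// `System` is used here; the patch swaps in `tikv_jemallocator::Jemalloc` behind the
// `jemalloc` feature. The real implementation also overrides `alloc_zeroed`/`realloc`
// and feeds the Tracy hooks.

use std::{
    alloc::{GlobalAlloc, Layout, System},
    sync::atomic::{AtomicUsize, Ordering},
};

static BACKING_DEMO: System = System;

pub struct CountingAllocator {
    current: AtomicUsize,
}

#[allow(unsafe_code)]
unsafe impl GlobalAlloc for CountingAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let ptr = BACKING_DEMO.alloc(layout);
        if !ptr.is_null() {
            self.current.fetch_add(layout.size(), Ordering::SeqCst);
        }
        ptr
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        self.current.fetch_sub(layout.size(), Ordering::SeqCst);
        BACKING_DEMO.dealloc(ptr, layout);
    }
}

// Installed the same way as the patch's allocator:
// #[global_allocator]
// static ALLOCATOR: CountingAllocator = CountingAllocator { current: AtomicUsize::new(0) };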
pub struct ProfilingAllocator { @@ -115,7 +120,7 @@ impl ProfilingAllocator { #[allow(unsafe_code)] unsafe impl GlobalAlloc for ProfilingAllocator { unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - let ptr = SystemAlloc.alloc(layout); + let ptr = BACKING.alloc(layout); let size = layout.size(); let current = self .current @@ -130,11 +135,11 @@ unsafe impl GlobalAlloc for ProfilingAllocator { unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { self.current.fetch_sub(layout.size(), Ordering::SeqCst); self.tracy_dealloc(ptr); - SystemAlloc.dealloc(ptr, layout); + BACKING.dealloc(ptr, layout); } unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { - let ptr = SystemAlloc.alloc_zeroed(layout); + let ptr = BACKING.alloc_zeroed(layout); let size = layout.size(); let current = self .current @@ -148,7 +153,7 @@ unsafe impl GlobalAlloc for ProfilingAllocator { unsafe fn realloc(&self, ptr: *mut u8, old_layout: Layout, new_size: usize) -> *mut u8 { self.tracy_dealloc(ptr); - let ptr = SystemAlloc.realloc(ptr, old_layout, new_size); + let ptr = BACKING.realloc(ptr, old_layout, new_size); let old_size = old_layout.size(); if new_size > old_size { let diff = new_size - old_size; diff --git a/tooling/provekit-ffi/src/ffi.rs b/tooling/provekit-ffi/src/ffi.rs index 1aeab548..9cd8c27e 100644 --- a/tooling/provekit-ffi/src/ffi.rs +++ b/tooling/provekit-ffi/src/ffi.rs @@ -163,8 +163,8 @@ pub unsafe extern "C" fn pk_free_buf(buf: PKBuf) { /// Returns `PKError::Success` on success. #[no_mangle] pub extern "C" fn pk_init() -> c_int { - // Initialize tracing/logging if needed - // For now, we'll keep it simple and just return success + // TODO: Initialize tracing/logging for FFI consumers. + provekit_common::register_ntt(); PKError::Success.into() } diff --git a/tooling/provekit-gnark/Cargo.toml b/tooling/provekit-gnark/Cargo.toml index 7558f054..da48874c 100644 --- a/tooling/provekit-gnark/Cargo.toml +++ b/tooling/provekit-gnark/Cargo.toml @@ -14,6 +14,7 @@ provekit-common.workspace = true # Cryptography and proof systems ark-poly.workspace = true +whir.workspace = true # 3rd party serde.workspace = true diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 41439a75..ad52dea9 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -1,116 +1,141 @@ use { ark_poly::EvaluationDomain, - provekit_common::{IOPattern, PublicInputs, WhirConfig}, + provekit_common::{FieldElement, PublicInputs, WhirConfig, WhirR1CSProof}, serde::{Deserialize, Serialize}, std::{fs::File, io::Write}, tracing::instrument, + whir::algebra::domain::Domain, }; +/// Configuration for the Gnark recursive verifier. #[derive(Debug, Serialize, Deserialize)] -/// Configuration for Gnark pub struct GnarkConfig { - /// WHIR parameters for witness - pub whir_config_witness: WHIRConfigGnark, - /// WHIR parameters for hiding spartan + /// WHIR parameters for witness commitment. + pub whir_config_witness: WHIRConfigGnark, + /// WHIR parameters for hiding Spartan. 
pub whir_config_hiding_spartan: WHIRConfigGnark, - /// log of number of constraints in R1CS - pub log_num_constraints: usize, - /// log of number of variables in R1CS - pub log_num_variables: usize, - /// log of number of non-zero terms matrix A - pub log_a_num_terms: usize, - /// nimue input output pattern - pub io_pattern: String, - /// transcript in byte form - pub transcript: Vec, - /// length of the transcript - pub transcript_len: usize, - /// number of logup challenges (0 = single commitment mode) - pub num_challenges: usize, - /// size of w1 - pub w1_size: usize, - /// public inputs - pub public_inputs: PublicInputs, + /// log₂ of number of constraints in R1CS. + pub log_num_constraints: usize, + /// log₂ of number of variables in R1CS. + pub log_num_variables: usize, + /// log₂ of number of non-zero terms in matrix A. + pub log_a_num_terms: usize, + /// Spongefish NARG string (transcript interaction pattern). + pub narg_string: Vec, + /// Explicit length for Go deserialisation. + pub narg_string_len: usize, + /// Prover hints (serialised transcript data). + pub hints: Vec, + /// Explicit length for Go deserialisation. + pub hints_len: usize, + /// Number of LogUp challenges (0 = single commitment mode). + pub num_challenges: usize, + /// Size of w1 partition. + pub w1_size: usize, + /// Public inputs to the circuit. + pub public_inputs: PublicInputs, } #[derive(Debug, Serialize, Deserialize)] - pub struct WHIRConfigGnark { - /// number of rounds + /// Number of WHIR rounds. pub n_rounds: usize, - /// rate + /// Reed-Solomon rate (log₂ of inverse rate). pub rate: usize, - /// number of variables + /// Number of variables in the multilinear polynomial. pub n_vars: usize, - /// folding factor + /// Folding factor per round. pub folding_factor: Vec, - /// out of domain samples + /// Out-of-domain samples per round. pub ood_samples: Vec, - /// number of queries + /// Number of queries per round. pub num_queries: Vec, - /// proof of work bits + /// Proof-of-work bits per round. pub pow_bits: Vec, - /// final queries + /// Final round query count. pub final_queries: usize, - /// final proof of work bits + /// Final round proof-of-work bits. pub final_pow_bits: i32, - /// final folding proof of work bits + /// Final folding proof-of-work bits. pub final_folding_pow_bits: i32, - /// domain generator string + /// Domain generator as a string. pub domain_generator: String, - /// batch size + /// Batch size (number of polynomials committed together). 
pub batch_size: usize, } impl WHIRConfigGnark { pub fn new(whir_params: &WhirConfig) -> Self { + let n_rounds = whir_params.n_rounds(); + let n_vars = whir_params.initial_num_variables(); + let rate = whir_params.initial_committer.expansion.ilog2() as usize; + + // Folding factor: initial round uses initial_sumcheck.num_rounds, + // subsequent rounds use round_configs[i].sumcheck.num_rounds + let mut folding_factor = Vec::with_capacity(n_rounds + 1); + folding_factor.push(whir_params.initial_sumcheck.num_rounds); + for rc in &whir_params.round_configs { + folding_factor.push(rc.sumcheck.num_rounds); + } + + let ood_samples: Vec = whir_params + .round_configs + .iter() + .map(|rc| rc.irs_committer.out_domain_samples) + .collect(); + + let num_queries: Vec = whir_params + .round_configs + .iter() + .map(|rc| rc.irs_committer.in_domain_samples) + .collect(); + + let pow_bits: Vec = whir_params + .round_configs + .iter() + .map(|rc| { + f64::from(whir::protocols::proof_of_work::difficulty(rc.pow.threshold)) as i32 + }) + .collect(); + + let final_queries = whir_params.final_in_domain_samples(); + let final_pow_bits = f64::from(whir::protocols::proof_of_work::difficulty( + whir_params.final_pow.threshold, + )) as i32; + let final_folding_pow_bits = f64::from(whir::protocols::proof_of_work::difficulty( + whir_params.final_sumcheck.round_pow.threshold, + )) as i32; + + // Reconstruct the starting domain to get its generator + let domain = Domain::::new(1 << n_vars, rate) + .expect("Should have found an appropriate domain"); + let domain_generator = format!("{}", domain.backing_domain.group_gen()); + + let batch_size = whir_params.initial_committer.num_polynomials; + WHIRConfigGnark { - n_rounds: whir_params - .folding_factor - .compute_number_of_rounds(whir_params.mv_parameters.num_variables) - .0, - rate: whir_params.starting_log_inv_rate, - n_vars: whir_params.mv_parameters.num_variables, - folding_factor: (0..(whir_params - .folding_factor - .compute_number_of_rounds(whir_params.mv_parameters.num_variables) - .0)) - .map(|round| whir_params.folding_factor.at_round(round)) - .collect(), - ood_samples: whir_params - .round_parameters - .iter() - .map(|x| x.ood_samples) - .collect(), - num_queries: whir_params - .round_parameters - .iter() - .map(|x| x.num_queries) - .collect(), - pow_bits: whir_params - .round_parameters - .iter() - .map(|x| x.pow_bits as i32) - .collect(), - final_queries: whir_params.final_queries, - final_pow_bits: whir_params.final_pow_bits as i32, - final_folding_pow_bits: whir_params.final_folding_pow_bits as i32, - domain_generator: format!( - "{}", - whir_params.starting_domain.backing_domain.group_gen() - ), - batch_size: whir_params.batch_size, + n_rounds, + rate, + n_vars, + folding_factor, + ood_samples, + num_queries, + pow_bits, + final_queries, + final_pow_bits, + final_folding_pow_bits, + domain_generator, + batch_size, } } } -/// Writes config used for Gnark circuit to a file +/// Build the Gnark recursive verifier configuration. 
#[instrument(skip_all)] pub fn gnark_parameters( whir_params_witness: &WhirConfig, whir_params_hiding_spartan: &WhirConfig, - transcript: &[u8], - io: &IOPattern, + proof: &WhirR1CSProof, m_0: usize, m: usize, a_num_terms: usize, @@ -124,22 +149,22 @@ pub fn gnark_parameters( log_num_constraints: m_0, log_num_variables: m, log_a_num_terms: a_num_terms, - io_pattern: String::from_utf8(io.as_bytes().to_vec()).unwrap(), - transcript: transcript.to_vec(), - transcript_len: transcript.to_vec().len(), + narg_string: proof.narg_string.clone(), + narg_string_len: proof.narg_string.len(), + hints: proof.hints.clone(), + hints_len: proof.hints.len(), num_challenges, w1_size, public_inputs: public_inputs.clone(), } } -/// Writes config used for Gnark circuit to a file +/// Serialize the Gnark configuration to a JSON file. #[instrument(skip_all)] pub fn write_gnark_parameters_to_file( whir_params_witness: &WhirConfig, whir_params_hiding_spartan: &WhirConfig, - transcript: &[u8], - io: &IOPattern, + proof: &WhirR1CSProof, m_0: usize, m: usize, a_num_terms: usize, @@ -151,8 +176,7 @@ pub fn write_gnark_parameters_to_file( let gnark_config = gnark_parameters( whir_params_witness, whir_params_hiding_spartan, - transcript, - io, + proof, m_0, m, a_num_terms, diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs index 398b282c..4d0c5629 100644 --- a/tooling/verifier-server/src/services/verification.rs +++ b/tooling/verifier-server/src/services/verification.rs @@ -90,8 +90,7 @@ impl VerificationService { write_gnark_parameters_to_file( &whir_scheme.whir_witness, &whir_scheme.whir_for_hiding_spartan, - &proof.whir_r1cs_proof.transcript, - &whir_scheme.create_io_pattern(), + &proof.whir_r1cs_proof, whir_scheme.m_0, whir_scheme.m, whir_scheme.a_num_terms,