Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper ZK treatment in plonky2 #1625

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions plonky2/src/batch_fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::extension::Extendable;
Expand All @@ -19,6 +19,7 @@ use crate::hash::batch_merkle_tree::BatchMerkleTree;
use crate::hash::hash_types::RichField;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -151,9 +152,15 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let is_zk = fri_params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -165,6 +172,28 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

if is_zk && idx == 0 {
let degree = 1 << degree_bits[i];
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

assert_eq!(final_poly.len(), 1 << degree_bits[i]);
Expand Down
10 changes: 8 additions & 2 deletions plonky2/src/batch_fri/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
fri_params: &FriParams,
timing: &mut TimingTree,
) -> FriProof<F, C::Hasher, D> {
let n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), n);
let mut n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), lde_polynomial_coeffs.len());
// The polynomial vectors should be sorted by degree, from largest to smallest, with no duplicate degrees.
assert!(lde_polynomial_values
.windows(2)
Expand All @@ -49,6 +49,12 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
}
assert_eq!(cur_poly_index, lde_polynomial_values.len());

// In the zk case, the final polynomial polynomial to be reduced has degree double that
// of the original batch FRI polynomial.
if fri_params.hiding {
n /= 2;
}

// Commit phase
let (trees, final_coeffs) = timed!(
timing,
Expand Down
54 changes: 49 additions & 5 deletions plonky2/src/batch_fri/recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use alloc::{format, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;

use crate::field::extension::Extendable;
use crate::fri::proof::{
Expand All @@ -15,6 +16,7 @@ use crate::iop::ext_target::{flatten_target, ExtensionTarget};
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactorTarget;
use crate::with_context;

Expand Down Expand Up @@ -62,7 +64,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
PrecomputedReducedOpeningsTarget::from_os_and_alpha(
opn,
challenges.fri_alpha,
self
self,
params.hiding,
)
);
precomputed_reduced_evals.push(pre);
Expand Down Expand Up @@ -165,13 +168,24 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut alpha = ReducingFactorTarget::new(alpha);
let mut sum = self.zero_extension();

for (batch, reduced_openings) in instance[index]
for (idx, (batch, reduced_openings)) in instance[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to multiply R(X) by alpha^n? You would always get a uniform poly, and is not what is written in the paper

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had seen that the indices in the paper started at 1 and assumed that the alpha for the quotient poly was shifted because of R (as in R was the first element in the sum: R + alpha q_1 + ...). But upon looking at the code, you are right: we already start with alpha and not 1, so I'll remove the multiplication by alpha^n.

// of the batch polynomial.
let FriBatchInfoTarget { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
Expand All @@ -184,6 +198,31 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let denominator = self.sub_extension(subgroup_x, *point);
sum = alpha.shift(sum, self);
sum = self.div_add_extension(numerator, denominator, sum);

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum, self);
let val = self
.constant_extension(F::Extension::from_canonical_u32((i == 0) as u32));
let power =
self.exp_power_of_2_extension(subgroup_x, i * params.degree_bits);
let pi =
self.constant_extension(F::Extension::from_canonical_u32(i as u32));
let power = self.mul_extension(power, pi);
let shift_val = self.add_extension(val, power);

let eval_extension = eval.to_ext_target(self.zero());
let tmp = self.mul_extension(eval_extension, shift_val);
sum = self.add_extension(sum, tmp);
});
}
}

sum
Expand All @@ -210,7 +249,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
Self::assert_noncanonical_indices_ok(&params.config);
let mut x_index_bits = self.low_bits(x_index, n, F::BITS);

let cap_index =
let initial_cap_index =
self.le_sum(x_index_bits[x_index_bits.len() - params.config.cap_height..].iter());
with_context!(
self,
Expand All @@ -221,7 +260,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&x_index_bits,
&round_proof.initial_trees_proof,
initial_merkle_caps,
cap_index
initial_cap_index
)
);

Expand Down Expand Up @@ -252,6 +291,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
batch_index += 1;

// In case of zk, the finaly polynomial's degree bits is increased by 1.
let cap_index = self.le_sum(
x_index_bits[x_index_bits.len() + params.hiding as usize - params.config.cap_height..]
.iter(),
);
for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() {
let evals = &round_proof.steps[i].evals;

Expand Down
36 changes: 33 additions & 3 deletions plonky2/src/batch_fri/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_proofs::{verify_batch_merkle_proof_to_cap, verify_merkle_proof_to_cap};
use crate::hash::merkle_tree::MerkleCap;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactor;
use crate::util::reverse_bits;

Expand Down Expand Up @@ -46,7 +47,8 @@ pub fn verify_batch_fri_proof<

let mut precomputed_reduced_evals = Vec::with_capacity(openings.len());
for opn in openings {
let pre = PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha);
let pre =
PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha, params.hiding);
precomputed_reduced_evals.push(pre);
}
let degree_bits = degree_bits
Expand Down Expand Up @@ -123,13 +125,24 @@ fn batch_fri_combine_initial<
let mut alpha = ReducingFactor::new(alpha);
let mut sum = F::Extension::ZERO;

for (batch, reduced_openings) in instances[index]
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
// of the batch polynomial.
for (idx, (batch, reduced_openings)) in instances[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
let FriBatchInfo { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
.sum();
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
Expand All @@ -142,6 +155,23 @@ fn batch_fri_combine_initial<
let denominator = subgroup_x - *point;
sum = alpha.shift(sum);
sum += numerator / denominator;

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum);
let shift_val = F::Extension::from_canonical_usize((i == 0) as usize)
+ subgroup_x.exp_power_of_2(i * params.degree_bits)
* F::Extension::from_canonical_usize(i);
sum += F::Extension::from_basefield(eval) * shift_val;
});
}
}

sum
Expand Down
5 changes: 3 additions & 2 deletions plonky2/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ impl FriConfig {
self.rate_bits,
self.cap_height,
self.num_query_rounds,
hiding,
);
FriParams {
config: self.clone(),
Expand Down Expand Up @@ -87,7 +88,7 @@ pub struct FriParams {

impl FriParams {
pub fn total_arities(&self) -> usize {
self.reduction_arity_bits.iter().sum()
self.reduction_arity_bits.iter().sum::<usize>()
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
}

pub(crate) fn max_arity_bits(&self) -> Option<usize> {
Expand All @@ -103,7 +104,7 @@ impl FriParams {
}

pub fn final_poly_bits(&self) -> usize {
self.degree_bits - self.total_arities()
self.degree_bits + self.hiding as usize - self.total_arities()
}

pub fn final_poly_len(&self) -> usize {
Expand Down
46 changes: 42 additions & 4 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;
Expand All @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_tree::MerkleTree;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -194,9 +195,23 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
//
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
LindaGuiga marked this conversation as resolved.
Show resolved Hide resolved
// of the batch polynomial.
// Then, since the degree of `R` is double that of the batch polynomial in our cimplementation, we need to
// compute one extra step in FRI to reach the correct degree.
let is_zk = fri_params.hiding;

for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials` until `last_poly`.
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -208,6 +223,29 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
let degree = 1 << oracles[0].degree_log;
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

let lde_final_poly = final_poly.lde(fri_params.config.rate_bits);
Expand Down
Loading