diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index f2a3f2cb0c..ce94fca704 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -5,12 +5,16 @@ #![allow(clippy::upper_case_acronyms)] +extern crate alloc; +use alloc::sync::Arc; use core::num::ParseIntError; use core::ops::RangeInclusive; use core::str::FromStr; use anyhow::{anyhow, Context as _, Result}; +use itertools::Itertools; use log::{info, Level, LevelFilter}; +use plonky2::gadgets::lookup::TIP5_TABLE; use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; @@ -59,6 +63,12 @@ struct Options { /// range. #[structopt(long, default_value="14", parse(try_from_str = parse_range_usize))] size: RangeInclusive, + + /// Lookup type. If `lookup_type == 0` or `lookup_type > 2`, then a benchmark with NoopGates only is run. + /// If `lookup_type == 1`, a benchmark with one lookup is run. + /// If `lookup_type == 2`, a benchmark with 515 lookups is run. + #[structopt(long, default_value="0", parse(try_from_str = parse_hex_u64))] + lookup_type: u64, } /// Creates a dummy proof which should have `2 ** log2_size` rows. @@ -91,6 +101,101 @@ fn dummy_proof, C: GenericConfig, const D Ok((proof, data.verifier_only, data.common)) } +fn dummy_lookup_proof, C: GenericConfig, const D: usize>( + config: &CircuitConfig, + log2_size: usize, +) -> Result> { + let mut builder = CircuitBuilder::::new(config.clone()); + let tip5_table = TIP5_TABLE.to_vec(); + let inps = 0..256; + let table = Arc::new(inps.zip_eq(tip5_table).collect()); + let tip5_idx = builder.add_lookup_table_from_pairs(table); + let initial_a = builder.add_virtual_target(); + builder.add_lookup_from_index(initial_a, tip5_idx); + builder.register_public_input(initial_a); + + // 'size' is in degree, but we want the number of gates in the circuit. 
+ // A non-zero amount of padding will be added and size will be rounded to the next power of two. + // To hit our target size, we go just under the previous power of two and hope padding is less than half the proof. + let targeted_num_gates = match log2_size { + 0 => return Err(anyhow!("size must be at least 1")), + 1 => 0, + 2 => 1, + n => (1 << (n - 1)) + 1, + }; + assert!( + targeted_num_gates >= builder.num_gates(), + "size is too small to support lookups" + ); + + for _ in builder.num_gates()..targeted_num_gates { + builder.add_gate(NoopGate, vec![]); + } + builder.print_gate_counts(0); + + let data = builder.build::(); + let mut inputs = PartialWitness::::new(); + inputs.set_target(initial_a, F::ONE); + let mut timing = TimingTree::new("prove with one lookup", Level::Debug); + let proof = prove(&data.prover_only, &data.common, inputs, &mut timing)?; + timing.print(); + data.verify(proof.clone())?; + + Ok((proof, data.verifier_only, data.common)) +} + +/// Creates a dummy proof which has more than 256 lookups to one LUT +fn dummy_many_rows_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + config: &CircuitConfig, + log2_size: usize, +) -> Result> { + let mut builder = CircuitBuilder::::new(config.clone()); + let tip5_table = TIP5_TABLE.to_vec(); + let inps: Vec = (0..256).collect(); + let tip5_idx = builder.add_lookup_table_from_table(&inps, &tip5_table); + let initial_a = builder.add_virtual_target(); + + let output = builder.add_lookup_from_index(initial_a, tip5_idx); + for _ in 0..514 { + builder.add_lookup_from_index(output, 0); + } + + // 'size' is in degree, but we want the number of gates in the circuit. + // A non-zero amount of padding will be added and size will be rounded to the next power of two. + // To hit our target size, we go just under the previous power of two and hope padding is less than half the proof. 
+ let targeted_num_gates = match log2_size { + 0 => return Err(anyhow!("size must be at least 1")), + 1 => 0, + 2 => 1, + n => (1 << (n - 1)) + 1, + }; + assert!( + targeted_num_gates >= builder.num_gates(), + "size is too small to support so many lookups" + ); + + for _ in 0..targeted_num_gates { + builder.add_gate(NoopGate, vec![]); + } + + builder.register_public_input(initial_a); + builder.register_public_input(output); + + let mut pw = PartialWitness::new(); + pw.set_target(initial_a, F::ONE); + let data = builder.build::(); + let mut timing = TimingTree::new("prove with many lookups", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + timing.print(); + + data.verify(proof.clone())?; + Ok((proof, data.verifier_only, data.common)) +} + fn recursive_proof< F: RichField + Extendable, C: GenericConfig, @@ -183,16 +288,34 @@ fn test_serialization, C: GenericConfig, Ok(()) } -fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { +pub fn benchmark_function( + config: &CircuitConfig, + log2_inner_size: usize, + lookup_type: u64, +) -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; + let dummy_proof_function = match lookup_type { + 0 => dummy_proof::, + 1 => dummy_lookup_proof::, + 2 => dummy_many_rows_proof::, + _ => dummy_proof::, + }; + + let name = match lookup_type { + 0 => "proof", + 1 => "one lookup proof", + 2 => "multiple lookups proof", + _ => "proof", + }; // Start with a dummy proof of specified size - let inner = dummy_proof::(config, log2_inner_size)?; + let inner = dummy_proof_function(config, log2_inner_size)?; let (_, _, cd) = &inner; info!( - "Initial proof degree {} = 2^{}", + "Initial {} degree {} = 2^{}", + name, cd.degree(), cd.degree_bits() ); @@ -201,7 +324,8 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { let middle = recursive_proof::(&inner, config, None)?; let (_, _, cd) = &middle; info!( - "Single recursion proof 
degree {} = 2^{}", + "Single recursion {} degree {} = 2^{}", + name, cd.degree(), cd.degree_bits() ); @@ -210,7 +334,8 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { let outer = recursive_proof::(&middle, config, None)?; let (proof, vd, cd) = &outer; info!( - "Double recursion proof degree {} = 2^{}", + "Double recursion {} degree {} = 2^{}", + name, cd.degree(), cd.degree_bits() ); @@ -223,7 +348,6 @@ fn benchmark(config: &CircuitConfig, log2_inner_size: usize) -> Result<()> { fn main() -> Result<()> { // Parse command line arguments, see `--help` for details. let options = Options::from_args_safe()?; - // Initialize logging let mut builder = env_logger::Builder::from_default_env(); builder.parse_filters(&options.log_filter); @@ -246,8 +370,9 @@ fn main() -> Result<()> { let threads = options.threads.unwrap_or(num_cpus..=num_cpus); let config = CircuitConfig::standard_recursion_config(); + for log2_inner_size in options.size { - // Since the `size` is most likely to be and unbounded range we make that the outer iterator. + // Since the `size` is most likely to be an unbounded range we make that the outer iterator. for threads in threads.clone() { rayon::ThreadPoolBuilder::new() .num_threads(threads) @@ -259,8 +384,8 @@ fn main() -> Result<()> { rayon::current_num_threads(), num_cpus ); - // Run the benchmark - benchmark(&config, log2_inner_size) + // Run the benchmark. `options.lookup_type` determines which benchmark to run. 
+ benchmark_function(&config, log2_inner_size, options.lookup_type) })?; } } diff --git a/plonky2/src/gadgets/lookup.rs b/plonky2/src/gadgets/lookup.rs new file mode 100644 index 0000000000..cc60cbe7ea --- /dev/null +++ b/plonky2/src/gadgets/lookup.rs @@ -0,0 +1,126 @@ +use crate::field::extension::Extendable; +use crate::gates::lookup::LookupGate; +use crate::gates::lookup_table::{LookupTable, LookupTableGate}; +use crate::gates::noop::NoopGate; +use crate::hash::hash_types::RichField; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +/// Lookup tables used in the tests and benchmarks. +/// +/// The following table was taken from the Tip5 paper. +pub const TIP5_TABLE: [u16; 256] = [ + 0, 7, 26, 63, 124, 215, 85, 254, 214, 228, 45, 185, 140, 173, 33, 240, 29, 177, 176, 32, 8, + 110, 87, 202, 204, 99, 150, 106, 230, 14, 235, 128, 213, 239, 212, 138, 23, 130, 208, 6, 44, + 71, 93, 116, 146, 189, 251, 81, 199, 97, 38, 28, 73, 179, 95, 84, 152, 48, 35, 119, 49, 88, + 242, 3, 148, 169, 72, 120, 62, 161, 166, 83, 175, 191, 137, 19, 100, 129, 112, 55, 221, 102, + 218, 61, 151, 237, 68, 164, 17, 147, 46, 234, 203, 216, 22, 141, 65, 57, 123, 12, 244, 54, 219, + 231, 96, 77, 180, 154, 5, 253, 133, 165, 98, 195, 205, 134, 245, 30, 9, 188, 59, 142, 186, 197, + 181, 144, 92, 31, 224, 163, 111, 74, 58, 69, 113, 196, 67, 246, 225, 10, 121, 50, 60, 157, 90, + 122, 2, 250, 101, 75, 178, 159, 24, 36, 201, 11, 243, 132, 198, 190, 114, 233, 39, 52, 21, 209, + 108, 238, 91, 187, 18, 104, 194, 37, 153, 34, 200, 143, 126, 155, 236, 118, 64, 80, 172, 89, + 94, 193, 135, 183, 86, 107, 252, 13, 167, 206, 136, 220, 207, 103, 171, 160, 76, 182, 227, 217, + 158, 56, 174, 4, 66, 109, 139, 162, 184, 211, 249, 47, 125, 232, 117, 43, 16, 42, 127, 20, 241, + 25, 149, 105, 156, 51, 53, 168, 145, 247, 223, 79, 78, 226, 15, 222, 82, 115, 70, 210, 27, 41, + 1, 170, 40, 131, 192, 229, 248, 255, +]; + +/// This is a table with 256 arbitrary values. 
+pub const OTHER_TABLE: [u16; 256] = [ + 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, + 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, + 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, + 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, + 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, + 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, + 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, + 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, + 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, + 3, 9, 7, 0, 3, 25, 35, 10, 19, 36, 45, 216, 247, 35, 39, 57, 126, 2, 6, 25, 3, 9, 7, 0, 3, 25, + 35, 10, 19, 36, 45, 216, 247, +]; + +/// This is a smaller lookup table with arbitrary values. +pub const SMALLER_TABLE: [u16; 8] = [2, 24, 56, 100, 128, 16, 20, 49]; + +impl, const D: usize> CircuitBuilder { + /// Adds a lookup table to the list of stored lookup tables `self.luts` based on a table of (input, output) pairs. It returns the index of the LUT within `self.luts`. + pub fn add_lookup_table_from_pairs(&mut self, table: LookupTable) -> usize { + self.update_luts_from_pairs(table) + } + + /// Adds a lookup table to the list of stored lookup tables `self.luts` based on a table, represented as a slice `&[u16]` of inputs and a slice `&[u16]` of outputs. It returns the index of the LUT within `self.luts`. + pub fn add_lookup_table_from_table(&mut self, inps: &[u16], outs: &[u16]) -> usize { + self.update_luts_from_table(inps, outs) + } + + /// Adds a lookup table to the list of stored lookup tables `self.luts` based on a function. 
It returns the index of the LUT within `self.luts`. + pub fn add_lookup_table_from_fn(&mut self, f: fn(u16) -> u16, inputs: &[u16]) -> usize { + self.update_luts_from_fn(f, inputs) + } + + /// Adds a lookup (input, output) pair to the stored lookups. Takes a `Target` input and returns a `Target` output. + pub fn add_lookup_from_index(&mut self, looking_in: Target, lut_index: usize) -> Target { + assert!( + lut_index < self.get_luts_length(), + "lut number {} not in luts (length = {})", + lut_index, + self.get_luts_length() + ); + let looking_out = self.add_virtual_target(); + self.update_lookups(looking_in, looking_out, lut_index); + looking_out + } + + /// We call this function at the end of circuit building right before the PI gate to add all `LookupTableGate` and `LookupGate`. + /// It also updates `self.lookup_rows` accordingly. + pub fn add_all_lookups(&mut self) { + for lut_index in 0..self.num_luts() { + assert!( + !self.get_lut_lookups(lut_index).is_empty() || lut_index >= self.get_luts_length(), + "LUT number {:?} is unused", + lut_index + ); + if !self.get_lut_lookups(lut_index).is_empty() { + // Create LU gates. Connect them to the stored lookups. + let last_lu_gate = self.num_gates(); + + let lut = self.get_lut(lut_index); + + let lookups = self.get_lut_lookups(lut_index).to_owned(); + + for (looking_in, looking_out) in lookups { + let gate = LookupGate::new_from_table(&self.config, lut.clone()); + let (gate, i) = + self.find_slot(gate, &[F::from_canonical_usize(lut_index)], &[]); + let gate_in = Target::wire(gate, LookupGate::wire_ith_looking_inp(i)); + let gate_out = Target::wire(gate, LookupGate::wire_ith_looking_out(i)); + self.connect(gate_in, looking_in); + self.connect(gate_out, looking_out); + } + + // Create LUT gates. Nothing is connected to them. 
+ let last_lut_gate = self.num_gates(); + let num_lut_entries = LookupTableGate::num_slots(&self.config); + let num_lut_rows = (self.get_luts_idx_length(lut_index) - 1) / num_lut_entries + 1; + let num_lut_cells = num_lut_entries * num_lut_rows; + for _ in 0..num_lut_cells { + let gate = + LookupTableGate::new_from_table(&self.config, lut.clone(), last_lut_gate); + self.find_slot(gate, &[], &[]); + } + + let first_lut_gate = self.num_gates() - 1; + + // Will ensure the next row's wires will be all zeros. With this, there is no distinction between the transition constraints on the first row + // and on the other rows. Additionally, initial constraints become a simple zero check. + self.add_gate(NoopGate, vec![]); + + // These elements are increasing: the gate rows are deliberately upside down. + // This is necessary for constraint evaluation so that you do not need values of the next + // row's wires, which aren't provided in the evaluation variables. + self.add_lookup_rows(last_lu_gate, last_lut_gate, first_lut_gate); + } + } + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index bac944754b..9016211f97 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -2,6 +2,7 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod hash; pub mod interpolation; +pub mod lookup; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 1dbc4f5747..dd0779dee2 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -100,6 +100,7 @@ pub trait Gate, const D: usize>: 'static + Send + S selector_index: usize, group_range: Range, num_selectors: usize, + num_lookup_selectors: usize, ) -> Vec { let filter = compute_filter( row, @@ -108,6 +109,7 @@ pub trait Gate, const D: usize>: 'static + Send + S num_selectors > 1, ); vars.remove_prefix(num_selectors); + vars.remove_prefix(num_lookup_selectors); self.eval_unfiltered(vars) .into_iter() 
.map(|c| filter * c) @@ -123,6 +125,7 @@ pub trait Gate, const D: usize>: 'static + Send + S selector_index: usize, group_range: Range, num_selectors: usize, + num_lookup_selectors: usize, ) -> Vec { let filters: Vec<_> = vars_batch .iter() @@ -135,7 +138,7 @@ pub trait Gate, const D: usize>: 'static + Send + S ) }) .collect(); - vars_batch.remove_prefix(num_selectors); + vars_batch.remove_prefix(num_selectors + num_lookup_selectors); let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); for res_chunk in res_batch.chunks_exact_mut(filters.len()) { batch_multiply_inplace(res_chunk, &filters); @@ -152,6 +155,7 @@ pub trait Gate, const D: usize>: 'static + Send + S selector_index: usize, group_range: Range, num_selectors: usize, + num_lookup_selectors: usize, combined_gate_constraints: &mut [ExtensionTarget], ) { let filter = compute_filter_circuit( @@ -162,6 +166,7 @@ pub trait Gate, const D: usize>: 'static + Send + S num_selectors > 1, ); vars.remove_prefix(num_selectors); + vars.remove_prefix(num_lookup_selectors); let my_constraints = self.eval_unfiltered_circuit(builder, vars); for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) { *acc = builder.mul_add_extension(filter, c, *acc); diff --git a/plonky2/src/gates/lookup.rs b/plonky2/src/gates/lookup.rs new file mode 100644 index 0000000000..0ecac5018e --- /dev/null +++ b/plonky2/src/gates/lookup.rs @@ -0,0 +1,204 @@ +use alloc::format; +use alloc::string::String; +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::usize; + +use super::lookup_table::LookupTable; +use crate::field::extension::Extendable; +use crate::field::packed::PackedField; +use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef}; +use crate::iop::target::Target; 
+use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; +use crate::util::serialization::{Buffer, IoResult, Read, Write}; + +pub type Lookup = Vec<(Target, Target)>; + +/// A gate which stores (input, output) lookup pairs made elsewhere in the trace. It doesn't check any constraints itself. +#[derive(Debug, Clone)] +pub struct LookupGate { + /// Number of lookups per gate. + pub num_slots: usize, + /// LUT associated to the gate. + lut: LookupTable, +} + +impl LookupGate { + pub fn new_from_table(config: &CircuitConfig, lut: LookupTable) -> Self { + Self { + num_slots: Self::num_slots(config), + lut, + } + } + pub(crate) fn num_slots(config: &CircuitConfig) -> usize { + let wires_per_lookup = 2; + config.num_routed_wires / wires_per_lookup + } + + pub fn wire_ith_looking_inp(i: usize) -> usize { + 2 * i + } + + pub fn wire_ith_looking_out(i: usize) -> usize { + 2 * i + 1 + } +} + +impl, const D: usize> Gate for LookupGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn serialize(&self, dst: &mut Vec) -> IoResult<()> { + dst.write_usize(self.num_slots)?; + dst.write_lut(&self.lut) + } + + fn deserialize(src: &mut Buffer) -> IoResult { + let num_slots = src.read_usize()?; + let lut = src.read_lut()?; + + Ok(Self { + num_slots, + lut: Arc::new(lut), + }) + } + + fn eval_unfiltered(&self, _vars: EvaluationVars) -> Vec { + // No main trace constraints for lookups. 
+ vec![] + } + + fn eval_unfiltered_base_one( + &self, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, + ) { + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + _builder: &mut CircuitBuilder, + _vars: EvaluationTargets, + ) -> Vec> { + // No main trace constraints for lookups. + vec![] + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + (0..self.num_slots) + .map(|i| { + WitnessGeneratorRef::new( + LookupGenerator { + row, + lut: self.lut.clone(), + slot_nb: i, + } + .adapter(), + ) + }) + .collect() + } + + fn num_wires(&self) -> usize { + self.num_slots * 2 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 0 + } + + fn num_constraints(&self) -> usize { + 0 + } +} + +impl, const D: usize> PackedEvaluableBase for LookupGate { + fn eval_unfiltered_base_packed>( + &self, + _vars: EvaluationVarsBasePacked

, + mut _yield_constr: StridedConstraintConsumer

, + ) { + } +} + +#[derive(Clone, Debug, Default)] +pub struct LookupGenerator { + row: usize, + lut: LookupTable, + slot_nb: usize, +} + +impl SimpleGenerator for LookupGenerator { + fn id(&self) -> String { + "LookupGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![Target::wire( + self.row, + LookupGate::wire_ith_looking_inp(self.slot_nb), + )] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_wire = |wire: usize| -> F { witness.get_target(Target::wire(self.row, wire)) }; + + let input_val = get_wire(LookupGate::wire_ith_looking_inp(self.slot_nb)); + let output_val = if input_val + == F::from_canonical_u16(self.lut[input_val.to_canonical_u64() as usize].0) + { + F::from_canonical_u16(self.lut[input_val.to_canonical_u64() as usize].1) + } else { + let mut cur_idx = 0; + while input_val != F::from_canonical_u16(self.lut[cur_idx].0) + && cur_idx < self.lut.len() + { + cur_idx += 1; + } + assert!(cur_idx < self.lut.len(), "Incorrect input value provided"); + F::from_canonical_u16(self.lut[cur_idx].1) + }; + + let out_wire = Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); + out_buffer.set_target(out_wire, output_val); + } + + fn serialize(&self, dst: &mut Vec) -> IoResult<()> { + dst.write_usize(self.row)?; + dst.write_lut(&self.lut)?; + dst.write_usize(self.slot_nb) + } + + fn deserialize(src: &mut Buffer) -> IoResult { + let row = src.read_usize()?; + let lut = src.read_lut()?; + let slot_nb = src.read_usize()?; + + Ok(Self { + row, + lut: Arc::new(lut), + slot_nb, + }) + } +} diff --git a/plonky2/src/gates/lookup_table.rs b/plonky2/src/gates/lookup_table.rs new file mode 100644 index 0000000000..b4ad902ab8 --- /dev/null +++ b/plonky2/src/gates/lookup_table.rs @@ -0,0 +1,228 @@ +use alloc::format; +use alloc::string::String; +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::usize; + +use plonky2_util::ceil_div_usize; + +use crate::field::extension::Extendable; +use 
crate::field::packed::PackedField; +use crate::gates::gate::Gate; +use crate::gates::packed_util::PackedEvaluableBase; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, WitnessWrite}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; +use crate::util::serialization::{Buffer, IoResult, Read, Write}; + +pub type LookupTable = Arc>; + +/// A gate which stores the set of (input, output) value pairs of a lookup table, and their multiplicities. +#[derive(Debug, Clone)] +pub struct LookupTableGate { + /// Number of lookup entries per gate. + pub num_slots: usize, + /// Lookup table associated to the gate. + pub lut: LookupTable, + /// First row of the lookup table. + last_lut_row: usize, +} + +impl LookupTableGate { + pub fn new_from_table(config: &CircuitConfig, lut: LookupTable, last_lut_row: usize) -> Self { + Self { + num_slots: Self::num_slots(config), + lut, + last_lut_row, + } + } + + pub(crate) fn num_slots(config: &CircuitConfig) -> usize { + let wires_per_entry = 3; + config.num_routed_wires / wires_per_entry + } + + /// Wire for the looked input. + pub fn wire_ith_looked_inp(i: usize) -> usize { + 3 * i + } + + // Wire for the looked output. + pub fn wire_ith_looked_out(i: usize) -> usize { + 3 * i + 1 + } + + /// Wire for the multiplicity. Set after the trace has been generated. 
+ pub fn wire_ith_multiplicity(i: usize) -> usize { + 3 * i + 2 + } +} + +impl, const D: usize> Gate for LookupTableGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn serialize(&self, dst: &mut Vec) -> IoResult<()> { + dst.write_usize(self.num_slots)?; + dst.write_lut(&self.lut)?; + dst.write_usize(self.last_lut_row) + } + + fn deserialize(src: &mut Buffer) -> IoResult { + let num_slots = src.read_usize()?; + let lut = src.read_lut()?; + let last_lut_row = src.read_usize()?; + + Ok(Self { + num_slots, + lut: Arc::new(lut), + last_lut_row, + }) + } + + fn eval_unfiltered(&self, _vars: EvaluationVars) -> Vec { + // No main trace constraints for the lookup table. + vec![] + } + + fn eval_unfiltered_base_one( + &self, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, + ) { + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + _builder: &mut CircuitBuilder, + _vars: EvaluationTargets, + ) -> Vec> { + // No main trace constraints for the lookup table. + vec![] + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + (0..self.num_slots) + .map(|i| { + WitnessGeneratorRef::new( + LookupTableGenerator { + row, + lut: self.lut.clone(), + slot_nb: i, + num_slots: self.num_slots, + last_lut_row: self.last_lut_row, + } + .adapter(), + ) + }) + .collect() + } + + fn num_wires(&self) -> usize { + self.num_slots * 3 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 0 + } + + fn num_constraints(&self) -> usize { + 0 + } +} + +impl, const D: usize> PackedEvaluableBase for LookupTableGate { + fn eval_unfiltered_base_packed>( + &self, + _vars: EvaluationVarsBasePacked

, + mut _yield_constr: StridedConstraintConsumer

, + ) { + } +} + +#[derive(Clone, Debug, Default)] +pub struct LookupTableGenerator { + row: usize, + lut: LookupTable, + slot_nb: usize, + num_slots: usize, + last_lut_row: usize, +} + +impl SimpleGenerator for LookupTableGenerator { + fn id(&self) -> String { + "LookupTableGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![] + } + + fn run_once(&self, _witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let first_row = self.last_lut_row + ceil_div_usize(self.lut.len(), self.num_slots) - 1; + let slot = (first_row - self.row) * self.num_slots + self.slot_nb; + + let slot_input_target = + Target::wire(self.row, LookupTableGate::wire_ith_looked_inp(self.slot_nb)); + let slot_output_target = + Target::wire(self.row, LookupTableGate::wire_ith_looked_out(self.slot_nb)); + + if slot < self.lut.len() { + out_buffer.set_target( + slot_input_target, + F::from_canonical_usize(self.lut[slot].0 as usize), + ); + out_buffer.set_target( + slot_output_target, + F::from_canonical_usize(self.lut[slot].1.into()), + ); + } else { + // Pad with zeros. 
+ out_buffer.set_target(slot_input_target, F::ZERO); + out_buffer.set_target(slot_output_target, F::ZERO); + } + } + + fn serialize(&self, dst: &mut Vec) -> IoResult<()> { + dst.write_usize(self.row)?; + dst.write_lut(&self.lut)?; + dst.write_usize(self.slot_nb)?; + dst.write_usize(self.num_slots)?; + dst.write_usize(self.last_lut_row) + } + + fn deserialize(src: &mut Buffer) -> IoResult { + let row = src.read_usize()?; + let lut = src.read_lut()?; + let slot_nb = src.read_usize()?; + let num_slots = src.read_usize()?; + let last_lut_row = src.read_usize()?; + + Ok(Self { + row, + lut: Arc::new(lut), + slot_nb, + num_slots, + last_lut_row, + }) + } +} diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 9df1a535bd..d9d90b37fb 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -8,6 +8,8 @@ pub mod constant; pub mod coset_interpolation; pub mod exponentiation; pub mod gate; +pub mod lookup; +pub mod lookup_table; pub mod multiplication_extension; pub mod noop; pub mod packed_util; diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index 0e690c6dbd..ef2dec8171 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -8,6 +8,7 @@ use crate::field::extension::Extendable; use crate::field::polynomial::PolynomialValues; use crate::gates::gate::{GateInstance, GateRef}; use crate::hash::hash_types::RichField; +use crate::plonk::circuit_builder::LookupWire; /// Placeholder value to indicate that a gate doesn't use a selector polynomial. pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize; @@ -24,6 +25,77 @@ impl SelectorsInfo { } } +/// Enum listing the different selectors for lookup constraints: +/// - `TransSre` is for Sum and RE transition constraints. +/// - `TransLdc` is for LDC transition constraints. +/// - `InitSre` is for the initial constraint of Sum and Re. +/// - `LastLdc` is for the final LDC (and Sum) constraint. 
+/// - `StartEnd` indicates where lookup end selectors begin. +pub enum LookupSelectors { + TransSre = 0, + TransLdc, + InitSre, + LastLdc, + StartEnd, +} + +/// Returns selector polynomials for each LUT. We have two constraint domains (remember that gates are stored upside down): +/// - [last_lut_row, first_lut_row] (Sum and RE transition contraints), +/// - [last_lu_row, last_lut_row - 1] (LDC column transition constraints). +/// We also add two more: +/// - {first_lut_row + 1} where we check the initial values of sum and RE (which are 0), +/// - {last_lu_row} where we check that the last value of LDC is 0. +/// Conceptually they're part of the selector ends lookups, but since we can have one polynomial for *all* LUTs it's here. +pub(crate) fn selectors_lookup, const D: usize>( + _gates: &[GateRef], + instances: &[GateInstance], + lookup_rows: &[LookupWire], +) -> Vec> { + let n = instances.len(); + let mut lookup_selectors = Vec::with_capacity(LookupSelectors::StartEnd as usize); + for _ in 0..LookupSelectors::StartEnd as usize { + lookup_selectors.push(PolynomialValues::::new(vec![F::ZERO; n])); + } + + for &LookupWire { + last_lu_gate: last_lu_row, + last_lut_gate: last_lut_row, + first_lut_gate: first_lut_row, + } in lookup_rows + { + for row in last_lut_row..first_lut_row + 1 { + lookup_selectors[LookupSelectors::TransSre as usize].values[row] = F::ONE; + } + for row in last_lu_row..last_lut_row { + lookup_selectors[LookupSelectors::TransLdc as usize].values[row] = F::ONE; + } + lookup_selectors[LookupSelectors::InitSre as usize].values[first_lut_row + 1] = F::ONE; + lookup_selectors[LookupSelectors::LastLdc as usize].values[last_lu_row] = F::ONE; + } + lookup_selectors +} + +/// Returns selectors for checking the validity of the LUTs. +/// Each selector equals one on its respective LUT's `last_lut_row`, and 0 elsewhere. 
+pub(crate) fn selector_ends_lookups, const D: usize>( + lookup_rows: &[LookupWire], + instances: &[GateInstance], +) -> Vec> { + let n = instances.len(); + let mut lookups_ends = Vec::with_capacity(lookup_rows.len()); + for &LookupWire { + last_lu_gate: _, + last_lut_gate: last_lut_row, + first_lut_gate: _, + } in lookup_rows + { + let mut lookup_ends = PolynomialValues::::new(vec![F::ZERO; n]); + lookup_ends.values[last_lut_row] = F::ONE; + lookups_ends.push(lookup_ends); + } + lookups_ends +} + /// Returns the selector polynomials and related information. /// /// Selector polynomials are computed as follows: diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs index b955fea72f..f52c876840 100644 --- a/plonky2/src/lib.rs +++ b/plonky2/src/lib.rs @@ -13,6 +13,7 @@ pub mod gadgets; pub mod gates; pub mod hash; pub mod iop; +pub mod lookup_test; pub mod plonk; pub mod recursion; pub mod util; diff --git a/plonky2/src/lookup_test.rs b/plonky2/src/lookup_test.rs new file mode 100644 index 0000000000..04b5720383 --- /dev/null +++ b/plonky2/src/lookup_test.rs @@ -0,0 +1,437 @@ +#[cfg(test)] +mod tests { + static LOGGER_INITIALIZED: Once = Once::new(); + + use alloc::sync::Arc; + use std::sync::Once; + + use itertools::Itertools; + use log::{Level, LevelFilter}; + + use crate::gadgets::lookup::{OTHER_TABLE, SMALLER_TABLE, TIP5_TABLE}; + use crate::gates::lookup_table::LookupTable; + use crate::gates::noop::NoopGate; + use crate::plonk::prover::prove; + use crate::util::timing::TimingTree; + + #[test] + fn test_no_lookup() -> anyhow::Result<()> { + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = 
CircuitBuilder::::new(config); + builder.add_gate(NoopGate, vec![]); + let pw = PartialWitness::new(); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove first", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + timing.print(); + data.verify(proof)?; + + Ok(()) + } + + // Tests two lookups in one lookup table. + #[test] + fn test_one_lookup() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let tip5_table = TIP5_TABLE.to_vec(); + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let out_a = table[look_val_a].1; + let out_b = table[look_val_b].1; + let table_index = builder.add_lookup_table_from_pairs(table); + let output_a = builder.add_lookup_from_index(initial_a, table_index); + + let output_b = builder.add_lookup_from_index(initial_b, table_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); + pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove one lookup", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, 
&mut timing)?; + timing.print(); + data.verify(proof.clone())?; + + assert!( + proof.public_inputs[2] == F::from_canonical_u16(out_a), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[3] == F::from_canonical_u16(out_b), + "Second lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + + Ok(()) + } + + // Tests one lookup in two different lookup tables. + #[test] + pub fn test_two_luts() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + + let first_out = tip5_table[look_val_a]; + let second_out = tip5_table[look_val_b]; + + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let other_table = OTHER_TABLE.to_vec(); + + let table_index = builder.add_lookup_table_from_pairs(table); + let output_a = builder.add_lookup_from_index(initial_a, table_index); + + let output_b = builder.add_lookup_from_index(initial_b, table_index); + let sum = builder.add(output_a, output_b); + + let s = first_out + second_out; + let final_out = other_table[s as usize]; + + let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect()); + let table2_index = builder.add_lookup_table_from_pairs(table2); + + let output_final = builder.add_lookup_from_index(sum, 
table2_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(sum); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + builder.register_public_input(output_final); + + let mut pw = PartialWitness::new(); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); + pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + let data = builder.build::(); + let mut timing = TimingTree::new("prove two_luts", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + data.verify(proof.clone())?; + timing.print(); + + assert!( + proof.public_inputs[3] == F::from_canonical_u16(first_out), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[4] == F::from_canonical_u16(second_out), + "Second lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + assert!( + proof.public_inputs[2] == F::from_canonical_u16(s), + "Sum between the first two LUT outputs is incorrect." 
+ ); + assert!( + proof.public_inputs[5] == F::from_canonical_u16(final_out), + "Output of the second LUT at index {} is incorrect.", + s + ); + + Ok(()) + } + + #[test] + pub fn test_different_inputs() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let init_a = 1; + let init_b = 2; + + let tab: Vec = SMALLER_TABLE.to_vec(); + let table: LookupTable = Arc::new((2..10).zip_eq(tab).collect()); + + let other_table = OTHER_TABLE.to_vec(); + + let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect()); + let small_index = builder.add_lookup_table_from_pairs(table.clone()); + let output_a = builder.add_lookup_from_index(initial_a, small_index); + + let output_b = builder.add_lookup_from_index(initial_b, small_index); + let sum = builder.add(output_a, output_b); + + let other_index = builder.add_lookup_table_from_pairs(table2.clone()); + let output_final = builder.add_lookup_from_index(sum, other_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(sum); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + builder.register_public_input(output_final); + + let mut pw = PartialWitness::new(); + + let look_val_a = table[init_a].0; + let look_val_b = table[init_b].0; + pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); + pw.set_target(initial_b, 
F::from_canonical_u16(look_val_b)); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove different lookups", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + data.verify(proof.clone())?; + timing.print(); + + let out_a = table[init_a].1; + let out_b = table[init_b].1; + let s = out_a + out_b; + let out_final = table2[s as usize].1; + + assert!( + proof.public_inputs[3] == F::from_canonical_u16(out_a), + "First lookup, at index {} in the smaller LUT gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[4] == F::from_canonical_u16(out_b), + "Second lookup, at index {} in the smaller LUT gives an incorrect output.", + proof.public_inputs[1] + ); + assert!( + proof.public_inputs[2] == F::from_canonical_u16(s), + "Sum between the first two LUT outputs is incorrect." + ); + assert!( + proof.public_inputs[5] == F::from_canonical_u16(out_final), + "Output of the second LUT at index {} is incorrect.", + s + ); + + Ok(()) + } + + // This test looks up over 514 values for one LookupTableGate, which means that several LookupGates are created. 
+ #[test] + pub fn test_many_lookups() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let out_a = table[look_val_a].1; + let out_b = table[look_val_b].1; + + let tip5_index = builder.add_lookup_table_from_pairs(table); + let output_a = builder.add_lookup_from_index(initial_a, tip5_index); + + let output_b = builder.add_lookup_from_index(initial_b, tip5_index); + let sum = builder.add(output_a, output_b); + + for _ in 0..514 { + builder.add_lookup_from_index(initial_a, tip5_index); + } + + let other_table = OTHER_TABLE.to_vec(); + + let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect()); + + let s = out_a + out_b; + let out_final = table2[s as usize].1; + + let other_index = builder.add_lookup_table_from_pairs(table2); + let output_final = builder.add_lookup_from_index(sum, other_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(sum); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + builder.register_public_input(output_final); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); + pw.set_target(initial_b, 
F::from_canonical_usize(look_val_b)); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove different lookups", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + + data.verify(proof.clone())?; + timing.print(); + + assert!( + proof.public_inputs[3] == F::from_canonical_u16(out_a), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[4] == F::from_canonical_u16(out_b), + "Second lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + assert!( + proof.public_inputs[2] == F::from_canonical_u16(s), + "Sum between the first two LUT outputs is incorrect." + ); + assert!( + proof.public_inputs[5] == F::from_canonical_u16(out_final), + "Output of the second LUT at index {} is incorrect.", + s + ); + + Ok(()) + } + + // Tests whether, when adding the same LUT to the circuit, the circuit only adds one copy, with the same index. 
+ #[test] + pub fn test_same_luts() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, WitnessWrite}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let table_index = builder.add_lookup_table_from_pairs(table.clone()); + let output_a = builder.add_lookup_from_index(initial_a, table_index); + + let output_b = builder.add_lookup_from_index(initial_b, table_index); + let sum = builder.add(output_a, output_b); + + let table2_index = builder.add_lookup_table_from_pairs(table); + + let output_final = builder.add_lookup_from_index(sum, table2_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(sum); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + builder.register_public_input(output_final); + + let luts_length = builder.get_luts_length(); + + assert!( + luts_length == 1, + "There are {} LUTs when there should be only one", + luts_length + ); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); + pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove two_luts", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut 
timing)?; + data.verify(proof)?; + timing.print(); + + Ok(()) + } + + fn init_logger() -> anyhow::Result<()> { + let mut builder = env_logger::Builder::from_default_env(); + builder.format_timestamp(None); + builder.filter_level(LevelFilter::Debug); + + builder.try_init()?; + Ok(()) + } +} diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 84b2a0edc5..8955a2bdb9 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -1,4 +1,5 @@ use alloc::collections::BTreeMap; +use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use core::cmp::max; @@ -8,6 +9,7 @@ use std::time::Instant; use hashbrown::{HashMap, HashSet}; use itertools::Itertools; use log::{debug, info, Level}; +use plonky2_util::ceil_div_usize; use crate::field::cosets::get_unique_coset_shifts; use crate::field::extension::{Extendable, FieldExtension}; @@ -23,9 +25,11 @@ use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef}; +use crate::gates::lookup::{Lookup, LookupGate}; +use crate::gates::lookup_table::LookupTable; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; -use crate::gates::selectors::selector_polynomials; +use crate::gates::selectors::{selector_ends_lookups, selector_polynomials, selectors_lookup}; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::hash::merkle_tree::MerkleCap; @@ -49,6 +53,36 @@ use crate::util::partial_products::num_partial_products; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, log2_strict, transpose, transpose_poly_values}; +/// Number of random coins needed for lookups (for each challenge). 
+/// A coin is a randomly sampled extension field element from the verifier, +/// consisting internally of `CircuitConfig::num_challenges` field elements. +pub const NUM_COINS_LOOKUP: usize = 4; + +/// Enum listing the different types of lookup challenges. +/// `ChallengeA` is used for the linear combination of input and output pairs in Sum and LDC. +/// `ChallengeB` is used for the linear combination of input and output pairs in the polynomial RE. +/// `ChallengeAlpha` is used for the running sums: 1/(alpha - combo_i). +/// `ChallengeDelta` is a challenge on which to evaluate the interpolated LUT function. +pub enum LookupChallenges { + ChallengeA = 0, + ChallengeB = 1, + ChallengeAlpha = 2, + ChallengeDelta = 3, +} + +/// Structure containing, for each lookup table, the indices of the last lookup row, +/// the last lookup table row and the first lookup table row. Since the rows are in +/// reverse order in the trace, they actually correspond, respectively, to: the indices +/// of the first `LookupGate`, the first `LookupTableGate` and the last `LookupTableGate`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct LookupWire { + /// Index of the last lookup row (i.e. the first `LookupGate`). + pub last_lu_gate: usize, + /// Index of the last lookup table row (i.e. the first `LookupTableGate`). + pub last_lut_gate: usize, + /// Index of the first lookup table row (i.e. the last `LookupTableGate`). + pub first_lut_gate: usize, +} pub struct CircuitBuilder, const D: usize> { pub config: CircuitConfig, @@ -92,6 +126,15 @@ pub struct CircuitBuilder, const D: usize> { /// List of constant generators used to fill the constant wires. constant_generators: Vec>, + /// Rows for each LUT: LookupWire contains: first `LookupGate`, first `LookupTableGate`, last `LookupTableGate`. + lookup_rows: Vec, + + /// For each LUT index, vector of `(looking_in, looking_out)` pairs. + lut_to_lookups: Vec, + + // Lookup tables in the form of `Vec<(input_value, output_value)>`. 
+ luts: Vec, + /// Optional common data. When it is `Some(goal_data)`, the `build` function panics if the resulting /// common data doesn't equal `goal_data`. /// This is used in cyclic recursion. @@ -120,6 +163,9 @@ impl, const D: usize> CircuitBuilder { arithmetic_results: HashMap::new(), current_slots: HashMap::new(), constant_generators: Vec::new(), + lookup_rows: Vec::new(), + lut_to_lookups: Vec::new(), + luts: Vec::new(), goal_common_data: None, verifier_data_public_input: None, }; @@ -173,6 +219,39 @@ impl, const D: usize> CircuitBuilder { self.public_inputs.len() } + /// Adds lookup rows for a lookup table. + pub fn add_lookup_rows( + &mut self, + last_lu_gate: usize, + last_lut_gate: usize, + first_lut_gate: usize, + ) { + self.lookup_rows.push(LookupWire { + last_lu_gate, + last_lut_gate, + first_lut_gate, + }); + } + + /// Adds a looking (input, output) pair to the corresponding LUT. + pub fn update_lookups(&mut self, looking_in: Target, looking_out: Target, lut_index: usize) { + assert!( + lut_index < self.lut_to_lookups.len(), + "The LUT with index {} has not been created. The last LUT is at index {}", + lut_index, + self.lut_to_lookups.len() - 1 + ); + self.lut_to_lookups[lut_index].push((looking_in, looking_out)); + } + + pub fn num_luts(&mut self) -> usize { + self.lut_to_lookups.len() + } + + pub fn get_lut_lookups(&self, lut_index: usize) -> &[(Target, Target)] { + &self.lut_to_lookups[lut_index] + } + /// Adds a new "virtual" target. This is not an actual wire in the witness, but just a target /// that help facilitate witness generation. In particular, a generator can assign a values to a /// virtual target, which can then be copied to other (virtual or concrete) targets. When we @@ -486,6 +565,100 @@ impl, const D: usize> CircuitBuilder { self.context_log.pop(self.num_gates()); } + /// Returns the total number of LUTs. + pub fn get_luts_length(&self) -> usize { + self.luts.len() + } + + /// Gets the length of the LUT at index `idx`. 
+ pub fn get_luts_idx_length(&self, idx: usize) -> usize { + assert!( + idx < self.luts.len(), + "index idx: {} greater than the total number of created LUTS: {}", + idx, + self.luts.len() + ); + self.luts[idx].len() + } + + /// Checks whether a LUT is already stored in `self.luts` + pub fn is_stored(&self, lut: LookupTable) -> Option { + self.luts.iter().position(|elt| *elt == lut) + } + + /// Returns the LUT at index `idx`. + pub fn get_lut(&self, idx: usize) -> LookupTable { + assert!( + idx < self.luts.len(), + "index idx: {} greater than the total number of created LUTS: {}", + idx, + self.luts.len() + ); + self.luts[idx].clone() + } + + /// Generates a LUT from a function. + pub fn get_lut_from_fn(f: fn(T) -> T, inputs: &[T]) -> Vec<(T, T)> + where + T: Copy, + { + inputs.iter().map(|&input| (input, f(input))).collect() + } + + /// Given a function `f: fn(u16) -> u16`, adds a LUT to the circuit builder. + pub fn update_luts_from_fn(&mut self, f: fn(u16) -> u16, inputs: &[u16]) -> usize { + let lut = Arc::new(Self::get_lut_from_fn::(f, inputs)); + + // If the LUT `lut` is already stored in `self.luts`, return its index. Otherwise, append `table` to `self.luts` and return its index. + if let Some(idx) = self.is_stored(lut.clone()) { + idx + } else { + self.luts.push(lut); + self.lut_to_lookups.push(vec![]); + assert!(self.luts.len() == self.lut_to_lookups.len()); + self.luts.len() - 1 + } + } + + /// Adds a table to the vector of LUTs in the circuit builder, given a list of inputs and table values. + pub fn update_luts_from_table(&mut self, inputs: &[u16], table: &[u16]) -> usize { + assert!( + inputs.len() == table.len(), + "Inputs and table have incompatible lengths: {} and {}", + inputs.len(), + table.len() + ); + let pairs = inputs + .iter() + .copied() + .zip_eq(table.iter().copied()) + .collect(); + let lut: LookupTable = Arc::new(pairs); + + // If the LUT `lut` is already stored in `self.luts`, return its index. 
Otherwise, append `table` to `self.luts` and return its index. + if let Some(idx) = self.is_stored(lut.clone()) { + idx + } else { + self.luts.push(lut); + self.lut_to_lookups.push(vec![]); + assert!(self.luts.len() == self.lut_to_lookups.len()); + self.luts.len() - 1 + } + } + + /// Adds a table to the vector of LUTs in the circuit builder. + pub fn update_luts_from_pairs(&mut self, table: LookupTable) -> usize { + // If the LUT `table` is already stored in `self.luts`, return its index. Otherwise, append `table` to `self.luts` and return its index. + if let Some(idx) = self.is_stored(table.clone()) { + idx + } else { + self.luts.push(table); + self.lut_to_lookups.push(vec![]); + assert!(self.luts.len() == self.lut_to_lookups.len()); + self.luts.len() - 1 + } + } + /// Find an available slot, of the form `(row, op)` for gate `G` using parameters `params` /// and constants `constants`. Parameters are any data used to differentiate which gate should be /// used for the given operation. @@ -739,11 +912,14 @@ impl, const D: usize> CircuitBuilder { /// Builds a "full circuit", with both prover and verifier data. pub fn build>(mut self) -> CircuitData { let mut timing = TimingTree::new("preprocess", Level::Trace); + #[cfg(feature = "std")] let start = Instant::now(); + let rate_bits = self.config.fri_config.rate_bits; let cap_height = self.config.fri_config.cap_height; - + // Total number of LUTs. + let num_luts = self.get_luts_length(); // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // those hash wires match the claimed public inputs. let num_public_inputs = self.public_inputs.len(); @@ -759,6 +935,9 @@ impl, const D: usize> CircuitBuilder { } self.randomize_unused_pi_wires(pi_gate); + // Place LUT-related gates. + self.add_all_lookups(); + // Make sure we have enough constant generators. If not, add a `ConstantGate`. 
while self.constants_to_targets.len() > self.constant_generators.len() { self.add_gate( @@ -808,6 +987,20 @@ impl, const D: usize> CircuitBuilder { gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id())); let (mut constant_vecs, selectors_info) = selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1); + + // Get the lookup selectors. + let num_lookup_selectors = if num_luts != 0 { + let selector_lookups = + selectors_lookup(&gates, &self.gate_instances, &self.lookup_rows); + let selector_ends = selector_ends_lookups(&self.lookup_rows, &self.gate_instances); + let all_lookup_selectors = [selector_lookups, selector_ends].concat(); + let num_lookup_selectors = all_lookup_selectors.len(); + constant_vecs.extend(all_lookup_selectors); + num_lookup_selectors + } else { + 0 + }; + constant_vecs.extend(self.constant_polys()); let num_constants = constant_vecs.len(); @@ -883,6 +1076,13 @@ impl, const D: usize> CircuitBuilder { let num_partial_products = num_partial_products(self.config.num_routed_wires, quotient_degree_factor); + let lookup_degree = self.config.max_quotient_degree_factor - 1; + let num_lookup_polys = if num_luts == 0 { + 0 + } else { + // There is 1 RE polynomial and multiple Sum/LDC polynomials. 
+ ceil_div_usize(LookupGate::num_slots(&self.config), lookup_degree) + 1 + }; let constants_sigmas_cap = constants_sigmas_commitment.merkle_tree.cap.clone(); let domain_separator = self.domain_separator.unwrap_or_default(); let domain_separator_digest = C::Hasher::hash_pad(&domain_separator); @@ -908,6 +1108,9 @@ impl, const D: usize> CircuitBuilder { num_public_inputs, k_is, num_partial_products, + num_lookup_polys, + num_lookup_selectors, + luts: self.luts, }; if let Some(goal_data) = self.goal_common_data { assert_eq!(goal_data, common, "The expected circuit data passed to cyclic recursion method did not match the actual circuit"); @@ -923,6 +1126,8 @@ impl, const D: usize> CircuitBuilder { representative_map: forest.parents, fft_root_table: Some(fft_root_table), circuit_digest, + lookup_rows: self.lookup_rows.clone(), + lut_to_lookups: self.lut_to_lookups.clone(), }; let verifier_only = VerifierOnlyCircuitData:: { diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 46bfd40cb0..7d91a3a3be 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -6,6 +6,7 @@ use core::ops::{Range, RangeFrom}; use anyhow::Result; use serde::Serialize; +use super::circuit_builder::LookupWire; use crate::field::extension::Extendable; use crate::field::fft::FftRootTable; use crate::field::types::Field; @@ -17,6 +18,8 @@ use crate::fri::structure::{ }; use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::GateRef; +use crate::gates::lookup::Lookup; +use crate::gates::lookup_table::LookupTable; use crate::gates::selectors::SelectorsInfo; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; @@ -312,6 +315,10 @@ pub struct ProverOnlyCircuitData< /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to /// seed Fiat-Shamir. 
pub circuit_digest: <>::Hasher as Hasher>::Hash, + /// The concrete placement of the lookup gates for each lookup table index. + pub lookup_rows: Vec, + /// A vector of (looking_in, looking_out) pairs for each lookup table index. + pub lut_to_lookups: Vec, } /// Circuit data required by the verifier, but not the prover. @@ -366,6 +373,15 @@ pub struct CommonCircuitData, const D: usize> { /// The number of partial products needed to compute the `Z` polynomials. pub num_partial_products: usize, + + /// The number of lookup polynomials. + pub num_lookup_polys: usize, + + /// The number of lookup selectors. + pub num_lookup_selectors: usize, + + /// The stored lookup tables. + pub luts: Vec, } impl, const D: usize> CommonCircuitData { @@ -426,9 +442,20 @@ impl, const D: usize> CommonCircuitData { 0..self.config.num_challenges } - /// Range of the partial products polynomials in the `zs_partial_products_commitment`. - pub fn partial_products_range(&self) -> RangeFrom { - self.config.num_challenges.. + /// Range of the partial products polynomials in the `zs_partial_products_lookup_commitment`. + pub fn partial_products_range(&self) -> Range { + self.config.num_challenges..(self.num_partial_products + 1) * self.config.num_challenges + } + + /// Range of lookup polynomials in the `zs_partial_products_lookup_commitment`. + pub fn lookup_range(&self) -> RangeFrom { + self.num_zs_partial_products_polys().. + } + + /// Range of lookup polynomials needed for evaluation at `g * zeta`. 
+ pub fn next_lookup_range(&self, i: usize) -> Range { + self.num_zs_partial_products_polys() + i * self.num_lookup_polys + ..self.num_zs_partial_products_polys() + i * self.num_lookup_polys + 2 } pub(crate) fn get_fri_instance(&self, zeta: F::Extension) -> FriInstanceInfo { @@ -443,7 +470,7 @@ impl, const D: usize> CommonCircuitData { let zeta_next = g * zeta; let zeta_next_batch = FriBatchInfo { point: zeta_next, - polynomials: self.fri_zs_polys(), + polynomials: self.fri_next_batch_polys(), }; let openings = vec![zeta_batch, zeta_next_batch]; @@ -469,7 +496,7 @@ impl, const D: usize> CommonCircuitData { let zeta_next = builder.mul_const_extension(g, zeta); let zeta_next_batch = FriBatchInfoTarget { point: zeta_next, - polynomials: self.fri_zs_polys(), + polynomials: self.fri_next_batch_polys(), }; let openings = vec![zeta_batch, zeta_next_batch]; @@ -490,7 +517,7 @@ impl, const D: usize> CommonCircuitData { blinding: PlonkOracle::WIRES.blinding, }, FriOracleInfo { - num_polys: self.num_zs_partial_products_polys(), + num_polys: self.num_zs_partial_products_polys() + self.num_all_lookup_polys(), blinding: PlonkOracle::ZS_PARTIAL_PRODUCTS.blinding, }, FriOracleInfo { @@ -527,14 +554,31 @@ impl, const D: usize> CommonCircuitData { self.config.num_challenges * (1 + self.num_partial_products) } + /// Returns the total number of lookup polynomials. + pub(crate) fn num_all_lookup_polys(&self) -> usize { + self.config.num_challenges * self.num_lookup_polys + } fn fri_zs_polys(&self) -> Vec { FriPolynomialInfo::from_range(PlonkOracle::ZS_PARTIAL_PRODUCTS.index, self.zs_range()) } + /// Returns polynomials that require evaluation at `zeta` and `g * zeta`. + fn fri_next_batch_polys(&self) -> Vec { + [self.fri_zs_polys(), self.fri_lookup_polys()].concat() + } + fn fri_quotient_polys(&self) -> Vec { FriPolynomialInfo::from_range(PlonkOracle::QUOTIENT.index, 0..self.num_quotient_polys()) } + /// Returns the information for lookup polynomials, i.e. 
the index within the oracle and the indices of the polynomials within the commitment. + fn fri_lookup_polys(&self) -> Vec { + FriPolynomialInfo::from_range( + PlonkOracle::ZS_PARTIAL_PRODUCTS.index, + self.num_zs_partial_products_polys() + ..self.num_zs_partial_products_polys() + self.num_all_lookup_polys(), + ) + } pub(crate) fn num_quotient_polys(&self) -> usize { self.config.num_challenges * self.quotient_degree_factor } @@ -545,6 +589,7 @@ impl, const D: usize> CommonCircuitData { self.fri_wire_polys(), self.fri_zs_partial_products_polys(), self.fri_quotient_polys(), + self.fri_lookup_polys(), ] .concat() } diff --git a/plonky2/src/plonk/get_challenges.rs b/plonky2/src/plonk/get_challenges.rs index d5f028b467..ee6167b90b 100644 --- a/plonky2/src/plonk/get_challenges.rs +++ b/plonky2/src/plonk/get_challenges.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; use hashbrown::HashSet; +use super::circuit_builder::NUM_COINS_LOOKUP; use crate::field::extension::Extendable; use crate::field::polynomial::PolynomialCoeffs; use crate::fri::proof::{CompressedFriProof, FriChallenges, FriProof, FriProofTarget}; @@ -38,6 +39,7 @@ fn get_challenges, C: GenericConfig, cons let num_challenges = config.num_challenges; let mut challenger = Challenger::::new(); + let has_lookup = common_data.num_lookup_polys != 0; // Observe the instance. challenger.observe_hash::(*circuit_digest); @@ -47,6 +49,22 @@ fn get_challenges, C: GenericConfig, cons let plonk_betas = challenger.get_n_challenges(num_challenges); let plonk_gammas = challenger.get_n_challenges(num_challenges); + // If there are lookups in the circuit, we should get delta challenges as well. + // But we can use the already generated `plonk_betas` and `plonk_gammas` as the first `plonk_deltas` challenges. 
+ let plonk_deltas = if has_lookup { + let num_lookup_challenges = NUM_COINS_LOOKUP * num_challenges; + let mut deltas = Vec::with_capacity(num_lookup_challenges); + let num_additional_challenges = num_lookup_challenges - 2 * num_challenges; + let additional = challenger.get_n_challenges(num_additional_challenges); + deltas.extend(&plonk_betas); + deltas.extend(&plonk_gammas); + deltas.extend(additional); + deltas + } else { + vec![] + }; + + // `plonk_zs_partial_products_cap` also contains the commitment to lookup polynomials. challenger.observe_cap::(plonk_zs_partial_products_cap); let plonk_alphas = challenger.get_n_challenges(num_challenges); @@ -59,6 +77,7 @@ fn get_challenges, C: GenericConfig, cons plonk_betas, plonk_gammas, plonk_alphas, + plonk_deltas, plonk_zeta, fri_challenges: challenger.fri_challenges::( commit_phase_merkle_caps, @@ -255,15 +274,32 @@ impl, const D: usize> CircuitBuilder { let num_challenges = config.num_challenges; let mut challenger = RecursiveChallenger::::new(self); + let has_lookup = inner_common_data.num_lookup_polys != 0; // Observe the instance. challenger.observe_hash(&inner_circuit_digest); challenger.observe_hash(&public_inputs_hash); challenger.observe_cap(wires_cap); + let plonk_betas = challenger.get_n_challenges(self, num_challenges); let plonk_gammas = challenger.get_n_challenges(self, num_challenges); + // If there are lookups in the circuit, we should get delta challenges as well. + // But we can use the already generated `plonk_betas` and `plonk_gammas` as the first `plonk_deltas` challenges. 
+ let plonk_deltas = if has_lookup { + let num_lookup_challenges = NUM_COINS_LOOKUP * num_challenges; + let mut deltas = Vec::with_capacity(num_lookup_challenges); + let num_additional_challenges = num_lookup_challenges - 2 * num_challenges; + let additional = challenger.get_n_challenges(self, num_additional_challenges); + deltas.extend(&plonk_betas); + deltas.extend(&plonk_gammas); + deltas.extend(additional); + deltas + } else { + vec![] + }; + challenger.observe_cap(plonk_zs_partial_products_cap); let plonk_alphas = challenger.get_n_challenges(self, num_challenges); @@ -276,6 +312,7 @@ impl, const D: usize> CircuitBuilder { plonk_betas, plonk_gammas, plonk_alphas, + plonk_deltas, plonk_zeta, fri_challenges: challenger.fri_challenges( self, diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index f000915095..bf70cacabc 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -264,6 +264,9 @@ pub struct ProofChallenges, const D: usize> { /// Random values used to combine PLONK constraints. pub plonk_alphas: Vec, + /// Lookup challenges. + pub plonk_deltas: Vec, + /// Point at which the PLONK polynomials are opened. 
pub plonk_zeta: F::Extension, @@ -274,6 +277,7 @@ pub(crate) struct ProofChallengesTarget { pub plonk_betas: Vec, pub plonk_gammas: Vec, pub plonk_alphas: Vec, + pub plonk_deltas: Vec, pub plonk_zeta: ExtensionTarget, pub fri_challenges: FriChallengesTarget, } @@ -299,6 +303,8 @@ pub struct OpeningSet, const D: usize> { pub plonk_zs_next: Vec, pub partial_products: Vec, pub quotient_polys: Vec, + pub lookup_zs: Vec, + pub lookup_zs_next: Vec, } impl, const D: usize> OpeningSet { @@ -307,7 +313,7 @@ impl, const D: usize> OpeningSet { g: F::Extension, constants_sigmas_commitment: &PolynomialBatch, wires_commitment: &PolynomialBatch, - zs_partial_products_commitment: &PolynomialBatch, + zs_partial_products_lookup_commitment: &PolynomialBatch, quotient_polys_commitment: &PolynomialBatch, common_data: &CommonCircuitData, ) -> Self { @@ -318,35 +324,64 @@ impl, const D: usize> OpeningSet { .collect::>() }; let constants_sigmas_eval = eval_commitment(zeta, constants_sigmas_commitment); - let zs_partial_products_eval = eval_commitment(zeta, zs_partial_products_commitment); + + // `zs_partial_products_lookup_eval` contains the permutation argument polynomials as well as lookup polynomials. 
+ let zs_partial_products_lookup_eval = + eval_commitment(zeta, zs_partial_products_lookup_commitment); + let zs_partial_products_lookup_next_eval = + eval_commitment(g * zeta, zs_partial_products_lookup_commitment); + let quotient_polys = eval_commitment(zeta, quotient_polys_commitment); + Self { constants: constants_sigmas_eval[common_data.constants_range()].to_vec(), plonk_sigmas: constants_sigmas_eval[common_data.sigmas_range()].to_vec(), wires: eval_commitment(zeta, wires_commitment), - plonk_zs: zs_partial_products_eval[common_data.zs_range()].to_vec(), - plonk_zs_next: eval_commitment(g * zeta, zs_partial_products_commitment) - [common_data.zs_range()] - .to_vec(), - partial_products: zs_partial_products_eval[common_data.partial_products_range()] + plonk_zs: zs_partial_products_lookup_eval[common_data.zs_range()].to_vec(), + plonk_zs_next: zs_partial_products_lookup_next_eval[common_data.zs_range()].to_vec(), + partial_products: zs_partial_products_lookup_eval[common_data.partial_products_range()] + .to_vec(), + quotient_polys, + lookup_zs: zs_partial_products_lookup_eval[common_data.lookup_range()].to_vec(), + lookup_zs_next: zs_partial_products_lookup_next_eval[common_data.lookup_range()] .to_vec(), - quotient_polys: eval_commitment(zeta, quotient_polys_commitment), } } - pub(crate) fn to_fri_openings(&self) -> FriOpenings { - let zeta_batch = FriOpeningBatch { - values: [ - self.constants.as_slice(), - self.plonk_sigmas.as_slice(), - self.wires.as_slice(), - self.plonk_zs.as_slice(), - self.partial_products.as_slice(), - self.quotient_polys.as_slice(), - ] - .concat(), + let has_lookup = !self.lookup_zs.is_empty(); + let zeta_batch = if has_lookup { + FriOpeningBatch { + values: [ + self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + self.lookup_zs.as_slice(), + ] + .concat(), + } + } else { + FriOpeningBatch { + values: [ + 
self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + ] + .concat(), + } }; - let zeta_next_batch = FriOpeningBatch { - values: self.plonk_zs_next.clone(), + let zeta_next_batch = if has_lookup { + FriOpeningBatch { + values: [self.plonk_zs_next.clone(), self.lookup_zs_next.clone()].concat(), + } + } else { + FriOpeningBatch { + values: self.plonk_zs_next.clone(), + } }; FriOpenings { batches: vec![zeta_batch, zeta_next_batch], @@ -362,25 +397,49 @@ pub struct OpeningSetTarget { pub wires: Vec>, pub plonk_zs: Vec>, pub plonk_zs_next: Vec>, + pub lookup_zs: Vec>, + pub next_lookup_zs: Vec>, pub partial_products: Vec>, pub quotient_polys: Vec>, } impl OpeningSetTarget { pub(crate) fn to_fri_openings(&self) -> FriOpeningsTarget { - let zeta_batch = FriOpeningBatchTarget { - values: [ - self.constants.as_slice(), - self.plonk_sigmas.as_slice(), - self.wires.as_slice(), - self.plonk_zs.as_slice(), - self.partial_products.as_slice(), - self.quotient_polys.as_slice(), - ] - .concat(), + let has_lookup = !self.lookup_zs.is_empty(); + let zeta_batch = if has_lookup { + FriOpeningBatchTarget { + values: [ + self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + self.lookup_zs.as_slice(), + ] + .concat(), + } + } else { + FriOpeningBatchTarget { + values: [ + self.constants.as_slice(), + self.plonk_sigmas.as_slice(), + self.wires.as_slice(), + self.plonk_zs.as_slice(), + self.partial_products.as_slice(), + self.quotient_polys.as_slice(), + ] + .concat(), + } }; - let zeta_next_batch = FriOpeningBatchTarget { - values: self.plonk_zs_next.clone(), + let zeta_next_batch = if has_lookup { + FriOpeningBatchTarget { + values: [self.plonk_zs_next.clone(), self.next_lookup_zs.clone()].concat(), + } + } else { + 
FriOpeningBatchTarget { + values: self.plonk_zs_next.clone(), + } }; FriOpeningsTarget { batches: vec![zeta_batch, zeta_next_batch], @@ -390,10 +449,14 @@ impl OpeningSetTarget { #[cfg(test)] mod tests { + use alloc::sync::Arc; + use anyhow::Result; + use itertools::Itertools; use crate::field::types::Sample; use crate::fri::reduction_strategies::FriReductionStrategy; + use crate::gates::lookup_table::LookupTable; use crate::gates::noop::NoopGate; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -438,4 +501,61 @@ mod tests { verify(proof, &data.verifier_only, &data.common)?; data.verify_compressed(compressed_proof) } + + #[test] + fn test_proof_compression_lookup() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + use plonky2_field::types::Field; + type F = >::F; + + let mut config = CircuitConfig::standard_recursion_config(); + config.fri_config.reduction_strategy = FriReductionStrategy::Fixed(vec![1, 1]); + config.fri_config.num_query_rounds = 50; + + let pw = PartialWitness::new(); + let tip5_table = vec![ + 0, 7, 26, 63, 124, 215, 85, 254, 214, 228, 45, 185, 140, 173, 33, 240, 29, 177, 176, + 32, 8, 110, 87, 202, 204, 99, 150, 106, 230, 14, 235, 128, 213, 239, 212, 138, 23, 130, + 208, 6, 44, 71, 93, 116, 146, 189, 251, 81, 199, 97, 38, 28, 73, 179, 95, 84, 152, 48, + 35, 119, 49, 88, 242, 3, 148, 169, 72, 120, 62, 161, 166, 83, 175, 191, 137, 19, 100, + 129, 112, 55, 221, 102, 218, 61, 151, 237, 68, 164, 17, 147, 46, 234, 203, 216, 22, + 141, 65, 57, 123, 12, 244, 54, 219, 231, 96, 77, 180, 154, 5, 253, 133, 165, 98, 195, + 205, 134, 245, 30, 9, 188, 59, 142, 186, 197, 181, 144, 92, 31, 224, 163, 111, 74, 58, + 69, 113, 196, 67, 246, 225, 10, 121, 50, 60, 157, 90, 122, 2, 250, 101, 75, 178, 159, + 24, 36, 201, 11, 243, 132, 198, 190, 114, 233, 39, 52, 21, 209, 108, 238, 91, 187, 18, + 104, 194, 37, 153, 34, 200, 143, 126, 155, 236, 118, 64, 80, 172, 89, 94, 193, 135, + 183, 86, 107, 252, 
13, 167, 206, 136, 220, 207, 103, 171, 160, 76, 182, 227, 217, 158, + 56, 174, 4, 66, 109, 139, 162, 184, 211, 249, 47, 125, 232, 117, 43, 16, 42, 127, 20, + 241, 25, 149, 105, 156, 51, 53, 168, 145, 247, 223, 79, 78, 226, 15, 222, 82, 115, 70, + 210, 27, 41, 1, 170, 40, 131, 192, 229, 248, 255, + ]; + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let lut_index = builder.add_lookup_table_from_pairs(table); + + // Build dummy circuit with a lookup to get a valid proof. + let x = F::TWO; + let out = builder.constant(F::from_canonical_usize(26)); + + let xt = builder.constant(x); + let look_out = builder.add_lookup_from_index(xt, lut_index); + builder.connect(look_out, out); + for _ in 0..100 { + builder.add_gate(NoopGate, vec![]); + } + let data = builder.build::(); + + let proof = data.prove(pw)?; + verify(proof.clone(), &data.verifier_only, &data.common)?; + + // Verify that `decompress ∘ compress = identity`. 
+ let compressed_proof = data.compress(proof.clone())?; + let decompressed_compressed_proof = data.decompress(compressed_proof.clone())?; + assert_eq!(proof, decompressed_compressed_proof); + + verify(proof, &data.verifier_only, &data.common)?; + data.verify_compressed(compressed_proof) + } } diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 86301bb9ac..df2249dfcd 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -1,19 +1,25 @@ use alloc::vec::Vec; use alloc::{format, vec}; +use core::cmp::min; use core::mem::swap; use anyhow::{ensure, Result}; use plonky2_maybe_rayon::*; +use super::circuit_builder::{LookupChallenges, LookupWire}; use crate::field::extension::Extendable; use crate::field::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::field::types::Field; use crate::field::zero_poly_coset::ZeroPolyOnCoset; use crate::fri::oracle::PolynomialBatch; +use crate::gates::lookup::LookupGate; +use crate::gates::lookup_table::LookupTableGate; use crate::hash::hash_types::RichField; use crate::iop::challenger::Challenger; use crate::iop::generator::generate_partial_witness; -use crate::iop::witness::{MatrixWitness, PartialWitness, Witness}; +use crate::iop::target::Target; +use crate::iop::witness::{MatrixWitness, PartialWitness, PartitionWitness, Witness, WitnessWrite}; +use crate::plonk::circuit_builder::NUM_COINS_LOOKUP; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::plonk_common::PlonkOracle; @@ -25,6 +31,74 @@ use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_pr use crate::util::timing::TimingTree; use crate::util::{ceil_div_usize, log2_ceil, transpose}; +/// Set all the lookup gate wires (including multiplicities) and pad unused LU slots. 
+/// Warning: rows are in descending order: the first gate to appear is the last LU gate, and +/// the last gate to appear is the first LUT gate. +pub fn set_lookup_wires< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, + pw: &mut PartitionWitness, +) { + for ( + lut_index, + &LookupWire { + last_lu_gate: _, + last_lut_gate, + first_lut_gate, + }, + ) in prover_data.lookup_rows.iter().enumerate() + { + let lut_len = common_data.luts[lut_index].len(); + let num_entries = LookupGate::num_slots(&common_data.config); + let num_lut_entries = LookupTableGate::num_slots(&common_data.config); + + // Compute multiplicities. + let mut multiplicities = vec![0; lut_len]; + for (inp_target, _) in prover_data.lut_to_lookups[lut_index].iter() { + let inp_value = pw.get_target(*inp_target); + let mut idx = 0; + while F::from_canonical_u16(common_data.luts[lut_index][idx].0) != inp_value { + idx += 1; + } + multiplicities[idx] += 1; + } + + // Pad the last `LookupGate` with the first entry from the LUT. + let remaining_slots = (num_entries + - (prover_data.lut_to_lookups[lut_index].len() % num_entries)) + % num_entries; + let first_inp_value = F::from_canonical_u16(common_data.luts[lut_index][0].0); + let first_out_value = F::from_canonical_u16(common_data.luts[lut_index][0].1); + for slot in (num_entries - remaining_slots)..num_entries { + let inp_target = + Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_inp(slot)); + let out_target = + Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_out(slot)); + pw.set_target(inp_target, first_inp_value); + pw.set_target(out_target, first_out_value); + + multiplicities[0] += 1; + } + + // We don't need to pad the last `LookupTableGate`; extra wires are set to 0 by default, which satisfies the constraints. 
+ for lut_entry in 0..lut_len { + let row = first_lut_gate - lut_entry / num_lut_entries; + let col = lut_entry % num_lut_entries; + + let mul_target = Target::wire(row, LookupTableGate::wire_ith_multiplicity(col)); + + pw.set_target( + mul_target, + F::from_canonical_usize(multiplicities[lut_entry]), + ); + } + } +} + pub fn prove, C: GenericConfig, const D: usize>( prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, @@ -35,17 +109,20 @@ where C::Hasher: Hasher, C::InnerHasher: Hasher, { + let has_lookup = !common_data.luts.is_empty(); let config = &common_data.config; let num_challenges = config.num_challenges; let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - let partition_witness = timed!( + let mut partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), generate_partial_witness(inputs, prover_data, common_data) ); + set_lookup_wires(prover_data, common_data, &mut partition_witness); + let public_inputs = partition_witness.get_targets(&prover_data.public_inputs); let public_inputs_hash = C::InnerHasher::hash_no_pad(&public_inputs); @@ -85,9 +162,26 @@ where challenger.observe_hash::(public_inputs_hash); challenger.observe_cap::(&wires_commitment.merkle_tree.cap); + + // We need 4 values per challenge: 2 for the combos, 1 for (X-combo) in the accumulators and 1 to prove that the lookup table was computed correctly. + // We can reuse betas and gammas for two of them. 
+ let num_lookup_challenges = NUM_COINS_LOOKUP * num_challenges; + let betas = challenger.get_n_challenges(num_challenges); let gammas = challenger.get_n_challenges(num_challenges); + let deltas = if has_lookup { + let mut delts = Vec::with_capacity(2 * num_challenges); + let num_additional_challenges = num_lookup_challenges - 2 * num_challenges; + let additional = challenger.get_n_challenges(num_additional_challenges); + delts.extend(&betas); + delts.extend(&gammas); + delts.extend(additional); + delts + } else { + vec![] + }; + assert!( common_data.quotient_degree_factor < common_data.config.num_routed_wires, "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." @@ -105,11 +199,21 @@ where .collect(); let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); - let partial_products_and_zs_commitment = timed!( + // All lookup polys: RE and partial SLDCs. + let lookup_polys = + compute_all_lookup_polys(&witness, &deltas, prover_data, common_data, has_lookup); + + let zs_partial_products_lookups = if has_lookup { + [zs_partial_products, lookup_polys].concat() + } else { + zs_partial_products + }; + + let partial_products_zs_and_lookup_commitment = timed!( timing, - "commit to partial products and Z's", - PolynomialBatch::::from_values( - zs_partial_products, + "commit to partial products, Z's and, if any, lookup polynomials", + PolynomialBatch::from_values( + zs_partial_products_lookups, config.fri_config.rate_bits, config.zero_knowledge && PlonkOracle::ZS_PARTIAL_PRODUCTS.blinding, config.fri_config.cap_height, @@ -118,7 +222,7 @@ where ) ); - challenger.observe_cap::(&partial_products_and_zs_commitment.merkle_tree.cap); + challenger.observe_cap::(&partial_products_zs_and_lookup_commitment.merkle_tree.cap); let alphas = challenger.get_n_challenges(num_challenges); @@ -130,15 +234,15 @@ where prover_data, &public_inputs_hash, &wires_commitment, - 
&partial_products_and_zs_commitment, + &partial_products_zs_and_lookup_commitment, &betas, &gammas, + &deltas, &alphas, ) ); - // Compute the quotient polynomials, aka `t` in the Plonk paper. - let all_quotient_poly_chunks = timed!( + let all_quotient_poly_chunks: Vec> = timed!( timing, "split up quotient polys", quotient_polys @@ -180,28 +284,29 @@ where let openings = timed!( timing, - "construct the opening set", - OpeningSet::new::( + "construct the opening set, including lookups", + OpeningSet::new( zeta, g, &prover_data.constants_sigmas_commitment, &wires_commitment, - &partial_products_and_zs_commitment, + &partial_products_zs_and_lookup_commitment, "ient_polys_commitment, - common_data, + common_data ) ); challenger.observe_openings(&openings.to_fri_openings()); + let instance = common_data.get_fri_instance(zeta); let opening_proof = timed!( timing, "compute opening proofs", PolynomialBatch::::prove_openings( - &common_data.get_fri_instance(zeta), + &instance, &[ &prover_data.constants_sigmas_commitment, &wires_commitment, - &partial_products_and_zs_commitment, + &partial_products_zs_and_lookup_commitment, "ient_polys_commitment, ], &mut challenger, @@ -212,7 +317,7 @@ where let proof = Proof:: { wires_cap: wires_commitment.merkle_tree.cap, - plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap, + plonk_zs_partial_products_cap: partial_products_zs_and_lookup_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, openings, opening_proof, @@ -310,6 +415,162 @@ fn wires_permutation_partial_products_and_zs< .collect() } +/// Computes lookup polynomials for a given challenge. +/// The polynomials hold the value of RE, Sum and Ldc of the Tip5 paper (https://eprint.iacr.org/2023/107.pdf). To reduce their +/// numbers, we batch multiple slots in a single polynomial. Since RE only involves degree one constraints, we can batch +/// all the slots of a row. 
For Sum and Ldc, batching increases the constraint degree, so we bound the number of +/// partial polynomials according to `max_quotient_degree_factor`. +/// As another optimization, Sum and LDC polynomials are shared (in so called partial SLDC polynomials), and the last value +/// of the last partial polynomial is Sum(end) - LDC(end). If the lookup argument is valid, then it must be equal to 0. +fn compute_lookup_polys< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + witness: &MatrixWitness, + deltas: &[F; 4], + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, +) -> Vec> { + let degree = common_data.degree(); + let num_lu_slots = LookupGate::num_slots(&common_data.config); + let max_lookup_degree = common_data.config.max_quotient_degree_factor - 1; + let num_partial_lookups = ceil_div_usize(num_lu_slots, max_lookup_degree); + let num_lut_slots = LookupTableGate::num_slots(&common_data.config); + let max_lookup_table_degree = ceil_div_usize(num_lut_slots, num_partial_lookups); + + // First poly is RE, the rest are partial SLDCs. + let mut final_poly_vecs = Vec::with_capacity(num_partial_lookups + 1); + for _ in 0..num_partial_lookups + 1 { + final_poly_vecs.push(PolynomialValues::::new(vec![F::ZERO; degree])); + } + + for LookupWire { + last_lu_gate: last_lu_row, + last_lut_gate: last_lut_row, + first_lut_gate: first_lut_row, + } in prover_data.lookup_rows.clone() + { + // Set values for partial Sums and RE. + for row in (last_lut_row..(first_lut_row + 1)).rev() { + // Get combos for Sum. + let looked_combos: Vec = (0..num_lut_slots) + .map(|s| { + let looked_inp = witness.get_wire(row, LookupTableGate::wire_ith_looked_inp(s)); + let looked_out = witness.get_wire(row, LookupTableGate::wire_ith_looked_out(s)); + + looked_inp + deltas[LookupChallenges::ChallengeA as usize] * looked_out + }) + .collect(); + // Get (alpha - combo). 
+ let minus_looked_combos: Vec = (0..num_lut_slots) + .map(|s| deltas[LookupChallenges::ChallengeAlpha as usize] - looked_combos[s]) + .collect(); + // Get 1/(alpha - combo). + let looked_combo_inverses = F::batch_multiplicative_inverse(&minus_looked_combos); + + // Get lookup combos, used to check the well formation of the LUT. + let lookup_combos: Vec = (0..num_lut_slots) + .map(|s| { + let looked_inp = witness.get_wire(row, LookupTableGate::wire_ith_looked_inp(s)); + let looked_out = witness.get_wire(row, LookupTableGate::wire_ith_looked_out(s)); + + looked_inp + deltas[LookupChallenges::ChallengeB as usize] * looked_out + }) + .collect(); + + // Compute next row's first value of RE. + // If `row == first_lut_row`, then `final_poly_vecs[0].values[row + 1] == 0`. + let mut new_re = final_poly_vecs[0].values[row + 1]; + for elt in &lookup_combos { + new_re = new_re * deltas[LookupChallenges::ChallengeDelta as usize] + *elt + } + final_poly_vecs[0].values[row] = new_re; + + for slot in 0..num_partial_lookups { + let prev = if slot != 0 { + final_poly_vecs[slot].values[row] + } else { + // If `row == first_lut_row`, then `final_poly_vecs[num_partial_lookups].values[row + 1] == 0`. + final_poly_vecs[num_partial_lookups].values[row + 1] + }; + let sum = (slot * max_lookup_table_degree + ..min((slot + 1) * max_lookup_table_degree, num_lut_slots)) + .fold(prev, |acc, s| { + acc + witness.get_wire(row, LookupTableGate::wire_ith_multiplicity(s)) + * looked_combo_inverses[s] + }); + final_poly_vecs[slot + 1].values[row] = sum; + } + } + + // Set values for partial LDCs. + for row in (last_lu_row..last_lut_row).rev() { + // Get looking combos. 
+ let looking_combos: Vec = (0..num_lu_slots) + .map(|s| { + let looking_in = witness.get_wire(row, LookupGate::wire_ith_looking_inp(s)); + let looking_out = witness.get_wire(row, LookupGate::wire_ith_looking_out(s)); + + looking_in + deltas[LookupChallenges::ChallengeA as usize] * looking_out + }) + .collect(); + // Get (alpha - combo). + let minus_looking_combos: Vec = (0..num_lu_slots) + .map(|s| deltas[LookupChallenges::ChallengeAlpha as usize] - looking_combos[s]) + .collect(); + // Get 1 / (alpha - combo). + let looking_combo_inverses = F::batch_multiplicative_inverse(&minus_looking_combos); + + for slot in 0..num_partial_lookups { + let prev = if slot == 0 { + // Valid at _any_ row, even `first_lu_row`. + final_poly_vecs[num_partial_lookups].values[row + 1] + } else { + final_poly_vecs[slot].values[row] + }; + let sum = (slot * max_lookup_degree + ..min((slot + 1) * max_lookup_degree, num_lu_slots)) + .fold(F::ZERO, |acc, s| acc + looking_combo_inverses[s]); + final_poly_vecs[slot + 1].values[row] = prev - sum; + } + } + } + + final_poly_vecs +} + +/// Computes lookup polynomials for all challenges. 
+fn compute_all_lookup_polys< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + witness: &MatrixWitness, + deltas: &[F], + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, + lookup: bool, +) -> Vec> { + if lookup { + let polys: Vec>> = (0..common_data.config.num_challenges) + .map(|c| { + compute_lookup_polys( + witness, + &deltas[c * NUM_COINS_LOOKUP..(c + 1) * NUM_COINS_LOOKUP] + .try_into() + .unwrap(), + prover_data, + common_data, + ) + }) + .collect(); + polys.concat() + } else { + vec![] + } +} + const BATCH_SIZE: usize = 32; fn compute_quotient_polys< @@ -322,12 +583,16 @@ fn compute_quotient_polys< prover_data: &'a ProverOnlyCircuitData, public_inputs_hash: &<>::InnerHasher as Hasher>::Hash, wires_commitment: &'a PolynomialBatch, - zs_partial_products_commitment: &'a PolynomialBatch, + zs_partial_products_and_lookup_commitment: &'a PolynomialBatch, betas: &[F], gammas: &[F], + deltas: &[F], alphas: &[F], ) -> Vec> { let num_challenges = common_data.config.num_challenges; + + let has_lookup = common_data.num_lookup_polys != 0; + let quotient_degree_bits = log2_ceil(common_data.quotient_degree_factor); assert!( quotient_degree_bits <= common_data.config.fri_config.rate_bits, @@ -364,6 +629,10 @@ fn compute_quotient_polys< let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len()); let mut local_zs_batch = Vec::with_capacity(xs_batch.len()); let mut next_zs_batch = Vec::with_capacity(xs_batch.len()); + + let mut local_lookup_batch = Vec::with_capacity(xs_batch.len()); + let mut next_lookup_batch = Vec::with_capacity(xs_batch.len()); + let mut partial_products_batch = Vec::with_capacity(xs_batch.len()); let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len()); @@ -379,13 +648,27 @@ fn compute_quotient_polys< let local_constants = &local_constants_sigmas[common_data.constants_range()]; let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()]; let local_wires = wires_commitment.get_lde_values(i, 
step); - let local_zs_partial_products = - zs_partial_products_commitment.get_lde_values(i, step); - let local_zs = &local_zs_partial_products[common_data.zs_range()]; - let next_zs = &zs_partial_products_commitment.get_lde_values(i_next, step) - [common_data.zs_range()]; + let local_zs_partial_and_lookup = + zs_partial_products_and_lookup_commitment.get_lde_values(i, step); + let next_zs_partial_and_lookup = + zs_partial_products_and_lookup_commitment.get_lde_values(i_next, step); + + let local_zs = &local_zs_partial_and_lookup[common_data.zs_range()]; + + let next_zs = &next_zs_partial_and_lookup[common_data.zs_range()]; + let partial_products = - &local_zs_partial_products[common_data.partial_products_range()]; + &local_zs_partial_and_lookup[common_data.partial_products_range()]; + + if has_lookup { + let local_lookup_zs = &local_zs_partial_and_lookup[common_data.lookup_range()]; + + let next_lookup_zs = &next_zs_partial_and_lookup[common_data.lookup_range()]; + debug_assert_eq!(local_lookup_zs.len(), common_data.num_all_lookup_polys()); + + local_lookup_batch.push(local_lookup_zs); + next_lookup_batch.push(next_lookup_zs); + } debug_assert_eq!(local_wires.len(), common_data.config.num_wires); debug_assert_eq!(local_zs.len(), num_challenges); @@ -431,10 +714,13 @@ fn compute_quotient_polys< vars_batch, &local_zs_batch, &next_zs_batch, + &local_lookup_batch, + &next_lookup_batch, &partial_products_batch, &s_sigmas_batch, betas, gammas, + deltas, alphas, &z_h_on_coset, ); diff --git a/plonky2/src/plonk/validate_shape.rs b/plonky2/src/plonk/validate_shape.rs index 0b4bf4a870..304aa04a23 100644 --- a/plonky2/src/plonk/validate_shape.rs +++ b/plonky2/src/plonk/validate_shape.rs @@ -52,6 +52,8 @@ where plonk_zs_next, partial_products, quotient_polys, + lookup_zs, + lookup_zs_next, } = openings; let cap_height = common_data.fri_params.config.cap_height; ensure!(wires_cap.height() == cap_height); @@ -64,5 +66,7 @@ where ensure!(plonk_zs_next.len() == 
config.num_challenges); ensure!(partial_products.len() == config.num_challenges * common_data.num_partial_products); ensure!(quotient_polys.len() == common_data.num_quotient_polys()); + ensure!(lookup_zs.len() == common_data.num_all_lookup_polys()); + ensure!(lookup_zs_next.len() == common_data.num_all_lookup_polys()); Ok(()) } diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index d1d06403f1..adc76654e2 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -1,10 +1,19 @@ use alloc::vec::Vec; use alloc::{format, vec}; +use core::cmp::min; +use plonky2_field::polynomial::PolynomialCoeffs; +use plonky2_util::ceil_div_usize; + +use super::circuit_builder::{LookupChallenges, NUM_COINS_LOOKUP}; +use super::vars::EvaluationVarsBase; use crate::field::batch_util::batch_add_inplace; use crate::field::extension::{Extendable, FieldExtension}; use crate::field::types::Field; use crate::field::zero_poly_coset::ZeroPolyOnCoset; +use crate::gates::lookup::LookupGate; +use crate::gates::lookup_table::LookupTableGate; +use crate::gates::selectors::LookupSelectors; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; @@ -18,6 +27,27 @@ use crate::util::reducing::ReducingFactorTarget; use crate::util::strided_view::PackedStridedView; use crate::with_context; +/// Get the polynomial associated to a lookup table with current challenges. 
+pub(crate) fn get_lut_poly, const D: usize>( + common_data: &CommonCircuitData, + lut_index: usize, + deltas: &[F], + degree: usize, +) -> PolynomialCoeffs { + let b = deltas[LookupChallenges::ChallengeB as usize]; + let mut coeffs = Vec::new(); + let n = common_data.luts[lut_index].len(); + for i in 0..n { + coeffs.push( + F::from_canonical_u16(common_data.luts[lut_index][i].0) + + b * F::from_canonical_u16(common_data.luts[lut_index][i].1), + ); + } + coeffs.append(&mut vec![F::ZERO; degree - n]); + coeffs.reverse(); + PolynomialCoeffs::new(coeffs) +} + /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random /// linear combination of gate constraints, plus some other terms relating to the permutation /// argument. All such terms should vanish on `H`. @@ -27,19 +57,37 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( vars: EvaluationVars, local_zs: &[F::Extension], next_zs: &[F::Extension], + local_lookup_zs: &[F::Extension], + next_lookup_zs: &[F::Extension], partial_products: &[F::Extension], s_sigmas: &[F::Extension], betas: &[F], gammas: &[F], alphas: &[F], + deltas: &[F], ) -> Vec { + let has_lookup = common_data.num_lookup_polys != 0; let max_degree = common_data.quotient_degree_factor; let num_prods = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints::(common_data, vars); + let lookup_selectors = &vars.local_constants[common_data.selectors_info.num_selectors() + ..common_data.selectors_info.num_selectors() + common_data.num_lookup_selectors]; + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); + + // The terms checking the lookup constraints, if any. 
+ let mut vanishing_all_lookup_terms = if has_lookup { + let num_sldc_polys = common_data.num_lookup_polys - 1; + Vec::with_capacity( + common_data.config.num_challenges * (4 + common_data.luts.len() + 2 * num_sldc_polys), + ) + } else { + Vec::new() + }; + // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); @@ -50,6 +98,26 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( let z_gx = next_zs[i]; vanishing_z_1_terms.push(l_0_x * (z_x - F::Extension::ONE)); + if has_lookup { + let cur_local_lookup_zs = &local_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + let cur_next_lookup_zs = &next_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + + let cur_deltas = &deltas[NUM_COINS_LOOKUP * i..NUM_COINS_LOOKUP * (i + 1)]; + + let lookup_constraints = check_lookup_constraints( + common_data, + vars, + cur_local_lookup_zs, + cur_next_lookup_zs, + lookup_selectors, + cur_deltas.try_into().unwrap(), + ); + + vanishing_all_lookup_terms.extend(lookup_constraints); + } + let numerator_values = (0..common_data.config.num_routed_wires) .map(|j| { let wire_value = vars.local_wires[j]; @@ -83,6 +151,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, + vanishing_all_lookup_terms, constraint_terms, ] .concat(); @@ -99,18 +168,30 @@ pub(crate) fn eval_vanishing_poly_base_batch, const vars_batch: EvaluationVarsBaseBatch, local_zs_batch: &[&[F]], next_zs_batch: &[&[F]], + local_lookup_zs_batch: &[&[F]], + next_lookup_zs_batch: &[&[F]], partial_products_batch: &[&[F]], s_sigmas_batch: &[&[F]], betas: &[F], gammas: &[F], + deltas: &[F], alphas: &[F], z_h_on_coset: &ZeroPolyOnCoset, ) -> Vec> { + let has_lookup = common_data.num_lookup_polys != 0; + let n = indices_batch.len(); assert_eq!(xs_batch.len(), n); assert_eq!(vars_batch.len(), n); assert_eq!(local_zs_batch.len(), n); 
assert_eq!(next_zs_batch.len(), n); + if has_lookup { + assert_eq!(local_lookup_zs_batch.len(), n); + assert_eq!(next_lookup_zs_batch.len(), n); + } else { + assert_eq!(local_lookup_zs_batch.len(), 0); + assert_eq!(next_lookup_zs_batch.len(), 0); + } assert_eq!(partial_products_batch.len(), n); assert_eq!(s_sigmas_batch.len(), n); @@ -134,13 +215,40 @@ pub(crate) fn eval_vanishing_poly_base_batch, const // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); + // The terms checking the lookup constraints. + let mut vanishing_all_lookup_terms = if has_lookup { + let num_sldc_polys = common_data.num_lookup_polys - 1; + Vec::with_capacity( + common_data.config.num_challenges * (4 + common_data.luts.len() + 2 * num_sldc_polys), + ) + } else { + Vec::new() + }; + let mut res_batch: Vec> = Vec::with_capacity(n); for k in 0..n { let index = indices_batch[k]; let x = xs_batch[k]; let vars = vars_batch.view(k); + + let lookup_selectors: Vec = (0..common_data.num_lookup_selectors) + .map(|i| vars.local_constants[common_data.selectors_info.num_selectors() + i]) + .collect(); + let local_zs = local_zs_batch[k]; let next_zs = next_zs_batch[k]; + let local_lookup_zs = if has_lookup { + local_lookup_zs_batch[k] + } else { + &[] + }; + + let next_lookup_zs = if has_lookup { + next_lookup_zs_batch[k] + } else { + &[] + }; + let partial_products = partial_products_batch[k]; let s_sigmas = s_sigmas_batch[k]; @@ -152,6 +260,26 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let z_gx = next_zs[i]; vanishing_z_1_terms.push(l_0_x * z_x.sub_one()); + // If there are lookups in the circuit, then we add the lookup constraints. 
+ if has_lookup { + let cur_deltas = &deltas[NUM_COINS_LOOKUP * i..NUM_COINS_LOOKUP * (i + 1)]; + + let cur_local_lookup_zs = &local_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + let cur_next_lookup_zs = &next_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + + let lookup_constraints = check_lookup_constraints_batch( + common_data, + vars, + cur_local_lookup_zs, + cur_next_lookup_zs, + &lookup_selectors, + cur_deltas.try_into().unwrap(), + ); + vanishing_all_lookup_terms.extend(lookup_constraints); + } + numerator_values.extend((0..num_routed_wires).map(|j| { let wire_value = vars.local_wires[j]; let k_i = common_data.k_is[j]; @@ -184,16 +312,361 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let vanishing_terms = vanishing_z_1_terms .iter() .chain(vanishing_partial_products_terms.iter()) + .chain(vanishing_all_lookup_terms.iter()) .chain(constraint_terms); let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); res_batch.push(res); vanishing_z_1_terms.clear(); vanishing_partial_products_terms.clear(); + vanishing_all_lookup_terms.clear(); } res_batch } +/// Evaluates all lookup constraints, based on the logarithmic derivatives paper (https://eprint.iacr.org/2022/1530.pdf), +/// following the Tip5 paper's implementation (https://eprint.iacr.org/2023/107.pdf). +/// +/// There are three polynomials to check: +/// - RE ensures the well formation of lookup tables; +/// - Sum is a running sum of m_i/(X - (input_i + a * output_i)) where (input_i, output_i) are input pairs in the lookup table (LUT); +/// - LDC is a running sum of 1/(X - (input_i + a * output_i)) where (input_i, output_i) are input pairs that look in the LUT. +/// Sum and LDC are broken down in partial polynomials to lower the constraint degree, similarly to the permutation argument. +/// They also share the same partial SLDC polynomials, so that the last SLDC value is Sum(end) - LDC(end). 
The final constraint +/// Sum(end) = LDC(end) becomes simply SLDC(end) = 0, and we can remove the LDC initial constraint. +pub fn check_lookup_constraints, const D: usize>( + common_data: &CommonCircuitData, + vars: EvaluationVars, + local_lookup_zs: &[F::Extension], + next_lookup_zs: &[F::Extension], + lookup_selectors: &[F::Extension], + deltas: &[F; 4], +) -> Vec { + let num_lu_slots = LookupGate::num_slots(&common_data.config); + let num_lut_slots = LookupTableGate::num_slots(&common_data.config); + let lu_degree = common_data.quotient_degree_factor - 1; + let num_sldc_polys = local_lookup_zs.len() - 1; + let lut_degree = ceil_div_usize(num_lut_slots, num_sldc_polys); + + let mut constraints = Vec::with_capacity(4 + common_data.luts.len() + 2 * num_sldc_polys); + + // RE is the first polynomial stored. + let z_re = local_lookup_zs[0]; + let next_z_re = next_lookup_zs[0]; + + // Partial Sums and LDCs are both stored in the remaining SLDC polynomials. + let z_x_lookup_sldcs = &local_lookup_zs[1..num_sldc_polys + 1]; + let z_gx_lookup_sldcs = &next_lookup_zs[1..num_sldc_polys + 1]; + + let delta_challenge_a = F::Extension::from(deltas[LookupChallenges::ChallengeA as usize]); + let delta_challenge_b = F::Extension::from(deltas[LookupChallenges::ChallengeB as usize]); + + // Compute all current looked and looking combos, i.e. the combos we need for the SLDC polynomials. 
+ let current_looked_combos: Vec = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + input_wire + delta_challenge_a * output_wire + }) + .collect(); + + let current_looking_combos: Vec = (0..num_lu_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupGate::wire_ith_looking_inp(s)]; + let output_wire = vars.local_wires[LookupGate::wire_ith_looking_out(s)]; + input_wire + delta_challenge_a * output_wire + }) + .collect(); + + // Compute all current lookup combos, i.e. the combos used to check that the LUT is correct. + let current_lookup_combos: Vec = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + input_wire + delta_challenge_b * output_wire + }) + .collect(); + + // Check last LDC constraint. + constraints.push( + lookup_selectors[LookupSelectors::LastLdc as usize] * z_x_lookup_sldcs[num_sldc_polys - 1], + ); + + // Check initial Sum constraint. + constraints.push(lookup_selectors[LookupSelectors::InitSre as usize] * z_x_lookup_sldcs[0]); + + // Check initial RE constraint. + constraints.push(lookup_selectors[LookupSelectors::InitSre as usize] * z_re); + + let current_delta = deltas[LookupChallenges::ChallengeDelta as usize]; + + // Check final RE constraints for each different LUT. 
+ for r in LookupSelectors::StartEnd as usize..common_data.num_lookup_selectors { + let cur_ends_selector = lookup_selectors[r]; + let lut_row_number = ceil_div_usize( + common_data.luts[r - LookupSelectors::StartEnd as usize].len(), + num_lut_slots, + ); + let cur_function_eval = get_lut_poly( + common_data, + r - LookupSelectors::StartEnd as usize, + deltas, + num_lut_slots * lut_row_number, + ) + .eval(current_delta); + + constraints.push(cur_ends_selector * (z_re - cur_function_eval.into())) + } + + // Check RE row transition constraint. + let mut cur_sum = next_z_re; + for elt in &current_lookup_combos { + cur_sum = + cur_sum * F::Extension::from(deltas[LookupChallenges::ChallengeDelta as usize]) + *elt; + } + let unfiltered_re_line = z_re - cur_sum; + + constraints.push(lookup_selectors[LookupSelectors::TransSre as usize] * unfiltered_re_line); + + for poly in 0..num_sldc_polys { + // Compute prod(alpha - combo) for the current slot for Sum. + let lut_prod: F::Extension = (poly * lut_degree + ..min((poly + 1) * lut_degree, num_lut_slots)) + .map(|i| { + F::Extension::from(deltas[LookupChallenges::ChallengeAlpha as usize]) + - current_looked_combos[i] + }) + .product(); + + // Compute prod(alpha - combo) for the current slot for LDC. + let lu_prod: F::Extension = (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .map(|i| { + F::Extension::from(deltas[LookupChallenges::ChallengeAlpha as usize]) + - current_looking_combos[i] + }) + .product(); + + // Function which computes, given index i: prod_{j!=i}(alpha - combo_j) for Sum. + let lut_prod_i = |i| { + (poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots)) + .map(|j| { + if j != i { + F::Extension::from(deltas[LookupChallenges::ChallengeAlpha as usize]) + - current_looked_combos[j] + } else { + F::Extension::ONE + } + }) + .product() + }; + + // Function which computes, given index i: prod_{j!=i}(alpha - combo_j) for LDC.
+ let lu_prod_i = |i| { + (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .map(|j| { + if j != i { + F::Extension::from(deltas[LookupChallenges::ChallengeAlpha as usize]) + - current_looking_combos[j] + } else { + F::Extension::ONE + } + }) + .product() + }; + // Compute sum_i(prod_{j!=i}(alpha - combo_j)) for LDC. + let lu_sum_prods = (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .fold(F::Extension::ZERO, |acc, i| acc + lu_prod_i(i)); + + // Compute sum_i(mul_i.prod_{j!=i}(alpha - combo_j)) for Sum. + let lut_sum_prods_with_mul = (poly * lut_degree + ..min((poly + 1) * lut_degree, num_lut_slots)) + .fold(F::Extension::ZERO, |acc, i| { + acc + vars.local_wires[LookupTableGate::wire_ith_multiplicity(i)] * lut_prod_i(i) + }); + + // The previous element is the previous poly of the current row or the last poly of the next row. + let prev = if poly == 0 { + z_gx_lookup_sldcs[num_sldc_polys - 1] + } else { + z_x_lookup_sldcs[poly - 1] + }; + + // Check Sum row and col transitions. It's the same constraint, with a row transition happening for slot == 0. + let unfiltered_sum_transition = + lut_prod * (z_x_lookup_sldcs[poly] - prev) - lut_sum_prods_with_mul; + constraints + .push(lookup_selectors[LookupSelectors::TransSre as usize] * unfiltered_sum_transition); + + // Check LDC row and col transitions. It's the same constraint, with a row transition happening for slot == 0. + let unfiltered_ldc_transition = lu_prod * (z_x_lookup_sldcs[poly] - prev) + lu_sum_prods; + constraints + .push(lookup_selectors[LookupSelectors::TransLdc as usize] * unfiltered_ldc_transition); + } + + constraints +} + +/// Same as `check_lookup_constraints`, but for the base field case. 
+pub fn check_lookup_constraints_batch, const D: usize>( + common_data: &CommonCircuitData, + vars: EvaluationVarsBase, + local_lookup_zs: &[F], + next_lookup_zs: &[F], + lookup_selectors: &[F], + deltas: &[F; 4], +) -> Vec { + let num_lu_slots = LookupGate::num_slots(&common_data.config); + let num_lut_slots = LookupTableGate::num_slots(&common_data.config); + let lu_degree = common_data.quotient_degree_factor - 1; + let num_sldc_polys = local_lookup_zs.len() - 1; + let lut_degree = ceil_div_usize(num_lut_slots, num_sldc_polys); + + let mut constraints = Vec::with_capacity(4 + common_data.luts.len() + 2 * num_sldc_polys); + + // RE is the first polynomial stored. + let z_re = local_lookup_zs[0]; + let next_z_re = next_lookup_zs[0]; + + // Partial Sums and LDCs are both stored in the remaining polynomials. + let z_x_lookup_sldcs = &local_lookup_zs[1..num_sldc_polys + 1]; + let z_gx_lookup_sldcs = &next_lookup_zs[1..num_sldc_polys + 1]; + + // Compute all current looked and looking combos, i.e. the combos we need for the SLDC polynomials. + let current_looked_combos: Vec = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + input_wire + deltas[LookupChallenges::ChallengeA as usize] * output_wire + }) + .collect(); + + let current_looking_combos: Vec = (0..num_lu_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupGate::wire_ith_looking_inp(s)]; + let output_wire = vars.local_wires[LookupGate::wire_ith_looking_out(s)]; + input_wire + deltas[LookupChallenges::ChallengeA as usize] * output_wire + }) + .collect(); + + // Compute all current lookup combos, i.e. the combos used to check that the LUT is correct. 
+ let current_lookup_combos: Vec = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + input_wire + deltas[LookupChallenges::ChallengeB as usize] * output_wire + }) + .collect(); + + // Check last LDC constraint. + constraints.push( + lookup_selectors[LookupSelectors::LastLdc as usize] * z_x_lookup_sldcs[num_sldc_polys - 1], + ); + + // Check initial Sum constraint. + constraints.push(lookup_selectors[LookupSelectors::InitSre as usize] * z_x_lookup_sldcs[0]); + + // Check initial RE constraint. + constraints.push(lookup_selectors[LookupSelectors::InitSre as usize] * z_re); + + let current_delta = deltas[LookupChallenges::ChallengeDelta as usize]; + + // Check final RE constraints for each different LUT. + for r in LookupSelectors::StartEnd as usize..common_data.num_lookup_selectors { + let cur_ends_selector = lookup_selectors[r]; + let lut_row_number = ceil_div_usize( + common_data.luts[r - LookupSelectors::StartEnd as usize].len(), + num_lut_slots, + ); + let cur_function_eval = get_lut_poly( + common_data, + r - LookupSelectors::StartEnd as usize, + deltas, + num_lut_slots * lut_row_number, + ) + .eval(current_delta); + + constraints.push(cur_ends_selector * (z_re - cur_function_eval)) + } + + // Check RE row transition constraint. + let mut cur_sum = next_z_re; + for elt in &current_lookup_combos { + cur_sum = cur_sum * deltas[LookupChallenges::ChallengeDelta as usize] + *elt; + } + let unfiltered_re_line = z_re - cur_sum; + + constraints.push(lookup_selectors[LookupSelectors::TransSre as usize] * unfiltered_re_line); + + for poly in 0..num_sldc_polys { + // Compute prod(alpha - combo) for the current slot for Sum.
+ let lut_prod: F = (poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots)) + .map(|i| deltas[LookupChallenges::ChallengeAlpha as usize] - current_looked_combos[i]) + .product(); + + // Compute prod(alpha - combo) for the current slot for LDC. + let lu_prod: F = (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .map(|i| deltas[LookupChallenges::ChallengeAlpha as usize] - current_looking_combos[i]) + .product(); + + // Function which computes, given index i: prod_{j!=i}(alpha - combo_j) for Sum. + let lut_prod_i = |i| { + (poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots)) + .map(|j| { + if j != i { + deltas[LookupChallenges::ChallengeAlpha as usize] - current_looked_combos[j] + } else { + F::ONE + } + }) + .product() + }; + + // Function which computes, given index i: prod_{j!=i}(alpha - combo_j) for LDC. + let lu_prod_i = |i| { + (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .map(|j| { + if j != i { + deltas[LookupChallenges::ChallengeAlpha as usize] + - current_looking_combos[j] + } else { + F::ONE + } + }) + .product() + }; + + // Compute sum_i(prod_{j!=i}(alpha - combo_j)) for LDC. + let lu_sum_prods = (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)) + .fold(F::ZERO, |acc, i| acc + lu_prod_i(i)); + + // Compute sum_i(mul_i.prod_{j!=i}(alpha - combo_j)) for Sum. + let lut_sum_prods_with_mul = (poly * lut_degree + ..min((poly + 1) * lut_degree, num_lut_slots)) + .fold(F::ZERO, |acc, i| { + acc + vars.local_wires[LookupTableGate::wire_ith_multiplicity(i)] * lut_prod_i(i) + }); + + // The previous element is the previous poly of the current row or the last poly of the next row. + let prev = if poly == 0 { + z_gx_lookup_sldcs[num_sldc_polys - 1] + } else { + z_x_lookup_sldcs[poly - 1] + }; + + // Check Sum row and col transitions. It's the same constraint, with a row transition happening for slot == 0. 
+ let unfiltered_sum_transition = + lut_prod * (z_x_lookup_sldcs[poly] - prev) - lut_sum_prods_with_mul; + constraints + .push(lookup_selectors[LookupSelectors::TransSre as usize] * unfiltered_sum_transition); + + // Check LDC row and col transitions. It's the same constraint, with a row transition happening for slot == 0. + let unfiltered_ldc_transition = lu_prod * (z_x_lookup_sldcs[poly] - prev) + lu_sum_prods; + constraints + .push(lookup_selectors[LookupSelectors::TransLdc as usize] * unfiltered_ldc_transition); + } + constraints +} + /// Evaluates all gate constraints. /// /// `num_gate_constraints` is the largest number of constraints imposed by any gate. It is not @@ -212,6 +685,7 @@ pub fn evaluate_gate_constraints, const D: usize>( selector_index, common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors(), + common_data.num_lookup_selectors, ); for (i, c) in gate_constraints.into_iter().enumerate() { debug_assert!( @@ -242,6 +716,7 @@ pub fn evaluate_gate_constraints_base_batch, const selector_index, common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors(), + common_data.num_lookup_selectors, ); debug_assert!( gate_constraints_batch.len() <= constraints_batch.len(), @@ -274,6 +749,7 @@ pub fn evaluate_gate_constraints_circuit, const D: selector_index, common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors(), + common_data.num_lookup_selectors, &mut all_gate_constraints, ) ); @@ -281,6 +757,39 @@ pub fn evaluate_gate_constraints_circuit, const D: all_gate_constraints } +pub(crate) fn get_lut_poly_circuit, const D: usize>( + builder: &mut CircuitBuilder, + common_data: &CommonCircuitData, + lut_index: usize, + deltas: &[Target], + degree: usize, +) -> Target { + let b = deltas[LookupChallenges::ChallengeB as usize]; + let delta = deltas[LookupChallenges::ChallengeDelta as usize]; + let n = common_data.luts[lut_index].len(); + 
let mut coeffs: Vec = (0..n) + .map(|i| { + let temp = + builder.mul_const(F::from_canonical_u16(common_data.luts[lut_index][i].1), b); + builder.add_const( + temp, + F::from_canonical_u16(common_data.luts[lut_index][i].0), + ) + }) + .collect(); + for _ in n..degree { + coeffs.push(builder.zero()); + } + coeffs.reverse(); + coeffs + .iter() + .rev() + .fold(builder.constant(F::ZERO), |acc, &c| { + let temp = builder.mul(acc, delta); + builder.add(temp, c) + }) +} + /// Evaluate the vanishing polynomial at `x`. In this context, the vanishing polynomial is a random /// linear combination of gate constraints, plus some other terms relating to the permutation /// argument. All such terms should vanish on `H`. @@ -295,12 +804,16 @@ pub(crate) fn eval_vanishing_poly_circuit, const D: vars: EvaluationTargets, local_zs: &[ExtensionTarget], next_zs: &[ExtensionTarget], + local_lookup_zs: &[ExtensionTarget], + next_lookup_zs: &[ExtensionTarget], partial_products: &[ExtensionTarget], s_sigmas: &[ExtensionTarget], betas: &[Target], gammas: &[Target], alphas: &[Target], + deltas: &[Target], ) -> Vec> { + let has_lookup = common_data.num_lookup_polys != 0; let max_degree = common_data.quotient_degree_factor; let num_prods = common_data.num_partial_products; @@ -310,8 +823,22 @@ pub(crate) fn eval_vanishing_poly_circuit, const D: evaluate_gate_constraints_circuit::(builder, common_data, vars,) ); + let lookup_selectors = &vars.local_constants[common_data.selectors_info.num_selectors() + ..common_data.selectors_info.num_selectors() + common_data.num_lookup_selectors]; + // The L_0(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); + + // The terms checking lookup constraints. 
+ let mut vanishing_all_lookup_terms = if has_lookup { + let num_sldc_polys = common_data.num_lookup_polys - 1; + Vec::with_capacity( + common_data.config.num_challenges * (4 + common_data.luts.len() + 2 * num_sldc_polys), + ) + } else { + Vec::new() + }; + // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); @@ -331,6 +858,27 @@ pub(crate) fn eval_vanishing_poly_circuit, const D: // L_0(x) (Z(x) - 1) = 0. vanishing_z_1_terms.push(builder.mul_sub_extension(l_0_x, z_x, l_0_x)); + // If there are lookups in the circuit, then we add the lookup constraints + if has_lookup { + let cur_local_lookup_zs = &local_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + let cur_next_lookup_zs = &next_lookup_zs + [common_data.num_lookup_polys * i..common_data.num_lookup_polys * (i + 1)]; + + let cur_deltas = &deltas[NUM_COINS_LOOKUP * i..NUM_COINS_LOOKUP * (i + 1)]; + + let lookup_constraints = check_lookup_constraints_circuit( + builder, + common_data, + vars, + cur_local_lookup_zs, + cur_next_lookup_zs, + lookup_selectors, + cur_deltas, + ); + vanishing_all_lookup_terms.extend(lookup_constraints); + } + let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); @@ -367,6 +915,7 @@ pub(crate) fn eval_vanishing_poly_circuit, const D: let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, + vanishing_all_lookup_terms, constraint_terms, ] .concat(); @@ -380,3 +929,215 @@ pub(crate) fn eval_vanishing_poly_circuit, const D: }) .collect() } + +/// Same as `check_lookup_constraints`, but for the recursive case. 
+pub fn check_lookup_constraints_circuit, const D: usize>( + builder: &mut CircuitBuilder, + common_data: &CommonCircuitData, + vars: EvaluationTargets, + local_lookup_zs: &[ExtensionTarget], + next_lookup_zs: &[ExtensionTarget], + lookup_selectors: &[ExtensionTarget], + deltas: &[Target], +) -> Vec> { + let num_lu_slots = LookupGate::num_slots(&common_data.config); + let num_lut_slots = LookupTableGate::num_slots(&common_data.config); + let lu_degree = common_data.quotient_degree_factor - 1; + let num_sldc_polys = local_lookup_zs.len() - 1; + let lut_degree = ceil_div_usize(num_lut_slots, num_sldc_polys); + + let mut constraints = Vec::with_capacity(4 + common_data.luts.len() + 2 * num_sldc_polys); + + // RE is the first polynomial stored. + let z_re = local_lookup_zs[0]; + let next_z_re = next_lookup_zs[0]; + + // Partial Sums and LDCs (i.e. the SLDC polynomials) are stored in the remaining polynomials. + let z_x_lookup_sldcs = &local_lookup_zs[1..num_sldc_polys + 1]; + let z_gx_lookup_sldcs = &next_lookup_zs[1..num_sldc_polys + 1]; + + // Convert deltas to ExtensionTargets. + let ext_deltas = deltas + .iter() + .map(|d| builder.convert_to_ext(*d)) + .collect::>(); + + // Computing all current looked and looking combos, i.e. the combos we need for the SLDC polynomials. 
+ let current_looked_combos = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeA as usize], + output_wire, + input_wire, + ) + }) + .collect::>(); + let current_looking_combos = (0..num_lu_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupGate::wire_ith_looking_inp(s)]; + let output_wire = vars.local_wires[LookupGate::wire_ith_looking_out(s)]; + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeA as usize], + output_wire, + input_wire, + ) + }) + .collect::>(); + + let current_lut_subs = (0..num_lut_slots) + .map(|s| { + builder.sub_extension( + ext_deltas[LookupChallenges::ChallengeAlpha as usize], + current_looked_combos[s], + ) + }) + .collect::>(); + + let current_lu_subs = (0..num_lu_slots) + .map(|s| { + builder.sub_extension( + ext_deltas[LookupChallenges::ChallengeAlpha as usize], + current_looking_combos[s], + ) + }) + .collect::>(); + + // Computing all current lookup combos, i.e. the combos used to check that the LUT is correct. + let current_lookup_combos = (0..num_lut_slots) + .map(|s| { + let input_wire = vars.local_wires[LookupTableGate::wire_ith_looked_inp(s)]; + let output_wire = vars.local_wires[LookupTableGate::wire_ith_looked_out(s)]; + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeB as usize], + output_wire, + input_wire, + ) + }) + .collect::>(); + + // Check last LDC constraint. + constraints.push(builder.mul_extension( + lookup_selectors[LookupSelectors::LastLdc as usize], + z_x_lookup_sldcs[num_sldc_polys - 1], + )); + + // Check initial Sum constraint. + constraints.push(builder.mul_extension( + lookup_selectors[LookupSelectors::InitSre as usize], + z_x_lookup_sldcs[0], + )); + + // Check initial RE constraint. 
+ constraints + .push(builder.mul_extension(lookup_selectors[LookupSelectors::InitSre as usize], z_re)); + + // Check final RE constraints for each different LUT. + for r in LookupSelectors::StartEnd as usize..common_data.num_lookup_selectors { + let cur_ends_selectors = lookup_selectors[r]; + let lut_row_number = ceil_div_usize( + common_data.luts[r - LookupSelectors::StartEnd as usize].len(), + num_lut_slots, + ); + let cur_function_eval = get_lut_poly_circuit( + builder, + common_data, + r - LookupSelectors::StartEnd as usize, + deltas, + num_lut_slots * lut_row_number, + ); + let cur_function_eval_ext = builder.convert_to_ext(cur_function_eval); + + let cur_re = builder.sub_extension(z_re, cur_function_eval_ext); + constraints.push(builder.mul_extension(cur_ends_selectors, cur_re)); + } + + // Check RE row transition constraint. + let mut cur_sum = next_z_re; + for elt in &current_lookup_combos { + cur_sum = builder.mul_add_extension( + cur_sum, + ext_deltas[LookupChallenges::ChallengeDelta as usize], + *elt, + ); + } + let unfiltered_re_line = builder.sub_extension(z_re, cur_sum); + + constraints.push(builder.mul_extension( + lookup_selectors[LookupSelectors::TransSre as usize], + unfiltered_re_line, + )); + + for poly in 0..num_sldc_polys { + // Compute prod(alpha - combo) for the current slot for Sum. + let mut lut_prod = builder.one_extension(); + for i in poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots) { + lut_prod = builder.mul_extension(lut_prod, current_lut_subs[i]); + } + + // Compute prod(alpha - combo) for the current slot for LDC. + let mut lu_prod = builder.one_extension(); + for i in poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots) { + lu_prod = builder.mul_extension(lu_prod, current_lu_subs[i]); + } + + let one = builder.one_extension(); + let zero = builder.zero_extension(); + + // Compute sum_i(prod_{j!=i}(alpha - combo_j)) for LDC.
+ let lu_sum_prods = + (poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots)).fold(zero, |acc, i| { + let mut prod_i = one; + + for j in poly * lu_degree..min((poly + 1) * lu_degree, num_lu_slots) { + if j != i { + prod_i = builder.mul_extension(prod_i, current_lu_subs[j]); + } + } + builder.add_extension(acc, prod_i) + }); + + // Compute sum_i(mul_i.prod_{j!=i}(alpha - combo_j)) for Sum. + let lut_sum_prods_mul = (poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots)) + .fold(zero, |acc, i| { + let mut prod_i = one; + + for j in poly * lut_degree..min((poly + 1) * lut_degree, num_lut_slots) { + if j != i { + prod_i = builder.mul_extension(prod_i, current_lut_subs[j]); + } + } + builder.mul_add_extension( + prod_i, + vars.local_wires[LookupTableGate::wire_ith_multiplicity(i)], + acc, + ) + }); + + // The previous element is the previous poly of the current row or the last poly of the next row. + let prev = if poly == 0 { + z_gx_lookup_sldcs[num_sldc_polys - 1] + } else { + z_x_lookup_sldcs[poly - 1] + }; + + let cur_sub = builder.sub_extension(z_x_lookup_sldcs[poly], prev); + + // Check sum row and col transitions. It's the same constraint, with a row transition happening for slot == 0. + let unfiltered_sum_transition = + builder.mul_sub_extension(lut_prod, cur_sub, lut_sum_prods_mul); + constraints.push(builder.mul_extension( + lookup_selectors[LookupSelectors::TransSre as usize], + unfiltered_sum_transition, + )); + + // Check ldc row and col transitions. It's the same constraint, with a row transition happening for slot == 0. 
+ let unfiltered_ldc_transition = builder.mul_add_extension(lu_prod, cur_sub, lu_sum_prods); + constraints.push(builder.mul_extension( + lookup_selectors[LookupSelectors::TransLdc as usize], + unfiltered_ldc_transition, + )); + } + constraints +} diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index ecb7e46c39..b160fddc28 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -55,6 +55,8 @@ pub(crate) fn verify_with_challenges< }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_next; + let local_lookup_zs = &proof.openings.lookup_zs; + let next_lookup_zs = &proof.openings.lookup_zs_next; let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; @@ -65,11 +67,14 @@ pub(crate) fn verify_with_challenges< vars, local_zs, next_zs, + local_lookup_zs, + next_lookup_zs, partial_products, s_sigmas, &challenges.plonk_betas, &challenges.plonk_gammas, &challenges.plonk_alphas, + &challenges.plonk_deltas, ); // Check each polynomial identity, of the form `vanishing(x) = Z_H(x) quotient(x)`, at zeta. @@ -93,6 +98,7 @@ pub(crate) fn verify_with_challenges< let merkle_caps = &[ verifier_data.constants_sigmas_cap.clone(), proof.wires_cap, + // In the lookup case, `plonk_zs_partial_products_cap` should also include the lookup commitment. 
proof.plonk_zs_partial_products_cap, proof.quotient_polys_cap, ]; diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index 6331118b60..3f3b626751 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -191,6 +191,8 @@ impl, const D: usize> CircuitBuilder { wires: self.select_vec_ext(b, &os0.wires, &os1.wires), plonk_zs: self.select_vec_ext(b, &os0.plonk_zs, &os1.plonk_zs), plonk_zs_next: self.select_vec_ext(b, &os0.plonk_zs_next, &os1.plonk_zs_next), + lookup_zs: self.select_vec_ext(b, &os0.lookup_zs, &os1.lookup_zs), + next_lookup_zs: self.select_vec_ext(b, &os0.next_lookup_zs, &os1.next_lookup_zs), partial_products: self.select_vec_ext(b, &os0.partial_products, &os1.partial_products), quotient_polys: self.select_vec_ext(b, &os0.quotient_polys, &os1.quotient_polys), } diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index ec13dab340..613766e455 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -66,6 +66,8 @@ impl, const D: usize> CircuitBuilder { }; let local_zs = &proof.openings.plonk_zs; let next_zs = &proof.openings.plonk_zs_next; + let local_lookup_zs = &proof.openings.lookup_zs; + let next_lookup_zs = &proof.openings.next_lookup_zs; let s_sigmas = &proof.openings.plonk_sigmas; let partial_products = &proof.openings.partial_products; @@ -82,11 +84,14 @@ impl, const D: usize> CircuitBuilder { vars, local_zs, next_zs, + local_lookup_zs, + next_lookup_zs, partial_products, s_sigmas, &challenges.plonk_betas, &challenges.plonk_gammas, &challenges.plonk_alphas, + &challenges.plonk_deltas, ) ); @@ -147,7 +152,7 @@ impl, const D: usize> CircuitBuilder { let num_leaves_per_oracle = &[ common_data.num_preprocessed_polys(), config.num_wires + salt, - common_data.num_zs_partial_products_polys() + salt, + 
common_data.num_zs_partial_products_polys() + common_data.num_all_lookup_polys() + salt, common_data.num_quotient_polys() + salt, ]; @@ -164,12 +169,20 @@ impl, const D: usize> CircuitBuilder { let config = &common_data.config; let num_challenges = config.num_challenges; let total_partial_products = num_challenges * common_data.num_partial_products; + let has_lookup = common_data.num_lookup_polys != 0; + let num_lookups = if has_lookup { + common_data.num_all_lookup_polys() + } else { + 0 + }; OpeningSetTarget { constants: self.add_virtual_extension_targets(common_data.num_constants), plonk_sigmas: self.add_virtual_extension_targets(config.num_routed_wires), wires: self.add_virtual_extension_targets(config.num_wires), plonk_zs: self.add_virtual_extension_targets(num_challenges), plonk_zs_next: self.add_virtual_extension_targets(num_challenges), + lookup_zs: self.add_virtual_extension_targets(num_lookups), + next_lookup_zs: self.add_virtual_extension_targets(num_lookups), partial_products: self.add_virtual_extension_targets(total_partial_products), quotient_polys: self.add_virtual_extension_targets(common_data.num_quotient_polys()), } @@ -178,12 +191,17 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use alloc::sync::Arc; + use anyhow::Result; + use itertools::Itertools; use log::{info, Level}; use super::*; use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::FriConfig; + use crate::gadgets::lookup::{OTHER_TABLE, TIP5_TABLE}; + use crate::gates::lookup_table::LookupTable; use crate::gates::noop::NoopGate; use crate::iop::witness::{PartialWitness, WitnessWrite}; use crate::plonk::circuit_data::{CircuitConfig, VerifierOnlyCircuitData}; @@ -208,6 +226,54 @@ mod tests { Ok(()) } + #[test] + fn test_recursive_verifier_one_lookup() -> Result<()> { + init_logger(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let config = CircuitConfig::standard_recursion_zk_config(); + + let (proof, vd, cd) = 
dummy_lookup_proof::(&config, 10)?; + let (proof, vd, cd) = + recursive_proof::(proof, vd, cd, &config, None, true, true)?; + test_serialization(&proof, &vd, &cd)?; + + Ok(()) + } + + #[test] + fn test_recursive_verifier_two_luts() -> Result<()> { + init_logger(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let config = CircuitConfig::standard_recursion_config(); + + let (proof, vd, cd) = dummy_two_luts_proof::(&config)?; + let (proof, vd, cd) = + recursive_proof::(proof, vd, cd, &config, None, true, true)?; + test_serialization(&proof, &vd, &cd)?; + + Ok(()) + } + + #[test] + fn test_recursive_verifier_too_many_rows() -> Result<()> { + init_logger(); + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let config = CircuitConfig::standard_recursion_config(); + + let (proof, vd, cd) = dummy_too_many_rows_proof::(&config)?; + let (proof, vd, cd) = + recursive_proof::(proof, vd, cd, &config, None, true, true)?; + test_serialization(&proof, &vd, &cd)?; + + Ok(()) + } + + #[test] fn test_recursive_recursive_verifier() -> Result<()> { init_logger(); @@ -339,6 +405,197 @@ mod tests { Ok((proof, data.verifier_only, data.common)) } + /// Creates a dummy lookup proof which does two lookups to one LUT. 
+ fn dummy_lookup_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + config: &CircuitConfig, + num_dummy_gates: u64, + ) -> Result> { + let mut builder = CircuitBuilder::::new(config.clone()); + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let out_a = table[look_val_a].1; + let out_b = table[look_val_b].1; + + let tip5_index = builder.add_lookup_table_from_pairs(table); + + let output_a = builder.add_lookup_from_index(initial_a, tip5_index); + let output_b = builder.add_lookup_from_index(initial_b, tip5_index); + + for _ in 0..num_dummy_gates + 1 { + builder.add_gate(NoopGate, vec![]); + } + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + + let data = builder.build::(); + let mut inputs = PartialWitness::new(); + inputs.set_target(initial_a, F::from_canonical_usize(look_val_a)); + inputs.set_target(initial_b, F::from_canonical_usize(look_val_b)); + + let proof = data.prove(inputs)?; + data.verify(proof.clone())?; + + assert!( + proof.public_inputs[2] == F::from_canonical_u16(out_a), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[3] == F::from_canonical_u16(out_b), + "Second lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + + Ok((proof, data.verifier_only, data.common)) + } + + /// Creates a dummy lookup proof which does two lookups to a first LUT and one lookup to a second LUT. 
+ fn dummy_two_luts_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + config: &CircuitConfig, + ) -> Result> { + let mut builder = CircuitBuilder::::new(config.clone()); + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + + let first_out = tip5_table[look_val_a]; + let second_out = tip5_table[look_val_b]; + + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let other_table = OTHER_TABLE.to_vec(); + + let tip5_index = builder.add_lookup_table_from_pairs(table); + let output_a = builder.add_lookup_from_index(initial_a, tip5_index); + + let output_b = builder.add_lookup_from_index(initial_b, tip5_index); + let sum = builder.add(output_a, output_b); + + let s = first_out + second_out; + let final_out = other_table[s as usize]; + + let table2: LookupTable = Arc::new((0..256).zip_eq(other_table).collect()); + + let other_index = builder.add_lookup_table_from_pairs(table2); + let output_final = builder.add_lookup_from_index(sum, other_index); + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + + builder.register_public_input(sum); + builder.register_public_input(output_a); + builder.register_public_input(output_b); + builder.register_public_input(output_final); + + let mut pw = PartialWitness::new(); + pw.set_target(initial_a, F::ONE); + pw.set_target(initial_b, F::TWO); + + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof.clone())?; + + assert!( + proof.public_inputs[3] == F::from_canonical_u16(first_out), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + assert!( + proof.public_inputs[4] == F::from_canonical_u16(second_out), + "Second lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + assert!( + 
proof.public_inputs[2] == F::from_canonical_u16(s), + "Sum between the first two LUT outputs is incorrect." + ); + assert!( + proof.public_inputs[5] == F::from_canonical_u16(final_out), + "Output of the second LUT at index {} is incorrect.", + s + ); + + Ok((proof, data.verifier_only, data.common)) + } + + /// Creates a dummy proof which has more than 256 lookups to one LUT. + fn dummy_too_many_rows_proof< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + config: &CircuitConfig, + ) -> Result> { + let mut builder = CircuitBuilder::::new(config.clone()); + + let initial_a = builder.add_virtual_target(); + let initial_b = builder.add_virtual_target(); + + let look_val_a = 1; + let look_val_b = 2; + + let tip5_table = TIP5_TABLE.to_vec(); + let table: LookupTable = Arc::new((0..256).zip_eq(tip5_table).collect()); + + let out_a = table[look_val_a].1; + let out_b = table[look_val_b].1; + + let tip5_index = builder.add_lookup_table_from_pairs(table); + let output_b = builder.add_lookup_from_index(initial_b, tip5_index); + let mut output = builder.add_lookup_from_index(initial_a, tip5_index); + for _ in 0..514 { + output = builder.add_lookup_from_index(initial_a, tip5_index); + } + + builder.register_public_input(initial_a); + builder.register_public_input(initial_b); + builder.register_public_input(output_b); + builder.register_public_input(output); + + let mut pw = PartialWitness::new(); + + pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); + pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + + let data = builder.build::(); + let proof = data.prove(pw)?; + assert!( + proof.public_inputs[2] == F::from_canonical_u16(out_b), + "First lookup, at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[1] + ); + assert!( + proof.public_inputs[3] == F::from_canonical_u16(out_a), + "Lookups at index {} in the Tip5 table gives an incorrect output.", + proof.public_inputs[0] + ); + 
data.verify(proof.clone())?; + + Ok((proof, data.verifier_only, data.common)) + } + fn recursive_proof< F: RichField + Extendable, C: GenericConfig, diff --git a/plonky2/src/util/serialization/gate_serialization.rs b/plonky2/src/util/serialization/gate_serialization.rs index 351fa76b7d..2d9e3e30ec 100644 --- a/plonky2/src/util/serialization/gate_serialization.rs +++ b/plonky2/src/util/serialization/gate_serialization.rs @@ -71,6 +71,8 @@ pub mod default { use crate::gates::constant::ConstantGate; use crate::gates::coset_interpolation::CosetInterpolationGate; use crate::gates::exponentiation::ExponentiationGate; + use crate::gates::lookup::LookupGate; + use crate::gates::lookup_table::LookupTableGate; use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::poseidon::PoseidonGate; @@ -92,6 +94,8 @@ pub mod default { ConstantGate, CosetInterpolationGate, ExponentiationGate, + LookupGate, + LookupTableGate, MulExtensionGate, NoopGate, PoseidonMdsGate, diff --git a/plonky2/src/util/serialization/generator_serialization.rs b/plonky2/src/util/serialization/generator_serialization.rs index 14f94acce6..cbc6d6bac2 100644 --- a/plonky2/src/util/serialization/generator_serialization.rs +++ b/plonky2/src/util/serialization/generator_serialization.rs @@ -111,6 +111,8 @@ pub mod default { use crate::gates::base_sum::BaseSplitGenerator; use crate::gates::coset_interpolation::InterpolationGenerator; use crate::gates::exponentiation::ExponentiationGenerator; + use crate::gates::lookup::LookupGenerator; + use crate::gates::lookup_table::LookupTableGenerator; use crate::gates::multiplication_extension::MulExtensionGenerator; use crate::gates::poseidon::PoseidonGenerator; use crate::gates::poseidon_mds::PoseidonMdsGenerator; @@ -147,6 +149,8 @@ pub mod default { EqualityGenerator, ExponentiationGenerator, InterpolationGenerator, + LookupGenerator, + LookupTableGenerator, LowHighGenerator, MulExtensionGenerator, 
NonzeroTestGenerator, diff --git a/plonky2/src/util/serialization/mod.rs b/plonky2/src/util/serialization/mod.rs index db86879d89..7c8141769c 100644 --- a/plonky2/src/util/serialization/mod.rs +++ b/plonky2/src/util/serialization/mod.rs @@ -5,6 +5,7 @@ pub mod generator_serialization; pub mod gate_serialization; use alloc::collections::BTreeMap; +use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use core::convert::Infallible; @@ -30,6 +31,7 @@ use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::{FriConfig, FriParams}; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::gates::gate::GateRef; +use crate::gates::lookup::Lookup; use crate::gates::selectors::SelectorsInfo; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::{MerkleProof, MerkleProofTarget}; @@ -38,6 +40,7 @@ use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::WitnessGeneratorRef; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; +use crate::plonk::circuit_builder::LookupWire; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, @@ -112,6 +115,14 @@ pub trait Read { Ok(buf[0]) } + /// Reads a `u16` value from `self`. + #[inline] + fn read_u16(&mut self) -> IoResult { + let mut buf = [0; size_of::()]; + self.read_exact(&mut buf)?; + Ok(u16::from_le_bytes(buf)) + } + /// Reads a `u32` value from `self`. 
#[inline] fn read_u32(&mut self) -> IoResult { @@ -334,6 +345,8 @@ pub trait Read { let wires = self.read_field_ext_vec::(config.num_wires)?; let plonk_zs = self.read_field_ext_vec::(config.num_challenges)?; let plonk_zs_next = self.read_field_ext_vec::(config.num_challenges)?; + let lookup_zs = self.read_field_ext_vec::(common_data.num_all_lookup_polys())?; + let lookup_zs_next = self.read_field_ext_vec::(common_data.num_all_lookup_polys())?; let partial_products = self .read_field_ext_vec::(common_data.num_partial_products * config.num_challenges)?; let quotient_polys = self.read_field_ext_vec::( @@ -347,6 +360,8 @@ pub trait Read { plonk_zs_next, partial_products, quotient_polys, + lookup_zs, + lookup_zs_next, }) } @@ -358,6 +373,8 @@ pub trait Read { let wires = self.read_target_ext_vec::()?; let plonk_zs = self.read_target_ext_vec::()?; let plonk_zs_next = self.read_target_ext_vec::()?; + let lookup_zs = self.read_target_ext_vec::()?; + let next_lookup_zs = self.read_target_ext_vec::()?; let partial_products = self.read_target_ext_vec::()?; let quotient_polys = self.read_target_ext_vec::()?; @@ -367,6 +384,8 @@ pub trait Read { wires, plonk_zs, plonk_zs_next, + lookup_zs, + next_lookup_zs, partial_products, quotient_polys, }) @@ -422,7 +441,9 @@ pub trait Read { evals_proofs.push((wires_v, wires_p)); let zs_partial_v = self.read_field_vec( - config.num_challenges * (1 + common_data.num_partial_products) + salt, + config.num_challenges + * (1 + common_data.num_partial_products + common_data.num_lookup_polys) + + salt, )?; let zs_partial_p = self.read_merkle_proof()?; evals_proofs.push((zs_partial_v, zs_partial_p)); @@ -740,6 +761,15 @@ pub trait Read { let num_partial_products = self.read_usize()?; + let num_lookup_polys = self.read_usize()?; + let num_lookup_selectors = self.read_usize()?; + let length = self.read_usize()?; + let mut luts = Vec::with_capacity(length); + + for _ in 0..length { + luts.push(Arc::new(self.read_lut()?)); + } + Ok(CommonCircuitData 
{ config, fri_params, @@ -751,6 +781,9 @@ pub trait Read { num_public_inputs, k_is, num_partial_products, + num_lookup_polys, + num_lookup_selectors, + luts, }) } @@ -825,6 +858,22 @@ pub trait Read { let circuit_digest = self.read_hash::>::Hasher>()?; + let length = self.read_usize()?; + let mut lookup_rows = Vec::with_capacity(length); + for _ in 0..length { + lookup_rows.push(LookupWire { + last_lu_gate: self.read_usize()?, + last_lut_gate: self.read_usize()?, + first_lut_gate: self.read_usize()?, + }); + } + + let length = self.read_usize()?; + let mut lut_to_lookups = Vec::with_capacity(length); + for _ in 0..length { + lut_to_lookups.push(self.read_target_lut()?); + } + Ok(ProverOnlyCircuitData { generators, generator_indices_by_watches, @@ -835,6 +884,8 @@ pub trait Read { representative_map, fft_root_table, circuit_digest, + lookup_rows, + lut_to_lookups, }) } @@ -1089,6 +1140,30 @@ pub trait Read { public_inputs, }) } + + /// Reads a lookup table stored as `Vec<(u16, u16)>` from `self`. + #[inline] + fn read_lut(&mut self) -> IoResult> { + let length = self.read_usize()?; + let mut lut = Vec::with_capacity(length); + for _ in 0..length { + lut.push((self.read_u16()?, self.read_u16()?)); + } + + Ok(lut) + } + + /// Reads a target lookup table stored as `Lookup` from `self`. + #[inline] + fn read_target_lut(&mut self) -> IoResult { + let length = self.read_usize()?; + let mut lut = Vec::with_capacity(length); + for _ in 0..length { + lut.push((self.read_target()?, self.read_target()?)); + } + + Ok(lut) + } } /// Writing @@ -1128,6 +1203,12 @@ pub trait Write { self.write_all(&[x]) } + /// Writes a `u16` value `x` to `self`. 
+ #[inline] + fn write_u16(&mut self, x: u16) -> IoResult<()> { + self.write_all(&x.to_le_bytes()) + } + /// Writes a word `x` to `self.` #[inline] fn write_u32(&mut self, x: u32) -> IoResult<()> { @@ -1334,6 +1415,8 @@ pub trait Write { self.write_field_ext_vec::(&os.wires)?; self.write_field_ext_vec::(&os.plonk_zs)?; self.write_field_ext_vec::(&os.plonk_zs_next)?; + self.write_field_ext_vec::(&os.lookup_zs)?; + self.write_field_ext_vec::(&os.lookup_zs_next)?; self.write_field_ext_vec::(&os.partial_products)?; self.write_field_ext_vec::(&os.quotient_polys) } @@ -1349,6 +1432,8 @@ pub trait Write { self.write_target_ext_vec::(&os.wires)?; self.write_target_ext_vec::(&os.plonk_zs)?; self.write_target_ext_vec::(&os.plonk_zs_next)?; + self.write_target_ext_vec::(&os.lookup_zs)?; + self.write_target_ext_vec::(&os.next_lookup_zs)?; self.write_target_ext_vec::(&os.partial_products)?; self.write_target_ext_vec::(&os.quotient_polys) } @@ -1664,6 +1749,9 @@ pub trait Write { num_public_inputs, k_is, num_partial_products, + num_lookup_polys, + num_lookup_selectors, + luts, } = common_data; self.write_circuit_config(config)?; @@ -1685,6 +1773,13 @@ pub trait Write { self.write_usize(*num_partial_products)?; + self.write_usize(*num_lookup_polys)?; + self.write_usize(*num_lookup_selectors)?; + self.write_usize(luts.len())?; + for lut in luts.iter() { + self.write_lut(lut)?; + } + Ok(()) } @@ -1722,6 +1817,8 @@ pub trait Write { representative_map, fft_root_table, circuit_digest, + lookup_rows, + lut_to_lookups, } = prover_only_circuit_data; self.write_usize(generators.len())?; @@ -1760,6 +1857,18 @@ pub trait Write { self.write_hash::>::Hasher>(*circuit_digest)?; + self.write_usize(lookup_rows.len())?; + for wire in lookup_rows.iter() { + self.write_usize(wire.last_lu_gate)?; + self.write_usize(wire.last_lut_gate)?; + self.write_usize(wire.first_lut_gate)?; + } + + self.write_usize(lut_to_lookups.len())?; + for tlut in lut_to_lookups.iter() { + self.write_target_lut(tlut)?; + } 
+ Ok(()) } @@ -1962,6 +2071,30 @@ pub trait Write { self.write_compressed_proof(proof)?; self.write_field_vec(public_inputs) } + + /// Writes a lookup table to `self`. + #[inline] + fn write_lut(&mut self, lut: &[(u16, u16)]) -> IoResult<()> { + self.write_usize(lut.len())?; + for (a, b) in lut.iter() { + self.write_u16(*a)?; + self.write_u16(*b)?; + } + + Ok(()) + } + + /// Writes a target lookup table to `self`. + #[inline] + fn write_target_lut(&mut self, lut: &[(Target, Target)]) -> IoResult<()> { + self.write_usize(lut.len())?; + for (a, b) in lut.iter() { + self.write_target(*a)?; + self.write_target(*b)?; + } + + Ok(()) + } } impl Write for Vec {