Skip to content

Commit

Permalink
Memory16 (#1871)
Browse files Browse the repository at this point in the history
Extracted from #1790

This PR adds std and witgen support for the new 16-bit limb memory
machine.

- The current version only supports 24-bit addresses. I think it's fine
for now, but we should fix it later.
- Github fails miserably at showing the proper diff, but:
- the new file `double_sorted_32.rs` is the same as the old unique
`double_sorted.rs`, minus the common parts which
- were left in the outermost `double_sorted.rs` which dispatches calls
depending on the field
- the new file `double_sorted_16.rs` is very similar to the 32 case, but
adjusted for 2 value fields.
- There are probably more common things we can extract, but this here is
enough for a first version.

---------

Co-authored-by: Leo Alt <[email protected]>
Co-authored-by: Georg Wiese <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 5d1a2fa commit 3364d91
Show file tree
Hide file tree
Showing 10 changed files with 808 additions and 23 deletions.
559 changes: 559 additions & 0 deletions executor/src/witgen/machines/double_sorted_witness_machine_16.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fn split_column_name(name: &str) -> (&str, &str) {

/// TODO make this generic

pub struct DoubleSortedWitnesses<'a, T: FieldElement> {
pub struct DoubleSortedWitnesses32<'a, T: FieldElement> {
degree_range: DegreeRange,
degree: DegreeType,
//key_col: String,
Expand Down Expand Up @@ -74,7 +74,7 @@ struct Operation<T> {
pub selector_id: PolyID,
}

impl<'a, T: FieldElement> DoubleSortedWitnesses<'a, T> {
impl<'a, T: FieldElement> DoubleSortedWitnesses32<'a, T> {
fn namespaced(&self, name: &str) -> String {
format!("{}::{}", self.namespace, name)
}
Expand Down Expand Up @@ -165,7 +165,7 @@ impl<'a, T: FieldElement> DoubleSortedWitnesses<'a, T> {

if !parts.prover_functions.is_empty() {
log::warn!(
"DoubleSortedWitness machine does not support prover functions.\
"DoubleSortedWitness32 machine does not support prover functions.\
The following prover functions are ignored:\n{}",
parts.prover_functions.iter().format("\n")
);
Expand All @@ -187,7 +187,7 @@ impl<'a, T: FieldElement> DoubleSortedWitnesses<'a, T> {
}
}

impl<'a, T: FieldElement> Machine<'a, T> for DoubleSortedWitnesses<'a, T> {
impl<'a, T: FieldElement> Machine<'a, T> for DoubleSortedWitnesses32<'a, T> {
fn identity_ids(&self) -> Vec<u64> {
self.selector_ids.keys().cloned().collect()
}
Expand Down Expand Up @@ -353,7 +353,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for DoubleSortedWitnesses<'a, T> {
}
}

impl<'a, T: FieldElement> DoubleSortedWitnesses<'a, T> {
impl<'a, T: FieldElement> DoubleSortedWitnesses32<'a, T> {
fn process_plookup_internal(
&mut self,
identity_id: u64,
Expand Down
18 changes: 13 additions & 5 deletions executor/src/witgen/machines/machine_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::collections::{BTreeMap, HashSet};
use itertools::Itertools;

use super::block_machine::BlockMachine;
use super::double_sorted_witness_machine::DoubleSortedWitnesses;
use super::double_sorted_witness_machine_16::DoubleSortedWitnesses16;
use super::double_sorted_witness_machine_32::DoubleSortedWitnesses32;
use super::fixed_lookup_machine::FixedLookup;
use super::sorted_witness_machine::SortedWitnesses;
use super::FixedData;
Expand Down Expand Up @@ -217,13 +218,20 @@ fn build_machine<'a, T: FieldElement>(
{
log::debug!("Detected machine: sorted witnesses / write-once memory");
KnownMachine::SortedWitnesses(machine)
} else if let Some(machine) = DoubleSortedWitnesses::try_new(
name_with_type("DoubleSortedWitnesses"),
} else if let Some(machine) = DoubleSortedWitnesses16::try_new(
name_with_type("DoubleSortedWitnesses16"),
fixed_data,
&machine_parts,
) {
log::debug!("Detected machine: memory");
KnownMachine::DoubleSortedWitnesses(machine)
log::debug!("Detected machine: memory16");
KnownMachine::DoubleSortedWitnesses16(machine)
} else if let Some(machine) = DoubleSortedWitnesses32::try_new(
name_with_type("DoubleSortedWitnesses32"),
fixed_data,
&machine_parts,
) {
log::debug!("Detected machine: memory32");
KnownMachine::DoubleSortedWitnesses32(machine)
} else if let Some(machine) = WriteOnceMemory::try_new(
name_with_type("WriteOnceMemory"),
fixed_data,
Expand Down
23 changes: 16 additions & 7 deletions executor/src/witgen/machines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use powdr_number::FieldElement;
use crate::Identity;

use self::block_machine::BlockMachine;
use self::double_sorted_witness_machine::DoubleSortedWitnesses;
use self::double_sorted_witness_machine_16::DoubleSortedWitnesses16;
use self::double_sorted_witness_machine_32::DoubleSortedWitnesses32;
pub use self::fixed_lookup_machine::FixedLookup;
use self::profiling::{record_end, record_start};
use self::sorted_witness_machine::SortedWitnesses;
Expand All @@ -20,7 +21,8 @@ use super::rows::RowPair;
use super::{EvalResult, FixedData, MutableState, QueryCallback};

mod block_machine;
mod double_sorted_witness_machine;
mod double_sorted_witness_machine_16;
mod double_sorted_witness_machine_32;
mod fixed_lookup_machine;
pub mod machine_extractor;
pub mod profiling;
Expand Down Expand Up @@ -71,7 +73,8 @@ pub trait Machine<'a, T: FieldElement>: Send + Sync {
/// which requires that all lifetime parameters are 'static.
pub enum KnownMachine<'a, T: FieldElement> {
SortedWitnesses(SortedWitnesses<'a, T>),
DoubleSortedWitnesses(DoubleSortedWitnesses<'a, T>),
DoubleSortedWitnesses16(DoubleSortedWitnesses16<'a, T>),
DoubleSortedWitnesses32(DoubleSortedWitnesses32<'a, T>),
WriteOnceMemory(WriteOnceMemory<'a, T>),
BlockMachine(BlockMachine<'a, T>),
Vm(Generator<'a, T>),
Expand All @@ -89,7 +92,10 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
KnownMachine::SortedWitnesses(m) => {
m.process_plookup(mutable_state, identity_id, caller_rows)
}
KnownMachine::DoubleSortedWitnesses(m) => {
KnownMachine::DoubleSortedWitnesses16(m) => {
m.process_plookup(mutable_state, identity_id, caller_rows)
}
KnownMachine::DoubleSortedWitnesses32(m) => {
m.process_plookup(mutable_state, identity_id, caller_rows)
}
KnownMachine::WriteOnceMemory(m) => {
Expand All @@ -108,7 +114,8 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
fn name(&self) -> &str {
match self {
KnownMachine::SortedWitnesses(m) => m.name(),
KnownMachine::DoubleSortedWitnesses(m) => m.name(),
KnownMachine::DoubleSortedWitnesses16(m) => m.name(),
KnownMachine::DoubleSortedWitnesses32(m) => m.name(),
KnownMachine::WriteOnceMemory(m) => m.name(),
KnownMachine::BlockMachine(m) => m.name(),
KnownMachine::Vm(m) => m.name(),
Expand All @@ -122,7 +129,8 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
) -> HashMap<String, Vec<T>> {
match self {
KnownMachine::SortedWitnesses(m) => m.take_witness_col_values(mutable_state),
KnownMachine::DoubleSortedWitnesses(m) => m.take_witness_col_values(mutable_state),
KnownMachine::DoubleSortedWitnesses16(m) => m.take_witness_col_values(mutable_state),
KnownMachine::DoubleSortedWitnesses32(m) => m.take_witness_col_values(mutable_state),
KnownMachine::WriteOnceMemory(m) => m.take_witness_col_values(mutable_state),
KnownMachine::BlockMachine(m) => m.take_witness_col_values(mutable_state),
KnownMachine::Vm(m) => m.take_witness_col_values(mutable_state),
Expand All @@ -133,7 +141,8 @@ impl<'a, T: FieldElement> Machine<'a, T> for KnownMachine<'a, T> {
fn identity_ids(&self) -> Vec<u64> {
match self {
KnownMachine::SortedWitnesses(m) => m.identity_ids(),
KnownMachine::DoubleSortedWitnesses(m) => m.identity_ids(),
KnownMachine::DoubleSortedWitnesses16(m) => m.identity_ids(),
KnownMachine::DoubleSortedWitnesses32(m) => m.identity_ids(),
KnownMachine::WriteOnceMemory(m) => m.identity_ids(),
KnownMachine::BlockMachine(m) => m.identity_ids(),
KnownMachine::Vm(m) => m.identity_ids(),
Expand Down
12 changes: 6 additions & 6 deletions pipeline/tests/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,25 +247,25 @@ fn vm_to_block_multiple_links() {
#[test]
fn mem_read_write() {
let f = "asm/mem_read_write.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

#[test]
fn mem_read_write_no_memory_accesses() {
let f = "asm/mem_read_write_no_memory_accesses.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

#[test]
fn mem_read_write_with_bootloader() {
let f = "asm/mem_read_write_with_bootloader.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

#[test]
fn mem_read_write_large_diffs() {
let f = "asm/mem_read_write_large_diffs.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

#[test]
Expand Down Expand Up @@ -440,7 +440,7 @@ fn vm_args() {
#[test]
fn vm_args_memory() {
let f = "asm/vm_args_memory.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

#[test]
Expand All @@ -452,7 +452,7 @@ fn vm_args_relative_path() {
#[test]
fn vm_args_two_levels() {
let f = "asm/vm_args_two_levels.asm";
regular_test(f, Default::default());
regular_test_without_babybear(f, Default::default());
}

mod reparse {
Expand Down
7 changes: 7 additions & 0 deletions pipeline/tests/powdr_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ fn memory_test_parallel_accesses() {
regular_test(f, &[]);
}

#[test]
#[ignore = "Too slow"]
fn memory16_test() {
let f = "std/memory16_test.asm";
test_plonky3_with_backend_variant::<BabyBearField>(f, vec![], BackendVariant::Composite);
}

#[test]
fn permutation_via_challenges_bn() {
let f = "std/permutation_via_challenges.asm";
Expand Down
116 changes: 116 additions & 0 deletions std/machines/memory16.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::array;
use std::field::modulus;
use std::check::assert;
use std::machines::range::Bit12;
use std::machines::range::Byte2;

// A read/write memory, similar to that of Polygon:
// https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/mem.pil
machine Memory16(bit12: Bit12, byte2: Byte2) with
latch: LATCH,
operation_id: m_is_write,
call_selectors: selectors,
{
// We compute m_diff (28-Bit) + m_step (28-Bit) + 1, which fits into 29 Bits.
assert(modulus() > 2**29, || "Memory16 requires a field that fits any 29-Bit value.");

operation mload<0> m_addr_high, m_addr_low, m_step -> m_value1, m_value2;
operation mstore<1> m_addr_high, m_addr_low, m_step, m_value1, m_value2 ->;

let LATCH = 1;

// =============== read-write memory =======================
// Read-write memory. Columns are sorted by addr and
// then by step. change is 1 if and only if addr changes
// in the next row.
// Note that these column names are used by witgen to detect
// this machine...
col witness m_addr_high, m_addr_low;
col witness m_step_high, m_step_low;
col witness m_change;
col witness m_value1, m_value2;

link => bit12.check(m_step_high);
link => byte2.check(m_step_low);
let m_step = m_step_high * 2**16 + m_step_low;

link => byte2.check(m_value1);
link => byte2.check(m_value2);

// Memory operation flags
col witness m_is_write;
std::utils::force_bool(m_is_write);

// is_write can only be 1 if a selector is active
let is_mem_op = array::sum(selectors);
std::utils::force_bool(is_mem_op);
(1 - is_mem_op) * m_is_write = 0;

// If the next line is a not a write and we have an address change,
// then the value is zero.
(1 - m_is_write') * m_change * m_value1' = 0;
(1 - m_is_write') * m_change * m_value2' = 0;

// change has to be 1 in the last row, so that a first read on row zero is constrained to return 0
(1 - m_change) * LAST = 0;

// If the next line is a read and we stay at the same address, then the
// value cannot change.
(1 - m_is_write') * (1 - m_change) * (m_value1' - m_value1) = 0;
(1 - m_is_write') * (1 - m_change) * (m_value2' - m_value2) = 0;

col fixed FIRST = [1] + [0]*;
let LAST = FIRST';

std::utils::force_bool(m_change);

// if change is zero, addr has to stay the same.
(m_addr_low' - m_addr_low) * (1 - m_change) = 0;
(m_addr_high' - m_addr_high) * (1 - m_change) = 0;

// Except for the last row, if m_change is 1, then addr has to increase,
// if it is zero, step has to increase.
// The diff has to be equal to the difference **minus one**.

// These two helper columns have different semantics, depending on
// whether we're comparing addresses or time steps.
// In both cases, m_tmp2 needs to be of 16 Bits.
col witness m_tmp1, m_tmp2;
link => byte2.check(m_diff);

// When comparing time steps, a 28-Bit diff is sufficient assuming a maximum step
// of 2**28.
// The difference is computed on the field, which is larger than 2**28.
// We prove that m_step' - m_step > 0 by letting the prover provide a 28-Bit value
// such that claimed_diff + 1 == m_step' - m_step.
// Because all values are constrained to be 28-Bit, no overflow can occur.
let m_diff_upper = m_tmp1;
let m_diff_lower = m_tmp2;
link if (1 - m_change) => bit12.check(m_diff_upper);
let claimed_time_step_diff = m_diff_upper * 2**16 + m_diff_lower;
let actual_time_step_diff = m_step' - m_step;
(1 - m_change) * (claimed_time_step_diff + 1 - actual_time_step_diff) = 0;

// When comparing addresses, we let the prover indicate whether the upper or lower
// limb needs to be compared and then assert that the diff is positive.
let address_high_unequal = m_tmp1;
let m_diff = m_tmp2;

// address_high_unequal is binary.
m_change * address_high_unequal * (address_high_unequal - 1) = 0;

// Whether to do any comparison.
// We want to compare whenever m_change == 1, but not in the last row.
// Because we constrained m_change to be 1 in the last row, this will just
// be equal to m_change, except that the last entry is 0.
// (`m_change * (1 - LAST)` would be the same, but of higher degree.)
let do_comparison = m_change - LAST;

// If address_high_unequal is 0, the higher limbs should be equal.
do_comparison * (1 - address_high_unequal) * (m_addr_high' - m_addr_high) = 0;

// Assert that m_diff stores the actual diff - 1.
let actual_addr_limb_diff = address_high_unequal * (m_addr_high' - m_addr_high)
+ (1 - address_high_unequal) * (m_addr_low' - m_addr_low);
do_comparison * (m_diff + 1 - actual_addr_limb_diff) = 0;
}
1 change: 1 addition & 0 deletions std/machines/mod.asm
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod binary_bb;
mod range;
mod hash;
mod memory;
mod memory16;
mod memory_with_bootloader_write;
mod shift;
mod shift16;
Expand Down
12 changes: 12 additions & 0 deletions std/machines/range.asm
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ machine Bit7 with
col fixed latch = [1]*;
col fixed operation_id = [0]*;
}

machine Bit12 with
latch: latch,
operation_id: operation_id,
degree: 4096
{
operation check<0> BIT12 -> ;

let BIT12: col = |i| i % (2**12);
let latch = 1;
col fixed operation_id = [0]*;
}
Loading

0 comments on commit 3364d91

Please sign in to comment.