Add a trim algorithm to lib.rs to prevent Polys from becoming unbearably gigantic

Our Poly algorithm is incredibly leaky. While it presumably should be improved,
we can take advantage of our known structure while constructing divisors (and
the small modulus) to simply trim out the zero coefficients leaked. This
maintains Polys in a manageable size.
This commit is contained in:
Luke Parker
2024-09-24 02:15:06 -04:00
parent 1be9084119
commit 1ea7cb8b5b

View File

@@ -185,6 +185,8 @@ pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement
None?;
}
let points_len = points.len();
// Create the initial set of divisors
let mut divs = vec![];
let mut iter = points.iter().copied();
@@ -194,11 +196,35 @@ pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement
// Draw the line between those points
// These unwraps are branching on the length of the iterator, not violating the constant-time
// priorites desired
divs.push((a + b.unwrap_or(C::identity()), line::<C>(a, b.unwrap_or(-a))));
divs.push((2, a + b.unwrap_or(C::identity()), line::<C>(a, b.unwrap_or(-a))));
}
let modulus = C::divisor_modulus();
// Our Poly algorithm is leaky and will create an excessive amount of y x**j and x**j
// coefficients which are zero, yet as our implementation is constant time, still come with
// an immense performance cost. This code truncates the coefficients we know are zero.
let trim = |divisor: &mut Poly<_>, points_len: usize| {
// We should only be trimming divisors reduced by the modulus
debug_assert!(divisor.yx_coefficients.len() <= 1);
if divisor.yx_coefficients.len() == 1 {
let truncate_to = ((points_len + 1) / 2).saturating_sub(2);
#[cfg(debug_assertions)]
for p in truncate_to .. divisor.yx_coefficients[0].len() {
debug_assert_eq!(divisor.yx_coefficients[0][p], <C::FieldElement as Field>::ZERO);
}
divisor.yx_coefficients[0].truncate(truncate_to);
}
{
let truncate_to = points_len / 2;
#[cfg(debug_assertions)]
for p in truncate_to .. divisor.x_coefficients.len() {
debug_assert_eq!(divisor.x_coefficients[p], <C::FieldElement as Field>::ZERO);
}
divisor.x_coefficients.truncate(truncate_to);
}
};
// Pair them off until only one remains
while divs.len() > 1 {
let mut next_divs = vec![];
@@ -207,23 +233,28 @@ pub fn new_divisor<C: DivisorCurve>(points: &[C]) -> Option<Poly<C::FieldElement
next_divs.push(divs.pop().unwrap());
}
while let Some((a, a_div)) = divs.pop() {
let (b, b_div) = divs.pop().unwrap();
while let Some((a_points, a, a_div)) = divs.pop() {
let (b_points, b, b_div) = divs.pop().unwrap();
let points = a_points + b_points;
// Merge the two divisors
let numerator = a_div.mul_mod(&b_div, &modulus).mul_mod(&line::<C>(a, b), &modulus);
let denominator = line::<C>(a, -a).mul_mod(&line::<C>(b, -b), &modulus);
let (q, r) = numerator.div_rem(&denominator);
let (mut q, r) = numerator.div_rem(&denominator);
debug_assert_eq!(r, Poly::zero());
next_divs.push((a + b, q));
trim(&mut q, 1 + points);
next_divs.push((points, a + b, q));
}
divs = next_divs;
}
// Return the unified divisor
Some(divs.remove(0).1)
let mut divisor = divs.remove(0).2;
trim(&mut divisor, points_len);
Some(divisor)
}
#[cfg(any(test, feature = "pasta"))]