Skip to content

Commit

Permalink
Merge pull request #2071 from fredrik-johansson/toom
Browse files Browse the repository at this point in the history
Generic Toom-3 multiplication for gr_poly
  • Loading branch information
fredrik-johansson authored Sep 18, 2024
2 parents 1f98483 + 6724256 commit 9500186
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 1 deletion.
13 changes: 13 additions & 0 deletions doc/source/gr_poly.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ Arithmetic
algorithm with `O(n^{1.6})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch
to :func:`_gr_poly_mul_karatsuba` above some cutoff.

.. function:: int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

Balanced Toom-3 multiplication with interpolation in five points,
using the Bodrato evaluation scheme. Assumes commutativity and that the ring
supports exact division by 2 and 3.
Not optimized for squaring.
The underscore method requires positive lengths and does not support aliasing.
This function calls :func:`_gr_poly_mul` recursively rather than itself, so to get a recursive
algorithm with `O(n^{1.5})` complexity, the ring must overload :func:`_gr_poly_mul` to dispatch
to :func:`_gr_poly_mul_toom33` above some cutoff.


Powering
--------------------------------------------------------------------------------

Expand Down
3 changes: 2 additions & 1 deletion src/gr_poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ WARN_UNUSED_RESULT int gr_poly_mul_scalar(gr_poly_t res, const gr_poly_t poly, g

WARN_UNUSED_RESULT int _gr_poly_mul_karatsuba(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_karatsuba(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

WARN_UNUSED_RESULT int _gr_poly_mul_toom33(gr_ptr res, gr_srcptr poly1, slong len1, gr_srcptr poly2, slong len2, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx);

/* powering */

Expand Down
207 changes: 207 additions & 0 deletions src/gr_poly/mul_toom33.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
Copyright (C) 2007 Marco Bodrato
Copyright (C) 2024 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "gr_vec.h"
#include "gr_poly.h"

/*
Toom33 (interpolation in 5 points) using Bodrato scheme
http://marco.bodrato.it/papers/Bodrato2007-OptimalToomCookMultiplicationForBinaryFieldAndIntegers.pdf
Assumes commutativity, division by 3.
Todo: squaring version.
Todo: skip unnecessary zero-extensions of vectors and tighten
allocations.
*/
int
_gr_poly_mul_toom33(gr_ptr res, gr_srcptr f, slong flen, gr_srcptr g, slong glen, gr_ctx_t ctx)
{
gr_srcptr U0, U1, U2, V0, V1, V2;
gr_ptr tmp, W0, W1, W2, W3, W4;
slong m, U2len, V2len, U1len, V1len, U0len, V0len, rlen, len;
slong W4len;
slong sz = ctx->sizeof_elem;
slong alloc;
int status = GR_SUCCESS;

/* TODO: should explicitly call basecase mul. */
if (flen <= 1 || glen <= 1)
return _gr_poly_mullow_generic(res, f, flen, g, glen, flen + glen - 1, ctx);

/* U = U2*x^(2m) + U1*x^m + U0 */
/* V = V2*x^(2m) + V1*x^m + V0 */
/* Each block has length m */
m = FLINT_MAX(flen, glen);
m = (m + 3 - 1) / 3;
U0 = f;
U1 = GR_ENTRY(f, m, sz);
U2 = GR_ENTRY(f, 2 * m, sz);
V0 = g;
V1 = GR_ENTRY(g, m, sz);
V2 = GR_ENTRY(g, 2 * m, sz);

U2len = FLINT_MAX(flen - 2 * m, 0);
V2len = FLINT_MAX(glen - 2 * m, 0);
U1len = FLINT_MIN(FLINT_MAX(flen - m, 0), m);
V1len = FLINT_MIN(FLINT_MAX(glen - m, 0), m);
U0len = FLINT_MIN(flen, m);
V0len = FLINT_MIN(glen, m);

alloc = 10 * m;
GR_TMP_INIT_VEC(tmp, alloc, ctx);
W0 = tmp;
W1 = GR_ENTRY(W0, 2 * m, sz);
W2 = GR_ENTRY(W1, 2 * m, sz);
W3 = GR_ENTRY(W2, 2 * m, sz);
W4 = GR_ENTRY(W3, 2 * m, sz);

/* Evaluation: 5*2 add, 2 shift; 5mul */
/* W0 = U2 + U0 */
/* if max(U2len,U0len) < m, assumes top coefficients are already zeroed from the initialization */
status |= _gr_poly_add(W0, U2, U2len, U0, U0len, ctx);
/* W4 = V2 + V0 */
/* if max(V2len,V0len) < m, assumes top coefficients are already zeroed from the initialization */
status |= _gr_poly_add(W4, V2, V2len, V0, V0len, ctx);
/* W2 = W0 - U1 */
status |= _gr_poly_sub(W2, W0, m, U1, U1len, ctx);
/* W1 = W4 - V1 */
status |= _gr_poly_sub(W1, W4, m, V1, V1len, ctx);
/* W0 = W0 + U1 */
status |= _gr_poly_add(W0, W0, m, U1, U1len, ctx);
/* W4 = W4 + V1 */
status |= _gr_poly_add(W4, W4, m, V1, V1len, ctx);
/* W3 = W2 * W1 */
status |= _gr_poly_mul(W3, W2, m, W1, m, ctx);
/* W1 = W0 * W4 */
status |= _gr_poly_mul(W1, W0, m, W4, m, ctx);
/* W0 = ((W0 + U2) << 1) - U0 */
status |= _gr_poly_add(W0, W0, m, U2, U2len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W0, W0, m, 1, ctx);
status |= _gr_poly_sub(W0, W0, m, U0, U0len, ctx);
/* W4 = ((W4 + V2) << 1) - V0 */
status |= _gr_poly_add(W4, W4, m, V2, V2len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W4, W4, m, 1, ctx);
status |= _gr_poly_sub(W4, W4, m, V0, V0len, ctx);
/* W2 = W0 * W4 */
status |= _gr_poly_mul(W2, W0, m, W4, m, ctx);
/* W0 = U0 * V0 */
if (U0len > 0 && V0len > 0)
{
status |= _gr_poly_mul(W0, U0, U0len, V0, V0len, ctx);
status |= _gr_vec_zero(GR_ENTRY(W0, U0len + V0len - 1, sz), 2 * m - (U0len + V0len - 1), ctx);
}
else
status |= _gr_vec_zero(W0, 2 * m, ctx);
/* W4 = U2 * V2 */
/* We compute this length accurately instead of zero-extending. */
if (U2len > 0 && V2len > 0)
{
W4len = U2len + V2len - 1;
status |= _gr_poly_mul(W4, U2, U2len, V2, V2len, ctx);
}
else
{
W4len = 0;
}

/* toom42 variant */
/* U = U3*x^(3m) + U2*x^(2m) + U1*x^m + U0 */
/* V = V1*x^m + V0 */
/* Evaluation: 7+3 add, 3 shift; 5mul */
/*
W0 = U1 + U3;
W4 = U0 + U2;
W3 = W4 + W0;
W4 = W4 - W0;
W0 = V0 + V1;
W2 = V0 - V1;
W1 = W3 * W0;
W3 = W4 * W2;
W4 = (((((U3<<1) + U2) << 1) + U1) << 1) + U0;
W0 = W0 + V1;
W2 = W4 * W0;
W0 = U0 * V0;
W4 = U3 * V1;
*/

/* Interpolation: 8 add, 3 shift, 1 Sdiv */
len = 2 * m - 1;
/* W2 = (W2 - W3) / 3 */
status |= _gr_vec_sub(W2, W2, W3, len, ctx);
status |= _gr_vec_divexact_scalar_ui(W2, W2, len, 3, ctx);
/* W3 = (W1 - W3) >> 1 */
status |= _gr_vec_sub(W3, W1, W3, len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W3, W3, len, -1, ctx);
/* W1 = W1 - W0 */
status |= _gr_vec_sub(W1, W1, W0, len, ctx);
/* W2 = ((W2 - W1) >> 1) - (W4 << 1) */
status |= _gr_vec_sub(W2, W2, W1, len, ctx);
status |= _gr_vec_mul_scalar_2exp_si(W2, W2, len, -1, ctx);
status |= _gr_vec_mul_scalar_2exp_si(res, W4, W4len, 1, ctx);
status |= _gr_vec_sub(W2, W2, res, W4len, ctx);
/* W1 = W1 - W3 - W4 */
status |= _gr_vec_sub(W1, W1, W3, len, ctx);
status |= _gr_poly_sub(W1, W1, len, W4, W4len, ctx);
/* W3 = W3 - W2 */
status |= _gr_vec_sub(W3, W3, W2, len, ctx);

/* Recomposition: */
/* W = W4 * x^(4m) + W2*x^(3m) + W1*x^(2m) + W*x^m + W0 */

rlen = flen + glen - 1;
len = FLINT_MIN(rlen, m);
status |= _gr_vec_set(res, W0, FLINT_MIN(rlen, m), ctx);
len = FLINT_MIN(rlen - m, m);
status |= _gr_vec_add(GR_ENTRY(res, m, sz), W3, GR_ENTRY(W0, m, sz), len, ctx);
len = FLINT_MIN(rlen - 2 * m, m);
status |= _gr_vec_add(GR_ENTRY(res, 2 * m, sz), W1, GR_ENTRY(W3, m, sz), len, ctx);
len = FLINT_MIN(rlen - 3 * m, m);
status |= _gr_vec_add(GR_ENTRY(res, 3 * m, sz), W2, GR_ENTRY(W1, m, sz), len, ctx);
len = FLINT_MIN(rlen - 4 * m, m);
status |= _gr_poly_add(GR_ENTRY(res, 4 * m, sz), W4, FLINT_MIN(W4len, len), GR_ENTRY(W2, m, sz), len, ctx);
len = FLINT_MIN(rlen - 5 * m, m);
status |= _gr_vec_set(GR_ENTRY(res, 5 * m, sz), GR_ENTRY(W4, m, sz), len, ctx);

GR_TMP_CLEAR_VEC(tmp, alloc, ctx);

return status;
}

int
gr_poly_mul_toom33(gr_poly_t res, const gr_poly_t poly1, const gr_poly_t poly2, gr_ctx_t ctx)
{
slong len_out;
int status;

if (poly1->length == 0 || poly2->length == 0)
return gr_poly_zero(res, ctx);

len_out = poly1->length + poly2->length - 1;

if (res == poly1 || res == poly2)
{
gr_poly_t t;
gr_poly_init2(t, len_out, ctx);
status = _gr_poly_mul_toom33(t->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx);
gr_poly_swap(res, t, ctx);
gr_poly_clear(t, ctx);
}
else
{
gr_poly_fit_length(res, len_out, ctx);
status = _gr_poly_mul_toom33(res->coeffs, poly1->coeffs, poly1->length, poly2->coeffs, poly2->length, ctx);
}

_gr_poly_set_length(res, len_out, ctx);
_gr_poly_normalise(res, ctx);
return status;
}
2 changes: 2 additions & 0 deletions src/gr_poly/test/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "t-log_series.c"
#include "t-make_monic.c"
#include "t-mul_karatsuba.c"
#include "t-mul_toom33.c"
#include "t-nth_derivative.c"
#include "t-pow_series_fmpq.c"
#include "t-pow_series_ui.c"
Expand Down Expand Up @@ -106,6 +107,7 @@ test_struct tests[] =
TEST_FUNCTION(gr_poly_log_series),
TEST_FUNCTION(gr_poly_make_monic),
TEST_FUNCTION(gr_poly_mul_karatsuba),
TEST_FUNCTION(gr_poly_mul_toom33),
TEST_FUNCTION(gr_poly_nth_derivative),
TEST_FUNCTION(gr_poly_pow_series_fmpq),
TEST_FUNCTION(gr_poly_pow_series_ui),
Expand Down
106 changes: 106 additions & 0 deletions src/gr_poly/test/t-mul_toom33.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
Copyright (C) 2023 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "test_helpers.h"
#include "ulong_extras.h"
#include "gr_poly.h"

FLINT_DLL extern gr_static_method_table _ca_methods;

int
test_mul1(flint_rand_t state, int which)
{
gr_ctx_t ctx;
slong n;
gr_poly_t A, B, C, D;
int status = GR_SUCCESS;

gr_ctx_init_random(ctx, state);

gr_poly_init(A, ctx);
gr_poly_init(B, ctx);
gr_poly_init(C, ctx);
gr_poly_init(D, ctx);

if (ctx->methods == _ca_methods)
n = 2;
else if (gr_ctx_is_finite(ctx) == T_TRUE)
n = 30;
else
n = 10;

GR_MUST_SUCCEED(gr_poly_randtest(A, state, 1 + n_randint(state, n), ctx));
GR_MUST_SUCCEED(gr_poly_randtest(B, state, 1 + n_randint(state, n), ctx));
GR_MUST_SUCCEED(gr_poly_randtest(C, state, 1 + n_randint(state, n), ctx));

switch (which)
{
case 0:
status |= gr_poly_mul_toom33(C, A, B, ctx);
break;
case 1:
status |= gr_poly_set(C, A, ctx);
status |= gr_poly_mul_toom33(C, C, B, ctx);
break;
case 2:
status |= gr_poly_set(C, B, ctx);
status |= gr_poly_mul_toom33(C, A, C, ctx);
break;
case 3:
status |= gr_poly_set(B, A, ctx);
status |= gr_poly_mul_toom33(C, A, A, ctx);
break;
case 4:
status |= gr_poly_set(B, A, ctx);
status |= gr_poly_set(C, A, ctx);
status |= gr_poly_mul_toom33(C, C, C, ctx);
break;

default:
flint_abort();
}

/* todo: should explicitly call basecase mul */
status |= gr_poly_mullow(D, A, B, FLINT_MAX(0, A->length + B->length - 1), ctx);

if (status == GR_SUCCESS && gr_poly_equal(C, D, ctx) == T_FALSE)
{
flint_printf("FAIL\n\n");
flint_printf("which = %d, n = %wd\n\n", which, n);
gr_ctx_println(ctx);
flint_printf("A = "); gr_poly_print(A, ctx); flint_printf("\n\n");
flint_printf("B = "); gr_poly_print(B, ctx); flint_printf("\n\n");
flint_printf("C = "); gr_poly_print(C, ctx); flint_printf("\n\n");
flint_printf("D = "); gr_poly_print(D, ctx); flint_printf("\n\n");
flint_abort();
}

gr_poly_clear(A, ctx);
gr_poly_clear(B, ctx);
gr_poly_clear(C, ctx);
gr_poly_clear(D, ctx);

gr_ctx_clear(ctx);

return status;
}

TEST_FUNCTION_START(gr_poly_mul_toom33, state)
{
slong iter;

for (iter = 0; iter < 1000; iter++)
{
test_mul1(state, n_randint(state, 5));
}

TEST_FUNCTION_END(state);
}

0 comments on commit 9500186

Please sign in to comment.