Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions doc/source/radix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ Except where otherwise noted, the following rules apply:
Sets *(res, xn)* to the sum or difference of *(x, xn)* and *(y, yn)*, returning
carry or borrow. Requires `xn \ge yn \ge 1`.

.. function:: ulong radix_mul_1(nn_ptr res, nn_srcptr a, slong n, ulong c, const radix_t radix)

Sets *(res, n)* to *(a, n)* multiplied by *c*, returning
carry out. Requires `1 \le c \le B - 1`.

.. function:: ulong radix_mul_two(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)

Sets *(res, an)* to *(a, an)* multiplied by 2, returning carry out.

.. function:: void radix_mul(nn_ptr res, nn_srcptr x, slong xn, nn_srcptr y, slong yn, const radix_t radix)

Sets *(res, xn + yn)* to the product of *(x, xn)* and *(y, yn)*.
Expand Down Expand Up @@ -163,6 +172,11 @@ Except where otherwise noted, the following rules apply:
Sets *(res, xn)* to the quotient of *(x, xn)* divided by *d*, returning the
remainder. Requires `1 \le d \le B - 1`.

.. function:: ulong radix_divrem_two(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)

Sets *(res, an)* to the quotient of *(a, an)* divided by 2, returning the
remainder.

.. function:: void radix_divexact_1(nn_ptr res, nn_srcptr x, slong xn, ulong d, const radix_t radix)

Sets *(res, xn)* to the quotient of *(x, xn)* divided by *d* assuming that
Expand All @@ -176,21 +190,26 @@ Except where otherwise noted, the following rules apply:
`a \in [1/B, 1)` with `an` fraction limbs,
sets `(q, n+2)` to an approximation of `1/a \in (1, B]` with `n` fraction limbs
and two integral limbs (the highest limb may be zero).
The relative error is bounded by `4 B^{-n}`, i.e. the absolute error is bounded
by `4 B^{-n} / a`.

.. function:: void radix_divrem_via_mpn(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, const radix_t radix)
void radix_divrem_newton(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, const radix_t radix)
void radix_divrem(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, const radix_t radix)

Sets `(q,an-bn+1)` to the quotient and `(r,bn)` to the remainder of
`(a,an)` divided by `(b,bn)`. Requires `an \ge bn \ge 1` and
`b_{bn-1} \ne 0`.
`b_{bn-1} \ne 0`. The user can pass ``NULL`` for `r` to compute
only the quotient.

.. function:: void radix_divrem_preinv(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, nn_srcptr binv, slong binvn, const radix_t radix)

Similar to :func:`radix_divrem`, but accepts a precomputed inverse
of `b` given as `(binv, binvn+2)` with `binvn` fraction limbs and two
integral limbs, as computed by :func:`radix_inv_approx`.
Currently requires that `binvn \ge an-bn+1`.
Currently requires that `binvn \ge an-bn+1`. Passing an inverse with
several extra limbs can improve performance. The user can pass ``NULL``
for `r` to compute only the quotient.

.. function:: int radix_div(nn_ptr q, nn_srcptr a, slong an, nn_srcptr b, slong bn, const radix_t radix)

Expand All @@ -214,6 +233,13 @@ Except where otherwise noted, the following rules apply:
Returns -1, 0 or 1 according to whether *(x, n)* is less, equal
or greater than `\lfloor B^n / 2 \rfloor`. We require *n* to be positive.

.. function:: void radix_rsqrt_1_approx_basecase(nn_ptr res, ulong a, slong n, const radix_t radix)
void radix_rsqrt_1_approx(nn_ptr res, ulong a, slong n, const radix_t radix)

Sets `(res,n)` to the fractional limbs of an approximation of `1 / \sqrt{a}`.
Assumes that `2 \le a < B`. The error is bounded by `2 B^{-n}`, i.e. by
2 fixed-point ulps.

Radix conversion
--------------------------------------------------------------------------------

Expand Down
14 changes: 14 additions & 0 deletions src/radix.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ radix_sub_1(nn_ptr res, nn_srcptr a, slong n, ulong c, const radix_t radix)

/* Multiplication */

ulong radix_mul_1(nn_ptr res, nn_srcptr a, slong n, ulong c, const radix_t radix);

void radix_mulmid_fft_small(nn_ptr res, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong lo, slong hi, const radix_t radix);
void radix_mulmid_classical(nn_ptr res, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong lo, slong hi, const radix_t radix);
void radix_mulmid_KS(nn_ptr res, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong lo, slong hi, const radix_t radix);
Expand Down Expand Up @@ -111,10 +113,17 @@ radix_sqr(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
radix_mul(res, a, an, a, an, radix);
}

/* Doubles (a, an) by adding the operand to itself with radix_add, storing
   the an-limb result in res. The value returned by radix_add (the carry
   out of the addition) is passed straight back to the caller. */
RADIX_INLINE ulong
radix_mul_two(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
{
    ulong carry_out;

    carry_out = radix_add(res, a, an, a, an, radix);

    return carry_out;
}

/* Division */

ulong radix_divrem_1(nn_ptr res, nn_srcptr a, slong an, ulong d, const radix_t radix);
void radix_divexact_1(nn_ptr res, nn_srcptr a, slong an, ulong d, const radix_t radix);
ulong radix_divrem_two(nn_ptr res, nn_srcptr a, slong an, const radix_t radix);

void radix_inv_approx_basecase(nn_ptr q, nn_srcptr a, slong an, slong n, const radix_t radix);
void radix_inv_approx(nn_ptr q, nn_srcptr a, slong an, slong n, const radix_t radix);
Expand Down Expand Up @@ -155,6 +164,11 @@ radix_cmp_bn_half(nn_srcptr x, slong n, const radix_t radix)
return 0;
}

/* Square roots */

void radix_rsqrt_1_approx_basecase(nn_ptr res, ulong a, slong n, const radix_t radix);
void radix_rsqrt_1_approx(nn_ptr res, ulong a, slong n, const radix_t radix);

/* Radix conversion */

typedef struct
Expand Down
146 changes: 130 additions & 16 deletions src/radix/div.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@
#include "longlong.h"
#include "radix.h"

/*
    Divides the an-limb number (a, an) by 2, writing the an quotient limbs
    to res and returning the final remainder (0 or 1).

    Limbs are processed from index an - 1 down to index 0. The remainder
    bit left over from one limb is carried into the next lower limb as
    rem * B, where B = LIMB_RADIX(radix), so each step halves the two-word
    value rem * B + a[k].

    Each input limb is read before the corresponding output limb is
    written. The top limb a[an - 1] is accessed unconditionally.
*/
ulong
radix_divrem_two(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
{
    const ulong base = LIMB_RADIX(radix);
    ulong top, rem, w1, w0;
    slong k;

    /* Top limb: there is no incoming remainder, so a plain shift suffices. */
    top = a[an - 1];
    rem = top & 1;
    res[an - 1] = top >> 1;

    k = an - 2;
    while (k >= 0)
    {
        /* (w1, w0) = rem * base + a[k] as a double-word value. */
        umul_ppmm(w1, w0, rem, base);
        add_ssaaaa(w1, w0, w1, w0, 0, a[k]);

        /* Shift the double word right by one bit; the bit shifted out
           becomes the remainder carried into the next lower limb. */
        res[k] = (w1 << (FLINT_BITS - 1)) | (w0 >> 1);
        rem = w0 & 1;

        k--;
    }

    return rem;
}


/* todo: optimise */
ulong
radix_divrem_1(nn_ptr res, nn_srcptr a, slong an, ulong d, const radix_t radix)
Expand Down Expand Up @@ -154,6 +177,33 @@ radix_inv_approx_basecase(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_
/* Must be at least 4 */
#define RADIX_INV_NEWTON_CUTOFF 8

/*
Claim: the absolute error is bounded by 4*B^(-n)/A. This is not tight.

Proof sketch: let m = ceil(n/2) + 1 + [B <= 3].

The initial approximation is assumed to satisfy

Y = 1/A + e1 where e1 <= C*B^(-m)/A where we take C = 4.

We compute Y' = Y + (1 - Y * (A + e2) + e3) * Y + e4 where

e2 = B^(-n-1) is the error from truncating A,

e3 = B^(-n) + {(m+2)*B^(-n-1)}
e4 = B^(-n) + {(m+2)*B^(-n-1)}

are the errors from fixed-point multiplications where the term in curly
brackets appears only when we use mulhigh, when B > 2 * (m + 2) is satisfied.

Expanding the expression for Y' gives

|Y' - 1/A| <= A*e1^2 + e1^2*e2 + e1*e3 + e4 - 2*e1*e2/A + e3/A + e2/A^2

Verifying that the RHS satisfies the claimed bound is now straightforward
(but exhausting) numerics.
*/

void
radix_inv_approx(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_t radix)
{
Expand All @@ -173,7 +223,7 @@ radix_inv_approx(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_t radix)
Compute T ~= 1 / A as a fixed-point number with m fraction limbs,
2 integral limbs and store in the high part of Q.
*/
m = n / 2 + 1 + (LIMB_RADIX(radix) == 2);
m = (n + 1) / 2 + 1 + (LIMB_RADIX(radix) <= 3);
radix_inv_approx(Q + n - m, A, An, m, radix);

TMP_START;
Expand Down Expand Up @@ -232,7 +282,7 @@ radix_inv_approx(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_t radix)
// product.
slong low_trunc_with_mulhigh = m - 1;

if (LIMB_RADIX(radix) > m + 2 && A_low_zeroes <= low_trunc_with_mulhigh)
if (LIMB_RADIX(radix) > 2 * (m + 2) && A_low_zeroes <= low_trunc_with_mulhigh)
{
U = TMP_ALLOC((2 + (control_limb + 1)) * sizeof(ulong));
radix_mulmid(U, Ahigh, Ahighn, T, Tn,
Expand Down Expand Up @@ -275,7 +325,7 @@ radix_inv_approx(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_t radix)
if (Uhighn != 0)
{
/* Compute V = T * U with n fraction limbs. */
if (LIMB_RADIX(radix) > m + 2)
if (LIMB_RADIX(radix) > 2 * (m + 2))
{
// V = T * |1 - A' * T| with n+2 fraction limbs
V = TMP_ALLOC((Tn + Uhighn - (m - 2)) * sizeof(ulong));
Expand Down Expand Up @@ -304,6 +354,15 @@ radix_inv_approx(nn_ptr Q, nn_srcptr A, slong An, slong n, const radix_t radix)
}
}

/* Instead of computing Q with ~1 ulp error and doing a full multiplication
to check the remainder, compute a few fraction limbs. If the first
fraction limb is not too close to the limb boundary, we have the correct
quotient and the remainder can be determined by a low multiplication
(or omitted if we only want the quotient). */
#define USE_PREINV2 1
/* Must be >= 2 */
#define PREINV2_EXTRA_LIMBS 2

void
radix_divrem_preinv(nn_ptr Q, nn_ptr R, nn_srcptr A, slong An, nn_srcptr B, slong Bn, nn_srcptr Binv, slong Binvn, const radix_t radix)
{
Expand All @@ -319,17 +378,61 @@ radix_divrem_preinv(nn_ptr Q, nn_ptr R, nn_srcptr A, slong An, nn_srcptr B, slon
if (Binvn < n)
flint_throw(FLINT_ERROR, "radix_divrem_preinv: inverse has too few limbs");

if (LIMB_RADIX(radix) > n + 2)
slong n2 = n + PREINV2_EXTRA_LIMBS;

/*
Consider A, B as fixed-point numbers with A in [0,1), B in [1/beta,1),
i.e. the integer quotient is floor((A/B)*beta^(an-bn)).

B' = High n2 limbs of Binv = 1/B * (1 + e1) where e1 <= 4*beta^(-n2)
e2 <= beta^(-n2) [truncation error for A]
e3 <= n2*beta^(-n2+1) [truncation error for product]

(A + e2) * B' + e3 - A/B = [A/B*e1 + e3 + e1*e2/B + e2/B]
<= [(e1 + e1*e2 + e2)*beta + e3]
< [(n2 + 6) beta^(-n2+1)]

PREINV2_EXTRA_LIMBS >= 2 ensures that q[-1] has <1 ulp error.
*/
#if USE_PREINV2
if (Binvn >= n2 && An >= n2 && LIMB_RADIX(radix) > n2 + 6)
{
U = TMP_ALLOC((n + 3) * sizeof(ulong));
radix_mulmid(U, A + Bn - 1, n, Binv + Binvn - n, n + 2, An - Bn, 2 * n + 2, radix);
q = U + 2;
U = TMP_ALLOC((n2 + 2) * sizeof(ulong));
/* n2 fraction limbs x (n2 fraction limbs + 2 integral limbs) */
radix_mulmid(U, A + An - n2, n2, Binv + Binvn - n2, n2 + 2, n2, 2 * n2 + 2, radix);

FLINT_ASSERT(U[n2 + 1] == 0);
q = U + n2 + 2 - (n + 1);

if (q[-1] > 1 && q[-1] < LIMB_RADIX(radix) - 1)
{
if (R == NULL)
r = NULL;
else
{
T = TMP_ALLOC(Bn * sizeof(ulong));
r = T;
radix_mulmid(r, q, FLINT_MIN(Bn, n + 1), B, Bn, 0, Bn, radix);
radix_sub(r, A, Bn, r, Bn, radix);
}
goto done;
}
}
else
#endif
{
U = TMP_ALLOC((2 * n + 2) * sizeof(ulong));
radix_mul(U, A + Bn - 1, n, Binv + Binvn - n, n + 2, radix);
q = U + An - Bn + 2;
if (LIMB_RADIX(radix) > n + 2)
{
U = TMP_ALLOC((n + 3) * sizeof(ulong));
radix_mulmid(U, A + Bn - 1, n, Binv + Binvn - n, n + 2, An - Bn, 2 * n + 2, radix);
q = U + 2;
}
else
{
U = TMP_ALLOC((2 * n + 2) * sizeof(ulong));
radix_mul(U, A + Bn - 1, n, Binv + Binvn - n, n + 2, radix);
q = U + An - Bn + 2;
}
}

T = TMP_ALLOC((Bn + n) * sizeof(ulong));
Expand All @@ -354,8 +457,10 @@ radix_divrem_preinv(nn_ptr Q, nn_ptr R, nn_srcptr A, slong An, nn_srcptr B, slon
radix_sub(r, r, Bn + 1, B, Bn, radix);
}

done:
flint_mpn_copyi(Q, q, An - Bn + 1);
flint_mpn_copyi(R, r, Bn);
if (R != NULL)
flint_mpn_copyi(R, r, Bn);

TMP_END;
}
Expand All @@ -370,15 +475,21 @@ radix_divrem_newton(nn_ptr Q, nn_ptr R, nn_srcptr A, slong An, nn_srcptr B, slon

if (flint_mpn_zero_p(A + Bn, An - Bn) && mpn_cmp(A, B, Bn) < 0)
{
flint_mpn_copyi(R, A, Bn);
if (R != NULL)
flint_mpn_copyi(R, A, Bn);
flint_mpn_zero(Q, An - Bn + 1);
return;
}

n = An - Bn + 1;
T = TMP_ALLOC((n + 2) * sizeof(ulong));
T = TMP_ALLOC((n + PREINV2_EXTRA_LIMBS + 2) * sizeof(ulong));
#if USE_PREINV2
radix_inv_approx(T, B, Bn, n + PREINV2_EXTRA_LIMBS, radix);
radix_divrem_preinv(Q, R, A, An, B, Bn, T, n + PREINV2_EXTRA_LIMBS, radix);
#else
radix_inv_approx(T, B, Bn, n, radix);
radix_divrem_preinv(Q, R, A, An, B, Bn, T, n, radix);
#endif
TMP_END;
return;
}
Expand All @@ -392,11 +503,13 @@ radix_divrem(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, c

if (bn == 1)
{
r[0] = radix_divrem_1(q, a, an, b[0], radix);
ulong r0 = radix_divrem_1(q, a, an, b[0], radix);
if (r != NULL)
r[0] = r0;
return;
}

if (bn >= 48)
if ((bn >= 20 && (an - bn + 1) <= 70) || bn >= 80)
{
radix_divrem_newton(q, r, a, an, b, bn, radix);
return;
Expand Down Expand Up @@ -438,7 +551,8 @@ radix_divrem(nn_ptr q, nn_ptr r, nn_srcptr a, slong an, nn_srcptr b, slong bn, c
cy = q[i * bn];
}

flint_mpn_copyi(r, R, bn);
if (r != NULL)
flint_mpn_copyi(r, R, bn);
TMP_END;
return;
}
Expand Down
Loading