#include "defs.h"
#include "mp.e"
#include "mp.h"


/*
Karatsuba constants (only used when KARA is defined):
    When (t > KARA_CONST_1), mp_mul() uses the Karatsuba function.
    When (n <= KARA_CONST_2), the Karatsuba function uses the classical method.

These constants are quite arbitrary.
*/

enum
{
    KARA_CONST_1 = 100,	/* mp_mul() calls mp_karatsuba() when t > this */
    KARA_CONST_2 = 50	/* mp_priv_karatsuba() goes classical when n <= this */
};



mp_float
mp_mul		WITH_3_ARGS(
	mp_float,	x,
	mp_float,	y,
	mp_float,	z
)
/*
Sets z = x * y and returns z.  x, y, and z must have compatible
parameters (checked by mp_check_3()).  If x or y is zero, z is set to
zero immediately.

Normally the simple O(t^2) algorithm is used, with rounding performed on
the digits of the product which cannot fit into z; this is faster if x
has more zero digits.  If KARA is defined and t > KARA_CONST_1, the
digits of x and y are instead multiplied by mp_karatsuba() and the
2t-digit product is normalized and rounded by mp_nzr().  Accumulator
operations are performed.
*/
{
    mp_ptr_type		xp = mp_ptr(x), yp = mp_ptr(y), zp = mp_ptr(z);
    mp_base_type	b;
    mp_length		t, guard, total;
    mp_acc_index	acc_index;
    mp_digit_ptr_type	x_dig_p, y_dig_p;
    mp_int		*z_dig_p, mul_count = 8;
    register mp_length	i;


    DEBUG_BEGIN(DEBUG_MUL);
    DEBUG_PRINTF_1("+mul {\n");
    DEBUG_1("x = ", xp);
    DEBUG_1("y = ", yp);

    /*
    Check that x, y, & z have compatible parameters.
    */

    mp_check_3("mp_mul", xp, yp, zp);


    if (mp_is_zero(xp) || mp_is_zero(yp))
    {
	mp_set_sign(zp, 0);
	DEBUG_PRINTF_1("-} z = 0\n");
	DEBUG_END();
	return z;
    }


    /*
    Extract base and number of digits from z.
    */

    t = mp_t(zp);
    b = mp_b(zp);
    mp_change_up();

#ifdef KARA
    if (t > KARA_CONST_1)
    {
	/*
	Karatsuba method: copy the t digits of x and y into the
	accumulator (mp_karatsuba() takes accumulator digit indices),
	multiply them into 2t digits, then normalize and round into z.
	*/

	mp_digit_index	x_copy, y_copy, result;

	mp_acc_digit_alloc(4 * t, acc_index);

	/*
	As in the classical branch below, allocating accumulator space
	may invalidate xp, yp, and zp, so re-fetch them before the digits
	of x and y are read.
	*/

	if (mp_has_changed())
	{
	    xp = mp_ptr(x);
	    yp = mp_ptr(y);
	    zp = mp_ptr(z);
	}

	x_copy = int_to_digit(acc_index);

	y_copy = x_copy + t;
	result = y_copy + t;

	mp_dig_copy(mp_karatsuba_ptr(x_copy), mp_digit_ptr(xp, 0), t);
	mp_dig_copy(mp_karatsuba_ptr(y_copy), mp_digit_ptr(yp, 0), t);

	mp_karatsuba(x_copy, y_copy, result, t, b);

	/*
	mp_karatsuba() allocates accumulator space too, so re-fetch again.
	*/

	if (mp_has_changed())
	{
	    xp = mp_ptr(x);
	    yp = mp_ptr(y);
	    zp = mp_ptr(z);
	}

	mp_nzr(mp_sign(xp) * mp_sign(yp), mp_expt(xp) + mp_expt(yp),
			    zp, mp_karatsuba_ptr(result), t);

    }

    else
    {
#endif

	/*
	Number of extra (guard) digits of the product to compute beyond
	the t digits of z: all t of them unless truncating with a small
	enough t, in which case at most 3.
	*/

	guard = (round != MP_TRUNC || t > b * b / 100)? t: int_min(t, 3);
	total = t + guard;


	/*
	Allocate space for digits.  All intermediate calculations with
	digits are done with mp_ints.  Because the size of a digit may be
	less than the size of an mp_int, the accumulator space needs to
	have one mp_int per digit, so we use appropriate casts and scale
	factors.  We use the macro mp_acc_alloc(), not the macro
	mp_acc_digits_alloc().  Note that zp has type (mp_int *), not
	(mp_digit_type *), as in all the other arithmetic operations where
	the accumulator is used.  The mp pointers x, y, and z are also
	reset if necessary.
	*/

	mp_acc_alloc(total, acc_index);
	z_dig_p = mp_acc_ptr(acc_index);

	if (mp_has_changed())
	{
	    xp = mp_ptr(x);
	    yp = mp_ptr(y);
	    zp = mp_ptr(z);
	}

#define z_dig(i)	z_dig_p[i]
#define z_ptr(i)	(z_dig_p + (i))


	/*
	Clear the product digits since we will add into them.
	mp_set_digits_zero() takes a number of digits, so we must scale
	the number of ints to the number of digits.
	*/

	mp_set_digits_zero((mp_digit_ptr_type)z_ptr(0), int_to_digit(total));


	/*
	Main multiplication loop: i runs along the digits of x, j along the
	digits of y.  We set pointers to where the digits of x & y are which
	makes the inner loop much faster.
	*/

	x_dig_p = mp_digit_ptr(xp, 0);
	y_dig_p = mp_digit_ptr(yp, 0);

#define x_dig(i)		(mp_int)x_dig_p[i]
#define y_dig(i)		(mp_int)y_dig_p[i]


	for (i = 0; i < t; i++)
	{
	    register mp_length	j;
	    register mp_int		xi;


	    if (xi = x_dig(i))
	    {
		/*
		Once i reaches guard, only accumulate terms that land
		inside the total digits being kept.
		*/

		register mp_length n = i >= guard? total - i - 1: t;


#ifndef NO_DEBUG
		if (xi < 0 || xi >= b)
		    mp_error("mp_mul: illegal digit >= b");
#endif

		for (j = 0; j < n; j++)
		{
		    z_dig(i + j + 1) += xi * y_dig(j);
		}
	    }

	    if (i == t - 1 || xi && !--mul_count)
	    {
		/*
		Propagate carries at end and every eighth time -
		faster than doing it every time.  If xi is zero, then
		no additions would have been made to the digits, so we
		need not propagate this time.
		*/

		mp_int		c = 0;

		mul_count = 8;


		for (j = total - 1; j >= 0; j--)
		{
		    register mp_int zj = z_dig(j) + c;

		    c = zj / b;
		    z_dig(j) = zj - b * c;		/* i.e. zj % b */
		}


		/*
		We place the digits of the product in the second digit and
		onwards and leave the first digit for a carry which may
		reach into it.  If c is non-zero now, a true overflow
		must have occured.
		*/

		if (c)
		    mp_error("mp_mul: illegal digit >= b");
	    }

	}


	/*
	Because we have used mp_int digits in the accumulator, and mp_nzr()
	expects mp_digit_type digits, we must change the mp_ints to digits.
	Since sizeof(mp_int) cannot be less than sizeof(mp_digit_type),
	we copy from the left, so no data can be overwritten before it
	is copied.
	*/

	for (i = 0; i < total; i++)
	    ((mp_digit_ptr_type)z_dig_p)[i] = z_dig_p[i];


	/*
	Normalize & round result.  
	*/

	mp_nzr(mp_sign(xp) * mp_sign(yp), mp_expt(xp) + mp_expt(yp), zp,
					(mp_digit_ptr_type)z_ptr(0), guard);
	
#ifdef KARA
    }
#endif

    mp_acc_delete(acc_index);
    mp_change_down();

    DEBUG_1("-} z = ", zp);
    DEBUG_END();
    return z;
}


#ifdef KARA

/*
KARATSUBA:
  Sadly, this method does not seem to be worthwhile unless t is very large.
  I haven't been able to test many possible values because each
  multiplication with a large t is rather slow.
*/


void
mp_priv_karatsuba	WITH_5_ARGS(
	mp_digit_index,	x,
	mp_digit_index,	y,
	mp_digit_index,	z,
	mp_length,	n,
	mp_base_type,	b
)
/*
Karatsuba multiplication: multiplies the n digits in x by the n digits in
y and places the result in 2 * n digits of z.  x, y, and z give a digit
index into the accumulator.  b is the base.

Digits are stored most significant first (carries propagate towards
index 0).  When n <= KARA_CONST_2 the classical O(n^2) method is used;
otherwise three products of about n/2 digits are formed by recursive
calls to mp_karatsuba().  z is used as scratch space before x and y have
been fully read, so the code assumes z does not overlap x or y.
*/
{
    mp_digit_index	z_mid;
    mp_acc_index	z_mid_index;
    mp_digit_ptr_type	xp, yp, zp, z_midp, x0, y0, z0;
    mp_length		i, j, k, k1, l; 
    mp_int		c, odd, last_carry, neg, swapped;
    mp_bool		product_zero;	/* TRUE when the middle product is zero */
    mp_int		xy_non_zero;


    if (n <= KARA_CONST_2)
    {
	/*
	Base case: use classical algorithm.  The details of the algorithm
	are the same as mp_mul() above, but the access to the digits is
	different, so we must repeat some of the source code.
	*/

	mp_acc_index	z_int_index;
	mp_int		*z_int, mul_count = 8;
	mp_length	i, total;


	/*
	Allocate space for total INTS - see mp_mul() for reasons why.
	*/

	total = 2 * n;

	mp_acc_alloc(total, z_int_index);
	z_int = mp_acc_ptr(z_int_index);

	xp = mp_karatsuba_ptr(x);
	yp = mp_karatsuba_ptr(y);
	zp = mp_karatsuba_ptr(z);


	DEBUG_BEGIN(DEBUG_KARA);
	DEBUG_PRINTF_1("+classical {\n");
	DEBUG_PRINTF_2("n = %d\n", n);
	DEBUG_DUMP("x = ", xp, n);
	DEBUG_DUMP("y = ", yp, n);


	/*
	Initialize product to zero.
	*/

	mp_set_digits_zero((mp_digit_ptr_type)z_int, int_to_digit(total));


	for (i = 0; i < n; i++)
	{
	    register mp_length		j;
	    register mp_digit_type	xi;


	    if (xi = xp[i])
	    {

#ifndef NO_DEBUG
		if (xi < 0 || xi >= b)
		    mp_bug("mp_class: illegal digit (%d) >= b", xi);
#endif

		/*
		Multiply in mp_int arithmetic, as mp_mul() does via its
		x_dig()/y_dig() casts; otherwise xi * yp[j] is evaluated in
		the (promoted) digit type, which may be narrower than mp_int.
		*/

		for (j = 0; j < n; j++)
		{
		    z_int[i + j + 1] += (mp_int)xi * yp[j];
		}
	    }

	    if (i == n - 1 || xi && !--mul_count)
	    {
		/*
		Propagate carries at end and every eighth time - faster
		than doing it every time.  If xi is zero, then no additions
		would have been made to the digits, so we need not propagate
		this time.
		*/

		register mp_int		c = 0;

		mul_count = 8;


		for (j = total - 1; j >= 0; j--)
		{
		    register mp_int	zj = z_int[j] + c;

		    c = zj / b;
		    z_int[j] = zj - b * c;	/* i.e. zj % b */
		}


		/*
		We place the digits of the product in the second digit and
		onwards and leave the first digit for a carry which may
		reach into it.  If c is non-zero now, a true overflow must
		have occured.
		*/

		if (c)
		    mp_bug("mp_class: illegal digit >= b");
	    }
	}



	/*
	Copy the ints of the product into the digits of z - see mp_mul().
	*/

	for (i = 0; i < total; i++)
	    zp[i] = z_int[i];


	DEBUG_DUMP("-} z = ", zp, 2 * n);
	DEBUG_END();

	mp_acc_delete(z_int_index);
	return;
    }


    /*
    Recursive case, based on the formula:

	(At + B)(Ct + D) = ACt^2 + (AD + BC)t + BD
			 = ACt^2 + [(A - B)(D - C) + AC + BD]t + BD,

    where t = b^k and k = ceil(n/2) is the number of digits in B and D.
    Thus we only need to calculate 3 sub-products (AC, BD, and (A - B)(D - C)),
    each of which multiplies numbers of about half as many digits.

    We use	A = the first half of x (xp),
		B = the second half of x (x0),
		C = the first half of y (yp),
		D = the second half of y (y0).

    If n is odd, A and C have only k1 = k - 1 = floor(n/2) real digits.
    */


    DEBUG_BEGIN(DEBUG_KARA);
    DEBUG_PRINTF_1("+kara {\n");

    odd = n & 1;

    k = n >> 1;
    if (odd)
	k++;

    /*
    Allocate 2k digits for z_mid, which will hold the product (x - x0)(y0 - y).
    */

    mp_acc_digit_alloc(2 * k, z_mid_index);
    mp_change_up();

    z_mid = int_to_digit(z_mid_index);


    k1 = k - odd;

    xp = mp_karatsuba_ptr(x);
    yp = mp_karatsuba_ptr(y);
    zp = mp_karatsuba_ptr(z);


    DEBUG_PRINTF_3("n = %d, k = %d\n", n, k);
    DEBUG_DUMP("x = ", xp, n);
    DEBUG_DUMP("y = ", yp, n);


    /*
    If n is odd, pretend we have an even number of digits by decrementing
    the pointers of x and y - the most significant digit of each is then 0.
    (xp[0] and yp[0] are never actually read in that case: the loops below
    start at index odd.)
    */

    if (odd)
    {
	xp--;
	yp--;
    }


    /*
    x0 points to the second half of x; similarly for y and z.
    */

    x0 = xp + k;
    y0 = yp + k;
    z0 = zp + 2 * k1;



    /*
    We compare digits of x and x0 to find out which is greater and thus to
    determine the sign.  If x < x0, we swap the pointers to x and x0 and
    invert the sign.  If the digits of x and x0 all correspond, then z_mid
    is zero, so we can avoid one recursive step.  We also check to
    see whether all of x is zero - if so, the answer to everything is zero
    and we can exit!  (This check is justifiable, since using a profiler
    I found that this situation happens often.)  Exactly the same calculations
    and checks are then done for y.
    */

    product_zero = FALSE;
    swapped = FALSE;
    neg = TRUE;

    xy_non_zero = 0;

    if (!odd || x0[0] == 0)
    {
	/*
	Run along digits of x and x0 until a difference is found.
	*/

	for (i = odd; i < k; i++)
	{
	    if (xp[i] == x0[i])
	    {
		zp[i] = 0;
		xy_non_zero |= xp[i];
	    }
	    else
		break;
	}

	DEBUG_PRINTF_2("i == %d\n", i);

	if (i == k)
	{
	    DEBUG_PRINTF_1("xdiff all zero\n");

	    if (!xy_non_zero)
	    {
		DEBUG_PRINTF_1("product all zero\n");
		mp_set_digits_zero(zp, 2 * n);
		goto finished;
	    }

	    product_zero = TRUE;
	}

	else if (xp[i] > x0[i])
	{
	    /*
	    x > x0, so swap pointers and negate.
	    */

	    z_midp = xp;
	    xp = x0;
	    x0 = z_midp;

	    swapped = TRUE;
	    neg = FALSE;
	}
    }
    else
    {
	/*
	odd is true, and the first digit of x0 is non-zero, so x0 is trivially
	greater than x.
	*/

	i = 1;
    }


    if (!product_zero)
    {
	/*
	Form x0 - x in z (z is used as a temporary scratch).
	*/

	for (j = k - 1, c = 0; j >= i; j--)
	{
	    zp[j] = c + x0[j] - xp[j];

	    if (zp[j] < 0)
	    {
		zp[j] += b;
		c = -1;
	    }
	    else
		c = 0;
	}

	if (odd)
	{
	    if (swapped)
		zp[0] = 0;

	    else
	    {
		zp[0] = x0[0];
		if (i == 1 && c < 0)
		    zp[0]--;
	    }
	}

	DEBUG_PRINTF_2("neg = %d\n", neg);


	/*
	Now do the same for y as for x.
	*/

	swapped = FALSE;
	xy_non_zero = FALSE;

	if (!odd || y0[0] == 0)
	{
	    for (i = odd; i < k; i++)
	    {
		if (yp[i] == y0[i])
		{
		    z0[i] = 0;
		    xy_non_zero |= yp[i];
		}
		else
		    break;
	    }

	    DEBUG_PRINTF_2("i == %d\n", i);

	    if (i == k)
	    {
		DEBUG_PRINTF_1("ydiff all zero\n");

		if (!xy_non_zero)
		{
		    DEBUG_PRINTF_1("product all zero\n");
		    mp_set_digits_zero(zp, 2 * n);

		    goto finished;
		}

		product_zero = TRUE;
	    }

	    else if (yp[i] > y0[i])
	    {
		z_midp = yp;
		yp = y0;
		y0 = z_midp;

		swapped = 1;
		neg = !neg;
	    }
	}
	else
	    i = 1;

	if (!product_zero)
	{
	    /*
	    Form y0 - yp in z0.
	    */

	    for (j = k - 1, c = 0; j >= i; j--)
	    {
		z0[j] = c + y0[j] - yp[j];

		if (z0[j] < 0)
		{
		    z0[j] += b;
		    c = -1;
		}
		else
		    c = 0;
	    }

	    if (odd)
	    {
		if (swapped)
		    z0[0] = 0;

		else
		{
		    z0[0] = y0[0];
		    if (i == 1 && c < 0)
			z0[0]--;
		}
	    }
	}
    }


    DEBUG_PRINTF_2("neg = %d\n", neg);


    if (!product_zero)
    {
	/*
	Now z == int_abs(x0 - x), z0 == int_abs(y0 - y), so place z * z0 in z_mid.
	This must happen before the calls below overwrite z.
	*/

	mp_karatsuba(z, z + 2 * k1, z_mid, k, b);
    }


    /*
    Place most significant digits in z (x * y) (coefficient of b^2k) and
    least significant digits in z0 (x0 * y0) (constant coefficient).
    */

    mp_karatsuba(x, y, z, k1, b);
    mp_karatsuba(x + k1, y + k1, z + 2 * k1, k, b);


    if (mp_has_changed())
    {
	zp = mp_karatsuba_ptr(z);
	z0 = zp + 2 * k1;
    }

    z_midp = mp_karatsuba_ptr(z_mid);

    DEBUG_DUMP("z_mid = ", z_midp, 2 * k);
    DEBUG_PRINTF_2("adding into z_mid (neg = %d)\n", neg);


    /*
    Now add z == (x * y) and z0 == (x0 * y0) into z_mid (the co-efficient
    of b^k).  z has only 2 * k1 digits, so when n is odd its index j runs
    out (goes negative) one digit before i does.
    */

    c = 0;

    for (i = 2 * k - 1, j = 2 * k1 - 1, l = 2 * k - 1; i >= 0; i--, j--, l--)
    {
	register mp_int		dig;

	if (product_zero)
	{
	    /*
	    z_mid is zero, so just add a digit each from z and z0.
	    */

	    dig = c + z0[i] + (j >= 0? zp[j]: 0);
	    c = 0;
	}

	else if (neg)
	{
	    /*
	    z_mid was really negative, so we negate it before adding in
	    the digits of z and z0.
	    */

	    dig = -z_midp[l] + c + z0[i] + (j >= 0? zp[j] : 0);

	    c = 0;

	    while (dig < 0)
	    {
		/*
		Account for borrow.
		*/

		dig += b;
		c--;
	    }
	}

	else
	{
	    /*
	    Just add in a digit each of z and z0.
	    */

	    DEBUG_PRINTF_5("i = %d, z_midp[l] = %d, z0[i] = %d, zp[j] = %d\n",
				i, z_midp[l], z0[i], (j >= 0? zp[j]: 0));
	    dig = z_midp[l] + c + z0[i] + (j >= 0? zp[j]: 0);
	    c = 0;
	}

	while (dig >= b)
	{
	    /*
	    Account for carry.
	    */

	    dig -= b;
	    c++;
	}

	z_midp[l] = dig;
    }

    DEBUG_DUMP("z_mid = ", z_midp, 2 * k);
    DEBUG_PRINTF_2("c = %d\n", c);

    /*
    Remember the last_carry (if any).
    */

    last_carry = c;


    /*
    Now add z_mid into the middle of the complete result (i.e. so z_mid is
    the co-efficient of b^k).
    */

    c = 0;

    for (i = 2 * k - 1, j = 2 * n - 1 - k; i >= 0; i--, j--)
    {
	zp[j] += c + z_midp[i];

	if (zp[j] >= b)
	{
	    zp[j] -= b;
	    c = 1;
	}
	else
	    c = 0;
    }


    DEBUG_PRINTF_3("c = %d, last_carry = %d\n", c, last_carry);

    /*
    Propagate the remaining carry into the more significant digits of z.
    */

    for (c += last_carry; c; j--)
	if ((zp[j] += c) >= b)
	{
	    zp[j] -= b;
	    c = 1;
	}
	else
	    break;

finished:

    mp_acc_delete(z_mid_index);
    mp_change_down();

    DEBUG_DUMP("-} z = ", zp, 2 * n);
    DEBUG_END();
}
#endif
