/*
This file is part of mfaktc.
Copyright (C) 2012  George Woltman (woltman@alum.mit.edu)

mfaktc is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

mfaktc is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.
                                
You should have received a copy of the GNU General Public License
along with mfaktc.  If not, see <http://www.gnu.org/licenses/>.
*/


__device__ static int cmp_ge_96(int96 a, int96 b)
/* returns non-zero if a >= b, zero otherwise

The words are compared from d2 down to d0; the first pair of words that
differs decides the result. If all three words are equal, a >= b holds. */
{
  if(a.d2 != b.d2) return(a.d2 > b.d2);
  if(a.d1 != b.d1) return(a.d1 > b.d1);
  return(a.d0 >= b.d0);
}


__device__ static void shl_96(int96 *a)
/* shift a left by one bit, in place

Implemented as a + a, word by word from d0 up to d2. __add_cc, __addc_cc and
__addc are helpers defined elsewhere in the project (presumably thin wrappers
around PTX add.cc / addc.cc / addc - confirm there): the bit leaving d0 is
carried into d1 and the bit leaving d1 into d2. No carry is taken out of d2,
so the top bit of a is lost.
The three statements must stay in this order, each one consumes the carry
produced by the one before it. */
{
  a->d0 = __add_cc (a->d0, a->d0);
  a->d1 = __addc_cc(a->d1, a->d1);
  a->d2 = __addc   (a->d2, a->d2);
}


__device__ static void sub_96(int96 *res, int96 a, int96 b)
/* res = a - b
a must be greater than or equal to b!

The subtraction runs word by word from d0 up to d2, passing the borrow along
via __sub_cc / __subc_cc / __subc (helpers defined elsewhere in the project,
presumably wrapping PTX sub.cc / subc.cc / subc - confirm there). A borrow
out of d2 is not reported, hence the precondition above.
a and b are passed by value, so res may point to the caller's copy of a or b.
Keep the three statements in this order, they form one borrow chain. */
{
  res->d0 = __sub_cc (a.d0, b.d0);
  res->d1 = __subc_cc(a.d1, b.d1);
  res->d2 = __subc   (a.d2, b.d2);
}


__device__ static void mul_96(int96 *res, int96 a, int96 b)
/* res = a * b (only lower 96 bits of the result)

Schoolbook multiplication on 32-bit words that keeps only the partial
products landing in words 0..2 of the product:
  res.d0 = lo(a.d0*b.d0)
  res.d1 = hi(a.d0*b.d0) + lo(a.d1*b.d0) + lo(a.d0*b.d1)
  res.d2 = lo(a.d2*b.d0) + hi(a.d1*b.d0) + hi(a.d0*b.d1)
         + lo(a.d0*b.d2) + lo(a.d1*b.d1) + carries out of res.d1
Anything that would carry out of res.d2 is discarded.

asm operands: %0..%2 = res.d0..d2 (outputs)
              %3..%5 = a.d0..d2, %6..%8 = b.d0..d2 (inputs) */
{
  asm("{\n\t"
      "mul.lo.u32    %0, %3, %6;\n\t"       /* (a.d0 * b.d0).lo */

      "mul.hi.u32    %1, %3, %6;\n\t"       /* (a.d0 * b.d0).hi */
      "mad.lo.cc.u32 %1, %4, %6, %1;\n\t"   /* (a.d1 * b.d0).lo, carry-out goes to the madc below */

      "mul.lo.u32    %2, %5, %6;\n\t"       /* (a.d2 * b.d0).lo, plain mul: leaves the pending carry alone */
      "madc.hi.u32   %2, %4, %6, %2;\n\t"   /* (a.d1 * b.d0).hi + carry from res.d1 */

      "mad.lo.cc.u32 %1, %3, %7, %1;\n\t"   /* (a.d0 * b.d1).lo */
      "madc.hi.u32   %2, %3, %7, %2;\n\t"   /* (a.d0 * b.d1).hi + carry from res.d1 */

      "mad.lo.u32    %2, %3, %8, %2;\n\t"   /* (a.d0 * b.d2).lo */

      "mad.lo.u32    %2, %4, %7, %2;\n\t"   /* (a.d1 * b.d1).lo */
      "}"
      : "=r" (res->d0), "=r" (res->d1), "=r" (res->d2)
      : "r" (a.d0), "r" (a.d1), "r" (a.d2), "r" (b.d0), "r" (b.d1), "r" (b.d2));
}


__device__ static void square_96_192(int192 *res, int96 a)
/* res = a^2
assuming that a is < 2^95 (a.d2 < 2^31)!

The full 192 bit square is written to res.d0..d5. The mixed products
a.d0*a.d1, a.d0*a.d2 and a.d1*a.d2 are needed twice: a.d0*a.d1 is doubled
after the multiplication, the other two are doubled up front by multiplying
with a2 = 2 * a.d2. That doubling is the reason for the a.d2 < 2^31
precondition, a2 must fit into 32 bits.

asm operands: %0..%5 = res.d0..d5 (outputs), %6..%8 = a.d0..d2 (inputs) */
{
  asm("{\n\t"
      ".reg .u32 a2;\n\t"

      "mul.lo.u32      %0, %6, %6;\n\t"       /* (a.d0 * a.d0).lo */
      "mul.lo.u32      %1, %6, %7;\n\t"       /* (a.d0 * a.d1).lo */
      "mul.hi.u32      %2, %6, %7;\n\t"       /* (a.d0 * a.d1).hi */

      "add.cc.u32      %1, %1, %1;\n\t"       /* 2 * (a.d0 * a.d1).lo */
      "addc.cc.u32     %2, %2, %2;\n\t"       /* 2 * (a.d0 * a.d1).hi */
      "madc.hi.cc.u32  %3, %7, %7, 0;\n\t"    /* (a.d1 * a.d1).hi */
/* highest possible value for next instruction: mul.lo.u32 (N, N) is 0xFFFFFFF9
this occurs for N = {479772853, 1667710795, 2627256501, 3815194443}
We'll use this knowledge later to avoid two carry steps into %5 */
      "madc.lo.u32     %4, %8, %8, 0;\n\t"    /* (a.d2 * a.d2).lo */
                                              /* %4 <= 0xFFFFFFFA => no carry to %5 needed! */

      "add.u32         a2, %8, %8;\n\t"       /* a2 = 2 * a.d2 */
                                              /* a is < 2^95 so a.d2 is < 2^31 */

      "mad.hi.cc.u32   %1, %6, %6, %1;\n\t"   /* (a.d0 * a.d0).hi */
      "madc.lo.cc.u32  %2, %7, %7, %2;\n\t"   /* (a.d1 * a.d1).lo */
      "madc.lo.cc.u32  %3, %7, a2, %3;\n\t"   /* 2 * (a.d1 * a.d2).lo */
      "addc.u32        %4, %4,  0;\n\t"       /* %4 <= 0xFFFFFFFB => no carry to %5 needed, see above! */

      "mad.lo.cc.u32   %2, %6, a2, %2;\n\t"   /* 2 * (a.d0 * a.d2).lo */
      "madc.hi.cc.u32  %3, %6, a2, %3;\n\t"   /* 2 * (a.d0 * a.d2).hi */
      "madc.hi.cc.u32  %4, %7, a2, %4;\n\t"   /* 2 * (a.d1 * a.d2).hi */
      "madc.hi.u32     %5, %8, %8, 0;\n\t"    /* (a.d2 * a.d2).hi + last carry of the chain */
      "}"
      : "=r" (res->d0), "=r" (res->d1), "=r" (res->d2), "=r" (res->d3), "=r" (res->d4), "=r" (res->d5)
      : "r" (a.d0), "r" (a.d1), "r" (a.d2));
}


__device__ static void square_96_160(int192 *res, int96 a)
/* res = a^2
this is a stripped down version of square_96_192, it doesn't compute res.d5
and is a little bit faster.
For correct results a must be less than 2^80 (a.d2 less than 2^16)

Only res.d0..d4 are written, res.d5 is left untouched (it is not an output
of the asm statement), so the caller must not rely on its value.

asm operands: %0..%4 = res.d0..d4 (outputs), %5..%7 = a.d0..d2 (inputs) */
{
  asm("{\n\t"
      ".reg .u32 a2;\n\t"

      "mul.lo.u32     %0, %5, %5;\n\t"     /* (a.d0 * a.d0).lo */
      "mul.lo.u32     %1, %5, %6;\n\t"     /* (a.d0 * a.d1).lo */
      "mul.hi.u32     %2, %5, %6;\n\t"     /* (a.d0 * a.d1).hi */

      "add.u32        a2, %7, %7;\n\t"     /* a2 = 2 * a.d2, i.e. shl(a.d2) */

      "add.cc.u32     %1, %1, %1;\n\t"     /* 2 * (a.d0 * a.d1).lo */
      "addc.cc.u32    %2, %2, %2;\n\t"     /* 2 * (a.d0 * a.d1).hi */
      "madc.hi.u32    %3, %5, a2, 0;\n\t"  /* 2 * (a.d0 * a.d2).hi */
                                           /* %3 (res.d3) has some space left because a2 is < 2^17 */

      "mad.hi.cc.u32  %1, %5, %5, %1;\n\t" /* (a.d0 * a.d0).hi */
      "madc.lo.cc.u32 %2, %6, %6, %2;\n\t" /* (a.d1 * a.d1).lo */
      "madc.hi.cc.u32 %3, %6, %6, %3;\n\t" /* (a.d1 * a.d1).hi */
      "madc.lo.u32    %4, %7, %7, 0;\n\t"  /* (a.d2 * a.d2).lo */

      "mad.lo.cc.u32  %2, %5, a2, %2;\n\t" /* 2 * (a.d0 * a.d2).lo */
      "madc.lo.cc.u32 %3, %6, a2, %3;\n\t" /* 2 * (a.d1 * a.d2).lo */
      "madc.hi.u32    %4, %6, a2, %4;\n\t" /* 2 * (a.d1 * a.d2).hi */
      "}"
      : "=r"(res->d0), "=r"(res->d1), "=r"(res->d2), "=r"(res->d3), "=r"(res->d4)
      : "r"(a.d0), "r"(a.d1), "r"(a.d2));
}


__device__ static void mul_96_192_no_low3(int96 *res, int96 a, int96 b)
/*
res ~= a * b / 2^96
Only the upper three words (d3..d5) of the 192 bit product are computed and
stored in res.d0..d2; the lower three words are never formed.
Carries into res.d0 are NOT computed. So the result differs from a full mul_96_192() / 2^96.
In a full mul_96_192() there are four possible carries from d2 to d3 of the 192 bit product
(d3 of the full product is res.d0 here). So ignoring the carries
the result is 0 to 4 lower than a full mul_96_192() / 2^96.

Partial products used:
  res.d0 = hi(a.d2*b.d0) + lo(a.d2*b.d1) + lo(a.d1*b.d2) + hi(a.d0*b.d2) + hi(a.d1*b.d1)
  res.d1 = hi(a.d2*b.d1) + lo(a.d2*b.d2) + hi(a.d1*b.d2) + carries out of res.d0
  res.d2 = hi(a.d2*b.d2) + carries out of res.d1

asm operands: %0..%2 = res.d0..d2 (outputs)
              %3..%5 = a.d0..d2, %6..%8 = b.d0..d2 (inputs)
 */
{
  asm("{\n\t"
      "mul.hi.u32      %0, %5, %6;\n\t"       /* (a.d2 * b.d0).hi */
      "mad.lo.cc.u32   %0, %5, %7, %0;\n\t"   /* (a.d2 * b.d1).lo */
      "addc.u32        %1,  0,  0;\n\t"       /* start res.d1 with the carry out of res.d0 */

      "mad.lo.cc.u32   %0, %4, %8, %0;\n\t"   /* (a.d1 * b.d2).lo */
      "madc.hi.u32     %1, %5, %7, %1;\n\t"   /* (a.d2 * b.d1).hi */

      "mad.hi.cc.u32   %0, %3, %8, %0;\n\t"   /* (a.d0 * b.d2).hi */
      "madc.lo.cc.u32  %1, %5, %8, %1;\n\t"   /* (a.d2 * b.d2).lo */
      "madc.hi.u32     %2, %5, %8,  0;\n\t"   /* (a.d2 * b.d2).hi */

      "mad.hi.cc.u32   %0, %4, %7, %0;\n\t"   /* (a.d1 * b.d1).hi */
      "madc.hi.cc.u32  %1, %4, %8, %1;\n\t"   /* (a.d1 * b.d2).hi */
      "addc.u32        %2, %2,  0;\n\t"       /* propagate the last carry into res.d2 */
      "}"
      : "=r" (res->d0), "=r" (res->d1), "=r" (res->d2)
      : "r" (a.d0), "r" (a.d1), "r" (a.d2), "r" (b.d0), "r" (b.d1), "r" (b.d2));
}


__device__ static void mulsub_96(int96 *res, int192 c, int96 a, int96 negb)
/* res = c - a * b (only lower 96 bits of the result)

The code itself only ADDS: it computes c + a * negb modulo 2^96. To get
c - a * b the caller has to pass negb = -b (the 96 bit two's complement of b).
Of c only the words d0..d2 are read, the upper words cannot influence the
lower 96 bits of the result.

asm operands: %0..%2  = res.d0..d2 (outputs)
              %3..%5  = a.d0..d2, %6..%8 = negb.d0..d2, %9..%11 = c.d0..d2 (inputs) */
{
  asm("{\n\t"
      "mad.lo.cc.u32   %0, %3, %6, %9;\n\t"    /* c += (a.d0 * negb.d0).lo */
      "madc.lo.cc.u32  %1, %3, %7, %10;\n\t"   /* c += (a.d0 * negb.d1).lo */
      "madc.lo.u32     %2, %3, %8, %11;\n\t"   /* c += (a.d0 * negb.d2).lo */

      "mad.hi.cc.u32   %1, %3, %6, %1;\n\t"    /* c += (a.d0 * negb.d0).hi */
      "madc.hi.u32     %2, %3, %7, %2;\n\t"    /* c += (a.d0 * negb.d1).hi */

      "mad.lo.cc.u32   %1, %4, %6, %1;\n\t"    /* c += (a.d1 * negb.d0).lo */
      "madc.lo.u32     %2, %4, %7, %2;\n\t"    /* c += (a.d1 * negb.d1).lo */

      "mad.hi.u32      %2, %4, %6, %2;\n\t"    /* c += (a.d1 * negb.d0).hi */

      "mad.lo.u32      %2, %5, %6, %2;\n\t"    /* c += (a.d2 * negb.d0).lo */
      "}"
      : "=r" (res->d0), "=r" (res->d1), "=r" (res->d2)
      : "r" (a.d0), "r" (a.d1), "r" (a.d2),
        "r" (negb.d0), "r" (negb.d1), "r" (negb.d2),
        "r" (c.d0), "r" (c.d1), "r" (c.d2));
}


__device__ static void div_224_96(int128 *res, int224 q, int96 n, float nf)
/* res = q / n (integer division)

Parameters:
  res  receives the quotient in res->d0..d3
  q    dividend, words d0..d6 are read. q is passed by value and reused as the
       running remainder.
  n    divisor
  nf   float factor used to estimate the quotient digits: in every step the
       leading words of the remainder are converted to float, scaled and
       multiplied with nf. From the scaling used below nf has to approximate
       2^32 / n, rounded so that the estimate never comes out too big (see
       the remark before the final correction) - confirm against the caller.

Method: float assisted long division. Seven quotient "digits" qi are
estimated at the bit offsets 115, 95, 75, 55, 35, 15 and 0. Each step
  - estimates qi from the top words of the remainder,
  - adds qi << offset to res,
  - subtracts (n * qi) << offset from the remainder q.
Since qi may be slightly too small the digits overlap and a later step picks
up what an earlier one left behind. After the last step the remainder is
compared with n once and the quotient is incremented if needed.

The __add_cc/__addc_cc/__addc and __sub_cc/__subc_cc/__subc sequences are
carry/borrow chains; the statements inside one chain must keep their order. */
{
  float qf;
  unsigned int qi;
  int224 nn;
  int96 tmp96;

/********** Step -1, Offset 2^115 (3*32 + 19) **********/
  qf= __uint2float_rn(q.d6);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d5);
  qf*= 8192.0f;                         // 2^13 = 2^(32-19)

  qi=__float2uint_rz(qf*nf);
//if (blockIdx.x==0 && threadIdx.x == 4) printf ("q4: %X, %X, %X\n", qi, q.d6, q.d5);

  res->d3 = qi << 19;

// nn = n * qi
  nn.d3 =                                 __umul32(n.d0, qi);
  nn.d4 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d5 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d6 = __addc   (__umul32hi(n.d2, qi),                  0);   // end of the chain, no carry-out needed

// shiftleft nn 19 bits
  nn.d6  = (nn.d6 << 19) + (nn.d5 >> 13);
  nn.d5  = (nn.d5 << 19) + (nn.d4 >> 13);
  nn.d4  = (nn.d4 << 19) + (nn.d3 >> 13);
  nn.d3  =  nn.d3 << 19;

//  q = q - nn
  q.d3  = __sub_cc (q.d3,  nn.d3);
  q.d4  = __subc_cc(q.d4,  nn.d4);
  q.d5  = __subc_cc(q.d5,  nn.d5);
  q.d6  = __subc   (q.d6,  nn.d6);

/********** Step 0, Offset 2^95 (2*32 + 31) **********/
  qf= __uint2float_rn(q.d6);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d5);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d4);
  qf*= 2.0f;                            // 2^1 = 2^(32-31)

  qi=__float2uint_rz(qf*nf);
//if (blockIdx.x==0 && threadIdx.x == 4) printf ("q5: %X\n", qi);

  res->d2  = qi << 31;
  res->d3 += qi >> 1;

// nn = n * qi
  nn.d2 =                                 __umul32(n.d0, qi);
  nn.d3 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d4 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d5 = __addc   (__umul32hi(n.d2, qi),                  0);   // end of the chain, no carry-out needed

//if (nn.d6 >> 1 != q.d7) printf ("1/f fail 7\n");
// shiftleft nn 31 bits
  nn.d5 = (nn.d5 << 31) + (nn.d4 >> 1);
  nn.d4 = (nn.d4 << 31) + (nn.d3 >> 1);
  nn.d3 = (nn.d3 << 31) + (nn.d2 >> 1);
  nn.d2 =  nn.d2 << 31;

//  q = q - nn
  q.d2 = __sub_cc (q.d2, nn.d2);
  q.d3 = __subc_cc(q.d3, nn.d3);
  q.d4 = __subc_cc(q.d4, nn.d4);
  q.d5 = __subc   (q.d5, nn.d5);

/********** Step 1, Offset 2^75 (2*32 + 11) **********/
  qf= __uint2float_rn(q.d5);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d4);
  qf*= 2097152.0f;                      // 2^21 = 2^(32-11)

  qi=__float2uint_rz(qf*nf);

  res->d2 = __add_cc(res->d2, qi << 11);
  res->d3 = __addc  (res->d3, 0);

// nn = n * qi
  nn.d2 =                                 __umul32(n.d0, qi);
  nn.d3 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d4 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d5 = __addc   (__umul32hi(n.d2, qi),                  0);

// shiftleft nn 11 bits
  nn.d5 = (nn.d5 << 11) + (nn.d4 >> 21);
  nn.d4 = (nn.d4 << 11) + (nn.d3 >> 21);
  nn.d3 = (nn.d3 << 11) + (nn.d2 >> 21);
  nn.d2 =  nn.d2 << 11;

//  q = q - nn
  q.d2 = __sub_cc (q.d2, nn.d2);
  q.d3 = __subc_cc(q.d3, nn.d3);
  q.d4 = __subc_cc(q.d4, nn.d4);
  q.d5 = __subc   (q.d5, nn.d5);

/********** Step 2, Offset 2^55 (1*32 + 23) **********/
  qf= __uint2float_rn(q.d5);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d4);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d3);
  qf*= 512.0f;                          // 2^9 = 2^(32-23)

  qi=__float2uint_rz(qf*nf);

  res->d1 =  qi << 23;
  res->d2 = __add_cc(res->d2, qi >>  9);
  res->d3 = __addc  (res->d3, 0);

// nn = n * qi
  nn.d1 =                                 __umul32(n.d0, qi);
  nn.d2 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d3 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d4 = __addc   (__umul32hi(n.d2, qi),                  0);

// shiftleft nn 23 bits
  nn.d4 = (nn.d4 << 23) + (nn.d3 >> 9);
  nn.d3 = (nn.d3 << 23) + (nn.d2 >> 9);
  nn.d2 = (nn.d2 << 23) + (nn.d1 >> 9);
  nn.d1 =  nn.d1 << 23;

// q = q - nn
  q.d1 = __sub_cc (q.d1, nn.d1);
  q.d2 = __subc_cc(q.d2, nn.d2);
  q.d3 = __subc_cc(q.d3, nn.d3);
  q.d4 = __subc   (q.d4, nn.d4);

/********** Step 3, Offset 2^35 (1*32 + 3) **********/

  qf= __uint2float_rn(q.d4);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d3);
  qf*= 536870912.0f;                    // 2^29 = 2^(32-3)

  qi=__float2uint_rz(qf*nf);

  res->d1 = __add_cc (res->d1, qi << 3 );
  res->d2 = __addc_cc(res->d2, qi >> 29);
  res->d3 = __addc   (res->d3, 0);

// shiftleft qi 3 bits to avoid "long shiftleft" after multiplication
  qi <<= 3;

// nn = n * qi
  nn.d1 =                                 __umul32(n.d0, qi);
  nn.d2 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d3 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d4 = __addc   (__umul32hi(n.d2, qi),                  0);

//  q = q - nn
  q.d1 = __sub_cc (q.d1, nn.d1);
  q.d2 = __subc_cc(q.d2, nn.d2);
  q.d3 = __subc_cc(q.d3, nn.d3);
  q.d4 = __subc   (q.d4, nn.d4);

/********** Step 4, Offset 2^15 (0*32 + 15) **********/

  qf= __uint2float_rn(q.d4);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d3);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d2);
  qf*= 131072.0f;                       // 2^17 = 2^(32-15)

  qi=__float2uint_rz(qf*nf);

  res->d0 = qi << 15;
  res->d1 = __add_cc (res->d1, qi >> 17);
  res->d2 = __addc_cc(res->d2, 0);
  res->d3 = __addc   (res->d3, 0);

// nn = n * qi
  nn.d0 =                                 __umul32(n.d0, qi);
  nn.d1 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d2 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d3 = __addc   (__umul32hi(n.d2, qi),                  0);

// shiftleft nn 15 bits
  nn.d3 = (nn.d3 << 15) + (nn.d2 >> 17);
  nn.d2 = (nn.d2 << 15) + (nn.d1 >> 17);
  nn.d1 = (nn.d1 << 15) + (nn.d0 >> 17);
  nn.d0 =  nn.d0 << 15;

//  q = q - nn
  q.d0 = __sub_cc (q.d0, nn.d0);
  q.d1 = __subc_cc(q.d1, nn.d1);
  q.d2 = __subc_cc(q.d2, nn.d2);
  q.d3 = __subc   (q.d3, nn.d3);

/********** Step 5, Offset 2^0 (0*32 + 0) **********/

  qf= __uint2float_rn(q.d3);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d2);
  qf= qf * 4294967296.0f + __uint2float_rn(q.d1);

  qi=__float2uint_rz(qf*nf);

  res->d0 = __add_cc (res->d0, qi);
  res->d1 = __addc_cc(res->d1,  0);
  res->d2 = __addc_cc(res->d2,  0);
  res->d3 = __addc   (res->d3, 0);

// nn = n * qi
// The top word of the product belongs in nn.d3. Writing it to nn.d2 would
// clobber the word computed on the line before and leave nn.d3 with the
// stale value from step 4, corrupting the final remainder.
  nn.d0 =                                 __umul32(n.d0, qi);
  nn.d1 = __add_cc (__umul32hi(n.d0, qi), __umul32(n.d1, qi));
  nn.d2 = __addc_cc(__umul32hi(n.d1, qi), __umul32(n.d2, qi));
  nn.d3 = __addc   (__umul32hi(n.d2, qi),                  0);

//  q = q - nn
  q.d0 = __sub_cc (q.d0, nn.d0);
  q.d1 = __subc_cc(q.d1, nn.d1);
  q.d2 = __subc_cc(q.d2, nn.d2);
  q.d3 = __subc   (q.d3, nn.d3);

/*
qi is always a little bit too small, this is OK for all steps except the last
one. Sometimes the result is a little bit bigger than n
*/

// the remainder is in q.d0..d3 now; if it is still >= n the quotient is one too small
  tmp96.d0=q.d0;
  tmp96.d1=q.d1;
  tmp96.d2=q.d2;
  if(q.d3 || cmp_ge_96(tmp96,n))
  {
    res->d0 = __add_cc (res->d0,  1);
    res->d1 = __addc_cc(res->d1,  0);
    res->d2 = __addc_cc(res->d2,  0);
    res->d3 = __addc   (res->d3,  0);
  }
}


__device__ static void mod_simple_128_96(int96 *res, int128 q, int96 n, float nf)
/*
res = q mod n
used for refinement in barrett modular multiplication
assumes q < Xn where X is a small integer

A single quotient estimate qi is computed from the float approximation of
q.d3..d1 multiplied with nf, then n * qi is subtracted from q and at most one
further n is subtracted afterwards.
nf is supplied by the caller; from the way it is used here it presumably
approximates 2^32 / n - confirm against the caller.
*/
{
  float qf;
  unsigned int qi;
  int128 nn;

  // float approximation of the upper three words of q (q.d0 is ignored)
  qf = __uint2float_rn(q.d3);
  qf = qf * 4294967296.0f + __uint2float_rn(q.d2);
  qf = qf * 4294967296.0f + __uint2float_rn(q.d1);

  qi=__float2uint_rz(qf*nf);

  // nn = n * qi (128 bit). __umad32hi_cc / __umad32hic_cc / __umad32hic are
  // defined elsewhere; presumably "high word of the product plus addend" with
  // carry-out / carry-in+out / carry-in - confirm there.
  nn.d0 =                           __umul32(n.d0, qi);
  nn.d1 = __umad32hi_cc  (n.d0, qi, __umul32(n.d1, qi));
  nn.d2 = __umad32hic_cc (n.d1, qi, __umul32(n.d2, qi));
  nn.d3 = __umad32hic    (n.d2, qi,                  0);

  // res = q - nn, the fourth word of the difference is kept in q.d3
  res->d0 = __sub_cc (q.d0, nn.d0);
  res->d1 = __subc_cc(q.d1, nn.d1);
  res->d2 = __subc_cc(q.d2, nn.d2);
  q.d3 =    __subc   (q.d3, nn.d3);

  if(q.d3 || cmp_ge_96(*res, n))			// final adjustment in case finalrem >= f
  {
    sub_96(res, *res, n);
  }
}


__device__ static void mod_simple_96(int96 *res, int96 q, int96 n, float nf)
/*
res = q mod n
used for refinement in barrett modular multiplication
assumes q < Xn where X is a small integer

96 bit variant: a single quotient estimate qi is computed from the float
approximation of q.d2 and q.d1 multiplied with nf, then n * qi is subtracted
from q and at most one further n is subtracted afterwards.
nf is supplied by the caller; from the way it is used here it presumably
approximates 2^32 / n - confirm against the caller.
*/
{
  float qf;
  unsigned int qi;
  int96 nn;

  // float approximation of the upper two words of q (q.d0 is ignored)
  qf = __uint2float_rn(q.d2);
  qf = qf * 4294967296.0f + __uint2float_rn(q.d1);

  qi=__float2uint_rz(qf*nf);

  // nn = lower 96 bits of n * qi. __umad32hi_cc / __umad32hic are defined
  // elsewhere; presumably "high word of the product plus addend" with
  // carry-out / carry-in - confirm there.
  nn.d0 =                          __umul32(n.d0, qi);
  nn.d1 = __umad32hi_cc (n.d0, qi, __umul32(n.d1, qi));
  nn.d2 = __umad32hic   (n.d1, qi, __umul32(n.d2, qi));

  // res = q - nn
  res->d0 = __sub_cc (q.d0, nn.d0);
  res->d1 = __subc_cc(q.d1, nn.d1);
  res->d2 = __subc   (q.d2, nn.d2);

  // one final subtraction of n if the remainder is still >= n
  if(cmp_ge_96(*res,n))
  {
    sub_96(res, *res, n);
  }
}

/* tf_barrett96_div.cu is included twice on purpose: once with DIV_160_96
undefined and once with it defined. Presumably the included file checks
DIV_160_96 to emit two variants of its division code - confirm in
tf_barrett96_div.cu. The macro is removed again afterwards so that it does
not leak into the code that follows. */
#undef DIV_160_96
#include "tf_barrett96_div.cu"
#define DIV_160_96
#include "tf_barrett96_div.cu"
#undef DIV_160_96

