#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

#define NO_HARDWARE_MUL 1

#ifdef NO_HARDWARE_MUL
// Multiply for 8/16 bit processor with no MUL instruction, eg 6502...
// Not needed on 6809

// generate table of squares or quarter-squares with iteration!
// a live program would put this table into a const array in ROM
uint16_t square[256];      // square[i]    == i*i        : { 0, 1, 4, 9, 16, 25, ..., 65025 }
uint16_t squareqtr[256*2]; // squareqtr[i] == (i*i) >> 2 : { 0, 0, 1, 2, 4, 6, ..., 65280 }

/*
 * Fill the lookup tables using additions only (no multiply needed).
 *
 * MAX is the number of entries to generate, starting at index 0.
 * square[] receives the first min(MAX, 256) squares; squareqtr[] receives
 * the first min(MAX, 512) quarter-squares.  Requests larger than the tables
 * are clamped and non-positive requests fill nothing, so the call can never
 * write outside either array.
 */
void init_sq(const int MAX) {
  const uint32_t qtr_entries = sizeof squareqtr / sizeof squareqtr[0];
  const uint32_t sq_entries  = sizeof square / sizeof square[0];
  uint32_t limit = 0;
  uint32_t tri = 0;       // triangular number T(i)   = 0 + 1 + ... + i
  uint32_t tri_prev = 0;  // triangular number T(i-1)

  if (MAX > 0) {
    limit = ((uint32_t)MAX < qtr_entries) ? (uint32_t)MAX : qtr_entries;
  }

  for (uint32_t i = 0; i < limit; i++) {
    tri += i;
    const uint32_t sq = tri + tri_prev;  // T(i) + T(i-1) == i*i
    if (i < sq_entries) {
      square[i] = (uint16_t)sq;
    }
    squareqtr[i] = (uint16_t)(sq >> 2);  // 511*511/4 == 65280 still fits in 16 bits
    tri_prev = tri;
  }
}

/*
 * 8x8 -> 16 bit unsigned multiply built from the square[] table, shifts and
 * additions.  square[] must already have been filled by init_sq().
 *
 * With X <= Y the identity  X * Y == X^2 + X * (Y - X)  is applied repeatedly,
 * peeling one square off per round until the remaining product is trivial
 * (equal factors, a zero factor, or X a power of two).
 * Not efficient for a small X paired with a big Y.
 */
static uint16_t umul8add(uint8_t X, uint8_t Y) {
  uint16_t acc = 0;  // sum of the squares peeled off so far

  for (;;) {
    // Equal factors: one table lookup finishes the job.
    if (X == Y) return acc + square[X];

    // Keep X as the smaller factor.
    if (X > Y) {
      uint8_t tmp = Y;
      Y = X;
      X = tmp;
    }

    if (X == 0) return acc;

    // X == 2^k: the remaining product is just Y shifted left by k.
    uint16_t shifted = Y;
    for (uint8_t bit = 1; bit != 0; bit <<= 1) {  // bit wraps to 0 after 128
      if (X == bit) return acc + shifted;
      shifted <<= 1;
    }

    // General case: take X^2 out and continue with X * (Y - X).
    acc += square[X];
    Y -= X;
  }
}

/*
 * 8x8 -> 16 bit unsigned multiply by quarter squares:
 *   X * Y == floor((X+Y)^2 / 4) - floor((X-Y)^2 / 4)
 * Two lookups in squareqtr[] and one subtraction; the table must have been
 * filled up to index 510 (init_sq(512)) beforehand.
 */
static inline uint16_t umul8qtr(uint8_t X, uint8_t Y) {
  const uint8_t larger  = (X < Y) ? Y : X;
  const uint8_t smaller = (X < Y) ? X : Y;

  return squareqtr[larger + smaller] - squareqtr[larger - smaller];
}

// umul8(x,y) is the 8x8 -> 16 bit multiply primitive used by the wider routines below.
// Software build: choose one of the table-driven implementations.
//#define umul8(x,y) umul8add(x,y)
#define umul8(x,y) umul8qtr(x,y)
#else
// Build without NO_HARDWARE_MUL: use the language's own multiply operator.
#define umul8(x,y) ((x)*(y))
#endif

// =====================================================================================

// This first implementation is based on:
//     https://stackoverflow.com/questions/22845801/32-bit-signed-integer-multiplication-without-using-64-bit-data-type/22847373#22847373
// with some minimal tweaking to supply two different word sizes...

/*
 * Compute the full 32-bit product of two UNSIGNED 16-bit integers from four
 * 8x8 -> 16 partial products obtained through umul8().
 */
uint32_t umultiply16 (uint16_t a, uint16_t b) {
  /* byte halves of each operand */
  const uint8_t a_lo = (uint8_t)a;
  const uint8_t a_hi = a >> 8;
  const uint8_t b_lo = (uint8_t)b;
  const uint8_t b_hi = b >> 8;

  /* partial products, named <half of a>_<half of b> */
  const uint16_t lo_lo = umul8(a_lo, b_lo);
  const uint16_t lo_hi = umul8(a_lo, b_hi);
  const uint16_t hi_lo = umul8(a_hi, b_lo);
  const uint16_t hi_hi = umul8(a_hi, b_hi);

  /* carry that the middle column pushes out of the low 16 bits */
  const uint16_t carry = ((lo_lo >> 8) + (lo_hi & 0xFF) + (hi_lo & 0xFF)) >> 8;

  /* bits <31:16> and <15:0> of a * b */
  const uint16_t high = hi_hi + (hi_lo >> 8) + (lo_hi >> 8) + carry;
  const uint16_t low = ((hi_lo + lo_hi) << 8) + lo_lo;

  return ((uint32_t)high << 16) | (uint32_t)low;
}

/*
 * Compute the full 32-bit product of two SIGNED 16-bit integers from four
 * 8x8 -> 16 partial products obtained through umul8().
 *
 * The operands are handled as their two's-complement bit patterns: the
 * unsigned product is formed first and its high half is then corrected for
 * each negative operand.  All shifting is done on unsigned values, so no
 * negative value is ever shifted (left-shifting one is undefined behaviour,
 * right-shifting one is implementation-defined).
 */
int32_t multiply16 (int16_t a, int16_t b) {
  const uint16_t ua = (uint16_t)a;  /* bit pattern of a */
  const uint16_t ub = (uint16_t)b;  /* bit pattern of b */

  /* split operands into halves */
  uint8_t al = (uint8_t)ua;
  uint8_t ah = (uint8_t)(ua >> 8);
  uint8_t bl = (uint8_t)ub;
  uint8_t bh = (uint8_t)(ub >> 8);

  /* compute partial products using 8x8 -> 16 multiplies */
  uint16_t p0 = umul8(al, bl);
  uint16_t p1 = umul8(al, bh);
  uint16_t p2 = umul8(ah, bl);
  uint16_t p3 = umul8(ah, bh);

  /* carry out of the low 16 bits */
  uint16_t cy = (uint16_t)(((uint32_t)(p0 >> 8) + (uint8_t)p1 + (uint8_t)p2) >> 8);

  /* upper 16 bits of the UNSIGNED product */
  uint16_t mul16hi = (uint16_t)(p3 + (p2 >> 8) + (p1 >> 8) + cy);

  /* A negative operand's bit pattern is 2^16 too large, which adds the other
     operand (times 2^16) to the unsigned product: take it back out, mod 2^16. */
  if (a < 0) mul16hi = (uint16_t)(mul16hi - ub);
  if (b < 0) mul16hi = (uint16_t)(mul16hi - ua);

  /* lower 16 bits are the same for the signed and unsigned product */
  uint16_t umul16lo = (uint16_t)((((uint32_t)p2 + p1) << 8) + p0);

  /* Assemble as unsigned, then reinterpret as a two's-complement value. */
  return (int32_t)(((uint32_t)mul16hi << 16) | (uint32_t)umul16lo);
}

// -----------

/*
 * Compute the full 64-bit product of two SIGNED 32-bit integers from four
 * 16x16 -> 32 partial products obtained through umultiply16().
 *
 * The operands are handled as their two's-complement bit patterns: the
 * unsigned product is formed first and its high half is then corrected for
 * each negative operand.  All shifting is done on unsigned values, so no
 * negative value is ever shifted (left-shifting one is undefined behaviour,
 * right-shifting one is implementation-defined).
 */
int64_t multiply32(int32_t a, int32_t b) {
  const uint32_t ua = (uint32_t)a;  /* bit pattern of a */
  const uint32_t ub = (uint32_t)b;  /* bit pattern of b */

  /* split operands into halves */
  uint16_t al = (uint16_t)ua;
  uint16_t ah = (uint16_t)(ua >> 16);
  uint16_t bl = (uint16_t)ub;
  uint16_t bh = (uint16_t)(ub >> 16);

  /* compute partial products */
  uint32_t p0 = umultiply16(al, bl);
  uint32_t p1 = umultiply16(al, bh);
  uint32_t p2 = umultiply16(ah, bl);
  uint32_t p3 = umultiply16(ah, bh);

  /* carry out of the low 32 bits */
  uint32_t cy = ((p0 >> 16) + (uint16_t)p1 + (uint16_t)p2) >> 16;

  /* upper 32 bits of the UNSIGNED product */
  uint32_t umul32hi = p3 + (p2 >> 16) + (p1 >> 16) + cy;

  /* A negative operand's bit pattern is 2^32 too large, which adds the other
     operand (times 2^32) to the unsigned product: take it back out, mod 2^32. */
  uint32_t mul32hi = umul32hi - ((a < 0) ? ub : 0u) - ((b < 0) ? ua : 0u);

  /* lower 32 bits are the same for the signed and unsigned product */
  uint32_t umul32lo = ((p2 + p1) << 16) + p0;

  /* Assemble as unsigned, then reinterpret as a two's-complement value. */
  return (int64_t)(((uint64_t)mul32hi << 32) | (uint64_t)umul32lo);
}

// =====================================================================================

// The next implementation is based on:
//    https://www.techiedelight.com/multiply-16-bit-integers-using-8-bit-multiplier/
// although it required a lot more hacking to fit our use-case...

/*
 * Full 32-bit product of two unsigned 16-bit integers, built from four
 * 8x8 -> 16 partial products obtained through umul8():
 *
 *   m * n == (mH*nH << 16) + ((mH*nL + mL*nH) << 8) + mL*nL
 *
 * The recombination is carried out in uint32_t.  Doing it in the promoted
 * 'int' type overflows for products >= 2^31 and loses the upper bits outright
 * when int is only 16 bits wide.
 */
static uint32_t umultiply16bit(uint16_t m, uint16_t n) {
  const uint8_t mLow  = (uint8_t)m;         // low byte of m
  const uint8_t mHigh = (uint8_t)(m >> 8);  // high byte of m

  const uint8_t nLow  = (uint8_t)n;         // low byte of n
  const uint8_t nHigh = (uint8_t)(n >> 8);  // high byte of n

  const uint16_t mLow_nLow   = umul8(mLow,  nLow); // 8x8 -> 16 multiplies
  const uint16_t mHigh_nLow  = umul8(mHigh, nLow);
  const uint16_t mLow_nHigh  = umul8(mLow,  nHigh);
  const uint16_t mHigh_nHigh = umul8(mHigh, nHigh);

  // The two cross products share a column; their sum needs 17 bits.
  const uint32_t middle = (uint32_t)mHigh_nLow + (uint32_t)mLow_nHigh;

  // Horner form of the sum above; every intermediate value fits in 32 bits.
  return (((((uint32_t)mHigh_nHigh) << 8) + middle) << 8) + (uint32_t)mLow_nLow;
}

/*
 * Full 32-bit product of two signed 16-bit integers: multiply the magnitudes
 * with umultiply16bit() and re-apply the sign afterwards.
 *
 * Magnitudes are formed with unsigned arithmetic so that INT16_MIN, whose
 * negation does not fit in int16_t, is handled like any other value.
 */
static int32_t multiply16bit(int16_t m, int16_t n) {
  const int negative = (m < 0) != (n < 0);

  // |x| as an unsigned value: 0 - x (mod 2^16) for negative x.
  const uint16_t um = (m < 0) ? (uint16_t)(0u - (uint16_t)m) : (uint16_t)m;
  const uint16_t un = (n < 0) ? (uint16_t)(0u - (uint16_t)n) : (uint16_t)n;

  // At most 32768 * 32768 == 2^30, so the magnitude always fits in int32_t.
  const uint32_t product = umultiply16bit(um, un);

  return negative ? -(int32_t)product : (int32_t)product;
}

// -----------

/*
 * Full 64-bit product of two unsigned 32-bit integers, built from four
 * 16x16 -> 32 partial products obtained through umultiply16bit():
 *
 *   m * n == (mH*nH << 32) + ((mH*nL + mL*nH) << 16) + mL*nL
 *
 * The sum of the two cross products can need 33 bits, so the columns are
 * accumulated in 64-bit arithmetic; keeping that sum in 32 bits drops a carry
 * (e.g. 0xFFFFFFFF * 0xFFFFFFFF came out 2^48 too small).
 */
static uint64_t umultiply32bit(uint32_t m, uint32_t n) {
  const uint16_t mLow  = (uint16_t)m;          // low 16 bits of m
  const uint16_t mHigh = (uint16_t)(m >> 16);  // high 16 bits of m

  const uint16_t nLow  = (uint16_t)n;          // low 16 bits of n
  const uint16_t nHigh = (uint16_t)(n >> 16);  // high 16 bits of n

  const uint64_t low   = umultiply16bit(mLow, nLow);
  const uint64_t cross = (uint64_t)umultiply16bit(mHigh, nLow)
                       + (uint64_t)umultiply16bit(mLow, nHigh);
  const uint64_t high  = umultiply16bit(mHigh, nHigh);

  // The three terms add up to exactly m * n, which is below 2^64: no wrap-around.
  return (high << 32) + (cross << 16) + low;
}

/*
 * Full 64-bit product of two signed 32-bit integers: multiply the magnitudes
 * with umultiply32bit() and re-apply the sign afterwards.
 * (For Vectrex, the parameters will need to be passed as multiple 16-bit
 * ints, Q16.16 format.)
 *
 * Magnitudes are formed with unsigned arithmetic, so INT32_MIN -- whose signed
 * negation overflows -- is handled correctly, including INT32_MIN * INT32_MIN.
 */
static int64_t multiply32bit(int32_t m, int32_t n) {
  const int negative = (m < 0) != (n < 0);

  // |x| as an unsigned value: 0 - x (mod 2^32) for negative x.
  const uint32_t um = (m < 0) ? (uint32_t)(0u - (uint32_t)m) : (uint32_t)m;
  const uint32_t un = (n < 0) ? (uint32_t)(0u - (uint32_t)n) : (uint32_t)n;

  // At most 2^31 * 2^31 == 2^62, so the magnitude always fits in int64_t.
  const uint64_t product = umultiply32bit(um, un);

  return negative ? -(int64_t)product : (int64_t)product;
}

// =====================================================================================

/* Isolate the most significant set bit of an 8-bit value (0 stays 0). */
static inline uint8_t msb8(register uint8_t x)
{
  /* Smear the top set bit into every position below it: 0..01xxx -> 0..01111 */
  for (unsigned shift = 1; shift < 8; shift <<= 1) {
    x |= (uint8_t)(x >> shift);
  }
  /* (x >> 1) is now x without its top bit, so XOR leaves only that bit. */
  return (uint8_t)(x ^ (x >> 1));
}

/* Isolate the most significant set bit of a 16-bit value (0 stays 0). */
static inline uint16_t msb16(register uint16_t x)
{
  /* Smear the top set bit into every position below it: 0..01xxx -> 0..01111 */
  for (unsigned shift = 1; shift < 16; shift <<= 1) {
    x |= (uint16_t)(x >> shift);
  }
  /* (x >> 1) is now x without its top bit, so XOR leaves only that bit. */
  return (uint16_t)(x ^ (x >> 1));
}

/* Isolate the most significant set bit of a 32-bit value (0 stays 0). */
static inline uint32_t msb32(register uint32_t x)
{
  /* Smear the top set bit into every position below it: 0..01xxx -> 0..01111 */
  for (unsigned shift = 1; shift < 32; shift <<= 1) {
    x |= x >> shift;
  }
  /* (x >> 1) is now x without its top bit, so XOR leaves only that bit. */
  return x ^ (x >> 1);
}

/* Isolate the most significant set bit of a 64-bit value (0 stays 0). */
static inline uint64_t msb64(register uint64_t x)
{
  /* Smear the top set bit into every position below it: 0..01xxx -> 0..01111 */
  for (unsigned shift = 1; shift < 64; shift <<= 1) {
    x |= x >> shift;
  }
  /* (x >> 1) is now x without its top bit, so XOR leaves only that bit. */
  return x ^ (x >> 1);
}

/*
 * Unsigned 16 / 16 -> 16 bit division without a divide instruction.
 *
 * Returns floor(qq / d).  Division by zero returns UINT16_MAX (saturated)
 * instead of looping forever.
 *
 * The quotient is built greedily, one bit at a time from the top: a bit is
 * kept whenever d * (quotient so far | bit) still does not exceed qq, using
 * umultiply16() for the trial products.  The loop runs at most 16 times and
 * terminates for every input (including qq < d and inexact divisions).
 */
static uint16_t udiv16(uint16_t qq, uint16_t d) {
  uint16_t top, t, bit, res;

  if (d == 0) return UINT16_MAX;
  if (qq < d) return 0;

  // Highest bit the quotient can possibly have: msb(qq) / msb(d).
  // qq >= d guarantees msb16(d) <= msb16(qq), so t reaches top exactly.
  top = msb16(qq);
  t = msb16(d);
  bit = 1;
  while (t < top) {
    t <<= 1;
    bit <<= 1;
  }

  res = 0;
  for (; bit != 0; bit >>= 1) {
    const uint16_t trial = (uint16_t)(res | bit);
    if (umultiply16(d, trial) <= qq) {
      res = trial;
    }
  }
  return res;
}
  
/*
 * Unsigned 16 / 8 -> 8 bit division without a divide instruction.
 *
 * Returns floor(q / d), saturated to UINT8_MAX when the true quotient does
 * not fit in 8 bits.  Division by zero also returns UINT8_MAX instead of
 * looping forever.
 *
 * The quotient is built greedily, one bit at a time from the top: a bit is
 * kept whenever d * (quotient so far | bit) still does not exceed q, using
 * umul8() for the trial products (both factors are genuine 8-bit values).
 * The loop runs at most 8 times and terminates for every input.
 */
static uint8_t udiv16_by_8(uint16_t q, uint8_t d) {
  uint16_t top, t, bit;
  uint8_t res;

  if (d == 0) return UINT8_MAX;
  if (q < d) return 0;

  // Highest bit the quotient can possibly have: msb(q) / msb(d).
  // q >= d guarantees msb8(d) <= msb16(q), so t reaches top exactly.
  top = msb16(q);
  t = msb8(d);
  bit = 1;
  while (t < top) {
    t <<= 1;
    bit <<= 1;
  }
  if (bit > 0x80) bit = 0x80;  // result is only 8 bits wide: saturate

  res = 0;
  for (; bit != 0; bit >>= 1) {
    const uint8_t trial = (uint8_t)(res | bit);
    if (umul8(d, trial) <= q) {
      res = trial;
    }
  }
  return res;
}
  
/*
 * Unsigned 32 / 32 -> 32 bit division without a divide instruction.
 *
 * Returns floor(qq / d).  Division by zero returns UINT32_MAX (saturated)
 * instead of looping forever.
 *
 * The quotient is built greedily, one bit at a time from the top: a bit is
 * kept whenever d * (quotient so far | bit) still does not exceed qq, using
 * umultiply32bit() for the trial products.  The loop runs at most 32 times
 * and terminates for every input (including qq < d and inexact divisions).
 */
static uint32_t udiv32(uint32_t qq, uint32_t d) {
  uint32_t top, t, bit, res;

  if (d == 0) return UINT32_MAX;
  if (qq < d) return 0;

  // Highest bit the quotient can possibly have: msb(qq) / msb(d).
  // qq >= d guarantees msb32(d) <= msb32(qq), so t reaches top exactly.
  top = msb32(qq);
  t = msb32(d);
  bit = 1;
  while (t < top) {
    t <<= 1;
    bit <<= 1;
  }

  res = 0;
  for (; bit != 0; bit >>= 1) {
    const uint32_t trial = res | bit;
    if (umultiply32bit(d, trial) <= qq) {
      res = trial;
    }
  }
  return res;
}
  
/*
 * Unsigned 32 / 16 -> 16 bit division without a divide instruction.
 *
 * Returns floor(q / d), saturated to UINT16_MAX when the true quotient does
 * not fit in 16 bits.  Division by zero also returns UINT16_MAX instead of
 * looping forever.
 *
 * The quotient is built greedily, one bit at a time from the top: a bit is
 * kept whenever d * (quotient so far | bit) still does not exceed q, using
 * umultiply16() for the trial products.  The loop runs at most 16 times and
 * terminates for every input.
 */
static uint16_t udiv32_by_16(uint32_t q, uint16_t d) {
  uint32_t top, t, bit;
  uint16_t res;

  if (d == 0) return UINT16_MAX;
  if (q < d) return 0;

  // Highest bit the quotient can possibly have: msb(q) / msb(d).
  // q >= d guarantees msb16(d) <= msb32(q), so t reaches top exactly.
  top = msb32(q);
  t = msb16(d);
  bit = 1;
  while (t < top) {
    t <<= 1;
    bit <<= 1;
  }
  if (bit > 0x8000u) bit = 0x8000u;  // result is only 16 bits wide: saturate

  res = 0;
  for (; bit != 0; bit >>= 1) {
    const uint16_t trial = (uint16_t)(res | bit);
    if (umultiply16(d, trial) <= q) {
      res = trial;
    }
  }
  return res;
}
  
/*
 * Unsigned 64 / 32 -> 32 bit division without a divide instruction.
 *
 * Returns floor(q / d), saturated to UINT32_MAX when the true quotient does
 * not fit in 32 bits.  Division by zero also returns UINT32_MAX instead of
 * looping forever.
 *
 * The quotient is built greedily, one bit at a time from the top: a bit is
 * kept whenever d * (quotient so far | bit) still does not exceed q, using
 * umultiply32bit() for the trial products.  The loop runs at most 32 times
 * and terminates for every input.
 */
static uint32_t udiv64_by_32(uint64_t q, uint32_t d) {
  uint64_t top, t, bit;
  uint32_t res;

  if (d == 0) return UINT32_MAX;
  if (q < d) return 0;

  // Highest bit the quotient can possibly have: msb(q) / msb(d).
  // q >= d guarantees msb32(d) <= msb64(q), so t reaches top exactly.
  top = msb64(q);
  t = msb32(d);
  bit = 1;
  while (t < top) {
    t <<= 1;
    bit <<= 1;
  }
  if (bit > 0x80000000u) bit = 0x80000000u;  // result is only 32 bits wide: saturate

  res = 0;
  for (; bit != 0; bit >>= 1) {
    const uint32_t trial = (uint32_t)(res | bit);
    if (umultiply32bit(d, trial) <= q) {
      res = trial;
    }
  }
  return res;
}
  
// =====================================================================================

// vectrex random:
uint16_t _x, _a, _b, _c;  // generator state: _x is a step counter, _a/_b/_c are mixed on every call

/*
 * Advance the generator by one step and return the new 16-bit value.
 * The value returned is also stored in _c; all arithmetic wraps at 16 bits.
 */
uint16_t vrandom() {
  _x += 1;
  _a ^= _c ^ _x;
  _b += _a;
  _c = (uint16_t)(_c + (_b >> 1)) ^ _a;
  return _c;
}

/*
 * Seed the generator: s1, s2, s3 become _a, _b, _c and x0 becomes the step
 * counter _x.  One value is then drawn and thrown away.
 */
void initRandom(uint16_t s1, uint16_t s2, uint16_t s3, uint16_t x0) {
  _a = s1;
  _b = s2;
  _c = s3;
  _x = x0;
  (void)vrandom();
}


// main function
/*
 * Demo driver: for 8-, 16- and 32-bit operands, print the compiler's own
 * product next to the software routines' results and, when the operands are
 * suitable, divide the product back down again.  Ten rounds per width; the
 * first round uses fixed extreme operands, later rounds use vrandom().
 * A division is only attempted with a non-zero divisor.
 */
int main(void) {
  uint8_t i;

#ifdef NO_HARDWARE_MUL
  init_sq(512); // 256 if using the less efficient adder.
#endif

  // Need to test GCC's native multiply for various types on Vectrex itself (int:8, long:16, long long:32)
  initRandom(294, 4780, 11978, 9371);

  // Need to do a lot more loops and time them to get comparative timings for each method
  for (i = 0; i < 10; i++) {
    static uint16_t m = 0x7f, n = 0x7f;  // static: operands carry over from one round to the next
    uint16_t t;
    t = m * n;
    printf("Normal 8x8 multiplication           m * n = %08x (%d = %d * %d)\n", (unsigned)t, t, m, n);
    t = umul8(m, n);
    printf("Using 8-bit mul                     m * n = %08x (%d)\n", (unsigned)t, t);
    if (m != 0) {  // operands are unsigned; only a zero divisor has to be excluded
      n = udiv16_by_8(t, m);
      printf("Division                            %d / %d = %d\n", t, m, n);
      m = 2;
      n = udiv16(t, m);
      printf("Division                            %d / %d = %d\n", t, m, n);
    }
    m = vrandom()&255; n = vrandom()&255;
  }
  for (i = 0; i < 10; i++) {
    static int16_t m = INT16_MIN, n = INT16_MIN;
    int32_t t, nn;
    t = (int32_t)m * n;  // widen first so the product cannot overflow a narrow int
    printf("Normal 16x16 multiplication         m * n = %08x (%d = %d * %d)\n", (unsigned)t, t, m, n);
    t = multiply16bit(m, n);
    printf("Using 8-bit multiplier A            m * n = %08x (%d)\n", (unsigned)t, t);
    t = multiply16(m, n);
    printf("Using 8-bit multiplier B'           m * n = %08x (%d)\n", (unsigned)t, t);
    if ((m > 0) && (n >= 0)) {  // unsigned division only: non-negative product, non-zero divisor
      n = udiv32_by_16(t, m);
      printf("Division                            %d / %d = %d\n", t, m, n);
      m = 2;
      nn = udiv32(t, (int32_t)m);
      printf("Division                            %d / %d = %d\n", t, m, nn);
    }
    m = (int16_t)vrandom(); n = (int16_t)vrandom();
  }
  for (i = 0; i < 10; i++) {
    static int32_t m = INT32_MIN, n = INT32_MIN;
    int64_t t;
    t = (int64_t)m * (int64_t)n;
    // int64_t need not be 'long long'; cast so the arguments match %llx / %lld exactly.
    printf("Normal 32x32 multiplication         m * n = %016llx (%lld = %d * %d)\n", (unsigned long long)t, (long long)t, m, n);
    t = multiply32bit(m, n);
    printf("Using cascaded 8-bit multipliers A  m * n = %016llx (%lld)\n", (unsigned long long)t, (long long)t);
    t = multiply32(m, n);
    printf("Using cascaded 8-bit multipliers B' m * n = %016llx (%lld)\n", (unsigned long long)t, (long long)t);
    if ((m > 0) && (n >= 0)) {  // unsigned division only: non-negative product, non-zero divisor
      n = udiv64_by_32(t, m);
      printf("Division                            %lld / %d = %d\n", (long long)t, m, n);
    }
    // Multiply as unsigned: 65535 * 65535 does not fit in int32_t, and signed overflow is undefined.
    m = (int32_t)((uint32_t)vrandom() * (uint32_t)vrandom());
    n = (int32_t)((uint32_t)vrandom() * (uint32_t)vrandom());
  }
  return 0;
}