#include <stdio.h>
#include <string.h>
#include <ctype.h>
#include <stdlib.h>

// DECIMAL MULTI-PRECISION ARITHMETIC
// ==================================

// This is a quick sketch of extended precision integer multiplication using decimal representation of digits
// (so no need to convert large numbers to and from binary before performing multiplication).

// I wrote this because I was thinking about ways to implement computations mechanically, after seeing a
// presentation about Ludgate and his mechanical computer design.  Ludgate used discrete logarithms to
// multiply two decimal digits.  I'm using a simple lookup table and pondering how one might implement
// that mechanically.  While contemplating that facet of the calculation I realised that the extension
// to multiple digits was actually going to be the bulk of any device and the ripple carry between adders
// was going to be a significant factor in the computation speed.

// Note that I'm using one digit per byte rather than two digits per byte (ie BCD) both for simplicity,
// and because a mechanical hardware implementation would have no concept of a byte anyway.

// I am representing decimal numbers in little-endian form, i.e. the lowest position (the units digit)
// comes first, followed by increasing powers of ten to the right.  This is the opposite of the normal
// written form in left-to-right languages.
// It simplifies computation because digit 0 of any multi-digit number is always the single units digit,
// unlike normal notation where the units are in an arbitrary n'th position depending on the size of
// the number.

// Note that this interface is likely to cause some heap lossage. It's unlikely to be significant.

// Debug flag.  main() sets it to 1 when the first command-line argument is "-d".
// No code visible in this file reads it.
static int debug = 0;

// One multi-precision decimal number.  The digits are stored in the flexible array
// member at the end of the struct: one decimal digit per byte, units digit first.
typedef struct mp {
  // Number of decimal digits stored in d[].
  // digits == 0 is meant to represent 0 with no space allocated for d[], but that
  // form is not yet completely supported.
  int digits;

  // Where the decimal point goes; 0 means all integer.
  // Currently only integer arithmetic is supported.
  int decimal;

  // Bounds of a repeating decimal sequence such as .142142142 etc.  NOT YET USED.
  int repeat_first;
  int repeat_last;

  // 0 if the number is 0 or positive, 1 if the number is negative.
  int negative;

  // The digits themselves; d[0] is the units digit.
  unsigned char d[];
} mp;


// Bring a freshly computed result into canonical form, in place, and return it.
//
//  * Leading (most significant) zero digits are dropped by shrinking self->digits;
//    at least one digit is kept, so zero is stored as the single digit 0.
//  * A value of zero is never left flagged as negative.
//
// Inputs to the arithmetic routines are assumed to be normalised already - only
// results about to be returned are passed through here.
// self->digits == 0 (the "empty" zero) is left as it is; d[] is not touched in that case.
static inline mp *normalise(mp *self) {
  int top = self->digits;

  // d[top-1] is the most significant stored digit; stop at the units digit.
  while (top > 1 && self->d[top-1] == 0) top--;
  self->digits = top;

  // Zero has no sign.
  if (top == 0 || (top == 1 && self->d[0] == 0)) self->negative = 0;
  return self;
}

mp *mpi(char *ascii) {
  // convert an ascii string into a low-endian decimal blob
  int neg = (*ascii == '-' ? 1 : 0);
  int low = 0, len = strlen(ascii+neg);
  mp *n = malloc(sizeof(mp)+len);
  if (!n) { fprintf(stderr, "mpi: malloc(%d) fails.\n", sizeof(mp)+len); exit(1); }
  n->digits = len; n->negative = neg;
  do n->d[low++] = ascii[--len+neg]-'0'; while (len);
  return normalise(n);
}

// Make a duplicate of a number, in case the original is going to be modified in situ.
// Returns a newly allocated block holding the same header fields and the same digits;
// the caller owns it.  Exits the program if the allocation fails.
mp *copy(mp *num) {
  int i;
  int count = (num->digits > 0) ? num->digits : 0;
  size_t bytes = sizeof(mp) + (size_t)count;
  mp *n = malloc(bytes);
  // Report the size that was actually requested, using the size_t conversion.
  if (!n) { fprintf(stderr, "copy: malloc(%zu) fails.\n", bytes); exit(1); }
  *n = *num; // copies the header fields only; the flexible array is copied below
  for (i = 0; i < count; i++) n->d[i] = num->d[i];
  return n;
}

// Convert a little-endian decimal number back to a standard left-to-right ASCII
// string, for printing.
//
// Returns a newly allocated NUL-terminated string which the caller must free().
// A negative non-zero value gets a leading '-'; a number with no digits
// (digits == 0) is rendered as "0".  Exits the program if the allocation fails.
char *mptoascii(mp *m) {
  int count = (m->digits > 0) ? m->digits : 0;
  int is_zero = (count == 0) || (count == 1 && m->d[0] == 0);
  int sign = (m->negative && !is_zero) ? 1 : 0;
  int total = sign + (count ? count : 1) + 1; // sign + digits + terminating NUL
  int i;
  char *s, *out;

  s = malloc((size_t)total);
  if (!s) { fprintf(stderr, "mptoascii: malloc(%d) fails.\n", total); exit(1); }

  out = s;
  if (sign) *out++ = '-';
  if (count == 0) *out++ = '0';
  // d[] holds digit VALUES 0..9, so '0' must be added to get printable characters.
  // The most significant digit is the last one stored and is emitted first.
  for (i = count-1; i >= 0; i--) *out++ = (char)('0' + m->d[i]);
  *out = '\0';
  return s;
}

// Print a number to f in ordinary left-to-right notation, with a leading '-' when
// m->negative is set.  Internal use only.
//
// Digits are written from the most significant one down to position `tens`; the most
// significant digit is always written.  With tens == 0 (what every caller in this
// file passes) the whole number is printed.  A negative `tens` is treated as 0.
// A number with no digits at all (digits == 0) is printed as "0".
// NOTE(review): the earlier comment described `tens` as adding trailing zeroes; the
// code has never done that - confirm the intended meaning before relying on tens > 0.
static inline void pnum(FILE *f, mp *m, int tens) {
  int i = m->digits;
  if (i <= 0) { fputc('0', f); return; } // nothing stored: avoid reading before d[0]
  if (tens < 0) tens = 0;                // never walk below the units digit
  if (m->negative) fputc('-', f);
  do fputc(m->d[--i]+'0', f); while (i > tens);
}

int compare(mp *a, mp *b) {
  // compare two numbers.  a < b => -1, a == b => 0, a > b => 1
  int i;
  if (a->digits > b->digits) {
    if (a->digits==1 && b->digits==0 && a->d[0]==0) {
      return 0;
    }
    return  1;
  }
  if (a->digits < b->digits) {
    if (a->digits==0 && b->digits==1 && b->d[0]==0) {
      return 0;
    }
    return -1;
  }
  if (a->digits == 0) return 0;
  // a->digits == b->digits
  for (i = a->digits-1; i >= 0; i--) { // rightmost digit is highest digit, eg 100 is represented as [001]
    if (a->d[i] > b->d[i]) return  1;
    if (a->d[i] < b->d[i]) return -1;
  }
  return 0;
}


// Lookup tables for N x N multiplication. One table for the units part of the result and one for the tens.

// Tens digit of every single-digit product: upper_product[x][y] is the tens part of x * y.
// e.g. 9 * 9 = 81, so upper_product[9][9] is the '8' part of '81'.
const int upper_product[10][10] = {
  /* 0 x n */ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
  /* 1 x n */ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
  /* 2 x n */ { 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 },
  /* 3 x n */ { 0, 0, 0, 0, 1, 1, 1, 2, 2, 2 },
  /* 4 x n */ { 0, 0, 0, 1, 1, 2, 2, 2, 3, 3 },
  /* 5 x n */ { 0, 0, 1, 1, 2, 2, 3, 3, 4, 4 },
  /* 6 x n */ { 0, 0, 1, 1, 2, 3, 3, 4, 4, 5 },
  /* 7 x n */ { 0, 0, 1, 2, 2, 3, 4, 4, 5, 6 },
  /* 8 x n */ { 0, 0, 1, 2, 3, 4, 4, 5, 6, 7 },
  /* 9 x n */ { 0, 0, 1, 2, 3, 4, 5, 6, 7, 8 }
};

// Units digit of every single-digit product: lower_product[x][y] is the units part of x * y.
// e.g. 9 * 9 = 81, so lower_product[9][9] is the '1' part of '81'.
const int lower_product[10][10] = {
  /* 0 x n */ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
  /* 1 x n */ { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 },
  /* 2 x n */ { 0, 2, 4, 6, 8, 0, 2, 4, 6, 8 },
  /* 3 x n */ { 0, 3, 6, 9, 2, 5, 8, 1, 4, 7 },
  /* 4 x n */ { 0, 4, 8, 2, 6, 0, 4, 8, 2, 6 },
  /* 5 x n */ { 0, 5, 0, 5, 0, 5, 0, 5, 0, 5 },
  /* 6 x n */ { 0, 6, 2, 8, 4, 0, 6, 2, 8, 4 },
  /* 7 x n */ { 0, 7, 4, 1, 8, 5, 2, 9, 6, 3 },
  /* 8 x n */ { 0, 8, 6, 4, 2, 0, 8, 6, 4, 2 },
  /* 9 x n */ { 0, 9, 8, 7, 6, 5, 4, 3, 2, 1 }
};

// This is a one-digit multiply-and-accumulate step.
//
// The two-digit product multiplicand_digit * multiplier_digit is added into result:
// its units digit at d[result_index] and its tens digit at d[result_index+1].
// Carries are then rippled upward so that, on return, every digit this call touched
// is back in the range 0..9.
//
// Both digit arguments must be in 0..9 (they index the 10x10 lookup tables).
// No bounds checking is done on result: d[] must have room at result_index+1 and at
// any position a carry reaches - the caller is responsible for sizing it.
void mul_add_digit(int multiplicand_digit, mp *result, int result_index, unsigned char multiplier_digit) {
  int pos = result_index;
  int carried;

  // We get a 2-digit result when multiplying 2 1-digit decimal numbers.  So two adds and two carries.

  // These two adds could execute in parallel. Note use of two parallel lookup tables instead of hardware multiply.
  result->d[result_index]   += lower_product[multiplicand_digit][multiplier_digit];
  result->d[result_index+1] += upper_product[multiplicand_digit][multiplier_digit];

  // Ripple carry.  The tens position received an add of its own, so it has to be
  // examined even when the units position did not overflow; beyond that, keep going
  // only for as long as a carry was actually passed up.
  do {
    carried = 0;
    while (result->d[pos] >= 10) {
      result->d[pos]   -= 10;
      result->d[pos+1] += 1;
      carried = 1;
    }
    pos += 1;
  } while (carried || pos == result_index + 1);
}

// One row of a long multiplication: the whole multiplicand is multiplied by the
// single multiplier digit found at position tens_shift, and the row is accumulated
// into result starting at that same position (i.e. shifted by 10^tens_shift).
static inline void multiply_and_accumulate(mp *multiplicand, mp *multiplier, mp *result, int tens_shift) {
  const int row_digit = multiplier->d[tens_shift];
  int pos = 0;

  while (pos < multiplicand->digits) {
    mul_add_digit(multiplicand->d[pos], result, tens_shift + pos, row_digit);
    pos++;
  }
}

// And finally we perform multiple multiply-and-accumulate steps to produce the product
// of two arbitrary length decimal numbers.  In hardware the 'arbitrary length' would
// be done by extending the hardware units side-by-side, though note that a ripple-carry
// is still required (which I didn't see mentioned in the Ludgate documentation)
// Slightly faster to use the smaller multiplier for the actual multiply step.
//
// Returns a newly allocated, normalised product owned by the caller.  The result's
// sign is the exclusive-or of the operands' signs, except that a zero product is
// never negative.  Exits the program if the allocation fails.
static inline mp *multiply(mp *multiplicand, mp *multiplier) {
  // A product never needs more digits than the operands have between them.
  int width = multiplicand->digits + multiplier->digits, tens_shift;
  size_t bytes;
  mp *result;

  if (width < 1) width = 1; // two empty operands still need room for the digit 0
  bytes = sizeof(mp) + (size_t)width;
  result = calloc(1, bytes); // zero-filled, so the accumulator starts at 0
  if (!result) { fprintf(stderr, "multiply: calloc(%zu) fails.\n", bytes); exit(1); }
  result->digits = width;

  // One pass per digit of the shorter operand.
  if (multiplicand->digits >= multiplier->digits) {
    for (tens_shift = 0; tens_shift < multiplier->digits; tens_shift++)
      multiply_and_accumulate(multiplicand, multiplier, result, tens_shift);
  } else {
    for (tens_shift = 0; tens_shift < multiplicand->digits; tens_shift++)
      multiply_and_accumulate(multiplier, multiplicand, result, tens_shift);
  }

  result->decimal = multiplicand->decimal + multiplier->decimal;
  result->negative = multiplicand->negative ^ multiplier->negative;
  normalise(result);
  if (result->digits == 1 && result->d[0] == 0) result->negative = 0; // no "-0"
  return result;
}

// Return a newly allocated copy of num multiplied by ten, i.e. with every digit moved
// up one position and a 0 placed in the units position.
//
// Zero (no digits, or the single digit 0) stays the single digit 0 rather than
// growing into "00", so the result is normalised whenever the input is.
// The caller owns the result.  Exits the program if the allocation fails.
mp *mulby10(mp *num) {
  int i;
  int count = (num->digits > 0) ? num->digits : 0;
  int is_zero = (count == 0) || (count == 1 && num->d[0] == 0);
  size_t bytes = sizeof(mp) + (size_t)count + 1;
  mp *n = malloc(bytes);
  if (!n) { fprintf(stderr, "mulby10: malloc(%zu) fails.\n", bytes); exit(1); }
  *n = *num; // copy the header fields; the digits are filled in below

  n->d[0] = 0;
  if (is_zero) {
    n->digits = 1;
    n->negative = 0;
    return n;
  }
  for (i = 0; i < count; i++) n->d[i+1] = num->d[i];
  n->digits = count + 1;
  return n;
}

// Return a newly allocated copy of num divided by ten, discarding the units digit
// (every remaining digit moves down one position).
//
// A number with one digit or none divides down to the single digit 0, which is
// returned without a sign.  The caller owns the result.
// Exits the program if the allocation fails.
mp *divby10(mp *num) {
  int i;
  int count = (num->digits > 0) ? num->digits : 0;
  // The result has one digit fewer than the input, but always at least one.
  size_t bytes = sizeof(mp) + (size_t)(count ? count : 1);
  mp *n = malloc(bytes);
  if (!n) { fprintf(stderr, "divby10: malloc(%zu) fails.\n", bytes); exit(1); }
  *n = *num; // copy the header fields; the digits are filled in below

  for (i = 1; i < count; i++) n->d[i-1] = num->d[i];
  if (count <= 1) {
    n->d[0] = 0; n->digits = 1;
    n->negative = 0; // the quotient is zero, which has no sign
  } else {
    n->digits = count - 1;
  }
  return n;
}

// Forward declaration: subtract() below calls add(), and add() in turn calls subtract().
mp *add(mp *a, mp *b);

// Return a - b as a newly allocated, normalised number owned by the caller.
//
// Mixed signs are handed to add(); while that call runs, the negative operand's sign
// flag is cleared and it is set again afterwards.  Two negative operands are swapped,
// since -a - -b == b - a.  What remains is a magnitude subtraction, done by adding
// the ten's complement of b and discarding the carry out of the top position.
// Exits the program if the allocation fails.
mp *subtract(mp *a, mp *b) { // a - b ...
  mp *tmp;
  if (a->negative != b->negative) { // subtract positive quantity
    if (a->negative) { // -a - b == - (a+b)
      a->negative = 0; tmp = add(a,b); a->negative = 1; // restore
      tmp->negative = 1;
    } else { // a - -b == a+b
      b->negative = 0; tmp = add(a,b); b->negative = 1; // restore
      tmp->negative = 0;
    }
    if (tmp->digits == 1 && tmp->d[0] == 0) tmp->negative = 0; // no "-0"
    return tmp;
  } else if (a->negative) { // both
    // -a - -b == -a + b == b - a
    tmp = a; a = b; b = tmp;
  }

  // calculate as if both are positive
  // One digit more than the longer operand: that top digit ends up as 0 when
  // a >= b and as 9 when a < b, which is how the sign of the result is detected.
  int width = 1+((a->digits > b->digits) ? a->digits : b->digits);
  if (width < 1) width = 1;
  size_t bytes = sizeof(mp) + (size_t)width;
  mp *result = calloc(1, bytes);
  if (!result) { fprintf(stderr, "subtract: calloc(%zu) fails.\n", bytes); exit(1); }
  result->digits = width; result->decimal = 0; result->negative = 0;

  // a - b == a + (-b) == a + ~b+1 where ~b is 9's complement.
  // The "+1" is fed in as the initial carry.  The carry is kept in a local so that
  // nothing is ever written beyond d[width-1]; the carry out of the top position is
  // the 10^width term of the complement and is simply dropped.
  int i, carry = 1;
  for (i = 0; i < width; i++) {
    int sum = (i < a->digits ? a->d[i] : 0) + 9-(i < b->digits ? b->d[i] : 0) + carry;
    result->d[i] = (unsigned char)(sum%10);
    carry = sum/10;
  }

  if (result->d[width-1] == 9) {
    // a < b: the digits hold the ten's complement of b - a.  Complement again
    // (9's complement, plus one) to recover the magnitude and flag the sign.
    result->negative = 1;
    carry = 1;
    for (i = 0; i < width; i++) {
      int sum = 9 - result->d[i] + carry;
      result->d[i] = (unsigned char)(sum%10);
      carry = sum/10;
    }
  }
  return normalise(result);
}

// Return a + b as a newly allocated, normalised number owned by the caller.
//
// Mixed signs are handed to subtract(); while that call runs, the negative operand's
// sign flag is cleared and it is set again afterwards.  Operands of equal sign have
// their magnitudes added digit by digit and the result takes that common sign (a
// zero result is never negative).  Exits the program if the allocation fails.
mp *add(mp *a, mp *b) {
  if (a->negative != b->negative) { // subtract positive quantity
    mp *tmp;
    if (a->negative) { // -a + b == b - a
      a->negative = 0; tmp = subtract(b,a); a->negative = 1; // restore
    } else {           // a + -b == a - b
      b->negative = 0; tmp = subtract(a,b); b->negative = 1; // restore
    }
    return tmp;
  }

  // One digit more than the longer operand, to hold a possible final carry.
  int width = 1+((a->digits > b->digits) ? a->digits : b->digits);
  if (width < 1) width = 1;
  size_t bytes = sizeof(mp) + (size_t)width;
  mp *result = calloc(1, bytes);
  if (!result) { fprintf(stderr, "add: calloc(%zu) fails.\n", bytes); exit(1); }
  result->digits = width; result->decimal = 0;

  // The carry lives in a local so that nothing beyond d[width-1] is ever touched.
  // Both operands contribute 0 at the top position, so no carry leaves it.
  int i, carry = 0;
  for (i = 0; i < width; i++) {
    int sum = (i < a->digits ? a->d[i] : 0) + (i < b->digits ? b->d[i] : 0) + carry;
    result->d[i] = (unsigned char)(sum%10);
    carry = sum/10;
  }

  result->negative = a->negative; // both
  normalise(result);
  if (result->digits == 1 && result->d[0] == 0) result->negative = 0; // no "-0"
  return result;
}

// Long division of two non-negative numbers by repeated subtraction.
//
// Returns the quotient numerator / denom as a newly allocated number and stores a
// newly allocated remainder in *remainder; the caller owns both.  The arguments
// themselves are not modified.  Both are expected to be non-negative - divide() is
// the entry point that deals with signs.
//
// Each pass of the outer loop scales the denominator up by powers of ten to the
// largest multiple that does not reach the numerator, then subtracts that multiple
// as many times as it fits, adding the matching power of ten to the quotient.
//
// A zero denominator is reported on stderr and the program exits, in the same way
// as an allocation failure elsewhere in this file.
// Every intermediate value created here is released before returning.
mp *pos_div(mp *numerator, mp *denom, mp **remainder) {
  mp *full_result, *scaled_tens, *scaled_denom, *next;
  int comp;

  if (denom->digits == 0 || (denom->digits == 1 && denom->d[0] == 0)) {
    // Subtracting zero never makes the numerator smaller, so the loop below could not end.
    fprintf(stderr, "pos_div: division by zero.\n");
    exit(1);
  }

  numerator = copy(numerator); // working copy; it ends up as the remainder
  full_result = mpi("0");

  for (;;) { // As long as we can subtract the denominator without going negative
             // i.e. as long as denominator < numerator
    comp = compare(denom, numerator);
    if (comp > 0) { // denom > numerator
      *remainder = numerator;
      return full_result;
    }
    if (comp == 0) { // denom == numerator: it fits exactly once more
      mp *one = mpi("1");
      next = add(full_result, one);
      free(one); free(full_result); free(numerator);
      *remainder = mpi("0");
      return next;
    }

    // Scale the denominator (and the matching quotient unit) up by tens until it
    // reaches or passes the numerator ...
    scaled_denom = copy(denom);    scaled_tens = mpi("1");
    while (compare(scaled_denom, numerator) < 0) {
      next = mulby10(scaled_denom); free(scaled_denom); scaled_denom = next;
      next = mulby10(scaled_tens);  free(scaled_tens);  scaled_tens  = next;
    }
    // ... then step back one place so that it is smaller than the numerator again.
    next = divby10(scaled_denom); free(scaled_denom); scaled_denom = next;
    next = divby10(scaled_tens);  free(scaled_tens);  scaled_tens  = next;

    // Subtract the scaled denominator until the numerator goes negative.
    for (;;) {
      next = subtract(numerator, scaled_denom); free(numerator); numerator = next;
      if (numerator->negative) break;
      next = add(full_result, scaled_tens); free(full_result); full_result = next;
    }
    // One subtraction too many: put it back.
    next = add(numerator, scaled_denom); free(numerator); numerator = next;

    free(scaled_denom); free(scaled_tens);
  }
}

// Signed division: returns num / denom and stores the remainder in *remainder.
// Both results are newly allocated and owned by the caller.
//
// The magnitudes are divided by pos_div(); while that call runs, the sign flags of
// num and denom are cleared, and they are put back afterwards.
// The quotient is negative when exactly one operand is negative.  The remainder takes
// the sign of num, so that num == quotient * denom + remainder always holds (the
// quotient is truncated towards zero).  A zero quotient or zero remainder is never
// flagged negative.
mp *divide(mp *num, mp *denom, mp **remainder) {
  int nn = num->negative, dn = denom->negative;
  int negative = nn ^ dn;
  num->negative = 0;  denom->negative = 0;
  mp *tmp = pos_div(num, denom, remainder);
  num->negative = nn; denom->negative = dn;

  int quotient_is_zero  = tmp->digits == 0 || (tmp->digits == 1 && tmp->d[0] == 0);
  int remainder_is_zero = (*remainder)->digits == 0 ||
                          ((*remainder)->digits == 1 && (*remainder)->d[0] == 0);
  tmp->negative = quotient_is_zero ? 0 : negative;
  (*remainder)->negative = remainder_is_zero ? 0 : nn;
  return tmp;
}

int main(int argc, char **argv) {
  //char *t1;
  mp *a, *b, *c, *d, *e, *f, *g, *h, *rem;
  
  if (argc > 1 && !strcmp(argv[1], "-d")) {
    debug = 1; argc -= 1; argv += 1;
  }
  
  if (argc != 3) {
    fprintf(stderr, "syntax: mp [-d] nnn nnn\n\n");
    exit(EXIT_FAILURE);
  }
  
  if (!(*argv[1] == '-' || isdigit(*argv[1])) || !(*argv[2] == '-' || isdigit(*argv[2]))) {
    fprintf(stderr, "mp: both arguments should be integers: %s %s\n\n", argv[1], argv[2]);
    exit(EXIT_FAILURE);
  }

  a = mpi(argv[1]);
  b = mpi(argv[2]);
  c = multiply(a,b);
  d = add(a,b);
  e = subtract(a,b);
  f = subtract(b,a);
  
  pnum(stdout, a,0); fprintf(stdout, " * "); pnum(stdout, b,0); fprintf(stdout, " = "); pnum(stdout, c,0); fprintf(stdout, "\n"); 
  pnum(stdout, a,0); fprintf(stdout, " + "); pnum(stdout, b,0); fprintf(stdout, " = "); pnum(stdout, d,0); fprintf(stdout, "\n"); 
  pnum(stdout, a,0); fprintf(stdout, " - "); pnum(stdout, b,0); fprintf(stdout, " = "); pnum(stdout, e,0); fprintf(stdout, "\n"); 
  pnum(stdout, b,0); fprintf(stdout, " - "); pnum(stdout, a,0); fprintf(stdout, " = "); pnum(stdout, f,0); fprintf(stdout, "\n"); 
  pnum(stdout, a,0); fprintf(stdout, " * 100 = "); g = mulby10(a); g = mulby10(g); pnum(stdout, g,0); fprintf(stdout, "\n");
  pnum(stdout, g,0); fprintf(stdout, " / 10 = "); h = divby10(g); pnum(stdout, h,0); fprintf(stdout, "\n");
  pnum(stdout, h,0); fprintf(stdout, " / 10 = "); g = divby10(h); pnum(stdout, g,0); fprintf(stdout, "\n");
  pnum(stdout, g,0); fprintf(stdout, " / 10 = "); h = divby10(g); pnum(stdout, h,0); fprintf(stdout, "\n");
  g = divide(a,b,&rem);
  pnum(stdout, a,0); fprintf(stdout, " / "); pnum(stdout, b,0); fprintf(stdout, " = ");
  pnum(stdout, g,0); fprintf(stdout, " remainder "); pnum(stdout, rem,0); fprintf(stdout, "\n"); 
  h = divide(b,a,&rem);
  pnum(stdout, b,0); fprintf(stdout, " / "); pnum(stdout, a,0); fprintf(stdout, " = ");
  pnum(stdout, h,0); fprintf(stdout, " remainder "); pnum(stdout, rem,0); fprintf(stdout, "\n"); 
  
  exit(EXIT_SUCCESS);
  return EXIT_FAILURE;
}