#include <Rcpp.h>

#include <algorithm>
#include <stdexcept>

using namespace Rcpp;

// Builds the weights of the kernel smoothing matrix S in a banded
// n x length(K) layout.
//
// Let L = (length(K) - 1) / 2 (integer division); K[L] acts as the centre
// weight. NOTE(review): the normaliser updates below visit every weight only
// when length(K) is odd (2 * L + 1); for an even length K[0] is never added.
//
// For output row i every stored value is K[k] / sumK, where sumK is the
// normaliser used for that row. K[k] is stored at
//   ret(i, k)            for k >= L,
//   ret(i - (L - k), k)  for k <  L,
// so column k starts at row 0 and presumably is the band with offset k - L
// (TODO confirm against the R caller).
//
// Rows i < L (left boundary) only keep the left-hand weights that land on
// rows >= 0, and (for odd length(K)) sumK = K[L - i] + ... + K[2 * L].
// Rows n - 1 - i (right boundary) are written with every weight, but are
// normalised with that same left-truncated sum. That equals the sum of
// K[0..L + i] only for a symmetric kernel. NOTE(review): assumes K is symmetric.
// All remaining (interior) rows are normalised by the sum of all weights.
//
// Throws std::invalid_argument (reported as an R error by the Rcpp wrapper)
// if K is empty or n < 2 * L, because the boundary passes would otherwise
// index rows outside ret.
// [[Rcpp::export(name = ".createS")]]
NumericMatrix createS(const int n, const NumericVector &K) {
  const int width = K.size();
  if (width == 0) {
    throw std::invalid_argument("K must not be empty");
  }
  const int L = (width - 1) / 2;
  if (n < 2 * L) {
    throw std::invalid_argument("n must be at least 2 * ((length(K) - 1) %/% 2)");
  }
  NumericMatrix ret = NumericMatrix(n, width);

  // Start the normaliser with the L weights right of the centre; each
  // boundary iteration then adds one more weight from the left-hand side.
  double sumK = 0.0;
  int indexK = width;
  for (int step = 0; step < L; ++step) {
    sumK += K[--indexK];
  }

  int i = 0;
  int j = n;  // one past the current right-boundary row
  int l, k;
  for (; i < L; ++i) {
    sumK += K(--indexK);

    // Left boundary row i: stop once the left-hand weights would reach row -1.
    l = i;
    for (k = width - 1; k >= L; --k) {
      ret(l, k) = K(k) / sumK;
    }
    for (; l > 0; --k) {
      ret(--l, k) = K(k) / sumK;
    }

    // Right boundary row n - 1 - i.
    --j;
    l = j;
    for (k = width - 1; k >= L; --k) {
      ret(l, k) = K(k) / sumK;
    }
    for (; k >= 0; --k) {
      ret(--l, k) = K(k) / sumK;
    }
  }

  // Interior rows use the full kernel.
  sumK += K(--indexK);
  for (; i < j; ++i) {
    l = i;
    for (k = width - 1; k >= L; --k) {
      ret(l, k) = K(k) / sumK;
    }
    for (; k >= 0; --k) {
      ret(--l, k) = K(k) / sumK;
    }
  }

  return ret;
}

// Fills an n x (length(K) - 1) banded matrix from the values K.
// Let L = (length(K) - 1) / 2 (integer division). For output row i, column k
// is stored at
//   ret(i, k)            for k >= L,
//   ret(i - (L - k), k)  for k <  L.
//
// Stored values:
//   interior row i:      -K[k + 1] / K[0]                 (k >= L)
//                        1 - K[k + 1] / K[0]              (k <  L)
//   left row i < L:      the same with K[0] replaced by K[L - i]; for k < L
//                        only columns k >= L - i are written (rows >= 0)
//   right row n - 1 - i: -(K[k + 1] - K[L + 1 + i]) / K[L - i]       (k >= L)
//                        1 - (K[k + 1] - K[L + 1 + i]) / K[L - i]    (k <  L)
//
// NOTE(review): K[0] acts as the full normaliser and K[L - i] as a
// left-truncated one. This suggests K holds right-to-left cumulative kernel
// sums (K[m] = w[m] + ... + w[2L]). Under that reading, the right-boundary
// denominator K[L - i] equals the truncated sum K[0] - K[L + 1 + i] only for
// a symmetric kernel. Confirm against the R caller.
//
// Throws std::invalid_argument (reported as an R error by the Rcpp wrapper)
// if K is empty or n < 2 * L, because the boundary passes would otherwise
// index rows outside ret.
// [[Rcpp::export(name = ".createImSX")]]
NumericMatrix createImSX(const int n, const NumericVector &K) {
  const int width = K.size();
  if (width == 0) {
    throw std::invalid_argument("K must not be empty");
  }
  const int L = (width - 1) / 2;
  if (n < 2 * L) {
    throw std::invalid_argument("n must be at least 2 * ((length(K) - 1) %/% 2)");
  }
  NumericMatrix ret = NumericMatrix(n, width - 1);

  int i = 0;
  int j = n;  // one past the current right-boundary row
  int l, k;
  for (; i < L; ++i) {
    // Left boundary row i: stop once the left-hand entries would reach row -1.
    l = i;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - K(k + 1) / K(L - i);
    }
    for (; l > 0; --k) {
      ret(--l, k) = 1 - K(k + 1) / K(L - i);
    }

    // Right boundary row n - 1 - i: K[L + 1 + i] is subtracted from every value.
    --j;
    l = j;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - (K(k + 1) - K(L + 1 + i)) / K(L - i);
    }
    for (; k >= 0; --k) {
      ret(--l, k) = 1 - (K(k + 1) - K(L + 1 + i)) / K(L - i);
    }
  }

  // Interior rows.
  for (; i < j; ++i) {
    l = i;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - K(k + 1) / K(0);
    }
    for (; k >= 0; --k) {
      ret(--l, k) = 1 - K(k + 1) / K(0);
    }
  }

  return ret;
}

// Fills an n x (length(K) - 1) banded matrix from the values K. It uses the
// same row/column pattern as .createImSX, but without dividing by a
// normaliser. Let L = (length(K) - 1) / 2 (integer division). For output
// row i, column k is stored at
//   ret(i, k)            for k >= L,
//   ret(i - (L - k), k)  for k <  L.
//
// Stored values:
//   interior and left rows: -K[k + 1]        (k >= L)
//                           1 - K[k + 1]     (k <  L; left rows i < L only
//                                             write columns k >= L - i)
//   right row n - 1 - i:    -(K[k + 1] - K[L + 1 + i])       (k >= L)
//                           1 - (K[k + 1] - K[L + 1 + i])    (k <  L)
//
// Throws std::invalid_argument (reported as an R error by the Rcpp wrapper)
// if K is empty or n < 2 * L, because the boundary passes would otherwise
// index rows outside ret.
// [[Rcpp::export(name = ".createImSXnumerator")]]
NumericMatrix createImSXnumerator(const int n, const NumericVector &K) {
  const int width = K.size();
  if (width == 0) {
    throw std::invalid_argument("K must not be empty");
  }
  const int L = (width - 1) / 2;
  if (n < 2 * L) {
    throw std::invalid_argument("n must be at least 2 * ((length(K) - 1) %/% 2)");
  }
  NumericMatrix ret = NumericMatrix(n, width - 1);

  int i = 0;
  int j = n;  // one past the current right-boundary row
  int l, k;
  for (; i < L; ++i) {
    // Left boundary row i: stop once the left-hand entries would reach row -1.
    l = i;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - K(k + 1);
    }
    for (; l > 0; --k) {
      ret(--l, k) = 1 - K(k + 1);
    }

    // Right boundary row n - 1 - i: K[L + 1 + i] is subtracted from every value.
    --j;
    l = j;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - (K(k + 1) - K(L + 1 + i));
    }
    for (; k >= 0; --k) {
      ret(--l, k) = 1 - (K(k + 1) - K(L + 1 + i));
    }
  }

  // Interior rows.
  for (; i < j; ++i) {
    l = i;
    for (k = width - 2; k >= L; --k) {
      ret(l, k) = - K(k + 1);
    }
    for (; k >= 0; --k) {
      ret(--l, k) = 1 - K(k + 1);
    }
  }

  return ret;
}

// For every row f in 0..Vm1-1, walks the weights K[p] from
// p = length(K) - 1 down to p = 0. It skips each p != L with
// modValues[p] == f + 1; the centre K[L], L = (length(K) - 1) / 2, is always
// kept. Row f of `val` receives the running sums of the kept weights, in
// visiting order.
// NOTE(review): presumably modValues[p] labels the cross-validation fold that
// position p belongs to; confirm against the R caller. modValues must have at
// least length(K) entries.
//
// Returns a list with
//   val          - Vm1 x length(K) matrix; row f holds lengths[f] running sums,
//   lengthsRight - number of kept weights with p > L, per row,
//   lengths      - total number of kept weights (centre included), per row.
// [[Rcpp::export(name = ".prepareImSXcv")]]
List prepareImSXcv(const NumericVector &K, const int Vm1, IntegerVector modValues) {
  const int kernelSize = K.size();
  const int centre = (kernelSize - 1) / 2;

  NumericMatrix val = NumericMatrix(Vm1, kernelSize);
  IntegerVector lengthsRight = IntegerVector(Vm1, 0);
  IntegerVector lengths = IntegerVector(Vm1, 0);

  for (int fold = 0; fold < Vm1; ++fold) {
    int kept = 0;
    double running = 0.0;
    for (int pos = kernelSize - 1; pos >= 0; --pos) {
      if (pos == centre) {
        // Everything gathered so far lies right of the centre.
        lengthsRight[fold] = kept;
      } else if (modValues[pos] == fold + 1) {
        continue;  // weight belongs to the excluded label
      }
      // Same operand order as "current += previous", so sums are bit-identical.
      running = (kept == 0) ? K[pos] : K[pos] + running;
      val(fold, kept++) = running;
    }
    lengths[fold] = kept;
  }

  return List::create(Named("val") = val,
                      Named("lengthsRight") = lengthsRight,
                      Named("lengths") = lengths);
}

// Fills an n x (maxLeft + maxRight) banded matrix using per-row running sums
// from `precomputed` (presumably the list returned by .prepareImSXcv: "val",
// "lengths", "lengthsRight"; TODO confirm). rows(i) is the 0-based row of
// `val` used for output row i.
//
// For output row i, let r = rows(i), c = val row r, R = lengthsRight(r) and
// last = lengths(r) - 1. Define
//   upper    = min(last, R + i)         (left truncation near row 0),
//   cut      = R - n + i                (right truncation near row n - 1),
//   excluded = c[cut] if cut >= 0, else 0,
//   denom    = c[upper] - excluded.
// Then
//   for j in [0, R):     ret(i, maxLeft - 1 + R - j) = -(c[j] - excluded) / denom
//   for j in [R, last):  1 - (c[j] - excluded) / denom is written to column
//                        maxLeft - 1 - (j - R), in rows i, i - 1, ..., stopping
//                        at row 0.
//
// Fix: the previous version read c[R - n + i] unconditionally in its
// right-boundary loop. For a row with R < maxRight that index is negative,
// which is an out-of-range read. A negative cut now means nothing is excluded.
// Every row also uses the same formula, so n < maxLeft no longer writes past
// row n - 1, and rows near both ends get both truncations. For
// n >= maxLeft + maxRight the output is identical to the previous
// implementation.
// Caller must ensure maxLeft >= lengths(r) - 1 - R and maxRight >= R for every
// r used; otherwise the column indices leave ret.
// [[Rcpp::export(name = ".createImSXcv")]]
NumericMatrix createImSXcv(const int n, const List &precomputed,
                           const int maxLeft, const int maxRight,
                           const IntegerVector rows) {
  NumericMatrix val = precomputed["val"];
  IntegerVector lengths = precomputed["lengths"];
  IntegerVector lengthsRight = precomputed["lengthsRight"];
  NumericMatrix ret = NumericMatrix(n, maxLeft + maxRight);

  for (int i = 0; i < n; ++i) {
    const int r = rows(i);
    const int nRight = lengthsRight(r);
    const int last = lengths(r) - 1;

    // Only i left-hand entries reach rows 0..i-1; the normaliser stops there.
    const int upper = std::min(last, nRight + i);

    // Running sum of the right-hand entries that would land beyond row n - 1.
    const int cut = nRight - n + i;
    const double excluded = cut >= 0 ? val(r, cut) : 0.0;

    const double denom = val(r, upper) - excluded;

    int j = 0;
    int k = maxLeft - 1 + nRight;
    for (; j < nRight; ++j, --k) {
      ret(i, k) = - (val(r, j) - excluded) / denom;
    }
    for (int l = i; j < last && l >= 0; ++j, --k) {
      ret(l--, k) = 1 - (val(r, j) - excluded) / denom;
    }
  }

  return ret;
}
