/*
  Author: Xuye Luo
  Date: December 12, 2025
*/

#include <Rcpp.h>
#include <unordered_map>
#include <cmath>
using namespace Rcpp;

// [[Rcpp::interfaces(r, cpp)]]

/* 
//' @title Fast G-Test Statistic (C++ Backend)
//' @description Calculates the G-test (Likelihood Ratio) statistic directly from raw data vectors
//' using hash maps for efficient counting.
//' @param x Numeric vector of the first variable.
//' @param y Numeric vector of the second variable.
//' @return A List with elements statistic (the G statistic), n (sample size), nr (number of distinct values of x, i.e. table rows) and nc (number of distinct values of y, i.e. table columns).
 */

// G-test (likelihood-ratio) statistic for independence of x and y.
//
// Builds the two-way contingency table of the pairs (x[i], y[i]) with hash
// maps and returns G = 2 * sum_ij O_ij * ln(O_ij / E_ij), where
// E_ij = RowSum_i * ColSum_j / N.
//
// Pairs in which either value is NaN/NA are dropped before counting. NaN never
// compares equal to itself, so it cannot act as a hash-map key: left in, every
// NaN would become its own category and corrupt N and the marginals. Dropping
// them mirrors the complete-case handling of R's chisq.test(x, y).
//
// Returns a List with:
//   statistic - the G statistic (0 when the table has < 2 rows or < 2 columns)
//   n         - number of complete (x, y) pairs counted
//   nr, nc    - number of distinct x values / distinct y values
// Calls Rcpp::stop() when x and y differ in length.
// [[Rcpp::export]]
List gtest_cpp(const NumericVector &x, const NumericVector &y) {

  const R_xlen_t len = x.size();

  // Input Validation
  if (len != y.size()) {
    stop("Lengths of 'x' and 'y' must match.");
  }

  // Build Contingency Table using Hash Maps
  std::unordered_map<double, std::unordered_map<double, double>> observed;
  std::unordered_map<double, double> row_sum;
  std::unordered_map<double, double> col_sum;

  int n = 0; // number of complete pairs actually tabulated
  for (R_xlen_t i = 0; i < len; ++i) {
    const double val_x = x[i];
    const double val_y = y[i];

    // NA_real_ is a NaN; skip incomplete pairs (see header comment).
    if (std::isnan(val_x) || std::isnan(val_y)) {
      continue;
    }

    observed[val_x][val_y] += 1.0;
    row_sum[val_x] += 1.0;
    col_sum[val_y] += 1.0;
    ++n;
  }

  const int nr = static_cast<int>(row_sum.size());
  const int nc = static_cast<int>(col_sum.size());

  // Degenerate table (includes the empty / all-NA case, where nr = nc = 0):
  // the statistic is defined as 0.
  if (nr < 2 || nc < 2) {
    return List::create(Named("statistic") = 0,
                        Named("n")  = n,
                        Named("nr") = nr,
                        Named("nc") = nc);
  }

  // Calculate G Statistic
  // Formula: G = 2 * Sum( O * ln(O / E) )
  const double N_dbl = static_cast<double>(n);
  double statistic = 0.0;

  for (auto const& row : observed) {
    // .at() instead of operator[]: every key in `observed` is present in the
    // marginal maps, and a lookup must never insert a spurious zero entry.
    const double r_sum = row_sum.at(row.first);

    for (auto const& cell : row.second) {
      const double O = cell.second;                // Observed count
      const double c_sum = col_sum.at(cell.first); // Col marginal

      // Expected = (RowSum * ColSum) / N
      const double E = (r_sum * c_sum) / N_dbl;

      // Only occupied cells are stored, so O > 0 and hence E > 0; the guard
      // is kept purely as a defensive check against log(0).
      if (O > 0 && E > 0) {
        statistic += O * std::log(O / E);
      }
    }
  }

  statistic *= 2.0;

  // G is non-negative in exact arithmetic; clamp tiny negative rounding error.
  if (statistic < 0) statistic = 0;

  return List::create(Named("statistic") = statistic,
                      Named("n")  = n,
                      Named("nr") = nr,
                      Named("nc") = nc);
}
