# =============================================================================
# PTL estimation methods (requires: DoubleML, ncvreg)
# =============================================================================

#' PTL (Partial Transfer Learning) fit: cross-fit on the source, then
#' reconstruct the target outcome through partial transfer learning.
#'
#' The source rows are cut into `fold` contiguous, equally sized blocks (in
#' row order). For each block, the treatment coefficient(s) are estimated with
#' `DoubleML::DoubleMLPLR` on the remaining blocks, and a SCAD-penalised
#' regression (`ncvreg::cv.ncvreg`) of `Y_s - D_s %*% rho` on `X_s` is fitted
#' on the held-out block. Both estimates are then averaged over the folds.
#'
#' @param D_t Target treatment; n_t x q matrix.
#' @param X_t Target design matrix; n_t x p.
#' @param Y_t Target outcome; n_t x 1.
#' @param D_s Source treatment; n_s x q matrix.
#' @param X_s Source design matrix; n_s x p matrix.
#' @param Y_s Source outcome; n_s x 1 matrix.
#' @param ml_f Outcome learner for DoubleML (e.g. lrn("regr.cv_glmnet"));
#'   passed to `DoubleMLPLR` as `ml_l`.
#' @param ml_g Treatment learner for DoubleML (same type); passed to
#'   `DoubleMLPLR` as `ml_m`.
#' @param fold Number of folds for cross-fitting (default 5). Must be an
#'   integer >= 2 that divides n_s.
#' @return List with: hat_rho_s (source causal estimate; numeric of length q),
#'   beta_hat_s (source nuisance estimate; p x 1 matrix), E_s (estimated
#'   E[Y - rho*D] on source when q == 1, otherwise NA), hat_rho_PTL (PTL
#'   causal estimate on target).
fit_PTL <- function(D_t, X_t, Y_t, D_s, X_s, Y_s, ml_f, ml_g, fold = 5L) {
  n_s <- nrow(Y_s)
  if (!is.numeric(fold) || length(fold) != 1L || is.na(fold) ||
      fold < 2 || fold != round(fold)) {
    stop("fold must be a single integer greater than or equal to 2.")
  }
  if (n_s %% fold != 0) {
    stop("Number of source observations must be divisible by the number of folds.")
  }
  if (nrow(D_s) != n_s || nrow(X_s) != n_s) {
    stop("D_s, X_s and Y_s must have the same number of rows.")
  }
  q <- ncol(D_s)
  p <- ncol(X_s)

  # Fold membership is the same for every iteration, so build it once:
  # contiguous blocks of n_s / fold rows, in row order.
  fold_id <- rep(seq_len(fold), each = n_s / fold)

  # cross-fitting; one column per fold
  rho_hat  <- matrix(NA_real_, nrow = q, ncol = fold)
  beta_hat <- matrix(NA_real_, nrow = p, ncol = fold)
  for (i in seq_len(fold)) {
    held_out <- fold_id == i
    # drop = FALSE keeps everything a matrix even when q == 1 or p == 1.
    D_s_train <- D_s[!held_out, , drop = FALSE]
    X_s_train <- X_s[!held_out, , drop = FALSE]
    Y_s_train <- Y_s[!held_out, , drop = FALSE]
    D_s_test  <- D_s[held_out, , drop = FALSE]
    X_s_test  <- X_s[held_out, , drop = FALSE]
    Y_s_test  <- Y_s[held_out, , drop = FALSE]

    dml_data_s_train <- DoubleML::double_ml_data_from_matrix(X = X_s_train, y = Y_s_train, d = D_s_train)
    obj_dml_s <- DoubleML::DoubleMLPLR$new(dml_data_s_train, ml_l = ml_f, ml_m = ml_g)
    fit_s <- obj_dml_s$fit()
    # One coefficient per treatment column (a single repetition is fitted).
    rho_hat[, i] <- as.numeric(fit_s$all_coef)

    # Remove the estimated treatment effect from the held-out outcome, using
    # all q treatment columns, and hand ncvreg a plain response vector.
    Y_s_test_new <- as.numeric(Y_s_test - D_s_test %*% rho_hat[, i])
    # NOTE(review): `intercept = FALSE` is forwarded through `...`; confirm
    # that ncvreg honours it. The leading (intercept) coefficient is dropped
    # below either way.
    model_cv <- ncvreg::cv.ncvreg(X_s_test, Y_s_test_new, intercept = FALSE, nfolds = 5, penalty = "SCAD")
    beta_hat[, i] <- coef(model_cv)[-1]
  }
  # Average over folds, separately for each treatment column.
  hat_rho_s <- vapply(seq_len(q), function(j) mean(rho_hat[j, ]), numeric(1))
  beta_hat_s <- matrix(rowMeans(beta_hat))

  if (q == 1) {
    E_s <- mean(Y_s - hat_rho_s * D_s)
    Delta_hat <- matrix(colMeans(X_s) - colMeans(X_t))
    # outcome reconstruction through partial transfer learning
    hat_rho_PTL <- as.numeric(
      (mean(Y_t) - E_s + t(Delta_hat) %*% beta_hat_s) / mean(D_t)
    )
  } else {
    E_s <- NA
    # outcome reconstruction through partial transfer learning: strip the
    # transferred covariate effect, then regress on D_t without intercept
    Y_t_new <- Y_t - X_t %*% beta_hat_s
    lm_fit_t_new <- lm(Y_t_new ~ D_t + 0)
    hat_rho_PTL <- lm_fit_t_new$coefficients
  }

  list(
    hat_rho_s = hat_rho_s,
    beta_hat_s = beta_hat_s,
    E_s = E_s,
    hat_rho_PTL = hat_rho_PTL
  )
}



#' HPTL (Heterogeneous Partial Transfer Learning) fit
#'
#' Splits the row-stacked source data into K sources, runs `fit_PTL` on each
#' source to obtain its covariate coefficients, subtracts from the target
#' outcome the contribution of each source's covariate module, and regresses
#' the adjusted outcome on the target treatment without an intercept.
#'
#' @param D_t Target treatment; n_t x q matrix.
#' @param X_t Target design matrix; n_t x p (same covariates as sources).
#' @param Y_t Target outcome; n_t x 1.
#' @param D_s_all Source treatments concatenated by row; (sum of source_sizes) x q. Rows are split by source_sizes into K sources.
#' @param X_s_all Source design matrices concatenated by row; (sum of source_sizes) x p. Rows split by source_sizes.
#' @param Y_s_all Source outcomes concatenated by row; (sum of source_sizes) x 1. Rows split by source_sizes.
#' @param source_sizes Integer vector of length K: sample size of each source. Must sum to nrow(Y_s_all).
#'   Each size must also satisfy the requirements of `fit_PTL` (divisible by `fold`).
#' @param module_sizes Integer vector of length K: covariate module sizes for each source.
#' The k-th source uses columns (cumsum(module_sizes)[k-1]+1):cumsum(module_sizes)[k] of X;
#' length(module_sizes) must equal length(source_sizes). Entries must be
#' non-negative and sum to at most p; a size of 0 means the source contributes
#' no columns.
#' @param ml_f Outcome learner for DoubleML (e.g. lrn("regr.cv_glmnet")).
#' @param ml_g Treatment learner for DoubleML (same type).
#' @param fold Number of folds for cross-fitting (default 5).
#' @return List with: hat_rho_HPTL (HPTL causal estimate on target; the
#'   coefficients of the no-intercept regression of the adjusted outcome on D_t).
fit_HPTL <- function(D_t, X_t, Y_t, D_s_all, X_s_all, Y_s_all, source_sizes, module_sizes, ml_f, ml_g, fold = 5) {

  # Validate everything up front so that bad arguments fail before the
  # (expensive) per-source fits are started.
  n_all <- sum(source_sizes)
  if (n_all != nrow(Y_s_all) || n_all != nrow(D_s_all) || n_all != nrow(X_s_all)) {
    stop("Sum of source_sizes must equal number of rows in D_s_all, X_s_all or Y_s_all.")
  }
  if (length(module_sizes) != length(source_sizes)) {
    stop("length(module_sizes) must equal length(source_sizes).")
  }
  if (ncol(X_s_all) != ncol(X_t)) {
    stop("X_s_all and X_t must have the same number of columns.")
  }
  if (any(module_sizes < 0) || sum(module_sizes) > ncol(X_t)) {
    stop("module_sizes must be non-negative and sum to at most ncol(X_t).")
  }

  # Consecutive index blocks for a vector of block sizes, e.g.
  # c(2, 3) -> list(1:2, 3:5). seq_len() yields an empty block for size 0
  # (a plain start:end would count downwards instead).
  consecutive_blocks <- function(sizes) {
    block_ends <- cumsum(sizes)
    lapply(seq_along(sizes), function(k) {
      (block_ends[k] - sizes[k]) + seq_len(sizes[k])
    })
  }

  # Cross-fit each source with fit_PTL and keep its p x 1 coefficient matrix.
  source_rows <- consecutive_blocks(source_sizes)
  beta_hat_s_list <- lapply(source_rows, function(rows) {
    fit_k <- fit_PTL(D_t, X_t, Y_t,
                     D_s_all[rows, , drop = FALSE],
                     X_s_all[rows, , drop = FALSE],
                     Y_s_all[rows, , drop = FALSE],
                     ml_f, ml_g, fold)
    fit_k$beta_hat_s
  })

  # Convert module_sizes to consecutive index vectors,
  # e.g. c(100, 50, 150, 200) -> 1:100, 101:150, 151:300, 301:500
  module_indexes <- consecutive_blocks(module_sizes)

  Y_t_new <- Y_t
  for (k in seq_along(module_indexes)) {
    # outcome reconstruction through partial transfer learning; source k only
    # contributes the coefficients (rows of beta) of its own module
    idx <- module_indexes[[k]]
    Y_t_new <- Y_t_new - X_t[, idx, drop = FALSE] %*% beta_hat_s_list[[k]][idx, 1, drop = FALSE]
  }
  hat_rho_HPTL <- lm(Y_t_new ~ D_t + 0)$coefficients
  list(hat_rho_HPTL = hat_rho_HPTL)
}
