#' @title Out-of-Sample Validation Functions
#' @name validation
#' @description Functions for rolling window and leave-one-sector-out validation.
NULL

#' Rolling Window Cross-Validation
#'
#' Slides a fixed-width training window across the sorted panel years and
#' scores a sector fixed-effects fit and a Mundlak fit on the periods that
#' immediately follow each window.
#'
#' @param panel_data Data frame in panel format. It is first passed to
#'   \code{validate_panel_data(require_log = TRUE)}; the columns \code{year},
#'   \code{sector} and \code{log_production} are read directly here.
#' @param window_sizes Integer vector of training window sizes (number of
#'   distinct years). Default c(20, 30).
#' @param step_size Integer step between consecutive window starts. Default 2.
#' @param test_horizon Integer number of periods to forecast after the
#'   training window (truncated at the last available year). Default 3.
#' @param verbose Logical. If TRUE, a progress message is emitted after every
#'   10th completed window. Default TRUE.
#'
#' @return A data frame with one row per completed window, or \code{NULL}
#'   (with a warning) when no window could be evaluated. Columns:
#' \describe{
#'   \item{window_size}{Training window size}
#'   \item{window_start, window_end}{First and last year of the training window}
#'   \item{test_start, test_end}{First and last year of the test period}
#'   \item{n_train, n_test}{Number of training and test observations}
#'   \item{n_test_common}{Number of test obs from sectors in training}
#'   \item{n_test_new}{Number of test obs from new sectors}
#'   \item{rmse_fe_all, mae_fe_all}{FE errors on all test observations}
#'   \item{rmse_fe_common, mae_fe_common}{FE errors on common-sector obs}
#'   \item{rmse_fe_new, mae_fe_new}{FE errors on new-sector obs}
#'   \item{rmse_m_all, mae_m_all}{Mundlak errors on all test observations}
#'   \item{rmse_m_common, mae_m_common}{Mundlak errors on common-sector obs}
#'   \item{rmse_m_new, mae_m_new}{Mundlak errors on new-sector obs}
#' }
#'
#' @details
#' A window size is skipped entirely when the panel has no more than
#' \code{window_size + 1} distinct years. An individual window is skipped
#' when it yields fewer than 50 training rows or fewer than 5 test rows.
#' Predictions come from \code{fit_predict_fe_sector()} and
#' \code{fit_predict_mundlak()}; errors are computed with \code{safe_rmse()}
#' and \code{safe_mae()}. Results are separated by whether test sectors were
#' seen during training (common) or not (new).
#'
#' @examples
#' \donttest{
#' if (requireNamespace("plm", quietly = TRUE)) {
#'   set.seed(123)
#'   panel <- data.frame(
#'     year = rep(2000:2029, 5),
#'     sector = rep(LETTERS[1:5], each = 30),
#'     log_direct = rnorm(150, 5, 0.5),
#'     log_production = rnorm(150, 5, 0.5)
#'   )
#'   panel$log_production <- panel$log_direct * 0.95 + rnorm(150, 0, 0.1)
#'
#'   cv_results <- rolling_window_cv(panel, window_sizes = c(15, 20))
#'   print(head(cv_results))
#' }
#' }
#'
#' @export
rolling_window_cv <- function(panel_data,
                               window_sizes = c(20L, 30L),
                               step_size = 2L,
                               test_horizon = 3L,
                               verbose = TRUE) {

    validate_panel_data(panel_data, require_log = TRUE)

    all_years <- sort(unique(panel_data$year))
    year_count <- length(all_years)

    # Scores one train/test split. Returns a one-row data frame, or NULL when
    # the split is too small to be evaluated.
    evaluate_split <- function(width, first_idx) {

        last_train_idx <- first_idx + width - 1L
        # The test period never runs past the last available year.
        last_test_idx <- min(last_train_idx + test_horizon, year_count)

        fit_years <- all_years[first_idx:last_train_idx]
        eval_years <- all_years[(last_train_idx + 1L):last_test_idx]

        train_set <- panel_data[panel_data$year %in% fit_years, ]
        test_set <- panel_data[panel_data$year %in% eval_years, ]

        if (nrow(train_set) < 50L || nrow(test_set) < 5L) {
            return(NULL)
        }

        sectors_fit <- unique(train_set$sector)
        sectors_eval <- unique(test_set$sector)

        # Row masks over the test set: sectors seen in training vs. unseen.
        in_common <- test_set$sector %in% intersect(sectors_fit, sectors_eval)
        in_new <- test_set$sector %in% setdiff(sectors_eval, sectors_fit)

        actual <- test_set$log_production
        pred_fe <- fit_predict_fe_sector(train_set, test_set)
        pred_m <- fit_predict_mundlak(train_set, test_set)

        actual_common <- actual[in_common]
        actual_new <- actual[in_new]

        data.frame(
            window_size = width,
            window_start = min(fit_years),
            window_end = max(fit_years),
            test_start = min(eval_years),
            test_end = max(eval_years),
            n_train = nrow(train_set),
            n_test = nrow(test_set),
            n_test_common = sum(in_common),
            n_test_new = sum(in_new),
            rmse_fe_all = safe_rmse(actual, pred_fe),
            mae_fe_all = safe_mae(actual, pred_fe),
            rmse_fe_common = safe_rmse(actual_common, pred_fe[in_common]),
            mae_fe_common = safe_mae(actual_common, pred_fe[in_common]),
            rmse_fe_new = safe_rmse(actual_new, pred_fe[in_new]),
            mae_fe_new = safe_mae(actual_new, pred_fe[in_new]),
            rmse_m_all = safe_rmse(actual, pred_m),
            mae_m_all = safe_mae(actual, pred_m),
            rmse_m_common = safe_rmse(actual_common, pred_m[in_common]),
            mae_m_common = safe_mae(actual_common, pred_m[in_common]),
            rmse_m_new = safe_rmse(actual_new, pred_m[in_new]),
            mae_m_new = safe_mae(actual_new, pred_m[in_new]),
            stringsAsFactors = FALSE
        )
    }

    collected <- list()
    n_done <- 0L

    for (width in window_sizes) {

        # Too few years to place this window and still have test periods.
        if (year_count <= width + 1L) {
            next
        }

        for (first_idx in seq(1L, year_count - width, by = step_size)) {

            split_row <- evaluate_split(width, first_idx)
            if (is.null(split_row)) {
                next
            }

            n_done <- n_done + 1L
            collected[[n_done]] <- split_row

            if (verbose && n_done %% 10L == 0L) {
                message(sprintf("Completed %d windows...", n_done))
            }
        }
    }

    if (length(collected) == 0L) {
        warning("No valid windows completed.")
        return(NULL)
    }

    do.call(rbind, collected)
}


#' Fit and Predict with FE Sector Model
#'
#' Internal function that fits
#' \code{log_production ~ log_direct + factor(sector)} by least squares on
#' the training data and predicts the test rows from the estimated
#' coefficients.
#'
#' @param df_train Training data with columns \code{log_production},
#'   \code{log_direct} and \code{sector}.
#' @param df_test Test data with columns \code{log_direct} and \code{sector}.
#'
#' @return Numeric vector of predictions, one per row of \code{df_test}.
#'   All values are \code{NA} when the model cannot be fitted. Test sectors
#'   without an estimated sector dummy (the reference level, or sectors absent
#'   from training) are predicted with a zero sector offset.
#'
#' @keywords internal
fit_predict_fe_sector <- function(df_train, df_test) {

    model <- tryCatch(
        stats::lm(log_production ~ log_direct + factor(sector), data = df_train),
        error = function(e) NULL
    )

    # A failed fit yields an all-NA prediction vector rather than an error.
    if (is.null(model)) {
        return(rep(NA_real_, nrow(df_test)))
    }

    est <- stats::coef(model)
    intercept <- unname(est["(Intercept)"])
    slope <- unname(est["log_direct"])

    # Keep only the sector dummies and re-key them by bare sector label.
    dummy_coefs <- est[grepl("^factor\\(sector\\)", names(est))]
    dummy_labels <- sub("^factor\\(sector\\)", "", names(dummy_coefs))
    dummy_coefs <- stats::setNames(unname(dummy_coefs), dummy_labels)

    # Look up each test row's sector offset; labels without a dummy get 0.
    test_labels <- as.character(df_test$sector)
    sector_shift <- dummy_coefs[test_labels]
    sector_shift[!(test_labels %in% names(dummy_coefs))] <- 0

    as.numeric(intercept + slope * df_test$log_direct + sector_shift)
}


#' Fit and Predict with Mundlak Model
#'
#' Internal function that fits a random-effects model of
#' \code{log_production} on \code{x_within} and \code{x_mean_sector} with
#' \code{plm::plm()} and predicts the test rows from the estimated
#' coefficients.
#'
#' @param df_train Training data with columns \code{log_production},
#'   \code{log_direct}, \code{sector} and \code{year}.
#' @param df_test Test data with columns \code{log_direct} and \code{sector}.
#'
#' @return Numeric vector of predictions, one per row of \code{df_test} and
#'   in the same row order as \code{df_test}. All values are \code{NA} when
#'   the \pkg{plm} package is not installed or when building the panel data or
#'   fitting the model fails.
#'
#' @details
#' Sector means of \code{log_direct} are taken from the training data. Test
#' rows whose sector has no training mean use the overall training mean
#' instead. The within component is the test value of \code{log_direct} minus
#' the mean assigned to that row.
#'
#' @keywords internal
fit_predict_mundlak <- function(df_train, df_test) {

    if (!requireNamespace("plm", quietly = TRUE)) {
        return(rep(NA_real_, nrow(df_test)))
    }

    # create_mundlak_data() is defined elsewhere; it is expected to supply the
    # x_within and x_mean_sector columns used in the model formula below.
    df_train_m <- create_mundlak_data(df_train, x_var = "log_direct")

    pdata_m <- tryCatch(
        plm::pdata.frame(df_train_m, index = c("sector", "year")),
        error = function(e) NULL
    )

    if (is.null(pdata_m)) {
        return(rep(NA_real_, nrow(df_test)))
    }

    m_fit <- tryCatch(
        plm::plm(
            log_production ~ x_within + x_mean_sector,
            data = pdata_m,
            model = "random",
            random.method = "swar",
            effect = "individual"
        ),
        error = function(e) NULL
    )

    if (is.null(m_fit)) {
        return(rep(NA_real_, nrow(df_test)))
    }

    b <- stats::coef(m_fit)
    b0 <- unname(b["(Intercept)"])
    bw <- unname(b["x_within"])
    bb <- unname(b["x_mean_sector"])

    sector_means_train <- stats::aggregate(
        log_direct ~ sector,
        data = df_train,
        FUN = mean,
        na.rm = TRUE
    )
    names(sector_means_train)[2L] <- "x_mean_sector"

    x_bar_global <- mean(df_train$log_direct, na.rm = TRUE)

    # Look the training means up by position instead of merge(): merge()
    # reorders rows (sorted by key, unmatched rows last), which would
    # misalign the predictions with the rows of df_test that callers compare
    # them against.
    mean_idx <- match(
        as.character(df_test$sector),
        as.character(sector_means_train$sector)
    )
    x_mean_sector <- sector_means_train$x_mean_sector[mean_idx]

    # Sectors unseen in training fall back to the overall training mean.
    x_mean_sector[is.na(x_mean_sector)] <- x_bar_global

    x_within <- df_test$log_direct - x_mean_sector

    as.numeric(b0 + bw * x_within + bb * x_mean_sector)
}


#' Leave-One-Sector-Out Cross-Validation
#'
#' Performs LOSO CV: each sector in turn is held out as the test set while
#' both models are fitted on all remaining sectors.
#'
#' @param panel_data Data frame in panel format. It is first passed to
#'   \code{validate_panel_data(require_log = TRUE)}; the columns \code{sector}
#'   and \code{log_production} are read directly here.
#' @param verbose Logical. If TRUE, a progress message is emitted for every
#'   held-out sector. Default TRUE.
#'
#' @return A data frame with one row per held-out sector and the columns
#'   \code{sector}, \code{n_test}, \code{rmse_fe}, \code{mae_fe},
#'   \code{rmse_m} and \code{mae_m}. Sectors are processed in sorted order.
#'
#' @details
#' Rows whose \code{sector} is \code{NA} cannot be assigned to a fold and are
#' left out of both the training and the test set.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("plm", quietly = TRUE)) {
#'   set.seed(123)
#'   panel <- data.frame(
#'     year = rep(2000:2019, 5),
#'     sector = rep(LETTERS[1:5], each = 20),
#'     log_direct = rnorm(100, 5, 0.5),
#'     log_production = rnorm(100, 5, 0.5)
#'   )
#'   panel$log_production <- panel$log_direct * 0.95 + rnorm(100, 0, 0.1)
#'
#'   loso_results <- leave_one_sector_out(panel)
#'   print(loso_results)
#' }
#' }
#'
#' @export
leave_one_sector_out <- function(panel_data, verbose = TRUE) {

    validate_panel_data(panel_data, require_log = TRUE)

    # sort() drops NA, so every fold label below is a real sector.
    sectors <- sort(unique(panel_data$sector))
    n_sectors <- length(sectors)

    has_sector <- !is.na(panel_data$sector)

    results <- vector("list", n_sectors)

    for (i in seq_along(sectors)) {

        s <- sectors[i]

        if (verbose) {
            message(sprintf("[%d/%d] Holding out sector: %s", i, n_sectors, s))
        }

        # NA-safe masks: a bare `sector == s` is NA for missing sectors, and an
        # NA row index would inject all-NA rows into both subsets.
        is_held_out <- has_sector & panel_data$sector == s
        df_train <- panel_data[has_sector & !is_held_out, ]
        df_test <- panel_data[is_held_out, ]

        y_test <- df_test$log_production

        fe_pred <- fit_predict_fe_sector(df_train, df_test)
        m_pred <- fit_predict_mundlak(df_train, df_test)

        results[[i]] <- data.frame(
            sector = s,
            n_test = nrow(df_test),
            rmse_fe = safe_rmse(y_test, fe_pred),
            mae_fe = safe_mae(y_test, fe_pred),
            rmse_m = safe_rmse(y_test, m_pred),
            mae_m = safe_mae(y_test, m_pred),
            stringsAsFactors = FALSE
        )
    }

    do.call(rbind, results)
}


#' Summarize Rolling Window Results
#'
#' Condenses the per-window error columns produced by
#' \code{rolling_window_cv()} into one summary row per model and partition.
#'
#' @param cv_results Data frame from rolling_window_cv.
#' @param bootstrap_reps Number of bootstrap replications, forwarded to
#'   \code{robust_summary()}. Default 300.
#'
#' @return A data frame with one row for each of FE and Mundlak crossed with
#'   the partitions all, common and new, holding \code{rmse_mean},
#'   \code{rmse_sd}, \code{rmse_median}, \code{rmse_ci_lo}, \code{rmse_ci_hi},
#'   \code{mae_mean} and \code{mae_median} (rounded to 4 decimals) plus
#'   \code{n_windows}, the number of finite RMSE values. Returns \code{NULL}
#'   with a warning when \code{cv_results} is \code{NULL} or has no rows.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("plm", quietly = TRUE)) {
#'   set.seed(123)
#'   panel <- data.frame(
#'     year = rep(2000:2029, 5),
#'     sector = rep(LETTERS[1:5], each = 30),
#'     log_direct = rnorm(150, 5, 0.5),
#'     log_production = rnorm(150, 5, 0.5)
#'   )
#'   panel$log_production <- panel$log_direct * 0.95 + rnorm(150, 0, 0.1)
#'
#'   cv_results <- rolling_window_cv(panel, window_sizes = c(15, 20))
#'   summary_stats <- summarize_cv_results(cv_results)
#'   print(summary_stats)
#' }
#' }
#'
#' @export
summarize_cv_results <- function(cv_results, bootstrap_reps = 300L) {

    nothing_to_summarize <- is.null(cv_results) || nrow(cv_results) == 0L
    if (nothing_to_summarize) {
        warning("No CV results to summarize.")
        return(NULL)
    }

    # One row per (model, partition) pair, naming the source columns to read.
    spec <- data.frame(
        model = rep(c("FE", "Mundlak"), each = 3L),
        partition = rep(c("all", "common", "new"), times = 2L),
        rmse_col = c(
            "rmse_fe_all", "rmse_fe_common", "rmse_fe_new",
            "rmse_m_all", "rmse_m_common", "rmse_m_new"
        ),
        mae_col = c(
            "mae_fe_all", "mae_fe_common", "mae_fe_new",
            "mae_m_all", "mae_m_common", "mae_m_new"
        ),
        stringsAsFactors = FALSE
    )

    summary_rows <- lapply(seq_len(nrow(spec)), function(k) {

        rmse_vals <- cv_results[[spec$rmse_col[k]]]
        mae_vals <- cv_results[[spec$mae_col[k]]]

        rmse_stats <- robust_summary(rmse_vals, bootstrap_reps = bootstrap_reps)
        mae_stats <- robust_summary(mae_vals, bootstrap_reps = bootstrap_reps)

        data.frame(
            model = spec$model[k],
            partition = spec$partition[k],
            rmse_mean = round(rmse_stats$mean, 4),
            rmse_sd = round(rmse_stats$sd, 4),
            rmse_median = round(rmse_stats$median, 4),
            rmse_ci_lo = round(rmse_stats$ci[1L], 4),
            rmse_ci_hi = round(rmse_stats$ci[2L], 4),
            mae_mean = round(mae_stats$mean, 4),
            mae_median = round(mae_stats$median, 4),
            # Windows where the RMSE could actually be computed.
            n_windows = sum(is.finite(rmse_vals)),
            stringsAsFactors = FALSE
        )
    })

    do.call(rbind, summary_rows)
}


#' Compute OOS Degradation
#'
#' Expresses the mean out-of-sample RMSE of each model as a percentage change
#' relative to its in-sample RMSE.
#'
#' @param insample_metrics Named list with in-sample metrics. The elements
#'   \code{rmse_fe} and \code{rmse_m} are used when present.
#' @param cv_results Data frame from rolling_window_cv. The columns
#'   \code{rmse_fe_all} and \code{rmse_m_all} are used when present.
#'
#' @return Data frame with the columns \code{model}, \code{insample_rmse},
#'   \code{oos_rmse_mean} and \code{degradation_pct}, one row per model that
#'   could be compared. A model is included only when its in-sample RMSE is
#'   supplied, finite and positive and its column exists in
#'   \code{cv_results}. Returns \code{NULL} when no model qualifies.
#'
#' @examples
#' insample <- list(rmse_fe = 0.05, rmse_m = 0.04)
#' cv_res <- data.frame(
#'   rmse_fe_all = c(0.06, 0.055, 0.058),
#'   rmse_m_all = c(0.045, 0.042, 0.044)
#' )
#' degradation <- compute_oos_degradation(insample, cv_res)
#' print(degradation)
#'
#' @export
compute_oos_degradation <- function(insample_metrics, cv_results) {

    # Builds the comparison row for one model, or NULL when that model cannot
    # be compared.
    degradation_row <- function(label, baseline, oos_col) {

        if (is.null(baseline) || !(oos_col %in% names(cv_results))) {
            return(NULL)
        }

        oos_avg <- mean(cv_results[[oos_col]], na.rm = TRUE)

        # A non-finite or non-positive baseline gives no meaningful percentage.
        if (!(is.finite(baseline) && baseline > 0)) {
            return(NULL)
        }

        data.frame(
            model = label,
            insample_rmse = baseline,
            oos_rmse_mean = oos_avg,
            degradation_pct = (oos_avg - baseline) / baseline * 100,
            stringsAsFactors = FALSE
        )
    }

    rows <- list(
        fe = degradation_row("FE", insample_metrics$rmse_fe, "rmse_fe_all"),
        mundlak = degradation_row("Mundlak", insample_metrics$rmse_m, "rmse_m_all")
    )

    # Drop the models that were skipped.
    rows <- rows[!vapply(rows, is.null, logical(1L))]

    if (length(rows) == 0L) {
        return(NULL)
    }

    do.call(rbind, rows)
}
