#' Fit a Neural Network Lee-Carter Model
#'
#' Fit a neural network mortality model based on prepared \code{NNMoMo} and
#' \code{NNMoMoData} objects. The function allows selection of specific ages
#' and years for fitting, restriction to particular series ("female", "male" or
#' "both"), and control of the training epochs. The model is trained using
#' \code{torch} and \code{luz} and returns a list of \code{fitStMoMo}-like
#' objects.
#'
#' @param object An object of class \code{NNMoMo} specifying the model structure
#' (e.g., embedding dimension, model type, activation function, loss function).
#' For more information see \code{\link{lcNN}}.
#' @param data An object of class \code{NNMoMoData} containing mortality rates,
#' population size, and relevant demographic features. See
#' \code{\link{NNMoMoData}} for details.
#' @param ages.fit Optional numeric vector specifying which ages to include
#' in the fitting process. By default, all available ages are used.
#' @param years.fit Optional numeric vector specifying which years to include
#' in the fitting process. By default, all available years are used.
#' @param series Indicates whether to fit the model for "female", "male" or
#' "both". Default and recommended is "both".
#' @param fitting.epochs Positive integer specifying the number of training
#' epochs for the neural network. Values much lower than 2000 are not
#' recommended.
#' @param batch.size Positive integer specifying the batch size when training
#' the model. A batch size of 128 is recommended as it was found to work the
#' smoothest.
#' @param ... Arguments to be passed to or from other methods.
#'
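#' @return A list of class \code{fitStMoMo_list}. It contains one element per
#' fitted country/sex combination (named \code{<country>_<sex>}), each of class
#' \code{c("fitNNMoMo", "fitStMoMo")}, i.e. a \code{fitStMoMo}-like object
#' containing the estimated parameters \eqn{a_x}, \eqn{b_x}, \eqn{k_t}, the
#' underlying data, and other metadata from the model fitting. These objects
#' can be further used with functions from the \pkg{StMoMo} package for
#' analysis, plotting, or extracting fitted mortality rates and exposures. The
#' additional element \code{NN_fittingModel} stores the settings used for
#' fitting and the final training loss.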
#'
#' @details
#' Missing or zero values in the mortality rates are imputed using the average
#' value at a certain age across all countries for that gender in that year.
#' Data are processed for each combination of year and sex. Categorical
#' variables such as country and sex are embedded via embedding layers.
#'
#' @examples
#' \donttest{
#' if (torch::torch_is_installed()) {
#'   # Example: fitting with random data, do not expect to get
#'   # any meaningful results.
#'
#'   # creating example data for fitting:
#'   demography_obj <- demography::demogdata(
#'     data = matrix(runif(10*5), nrow = 10),
#'     pop = matrix(runif(10*5, 1000, 2000), nrow = 10),
#'     ages = 50:59,
#'     years = 2000:2004,
#'     type = "mortality",
#'     name = "male",
#'     label = "France"
#'   )
#'   nn_data <- NNMoMoData(demography_obj)
#'
#'   # fitting the data in 10 epochs (in practice, many more epochs are needed):
#'   fit <- fit(object = lcNN(),
#'              data = nn_data,
#'              fitting.epochs = 10)
#'
#'   plot(fit$France_male)
#' }
#' }
#'
#' @export
fit.NNMoMo <- function(object,
                       data,
                       ages.fit = NULL,
                       years.fit = NULL,
                       series = c("both", "female", "male"),
                       fitting.epochs = 2000,
                       batch.size = 128,
                       ...) {

  mc <- match.call()   # stored as `call` in every fitted object

# ---- validate input ----------------------------------------------------------
  if (!inherits(data, "NNMoMoData")) {
    stop("Argument data needs to be of class NNMoMoData")
  }
  if (nrow(data) == 0) {
    stop("Argument data does not contain any observations")
  }

  # TRUE for a single finite, strictly positive whole number. NA/NaN/Inf are
  # rejected here; otherwise they would make the `if` condition itself fail.
  is_positive_integer <- function(x) {
    is.numeric(x) && length(x) == 1 && is.finite(x) && x > 0 && x %% 1 == 0
  }

  if (!is_positive_integer(fitting.epochs)) {
    stop("Argument fitting.epochs needs to be a positive integer")
  }

  if (is.null(ages.fit)) {
    # default: every age for which a `rate_<age>` column exists
    age_labels <- sub("^rate_", "", grep("^rate_", names(data), value = TRUE))
    ages.fit <- as.integer(age_labels)
    if (anyNA(ages.fit)) {
      stop("Some column names in `data` could not be converted to integer ages")
    }
    ages.fit <- sort(unique(ages.fit))
  } else if (!is.numeric(ages.fit) ||
             length(ages.fit) < 1 ||
             any(!is.finite(ages.fit)) ||
             any(ages.fit < 0) ||
             any(ages.fit %% 1 != 0)) {
    stop(
      paste0(
        "Argument ages.fit needs to be a numeric vector ",
        "of non-negative integers"
      )
    )
  }

  if (is.null(years.fit)) {
    # default: the full range of years present in the data
    years.fit <- min(data$year):max(data$year)
  } else {
    if (!is.numeric(years.fit) ||
        length(years.fit) < 1 ||
        any(!is.finite(years.fit)) ||
        any(years.fit %% 1 != 0)) {
      stop("Argument years.fit needs to be a numeric vector of integers")
    }
    # filter by specified years
    data <- data[data$year %in% years.fit, , drop = FALSE]
  }

  series <- match.arg(series)

  if (!is_positive_integer(batch.size)) {
    stop("Argument batch.size needs to be a positive integer")
  }

# ---- filter by specified ages ------------------------------------------------
  # keep every non-age column and those `rate_<age>` / `pop_<age>` columns
  # whose age is part of ages.fit
  pattern <- "^(rate|pop)_(\\d+)$"
  col_names <- names(data)
  is_age_col <- grepl(pattern, col_names)
  col_age <- rep(NA_integer_, length(col_names))
  col_age[is_age_col] <- as.integer(sub(pattern, "\\2", col_names[is_age_col]))
  keep_cols <- !is_age_col | col_age %in% ages.fit
  data <- data[, keep_cols, drop = FALSE]

# ---- filter by specified gender ----------------------------------------------
  if (series != "both") {
    data <- data[data$sex == series, , drop = FALSE]
  }

# ---- imputation of NA's and 0's ----------------------------------------------
  rate_cols <- names(data)[startsWith(names(data), "rate")]
  if (length(rate_cols) == 0) {
    stop("No rate columns left in data for the ages given in ages.fit")
  }

  # Rows sharing year and sex form one imputation group, so a missing or zero
  # rate is replaced by the mean over all countries for that year and sex.
  # (Including the country in the key would leave a single row per group and
  # the group mean would never be available.)
  groups <- split(seq_len(nrow(data)), paste(data$year, data$sex, sep = "_"))

  # usable observations: finite and non-zero (the log of the rate is taken)
  is_valid <- function(v) is.finite(v) & v != 0

  for (col in rate_cols) {
    values <- data[[col]]

    # fallback when a whole year/sex group has no usable value: mean over all
    # usable observations of that age; NA if there are none at all
    valid_all <- is_valid(values)
    fallback <- if (any(valid_all)) mean(values[valid_all]) else NA_real_

    for (idx in groups) {
      entries <- values[idx]
      miss <- is.na(entries) | entries == 0

      if (any(miss)) {
        valid <- is_valid(entries)
        data[[col]][idx[miss]] <- if (any(valid)) {
          mean(entries[valid])
        } else {
          fallback
        }
      }
    }
  }

  # rows that could not be imputed are dropped
  data <- stats::na.omit(data)

  if (nrow(data) == 0) {
    stop(
      paste0(
        "No observations left to fit after filtering by years.fit and ",
        "series and removing incomplete rows"
      )
    )
  }

# ---- initializing torch dataset ----------------------------------------------
  initialize_torch_dataset <- torch::dataset(

    name = "LC_NN_dataset",

    initialize = function(df, loss_type) {

      df <- stats::na.omit(df) # again: for safety

      # continuous input data: the network is fed the log rates
      rate_cols <- startsWith(names(df), "rate")
      x_rate <- torch::torch_tensor(as.matrix(df[, rate_cols]))
      log_x_rate <- torch::torch_log(x_rate)

      pop_cols <- startsWith(names(df), "pop")
      x_pop <- torch::torch_tensor(as.matrix(df[, pop_cols]))

      # categorical input data, passed on as integer codes
      x_cat <- df[, c("country", "sex")]
      x_cat$country <- as.integer(x_cat$country)
      x_cat$sex <- as.integer(x_cat$sex)
      x_cat <- torch::torch_tensor(as.matrix(x_cat))

      self$x <- list(
        x_rate = log_x_rate,
        x_pop  = x_pop,
        x_cat  = x_cat
      )

      # target: log rates for MSE, rate * pop for Poisson
      if (loss_type == "MSE") {
        self$y <- log_x_rate
      }
      else if (loss_type == "Poisson") {
        self$y <- torch::torch_multiply(x_rate, x_pop)
      }

    },

    .getitem = function(i) {
      list(
        x = list(
          x_rate = self$x$x_rate[i, ],
          x_pop  = self$x$x_pop[i, ],
          x_cat  = self$x$x_cat[i, ]
        ),
        y = self$y[i, ]
      )
    },

    .length = function() {
      self$x$x_rate$size()[[1]]
    }
  )
  torch_dataset <- initialize_torch_dataset(data, object$loss_type)


# ---- data loaders (one for fitting, one for evaluation) ----------------------
  dataloader_learn <- torch::dataloader(torch_dataset,
                                        batch_size = batch.size,
                                        shuffle = TRUE)
  # a single unshuffled batch, so that the evaluated parameters are in the
  # same row order as `data`
  dataloader_eval <- torch::dataloader(torch_dataset,
                                       batch_size = length(torch_dataset),
                                       shuffle = FALSE)


# ---- embedding layer (countries and sex) -------------------------------------
  embedding_module <- torch::nn_module(

    initialize = function(cardinalities, q_e) {

      # one embedding of dimension q_e per categorical variable
      self$embeddings <- torch::nn_module_list(
        lapply(cardinalities, function(x)
          torch::nn_embedding(num_embeddings = x, embedding_dim = q_e))
      )

    },

    forward = function(x) {

      # column i of x holds the codes for the i-th categorical variable
      n_emb <- length(self$embeddings)
      embedded <- vector(mode = "list", length = n_emb)
      for (i in seq_len(n_emb)) {
        embedded[[i]] <- self$embeddings[[i]](x[, i])
      }
      torch::torch_cat(embedded, dim = 2)
    }
  )


# ---- LCN layer (not available as a default function in torch) ----------------
# probably has some issues - it is not as good as the others
  LocallyConnected1d <- torch::nn_module(

    initialize = function(in_channels,
                          out_channels,
                          kernel_size,
                          stride,
                          bias = TRUE) {

      self$in_channels <- in_channels
      self$out_channels <- out_channels
      self$kernel_size <- kernel_size
      self$stride <- stride
      # weight and bias are created on the first forward pass because their
      # size depends on the length of the input
      self$weight <- NULL
      self$bias <- NULL
      self$bias_flag <- bias
    },

    forward = function(x) {
      dims <- x$size()
      B <- dims[1]
      C_in <- dims[2]
      L_in <- dims[3]

      k <- self$kernel_size
      d <- self$stride
      L_out <- base::floor((L_in - k)/d) + 1

      device <- x$device

      # NOTE(review): the parameters are only created here, during the first
      # forward pass - confirm that the optimizer set up by luz picks them up.
      if (is.null(self$weight)) {
        self$weight <- torch::nn_parameter(torch::torch_randn(self$out_channels,
                                                              C_in*k,
                                                              L_out,
                                                              device = device))
        if (self$bias_flag) {
          self$bias <- torch::nn_parameter(torch::torch_randn(self$out_channels,
                                                              L_out,
                                                              device = device))
        }
      } else {
        self$weight <- self$weight$to(device = device)
        if (!is.null(self$bias)) {
          self$bias <- self$bias$to(device = device)
        }
      }

      # cut the input into windows of size k with step d and flatten each
      # window to one column per output position
      x_unf <- x$unfold(dimension = 2, size = k, step = d)
      x_unf_flat <- x_unf$permute(c(1,2,4,3))$reshape(c(B, C_in*k, L_out))

      # every output position has its own weight matrix (no weight sharing)
      out_list <- vector("list", L_out)
      for (i in seq_len(L_out)) {
        out_list[[i]] <- torch::torch_matmul(x_unf_flat[,,i],
                                             self$weight[,,i]$t())
      }
      out <- torch::torch_stack(out_list, dim = 3)

      if (!is.null(self$bias)) {
        out <- out + self$bias$unsqueeze(1)
      }

      out
    }
  )


# ---- neural network ----------------------------------------------------------
  net <- torch::nn_module(

    "LC_net",

    initialize = function(cardinalities,
                          q_e,
                          q_z1,
                          act_fun,
                          mod_type,
                          loss_type,
                          age_length
    ) {

      # fail early instead of running into an undefined layer in forward()
      if (!mod_type %in% c("FCN", "LCN", "CNN")) {
        stop("Unknown model type: ", mod_type)
      }
      if (!act_fun %in% c("linear", "tanh", "relu")) {
        stop("Unknown activation function: ", act_fun)
      }

      self$age_length <- age_length
      self$act_fun <- act_fun
      self$mod_type <- mod_type
      self$loss_type <- loss_type

      # separate embeddings feed the a_x and the b_x branch
      self$embedder_a <- embedding_module(cardinalities = cardinalities,
                                          q_e = q_e)
      self$embedder_b <- embedding_module(cardinalities = cardinalities,
                                          q_e = q_e)

      self$a_fc <- torch::nn_linear(in_features = length(cardinalities) * q_e,
                                    out_features = age_length)

      self$b_fc <- torch::nn_linear(in_features = length(cardinalities) * q_e,
                                    out_features = age_length)

      # first layer of the k_t branch, depending on the model type
      if (mod_type == "FCN") {
        self$k_fc1 <- torch::nn_linear(in_features = age_length,
                                       out_features = q_z1)
      }
      else if (mod_type == "LCN") {
        self$k_fc1 <- LocallyConnected1d(
          in_channels = 1,
          out_channels = 1,
          kernel_size = as.integer(age_length/q_z1),
          stride = as.integer(age_length/q_z1))
      }
      else if (mod_type == "CNN") {
        self$k_fc1 <- torch::nn_conv1d(
          in_channels = 1,
          out_channels = 1,
          kernel_size = as.integer(age_length/q_z1),
          stride = as.integer(age_length/q_z1))
      }

      self$k_fc2 <- torch::nn_linear(in_features = q_z1, out_features = 1)

      self$dropout50 <- torch::nn_dropout(p = 0.5)
      self$dropout5 <- torch::nn_dropout(p = 0.05)

      if (act_fun == "tanh") {
        self$tanh <- torch::nn_tanh()
      }

      # filled on every forward pass; read after evaluation to extract the
      # fitted parameters
      self$a_x <- NULL
      self$b_x <- NULL
      self$k_t <- NULL

    },

    forward = function(x) {

      x_rate <- x$x_rate
      x_pop <- x$x_pop
      x_cat <- x$x_cat

      batch_size <- x_rate$size(1)

      embedded_a <- self$embedder_a(x_cat)
      embedded_b <- self$embedder_b(x_cat)

      # a_x
      a_x <- self$dropout50(embedded_a)
      a_x <- self$a_fc(a_x)
      a_x <- a_x$view(c(batch_size, self$age_length))
      self$a_x <- a_x

      # b_x
      b_x <- self$dropout50(embedded_b)
      b_x <- self$b_fc(b_x)
      b_x <- b_x$view(c(batch_size, self$age_length))
      self$b_x <- b_x

      # k_t: LCN and CNN get an extra dimension inserted at position 2
      k_input <- if (self$mod_type == "FCN") x_rate else x_rate$unsqueeze(2)
      k_t <- self$k_fc1(k_input)
      if (self$act_fun == "tanh") {
        k_t <- self$tanh(k_t)
      } else if (self$act_fun == "relu") {
        k_t <- torch::nnf_relu(k_t)
      } # "linear": no activation
      k_t <- self$dropout5(k_t)
      # flatten to (batch, features) before the final linear layer
      k_t <- self$k_fc2(k_t$view(c(batch_size, -1)))
      k_t <- k_t$view(c(batch_size, 1))
      self$k_t <- k_t

      # bx_kt (multiply bx and kt)
      bx_kt <- torch::torch_multiply(k_t, b_x)
      bx_kt <- self$dropout5(bx_kt)

      # output (log(m_xt) for MSE and D_xt for Poisson)
      if (self$loss_type == "MSE") {
        m_xt <- torch::torch_add(a_x, bx_kt)
        m_xt
      }
      else if (self$loss_type == "Poisson") {
        m_xt <- torch::torch_add(a_x, bx_kt)
        m_xt <- torch::torch_exp(m_xt)
        D_xt <- torch::torch_multiply(m_xt, x_pop)
        D_xt
      }
    }
  )


# ---- fitting process ---------------------------------------------------------
  fitted_setup <- luz::setup(
    net,
    loss = function(y_hat, y_true) {
      if (object$loss_type == "MSE") {
        torch::nnf_mse_loss(y_hat, y_true)
      } else if (object$loss_type == "Poisson") {
        torch::nnf_poisson_nll_loss(y_hat, y_true, log_input = FALSE)
      }
    },
    optimizer = torch::optim_adam
  )

  # the network is sized by the ages that are actually left in the data
  age_cols <- grep("^rate_", names(data), value = TRUE)

  fitted_hparams <- luz::set_hparams(
    fitted_setup,
    cardinalities = c(length(levels(data$country)), length(levels(data$sex))),
    q_e = object$q_e,
    q_z1 = object$q_z1,
    act_fun = object$activation,
    mod_type = object$model_type,
    loss_type = object$loss_type,
    age_length = length(age_cols)
  )

  fitted_opt <- luz::set_opt_hparams(fitted_hparams, lr = 0.001)

  fitted <- suppressWarnings(luz::fit(fitted_opt,
                                      dataloader_learn,
                                      epochs = fitting.epochs))

  message("Fitting successful, creating output...")

# ---- converting NN output back to list of StMoMo objects ---------------------
  # one output object per country/sex combination
  group_keys <- paste(data$country, data$sex, sep = "_")
  split_indices <- split(seq_len(nrow(data)), group_keys)
  group_data <- list()

  # rearrange the input data into ages x years matrices per group
  for (grp in names(split_indices)) {
    data_grp <- data[split_indices[[grp]], ]

    rate_cols <- grep("^rate_", names(data_grp), value = TRUE)
    pop_cols  <- grep("^pop_",  names(data_grp), value = TRUE)

    age_order <- order(as.integer(sub("rate_", "", rate_cols)))
    rate_cols <- rate_cols[age_order]
    pop_cols  <- pop_cols[age_order]
    ages      <- as.integer(sub("rate_", "", rate_cols))

    rates <- as.matrix(data_grp[, rate_cols, drop = FALSE])
    pops  <- as.matrix(data_grp[, pop_cols,  drop = FALSE])
    years <- data_grp$year

    # rows are years in the data; transpose to ages x years
    Dxt <- t(rates * pops)
    Ext <- t(pops)
    dimnames(Dxt) <- list(ages, years)
    dimnames(Ext) <- list(ages, years)

    group_data[[grp]] <- list(
      Dxt = Dxt,
      Ext = Ext,
      ages = ages,
      years = years,
      type = "central",
      series = as.character(data_grp$sex[1]),
      label = as.character(data_grp$country[1])
    )
  }

  # evaluate NN and convert to ax, bx and kt
  evaluation <- suppressWarnings(luz::evaluate(fitted, dataloader_eval))

  a_x <- torch::as_array(evaluation$model$a_x)
  b_x <- torch::as_array(evaluation$model$b_x)
  k_t <- torch::as_array(evaluation$model$k_t)

  split_k_t <- split(as.data.frame(k_t), group_keys)

  output <- list()
  for (grp in names(split_k_t)) {

    idx   <- split_indices[[grp]]
    years <- group_data[[grp]]$years
    ages  <- group_data[[grp]]$ages

    # a_x and b_x only depend on country and sex, so the first row of the
    # group is used; k_t has one value per row (year) of the group
    a_vec <- as.vector(a_x[idx[1], ])
    b_vec <- as.vector(b_x[idx[1], ])
    k_mat <- as.matrix(split_k_t[[grp]])

    # rescale so that mean(k_t) = 0 and sum(b_x) = 1, leaving
    # a_x + b_x * k_t unchanged
    c1 <- mean(as.numeric(k_mat))
    c2 <- sum(b_vec)

    a_vec <- a_vec + c1 * b_vec
    b_vec <- b_vec / c2
    k_mat <- c2 * (k_mat - c1)

    names(a_vec) <- ages
    bx_mat <- matrix(b_vec, ncol = 1)
    rownames(bx_mat) <- ages

    k_mat <- t(k_mat)
    colnames(k_mat) <- years

    # no offset and unit weights for every age/year cell
    oxt <- matrix(0, nrow = length(ages), ncol = length(years),
                  dimnames = list(ages, years))
    wxt <- matrix(1, nrow = length(ages), ncol = length(years),
                  dimnames = list(ages, years))

    cohorts <- (years[1] - ages[length(ages)]):(years[length(years)] - ages[1])

    output[[grp]] <- list(
      model = StMoMo::lc(),
      ax    = a_vec,
      bx    = bx_mat,
      kt    = k_mat,
      b0x   = NULL,
      gc    = NULL,
      data  = group_data[[grp]],
      Dxt   = group_data[[grp]]$Dxt,
      Ext   = group_data[[grp]]$Ext,
      oxt   = oxt,
      wxt   = wxt,
      ages  = ages,
      years = years,
      cohorts = cohorts,
      fittingModel = NULL,
      loglik = NA_real_,
      deviance = NA_real_,
      npar = NA_integer_,
      nobs = NA_integer_,
      conv = TRUE,
      fail = FALSE,
      call = mc
    )
    class(output[[grp]]) <- c("fitNNMoMo", "fitStMoMo")
  }
  class(output) <- "fitStMoMo_list"

  # settings of this fit plus the training loss of the last epoch
  train_metrics <- fitted$records$metrics$train
  output$NN_fittingModel <- list(
    ages.fit = ages.fit,
    years.fit = years.fit,
    series = series,
    fitting.epochs = fitting.epochs,
    batch.size = batch.size,
    activation = object$activation,
    model_type = object$model_type,
    loss_type = object$loss_type,
    q_e = object$q_e,
    q_z1 = object$q_z1,
    final_loss = train_metrics[[length(train_metrics)]]$loss
  )

  message("Computation finished")
  output
}
