#' Evaluate Strategy
#' 
#' Given an unevaluated strategy, an initial number of 
#' individuals and a number of cycles to compute, returns the 
#' evaluated version of the objects and the count of 
#' individuals per state per model cycle.
#' 
#' \code{init} need not be integer. E.g. specifying a vector
#' of type \code{c(A = 1, B = 0, C = 0, ...)} returns the 
#' probabilities for an individual starting in state A to be
#' in each state, per cycle.
#' 
#' When \code{state_time} is used in a transition or a state,
#' each affected state \code{st} is expanded into
#' \code{expand_limit[st] + 1} sub-states (whose count columns
#' are named \code{.<st>_<i>}) before evaluation. Their counts
#' are summed back into a single \code{st} column afterwards.
#' 
#' @param strategy An \code{uneval_strategy} object.
#' @param parameters Optional. An object generated by 
#'   \code{\link{define_parameters}}.
#' @param cycles positive integer. Number of Markov Cycles 
#'   to compute.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method.
#' @param expand_limit A named vector of state expansion
#'   limits (number of additional sub-states created for each
#'   expanded state).
#' @param inflow Numeric vector, similar to \code{init}.
#'   Number of new individuals in each state per cycle.
#' @param strategy_name Name of the strategy.
#'   
#' @return An \code{eval_strategy} object (actually a list of 
#'   evaluated parameters, matrix, states and cycles 
#'   counts).
#'   
#' @example inst/examples/example_eval_strategy.R
#'   
#' @keywords internal
eval_strategy <- function(strategy, parameters, cycles, 
                          init, method, expand_limit,
                          inflow, strategy_name) {
  # Check the length first so that 'cycles > 0' is a scalar test.
  stopifnot(
    length(cycles) == 1,
    cycles > 0,
    all(init >= 0)
  )
  
  uneval_transition <- get_transition(strategy)
  uneval_states <- get_states(strategy)
  
  # Substitute the parameter definitions into the transition and
  # state expressions. Presumably this lets 'state_time' be
  # detected even when it is only referenced through a parameter
  # (TODO confirm in interp_heemod()).
  i_parameters <- interp_heemod(parameters)
  
  i_uneval_transition <- interp_heemod(
    uneval_transition,
    more = as_expr_list(i_parameters)
  )
  
  i_uneval_states <- interp_heemod(
    uneval_states,
    more = as_expr_list(i_parameters)
  )
  
  td_tm <- has_state_time(i_uneval_transition)
  td_st <- has_state_time(i_uneval_states)
  
  # States are expanded only if 'state_time' is used somewhere.
  expand <- any(c(td_tm, td_st))
  
  if (expand) {
    
    if (inherits(uneval_transition, "part_surv")) {
      stop("Cannot use 'state_time' with partitionned survival.")
    }
    
    uneval_transition <- i_uneval_transition
    uneval_states <- i_uneval_states
    
    # Parameters are not needed anymore: they were interpolated
    # into the expressions above.
    parameters <- define_parameters()
    
    # One flag per matrix cell -> one flag per matrix row (origin
    # state). Assumes has_state_time() returns cells in row-major
    # order, hence 'byrow = TRUE'.
    td_tm <- td_tm %>% 
      matrix(
        nrow = get_matrix_order(uneval_transition), 
        byrow = TRUE
      ) %>% 
      apply(1, any)
    
    to_expand <- sort(unique(c(
      get_state_names(uneval_transition)[td_tm],
      get_state_names(uneval_states)[td_st]
    )))
    
    message(sprintf(
      "%s: detected use of 'state_time', expanding state%s: %s.",
      strategy_name,
      plur(length(to_expand)),
      paste(to_expand, collapse = ", ")
    ))
    
    for (st in to_expand) {
      # Locate 'st' in the *current*, possibly already partially
      # expanded, model. Padding 'init' / 'inflow' in the same
      # iteration as the transition expansion keeps all three
      # aligned when several states are expanded.
      st_pos <- which(get_state_names(uneval_transition) == st)
      
      init <- insert(init, st_pos, rep(0, expand_limit[st]))
      inflow <- insert(inflow, st_pos, rep(0, expand_limit[st]))
      
      uneval_transition <- expand_state(
        x = uneval_transition,
        state_pos = st_pos,
        state_name = st,
        cycles = expand_limit[st]
      )
      
      uneval_states <- expand_state(
        x = uneval_states,
        state_name = st,
        cycles = expand_limit[st]
      )
    }
  }
  
  parameters <- eval_parameters(parameters,
                                cycles = cycles,
                                strategy_name = strategy_name)
  
  states <- eval_state_list(uneval_states, parameters)
  
  transition <- eval_transition(uneval_transition,
                                parameters)
  
  count_table <- compute_counts(
    x = transition,
    init = init,
    method = method,
    inflow = inflow
  )
  
  # Values are computed on the expanded counts, before collapsing.
  values <- compute_values(states, count_table)
  
  if (expand) {
    # Sum the sub-state counts back into one column per state.
    for (st in to_expand) {
      exp_cols <- sprintf(".%s_%i", st, seq_len(expand_limit[st] + 1))
      
      count_table[[st]] <- rowSums(count_table[exp_cols])
      # Logical mask: unlike '-which()', this cannot drop every
      # column when nothing matches.
      count_table <- count_table[! names(count_table) %in% exp_cols]
    }
  }
  
  structure(
    list(
      parameters = parameters,
      transition = transition,
      states = states,
      counts = count_table,
      values = values,
      init = init,
      cycles = cycles,
      expand_limit = expand_limit
    ),
    class = "eval_strategy"
  )
}

#' Compute Count of Individuals in Each State per Cycle
#' 
#' Given an initial number of individuals and an evaluated 
#' transition object, returns the number of individuals per 
#' state per cycle.
#' 
#' The \code{method} argument specifies whether transitions 
#' are assumed to happen at the beginning or at the end of 
#' each cycle. Linear interpolation between cycles can also be
#' requested.
#' 
#' @param x An \code{eval_matrix} or
#'   \code{eval_part_surv} object.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method.
#' @param inflow numeric vector, similar to \code{init}.
#'   Number of new individuals in each state per cycle.
#' @param ... Additional arguments passed to methods.
#'   
#' @return A \code{cycle_counts} object.
#'   
#' @keywords internal
compute_counts <- function(x, ...) UseMethod("compute_counts")

#' @export
compute_counts.eval_matrix <- function(x, init,
                                       method, inflow,
                                       ...) {
  # Counts individuals per state per cycle. 'x' is folded with
  # Reduce(), presumably one transition matrix per cycle
  # (TODO confirm against eval_transition()).
  
  n_states <- get_matrix_order(x)
  
  if (length(init) != n_states) {
    stop(sprintf(
      "Length of 'init' vector (%i) differs from the number of states (%i).",
      length(init),
      n_states
    ))
  }
  
  if (length(inflow) != n_states) {
    stop(sprintf(
      "Length of 'inflow' vector (%i) differs from the number of states (%i).",
      length(inflow),
      n_states
    ))
  }
  
  # One cycle: add the new individuals, then apply the transition.
  add_and_mult <- function(counts, trans) {
    (counts + inflow) %*% trans
  }
  
  # List of length(x) + 1 count vectors; the first one is 'init'.
  list_counts <- Reduce(
    add_and_mult,
    x,
    init,
    accumulate = TRUE
  )
  
  # dplyr::as.tbl() is deprecated; as_tibble() is its equivalent.
  res <- dplyr::as_tibble(
    as.data.frame(
      matrix(
        unlist(list_counts),
        byrow = TRUE,
        ncol = n_states
      )
    )
  )
  
  colnames(res) <- get_state_names(x)
  
  # n0: counts before each cycle's transition, n1: counts after it.
  n0 <- res[- nrow(res), ]
  n1 <- res[-1, ]
  
  switch(
    method,
    "beginning" = {
      out <- n1
    },
    "end" = {
      out <- n0
    },
    "half-cycle" = {
      warning(
        "Method 'half-cycle' is deprecated and will be removed soon.\n",
        "Consider using the 'life-table' method instead.\n",
        "See https://github.com/pierucci/heemod/issues/173 for a discussion of the reasons.",
        call. = FALSE
      )
      out <- n1
      out[1, ] <- out[1, ] + init / 2
      out[nrow(out), ] <- out[nrow(out), ] + out[nrow(out), ] / 2
    },
    "life-table" = {
      # Linear interpolation between consecutive counts.
      out <- (n0 + n1) / 2
    },
    {
      stop(sprintf("Unknown counting method, '%s'.", method))
    }
  )
  
  structure(out, class = c("cycle_counts", class(out)))
  
}

#' Compute State Values per Cycle
#' 
#' Given states and counts, computes the total state values 
#' per cycle: for each state value, the sum over states of
#' (count in state) * (value of that state).
#' 
#' @param states An object of class \code{eval_state_list}.
#' @param counts An object of class \code{cycle_counts}, with
#'   one column per state name.
#'   
#' @return A data.frame of state values, one column per 
#'   state value and one row per cycle.
#'   
#' @keywords internal
compute_values <- function(states, counts) {
  
  states_names <- get_state_names(states)
  state_values_names <- get_state_value_names(states)
  
  res <- data.frame(
    markov_cycle = states[[1]][["markov_cycle"]]
  )
  
  for (state_value in state_values_names) {
    # Accumulate on plain numeric vectors ('[[' extraction) rather
    # than on data-frame slices: this avoids data-frame arithmetic
    # in the inner loop. It also behaves the same whether the
    # inputs are data.frames or tibbles.
    total <- 0
    for (state in states_names) {
      total <- total +
        counts[[state]] * states[[state]][[state_value]]
    }
    res[[state_value]] <- total
  }
  res
}
