#' A function to fit the stochastic mortality model M2A1.
#'
#' Carry out Bayesian estimation of the stochastic mortality \bold{model M2A1}.
#' 
#' The model can be described mathematically as follows:
#' If \code{family="poisson"}, then
#' \deqn{d_{x,t,p} \sim \text{Poisson}(E^c_{x,t,p} m_{x,t,p}) , }
#' \deqn{\log(m_{x,t,p})=a_{x}+(c_p+b_x)k_t , }
#' where \eqn{d_{x,t,p}} represents the number of deaths at age \eqn{x} in year \eqn{t} of stratum \eqn{p},
#' while \eqn{E^c_{x,t,p}} and \eqn{m_{x,t,p}} represent, respectively, the corresponding central exposed to risk and central mortality rate at age \eqn{x} in year \eqn{t} of stratum \eqn{p}.
#' Similarly, if \code{family="nb"}, then a negative binomial distribution is fitted, i.e.
#' \deqn{d_{x,t,p} \sim \text{Negative-Binomial}(\phi,\frac{\phi}{\phi+E^c_{x,t,p} m_{x,t,p}}) , }
#' \deqn{\log(m_{x,t,p})=a_{x}+(c_p+b_x)k_t , }
#' where \eqn{\phi} is the overdispersion parameter. See Wong et al. (2018).
#' But if \code{family="binomial"}, then  
#' \deqn{d_{x,t,p} \sim \text{Binomial}(E^0_{x,t,p} , q_{x,t,p}) , }
#' \deqn{\text{logit}(q_{x,t,p})=a_{x}+(c_p+b_x)k_t , }
#' where \eqn{q_{x,t,p}} represents the initial mortality rate at age \eqn{x} in year \eqn{t} of stratum \eqn{p},
#' while \eqn{E^0_{x,t,p}\approx E^c_{x,t,p}+\frac{1}{2}d_{x,t,p}} is the corresponding initial exposed to risk.
#' Constraints used are:
#' \deqn{\sum_p c_p = 0, \sum_x b_x = 1, \sum_t k_t = 0 .}
#' If \code{forecast=TRUE}, then the following time series model is fitted on \eqn{k_t}:
#' \deqn{k_t = \eta_1+\eta_2 t +\rho (k_{t-1}-(\eta_1+\eta_2 (t-1))) + \epsilon_t \text{ for } t=1,\ldots,T,}
#' where \eqn{\epsilon_t\sim N(0,\sigma_k^2)} for \eqn{t=1,\ldots,T}, while \eqn{\eta_1,\eta_2,\rho,\sigma_k^2} are additional parameters to be estimated.
#' Note that the forecasting models are inspired by Wong et al. (2023).
#' 
#' @references Jackie S. T. Wong, Jonathan J. Forster, and Peter W. F. Smith. (2018). Bayesian mortality forecasting with overdispersion, Insurance: Mathematics and Economics, Volume 2018, Issue 3, 206-221. \doi{10.1016/j.insmatheco.2017.09.023}
#' @references Jackie S. T. Wong, Jonathan J. Forster, and Peter W. F. Smith. (2023). Bayesian model comparison for mortality forecasting, Journal of the Royal Statistical Society Series C: Applied Statistics, Volume 72, Issue 3, 566–586. \doi{10.1093/jrsssc/qlad021}
#' 
#' @param death death data that has been formatted through the function \code{preparedata_fn}.
#' @param expo expo data that has been formatted through the function \code{preparedata_fn}.
#' @param n_iter number of iterations to run. Default is \code{n_iter=10000}. 
#' @param family a string of characters that defines the family function associated with the mortality model. "poisson" would assume that deaths follow a Poisson distribution and use a log link; "binomial" would assume that deaths follow a Binomial distribution and a logit link; "nb" (default) would assume that deaths follow a Negative-Binomial distribution and a log link.
#' @param n.chain number of parallel chains for the model.
#' @param thin thinning interval for monitoring purposes.
#' @param n.adapt the number of iterations for adaptation. See \code{?rjags::adapt} for details.
#' @param forecast a logical value indicating if forecast is to be performed (default is \code{FALSE}). See below for details.
#' @param h a numeric value giving the number of years to forecast. Default is \code{h=5}.
#' @param quiet if TRUE then messages generated during compilation will be suppressed, as well as the progress bar during adaptation.
#' @return A list with components:
#' \describe{
#'   \item{\code{post_sample}}{An \code{mcmc.list} object containing the posterior samples generated.}
#'   \item{\code{param}}{A vector of character strings describing the names of model parameters.}
#'   \item{\code{death}}{The death data that was used.}
#'   \item{\code{expo}}{The expo data that was used.}
#'   \item{\code{family}}{The family function used.}
#'   \item{\code{forecast}}{A logical value indicating if forecast has been performed.}
#'   \item{\code{h}}{The forecast horizon used.}
#' }
#' @keywords bayesian estimation models
#' @concept stochastic mortality models
#' @concept parameter estimation
#' @concept M2A1
#' @importFrom stats dnbinom dbinom dpois quantile sd
#' @export
#' @examples
#' #load and prepare mortality data
#' data("dxt_array_product");data("Ext_array_product")
#' death<-preparedata_fn(dxt_array_product,strat_name = c("ACI","DB","SCI"),ages=35:65)
#' expo<-preparedata_fn(Ext_array_product,strat_name = c("ACI","DB","SCI"),ages=35:65)
#' 
#' #fit the model (poisson family)
#' #NOTE: This is a toy example, please run it longer in practice.
#' fit_M2A1_result<-fit_M2A1(death=death,expo=expo,n_iter=50,n.adapt=50,family="poisson")
#' head(fit_M2A1_result)

fit_M2A1<-function(death,expo,n_iter=10000,family="nb",n.chain=1,thin=1,n.adapt=1000,forecast=FALSE,h=5,quiet=FALSE){
  
  # Dimensions as recorded on the prepared death data.
  n_strat<-death$n_strat
  n_ages<-death$n_ages
  n_years<-death$n_years
  
  # Single-stratum data are redirected to fit_LC().
  if (n_strat==1){
    stop("Only 1 stratum detected (p=1). Please use the function fit_LC() instead.",call. = FALSE)
  }
  
  # Reject unsupported families up front. Without this check no model is run
  # and the function fails later with an unrelated "object not found" error.
  valid_families<-c("nb","poisson","binomial")
  if (!(is.character(family) && length(family)==1 && family %in% valid_families)){
    stop("`family` must be one of \"nb\", \"poisson\" or \"binomial\".",call. = FALSE)
  }
  
  # A horizon below 1 makes (n_years+1):(n_years+h) run backwards and index
  # outside the forecast arrays, so fail early with a clear message.
  if (forecast && !(is.numeric(h) && length(h)==1 && !is.na(h) && h>=1)){
    stop("`h` must be a single number greater than or equal to 1 when forecast=TRUE.",call. = FALSE)
  }
  
  # Prior on the first (n_ages-1) age effects (beta_rest).
  prior_mean_beta<-rep(1/n_ages,n_ages-1)
  sigma2_beta<-0.001
  prior_prec_beta<-solve(
    sigma2_beta*(diag(rep(1,n_ages-1))-1/n_ages*(matrix(1,nrow=n_ages-1,ncol=n_ages-1)))
  )
  
  # Prior on the regression coefficients eta of the period effect.
  prior_prec_eta<-solve(matrix(c(400,0,0,2),nrow=2))
  prior_mean_eta<-c(0,0)
  
  # Centred time index over the observed years.
  time_index<-seq_len(n_years)-mean(seq_len(n_years))
  
  if (forecast){
    # Extend the time index and the data arrays by the forecast horizon:
    # deaths are left as NA there, exposures repeat the last observed year.
    time_index<-c(time_index,time_index[n_years]+1:h)
    observed<-1:n_years
    future<-(n_years+1):(n_years+h)
    
    dxt<-array(dim=c(n_strat,n_ages,n_years+h))
    ext<-array(dim=c(n_strat,n_ages,n_years+h))
    dxt[,,observed]<-death$data
    dxt[,,future]<-NA
    ext[,,observed]<-expo$data
    ext[,,future]<-expo$data[,,n_years]
    
    if (family=="binomial"){
      # Initial exposed to risk: central exposure plus half the deaths, rounded.
      # Forecast years reuse the last observed year's exposure and deaths.
      ext[,,observed]<-round(ext[,,observed]+0.5*death$data)
      ext[,,future]<-round(expo$data[,,n_years]+0.5*death$data[,,n_years])
    }
  } else {
    dxt<-death$data
    ext<-if (family=="binomial") round(expo$data+0.5*death$data) else expo$data
  }
  
  # Design matrix for the period effect: intercept and centred time.
  matrix_kappa_X<-matrix(c(rep(1,length(time_index)),time_index),byrow=FALSE,ncol=2)
  
  # Data passed to JAGS; the element names (A, T, p, h, ...) are the names the
  # model files refer to, so they are kept as they were.
  data<-list(dxt=dxt,ext=ext,A=n_ages,T=n_years,p=n_strat)
  if (forecast){
    data$h<-h
  }
  data<-c(
    data,
    list(
      matrix_kappa_X=matrix_kappa_X,
      prior_mean_eta=prior_mean_eta,
      prior_prec_eta=prior_prec_eta,
      prior_mean_beta=prior_mean_beta,
      prior_prec_beta=prior_prec_beta
    )
  )
  
  # Initial values and monitored variables shared by all families.
  init_values<-list(
    alpha=rep(0,n_ages),
    beta_rest=rep(1/n_ages,n_ages-1),
    kappa_rest=rep(0,n_years-1),
    c_p_rest=rep(0,n_strat-1),
    rho=0.5,
    eta=c(0,0),
    i_sigma2_kappa=0.1
  )
  vars<-c("q","alpha","beta","kappa","c_p","eta","rho","sigma2_kappa")
  if (family=="nb"){
    # The negative binomial model additionally has the overdispersion phi.
    init_values$phi<-100
    vars<-c(vars,"phi")
  }
  inits<-function() init_values
  
  # Model file: "<prefix>_LC_M2A1[_alt_forecast].jags" inside the package.
  file_prefix<-switch(family,binomial="logit",poisson="log",nb="nb")
  file_suffix<-if (forecast) "_alt_forecast" else ""
  model_file<-paste0("models/",file_prefix,"_LC_M2A1",file_suffix,".jags")
  
  jags_model<-rjags::jags.model(
    system.file(model_file, package = "BayesMoFo"),
    data=data,inits=inits,n.chains=n.chain,n.adapt=n.adapt,quiet=quiet
  )
  result_jags<-rjags::coda.samples(jags_model,vars,n.iter=n_iter,thin=thin)
  
  # `param` drops the first monitored variable ("q").
  list(post_sample=result_jags,param=vars[-1],death=death,expo=expo,family=family,forecast=forecast,h=h)
  
}