#' Linearized Bregman solver for linear, binomial, multinomial models
#'  with lasso, group lasso or column lasso penalty.
#' 
#' Solver for the entire solution path of coefficients for Linearized Bregman iteration.
#' 
#' The Linearized Bregman solver computes the whole regularization path
#'  for different types of lasso-penalty for gaussian, binomial and 
#'  multinomial models through iterations. It is the Euler forward 
#'  discretized form of the continuous Bregman Inverse Scale Space 
#'  Differential Inclusion. For binomial models, the response variable y
#'  is assumed to be a vector of two classes coded as \{1,-1\}.
#'  For the multinomial models, the response variable y can be a vector of k classes
#'  or an n-by-k matrix in which each entry is in \{0,1\}, with 1 indicating 
#'  the class. Under all circumstances, two parameters, kappa 
#'  and alpha need to be specified beforehand. The definitions of kappa 
#'  and alpha are the same as that defined in the reference paper. 
#'  Parameter alpha is defined as stepsize and kappa is the damping factor
#'  of the Linearized Bregman Algorithm that is defined in the reference paper.
#'
#' @param X An n-by-p matrix of predictors
#' @param y Response Variable
#' @param kappa The damping factor of the Linearized Bregman Algorithm that is
#'  defined in the reference paper. See details. 
#' @param alpha Parameter in Linearized Bregman algorithm which controls the 
#' step-length of the discretized solver for the Bregman Inverse Scale Space. 
#' See details. 
#' @param c Normalized step-length. If alpha is missing, alpha is automatically generated by 
#' \code{alpha=n*c/(kappa*||X^T*X||_2)}. It should be in (0,2) for 
#' family = "gaussian"(Default is 1), (0,8) for family = "binomial"(Default is 4),
#'  (0,4) for family = "multinomial"(Default is 2).
#' Beyond these ranges the path may oscillate at large t values.
#' @param family Response type
#' @param group Whether to use a group penalty, Default is FALSE.
#' @param index For group models, the index is a vector that determines the 
#' group of the parameters. Parameters of the same group should have equal 
#' value in index. Note that the multinomial group model by default assumes 
#' the variables in the same column are in the same group, and an empty value of
#' index means each variable is a group.
#' @param intercept if TRUE, an intercept is included in the model (and not 
#' penalized), otherwise no intercept is included. Default is TRUE.
#' @param normalize if TRUE, each variable is scaled to have L2 norm 
#' square-root n. Default is TRUE.
#' @param tlist Parameters t along the path.
#' @param nt Number of t. Used only if tlist is missing. Default is 100.
#' @param trate tmax/tmin. Used only if tlist is missing. Default is 100.
#' @param print If TRUE, the percentage of finished computation is printed.
#' @return A "lb" class object is returned. The list contains the call, 
#' the type, the path, the intercept term a0 and value for alpha, kappa, 
#' iter, and meanvalue, scale factor of X, meanx and normx. For gaussian and
#' binomial, path is a p-by-nt matrix, and for multinomial, path is a k-by-p-by-nt 
#' array, each dimension represents class, predictor and parameter t.
#' @references Osher, Ruan, Xiong, Yao and Yin, Sparse Recovery via Differential
#'  Inclusions, \url{https://arxiv.org/abs/1406.7728}
#' @author Feng Ruan, Jiechao Xiong and Yuan Yao
#' @keywords regression
#' @examples
#' #Examples in the reference paper
#' library(MASS)
#' n = 80;p = 100;k = 30;sigma = 1
#' Sigma = 1/(3*p)*matrix(rep(1,p^2),p,p)
#' diag(Sigma) = 1
#' A = mvrnorm(n, rep(0, p), Sigma)
#' u_ref = rep(0,p)
#' supp_ref = 1:k
#' u_ref[supp_ref] = rnorm(k)
#' u_ref[supp_ref] = u_ref[supp_ref]+sign(u_ref[supp_ref])
#' b = as.vector(A%*%u_ref + sigma*rnorm(n))
#' kappa = 16
#' alpha = 1/160
#' object <- lb(A,b,kappa,alpha,family="gaussian",group=FALSE,
#'              trate=20,intercept=FALSE,normalize=FALSE)
#' plot(object,xlim=c(0,3),main=bquote(paste("LB ",kappa,"=",.(kappa))))
#' 
#' 
#' #Diabetes, linear case
#' library(Libra)
#' data(diabetes)
#' attach(diabetes)
#' object <- lb(x,y,100,1e-3,family="gaussian",group=FALSE)
#' plot(object)
#' detach(diabetes)
#' 
#' #Simulated data, binomial case
#' data('west10')
#' y<-2*west10[,1]-1;
#' X<-as.matrix(2*west10[,2:10]-1);
#' path <- lb(X,y,kappa = 1,family="binomial",trate=100,normalize = FALSE)
#' plot(path,xtype="norm",omit.zeros=FALSE)
#' 
#' #Simulated data, multinomial case
#' X <- matrix(rnorm(500*100), nrow=500, ncol=100)
#' alpha <- matrix(c(rnorm(30*3), rep(0,70*3)),nrow=3)
#' P <- exp(alpha%*%t(X))
#' P <- scale(P,FALSE,apply(P,2,sum))
#' y <- rep(0,500)
#' rd <- runif(500)
#' y[rd<P[1,]] <- 1
#' y[rd>1-P[3,]] <- -1
#' result <- lb(X,y,kappa=5,alpha=0.1,family="multinomial",
#'  group=TRUE,intercept=FALSE,normalize = FALSE)
#' plot(result)
#' 

lb <- function(X, y, kappa, alpha, c = 1, tlist, nt = 100, trate = 100,
               family = c("gaussian", "binomial", "multinomial"),
               group = FALSE, index = NA, intercept = TRUE, normalize = TRUE,
               print = FALSE)
{
  # Entry point: validates the inputs, centres/scales X, chooses the step size,
  # dispatches to the family-specific solver (LB.lasso / LB.logistic /
  # LB.multilogistic) and finally maps the returned path back to the original
  # scale of X, splitting the intercept off into `a0`.
  family <- match.arg(family)
  if (!is.matrix(X)) stop("X must be a matrix!")

  # ---- response validation -------------------------------------------------
  if (family != "multinomial") {
    if (!is.vector(y)) stop("y must be a vector unless in multinomial model!")
    if (nrow(X) != length(y)) stop("Number of rows of X must equal the length of y!")
    if (family == "binomial" && any(abs(y) != 1)) stop("y must be in {1,-1}")
  } else if (is.vector(y)) {
    if (nrow(X) != length(y)) stop("Number of rows of X must equal the length of y!")
    # One-hot encode the class labels: column k flags observations of class k.
    # matrix() keeps the n-by-k shape even when there is a single observation.
    y_unique <- unique(y)
    y <- matrix(
      vapply(seq_along(y_unique),
             function(k) as.numeric(y == y_unique[k]),
             numeric(length(y))),
      nrow = length(y)
    )
  } else if (is.matrix(y)) {
    if (nrow(X) != nrow(y)) stop("Number of rows of X and y must equal!")
    if (any((y != 1) & (y != 0)) || any(rowSums(y) != 1)) stop("y should be indicator matrix!")
  } else {
    stop("y must be a vector or matrix!")
  }

  # ---- group / index validation --------------------------------------------
  if (group && missing(index)) {
    if (family == "multinomial") {
      index <- NA
    } else {
      # Without an index there is nothing to group on for these families;
      # fall back to the ungrouped penalty instead of failing below.
      group <- FALSE
      warning("Index is missing, using group=FALSE instead!", call. = FALSE)
    }
  }
  if (group) {
    if (!is.vector(index)) stop("Index must be a vector!")
    # A single NA is the "no explicit index" marker, accepted only for the
    # multinomial family. Checked via length first so that a vector index
    # never reaches a scalar condition.
    index_unset <- length(index) == 1L && is.na(index)
    if (length(index) != ncol(X) && !(family == "multinomial" && index_unset))
      stop("Length of index must be the same as the number of columns of X!")
  }

  # ---- centring and scaling ------------------------------------------------
  n <- nrow(X)
  p <- ncol(X)
  one <- rep(1, n)
  if (intercept) {
    meanx <- drop(one %*% X) / n
    X <- scale(X, meanx, FALSE)
  } else {
    meanx <- rep(0, p)
  }
  if (normalize) {
    normx <- sqrt(drop(one %*% (X^2)) / n)
    X <- scale(X, FALSE, normx)
  } else {
    normx <- rep(1, p)
  }

  # tlist filled with -1 is handed to the solver when the caller gave no grid.
  if (missing(tlist)) tlist <- rep(-1.0, nt) else nt <- length(tlist)

  # ---- step size -----------------------------------------------------------
  alpha0_rate <- 1.0
  if (missing(alpha)) {
    sigma <- norm(X, "2")
    # Family-specific default for the normalized step length; an explicitly
    # supplied `c` is honoured.
    if (missing(c)) c <- switch(family, gaussian = 1, binomial = 4, multinomial = 2)
    alpha <- n * c / kappa / sigma^2
    if (intercept) alpha0_rate <- sigma^2 / n
  }
  if (intercept && !normalize) {
    alpha0_rate <- max(drop(one %*% (X^2))) / n
  }

  # ---- solve ---------------------------------------------------------------
  # match.arg() above guarantees `family` is one of these three values.
  object <- switch(family,
    gaussian = LB.lasso(X, y, kappa, alpha, alpha0_rate, tlist, nt, trate,
                        intercept, group, index, print),
    binomial = LB.logistic(X, y, kappa, alpha, alpha0_rate, tlist, nt, trate,
                           intercept, group, index, print),
    multinomial = LB.multilogistic(X, y, kappa, alpha, alpha0_rate, tlist, nt,
                                   trate, intercept, group, index, print)
  )

  # ---- separate intercept from path ----------------------------------------
  if (intercept) {
    if (family != "multinomial") {
      object$a0 <- object$path[p + 1, , drop = TRUE]
      object$path <- object$path[-(p + 1), , drop = FALSE]
    } else {
      object$a0 <- object$path[, p + 1, , drop = TRUE]
      object$path <- object$path[, -(p + 1), , drop = FALSE]
    }
  } else if (family == "multinomial") {
    object$a0 <- matrix(rep(0, nt * ncol(y)), ncol = nt)
  } else {
    object$a0 <- rep(0, nt)
  }

  # ---- re-scale back to the original units of X ----------------------------
  if (family == "multinomial") {
    object$path <- sapply(seq_len(nt), function(x)
      scale(object$path[, , x], FALSE, normx), simplify = "array")
    if (intercept) object$a0 <- object$a0 -
        sapply(seq_len(nt), function(x) object$path[, , x] %*% meanx)
  } else {
    object$path <- t(scale(t(object$path), FALSE, normx))
    if (intercept) object$a0 <- object$a0 - meanx %*% object$path
  }

  object$meanx <- meanx
  object$normx <- normx
  object$family <- family
  object$group <- group
  return(object)
}

# Gaussian (lasso / group lasso) Linearized Bregman path.
#
# Marshals the arguments for the compiled routine "LB_lasso" and reshapes its
# output into a (col + intercept)-by-nt path matrix. The routine is expected to
# write the path into argument 8 and the t grid into argument 12 -- inferred
# from how the result of .C() is read below.
#
# With group = TRUE the columns of X are sorted by `index` so that each group
# is contiguous; `group_split` then holds the cumulative group boundaries and
# the path rows are permuted back to the caller's column order afterwards.
#
# Returns a list of class "lb" with elements call, kappa, alpha, path, nt, t.
LB.lasso <- function(X, y, kappa, alpha,alpha0_rate,tlist,nt ,trate, intercept = TRUE,group,index,print=FALSE) {
  call <- match.call()
  row <- nrow(X)
  col <- ncol(X)
  intercept <- as.integer(intercept != 0)
  if (!group) {
    group_split <- 0
    group_split_length <- 0
  } else {
    ord <- order(index)
    ord_rev <- order(ord)
    # drop = FALSE keeps X a matrix even when there is a single predictor.
    X <- X[, ord, drop = FALSE]
    group_size <- as.vector(table(index))
    group_split <- c(0, cumsum(group_size))
    group_split_length <- length(group_split)
  }
  print <- as.integer(print != 0)
  result_r <- numeric(nt * (col + intercept))
  solution <- .C("LB_lasso",
    as.numeric(X),
    as.integer(row),
    as.integer(col),
    as.numeric(y),
    as.numeric(kappa),
    as.numeric(alpha),
    as.numeric(alpha0_rate),
    as.numeric(result_r),
    as.integer(group_split),
    as.integer(group_split_length),
    as.integer(intercept),
    as.numeric(tlist),
    as.integer(nt),
    as.numeric(trate),
    as.integer(print)
  )
  path <- matrix(solution[[8]], ncol = nt)
  if (group) {
    # Undo the column sort. ord_rev only addresses the first `col` rows, so a
    # trailing intercept row stays in place; drop = FALSE keeps the
    # single-predictor case a matrix.
    path[seq_len(col), ] <- path[ord_rev, , drop = FALSE]
  }
  object <- list(call = call, kappa = kappa, alpha = alpha, path = path, nt = nt, t = solution[[12]])
  class(object) <- "lb"
  return(object)
}


# Binomial (logistic) Linearized Bregman path.
#
# Marshals the arguments for the compiled routine "LB_logistic" and reshapes
# its output into a (col + intercept)-by-nt path matrix. The routine is
# expected to write the path into argument 8 and the t grid into argument 12
# -- inferred from how the result of .C() is read below.
#
# With group = TRUE the columns of X are sorted by `index` so that each group
# is contiguous; the path rows are permuted back afterwards.
#
# Returns a list of class "lb" with elements call, kappa, alpha, path, nt, t.
LB.logistic <- function(X, y, kappa, alpha,alpha0_rate,tlist,nt ,trate, intercept = TRUE,group,index,print=FALSE) {
  call <- match.call()
  row <- nrow(X)
  col <- ncol(X)
  if (!group) {
    group_split <- 0
    group_split_length <- 0
  } else {
    ord <- order(index)
    ord_rev <- order(ord)
    # drop = FALSE keeps X a matrix even when there is a single predictor.
    X <- X[, ord, drop = FALSE]
    group_size <- as.vector(table(index))
    group_split <- c(0, cumsum(group_size))
    group_split_length <- length(group_split)
  }
  intercept <- as.integer(intercept != 0)
  print <- as.integer(print != 0)
  result_r <- numeric(nt * (col + intercept))
  solution <- .C("LB_logistic",
    as.numeric(X),
    as.integer(row),
    as.integer(col),
    as.numeric(y),
    as.numeric(kappa),
    as.numeric(alpha),
    as.numeric(alpha0_rate),
    as.numeric(result_r),
    as.integer(group_split),
    as.integer(group_split_length),
    as.integer(intercept),
    as.numeric(tlist),
    as.integer(nt),
    as.numeric(trate),
    as.integer(print)
  )
  path <- matrix(solution[[8]], ncol = nt)
  if (group) {
    # Undo the column sort. ord_rev only addresses the first `col` rows, so a
    # trailing intercept row stays in place; drop = FALSE keeps the
    # single-predictor case a matrix.
    path[seq_len(col), ] <- path[ord_rev, , drop = FALSE]
  }
  # `nt` is included so the result carries the same fields as LB.lasso's.
  object <- list(call = call, kappa = kappa, alpha = alpha, path = path, nt = nt, t = solution[[12]])
  class(object) <- "lb"
  return(object)
}


# Multinomial (multi-class logistic) Linearized Bregman path.
#
# Marshals the arguments for the compiled routine "LB_multi_logistic" and
# reshapes its output into a category-by-(col + intercept)-by-nt array. The
# routine is expected to write the path into argument 9 and the t grid into
# argument 13 -- inferred from how the result of .C() is read below.
#
# Grouping modes, as encoded in group_split / group_split_length:
#   * group = FALSE                -> length 0
#   * group = TRUE, index is a NA  -> length 1 (no explicit index given)
#   * group = TRUE, index vector   -> columns of X sorted by `index`, with the
#                                     cumulative group boundaries passed on
#
# Returns a list of class "lb" with elements call, kappa, alpha, path, nt, t.
LB.multilogistic <- function(X, y, kappa, alpha,alpha0_rate,tlist,nt ,trate, intercept = TRUE,group,index,print=FALSE) {
  call <- match.call()
  row <- nrow(X)
  col <- ncol(X)
  # TRUE only when the columns of X were permuted and must be restored.
  reordered <- FALSE
  if (!group) {
    group_split <- 0
    group_split_length <- 0
  } else if (length(index) == 1L && is.na(index)) {
    # The length check comes first so that a vector index never reaches a
    # scalar condition (an error on current R versions).
    group_split <- 0
    group_split_length <- 1
  } else {
    ord <- order(index)
    ord_rev <- order(ord)
    X <- X[, ord, drop = FALSE]
    group_size <- as.vector(table(index))
    group_split <- c(0, cumsum(group_size))
    group_split_length <- length(group_split)
    reordered <- TRUE
  }
  category <- ncol(y)
  intercept <- as.integer(intercept != 0)
  result_r <- numeric(nt * (col + intercept) * category)
  solution <- .C("LB_multi_logistic",
    as.numeric(X),
    as.integer(row),
    as.integer(col),
    as.numeric(y),
    as.integer(category),
    as.numeric(kappa),
    as.numeric(alpha),
    as.numeric(alpha0_rate),
    as.numeric(result_r),
    as.integer(group_split),
    as.integer(group_split_length),
    as.integer(intercept),
    as.numeric(tlist),
    as.integer(nt),
    as.numeric(trate),
    as.integer(print)
  )
  # The flat output holds, for each t in turn, one column-major
  # category-by-(col + intercept) block; that is exactly the layout of a 3-d
  # array with these dimensions.
  path.multi <- array(solution[[9]], dim = c(category, col + intercept, nt))
  if (reordered) {
    # Undo the column sort; drop = FALSE keeps three dimensions when any of
    # them has extent one.
    path.multi[, seq_len(col), ] <- path.multi[, ord_rev, , drop = FALSE]
  }
  # `nt` is included so the result carries the same fields as LB.lasso's.
  object <- list(call = call, kappa = kappa, alpha = alpha, path = path.multi, nt = nt, t = solution[[13]])
  class(object) <- "lb"
  return(object)
}


# Attach hook: emits a startup message containing the "Version" field that
# packageDescription() reports for "Libra".
.onAttach <- function(libname, pkgname) {
  libra_version <- as.character(packageDescription("Libra")[["Version"]])
  packageStartupMessage("Loaded Libra ", libra_version, "\n")
}