test_that("returns the right output", {
  # A single penalty value keeps the fitted formula well defined
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)

  pm <- parse_model(model)
  tf <- tidypredict_fit(model)

  # The fitted formula is an unevaluated R expression
  expect_type(tf, "language")

  # The parsed model is a two-element list tagged as a glmnet model
  expect_s3_class(pm, "list")
  expect_length(pm, 2)
  expect_equal(pm$general$model, "glmnet")
  expect_equal(pm$general$version, 1)

  expect_snapshot(rlang::expr_text(tf))
})

test_that("Model can be saved and re-loaded", {
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)

  # Round-trip the parsed model through a YAML file on disk.
  pm <- parse_model(model)
  mp <- tempfile(fileext = ".yml")
  yaml::write_yaml(pm, mp)
  l <- yaml::read_yaml(mp)
  # The file is not needed once it has been read back; remove it so repeated
  # test runs do not leave files behind in the session temp directory.
  unlink(mp)
  reloaded <- as_parsed_model(l)

  # The formula built from the reloaded model must match the original fit.
  expect_identical(
    round_print(tidypredict_fit(model)),
    round_print(tidypredict_fit(reloaded))
  )
})

test_that("formulas produces correct predictions", {
  # The test data must hold the same predictor columns, in the same order, as
  # the training matrix: glmnet's predict() uses `newx` columns by position,
  # while the tidypredict formula refers to columns by name.

  # gaussian: trained on every column except `mpg`
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -1], mtcars$mpg, family = "gaussian", lambda = 1),
      mtcars[, -1]
    )
  )

  # binomial: trained on every column except `vs`
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "binomial", lambda = 1),
      mtcars[, -8]
    )
  )

  # poisson: trained on every column except `vs`
  expect_snapshot(
    tidypredict_test(
      glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "poisson", lambda = 1),
      mtcars[, -8]
    )
  )
})

test_that("family function syntax works (#197)", {
  x <- as.matrix(mtcars[, -1])

  # The family is passed as a family object rather than a string; each fitted
  # model only has to convert to a formula without error.
  fit_gaussian <- glmnet::glmnet(
    x, mtcars$mpg, family = gaussian(), lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_gaussian))

  fit_binomial <- glmnet::glmnet(
    x, mtcars$am, family = binomial(), lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_binomial))

  fit_poisson <- glmnet::glmnet(
    x, mtcars$carb, family = poisson(), lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_poisson))
})

test_that("family string syntax works (#197)", {
  x <- as.matrix(mtcars[, -1])

  # The family is passed as a character string; each fitted model only has
  # to convert to a formula without error.
  fit_gaussian <- glmnet::glmnet(
    x, mtcars$mpg, family = "gaussian", lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_gaussian))

  fit_binomial <- glmnet::glmnet(
    x, mtcars$am, family = "binomial", lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_binomial))

  fit_poisson <- glmnet::glmnet(
    x, mtcars$carb, family = "poisson", lambda = 0.5
  )
  expect_no_error(tidypredict_fit(fit_poisson))
})

test_that("errors if more than 1 penalty is selected", {
  # No `lambda` given: glmnet fits a whole path of penalty values
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg)
  expect_snapshot(tidypredict_fit(model), error = TRUE)

  # Two explicit penalty values
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = c(1, 5))
  expect_snapshot(tidypredict_fit(model), error = TRUE)
})

# NOTE: the misspelling "handeld" in the description is intentional to keep:
# the description is the key of the stored snapshot for this test.
test_that("glmnet are handeld neatly with parsnip", {
  # parsnip is an optional dependency; skip rather than error without it.
  skip_if_not_installed("parsnip")

  spec <- parsnip::linear_reg(engine = "glmnet", penalty = 1)

  model <- parsnip::fit(spec, mpg ~ ., mtcars)

  tf <- tidypredict_fit(model)
  pm <- parse_model(model)

  # The fitted formula is an unevaluated R expression
  expect_type(tf, "language")

  # The parsed model of the parsnip wrapper describes the underlying glmnet
  expect_s3_class(pm, "list")
  expect_equal(length(pm), 2)
  expect_equal(pm$general$model, "glmnet")
  expect_equal(pm$general$version, 1)

  expect_snapshot(
    rlang::expr_text(tf)
  )
})

test_that("Gamma family works (#200)", {
  x <- as.matrix(mtcars[, -1])
  model <- glmnet::glmnet(x, mtcars$mpg, family = Gamma(), lambda = 0.5)

  # Response-scale predictions from glmnet itself, as a plain vector
  expected <- predict(model, x, type = "response")[, 1]
  expected <- unname(expected)

  # The same predictions computed from the tidypredict formula
  actual <- rlang::eval_tidy(tidypredict_fit(model), mtcars)

  expect_equal(actual, expected)
})

test_that("Cox family works (#201)", {
  skip_if_not_installed("survival")

  # Survival outcome built from mpg (time) and vs (status); both columns are
  # left out of the predictor matrix.
  y <- survival::Surv(mtcars$mpg, mtcars$vs)
  x <- as.matrix(mtcars[, -c(1, 8)])
  model <- glmnet::glmnet(x, y, family = "cox", lambda = 0.1)

  # Compare on the linear-predictor (link) scale
  expected <- unname(predict(model, x, type = "link")[, 1])
  actual <- rlang::eval_tidy(tidypredict_fit(model), mtcars)

  expect_equal(actual, expected)
})

test_that("multinomial family errors with helpful message (#198)", {
  # A multi-class model has one equation per class, so a single formula
  # cannot represent it; the user-facing error is captured by the snapshot.
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]), iris$Species,
    family = "multinomial", lambda = 0.5
  )

  expect_snapshot(tidypredict_fit(model), error = TRUE)
})

test_that("mgaussian family errors with helpful message (#199)", {
  # Two numeric responses (mpg and hp); both are removed from the predictors
  y <- cbind(mtcars$mpg, mtcars$hp)
  x <- as.matrix(mtcars[, -c(1, 4)])
  model <- glmnet::glmnet(x, y, family = "mgaussian", lambda = 0.5)

  expect_snapshot(tidypredict_fit(model), error = TRUE)
})

# Tests for .extract_glmnet_multiclass()

test_that(".extract_glmnet_multiclass returns correct structure", {
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]), iris$Species,
    family = "multinomial", lambda = 0.5
  )

  eqs <- .extract_glmnet_multiclass(model)

  # One entry per class, named after the class levels, holding text
  expect_type(eqs, "list")
  expect_length(eqs, nlevels(iris$Species))
  expect_named(eqs, levels(iris$Species))
  expect_type(eqs[[1]], "character")
})

test_that(".extract_glmnet_multiclass errors on non-multnet model", {
  # A single-response gaussian fit is not a multinomial model
  model <- glmnet::glmnet(mtcars[, -1], mtcars$mpg, lambda = 1)
  expect_snapshot(.extract_glmnet_multiclass(model), error = TRUE)
})

test_that(".extract_glmnet_multiclass errors with multiple penalties", {
  # No `lambda` given, so the fit holds a path of penalties and none is chosen
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]), iris$Species,
    family = "multinomial"
  )

  expect_snapshot(.extract_glmnet_multiclass(model), error = TRUE)
})

test_that(".extract_glmnet_multiclass works with explicit penalty", {
  # A full penalty path, disambiguated by passing `penalty` explicitly
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]), iris$Species,
    family = "multinomial"
  )

  per_class <- .extract_glmnet_multiclass(model, penalty = 0.01)

  expect_type(per_class, "list")
  expect_length(per_class, nlevels(iris$Species))
})

test_that(".extract_glmnet_multiclass handles sparse coefficients", {
  # A high penalty should zero out many (possibly all) slope coefficients, so
  # the per-class equations may shrink to intercept-only expressions.
  model <- glmnet::glmnet(
    as.matrix(iris[, 1:4]),
    iris$Species,
    family = "multinomial",
    lambda = 10
  )

  result <- .extract_glmnet_multiclass(model)

  expect_type(result, "list")
  expect_length(result, 3)

  # However sparse, every equation must still be valid, parseable R code.
  for (eq in result) {
    expect_type(eq, "character")
    expect_no_error(rlang::parse_expr(eq))
  }
})

test_that(".extract_glmnet_multiclass produces correct predictions", {
  x <- as.matrix(iris[, 1:4])
  model <- glmnet::glmnet(
    x,
    iris$Species,
    family = "multinomial",
    lambda = 0.01
  )

  eqs <- .extract_glmnet_multiclass(model)
  n_rows <- nrow(iris)

  # Evaluate each class's linear predictor on the data. An equation without
  # column references evaluates to a scalar, so recycle it to every row.
  # vapply() guarantees an n_rows x n_classes numeric matrix (sapply() would
  # silently change shape on unexpected lengths).
  logits <- vapply(
    eqs,
    function(eq) {
      val <- rlang::eval_tidy(rlang::parse_expr(eq), iris)
      if (length(val) == 1) rep(val, n_rows) else val
    },
    numeric(n_rows)
  )

  # Numerically stable softmax: subtracting each row's maximum logit leaves
  # the probabilities unchanged but keeps exp() from overflowing.
  logits <- logits - apply(logits, 1, max)
  exp_logits <- exp(logits)
  probs <- exp_logits / rowSums(exp_logits)

  # Compare to glmnet's native class probabilities at the single penalty.
  native <- predict(model, x, type = "response")[, , 1]

  expect_equal(unname(probs), unname(native), tolerance = 1e-10)
})

# Tests for .build_linear_pred()

test_that(".build_linear_pred handles intercept only", {
  # A lone intercept is rendered as just its number
  expect_equal(.build_linear_pred("(Intercept)", 5.5), "5.5")
})

test_that(".build_linear_pred handles single predictor", {
  terms <- c("(Intercept)", "x")
  coefs <- c(1.5, 2)

  expect_equal(.build_linear_pred(terms, coefs), "1.5 + (`x` * 2)")
})

test_that(".build_linear_pred handles multiple predictors", {
  terms <- c("(Intercept)", "x", "y")
  coefs <- c(1, 2, 3)

  # Each predictor becomes a backticked `name * coef` product, joined by +
  expect_equal(.build_linear_pred(terms, coefs), "1 + (`x` * 2) + (`y` * 3)")
})

test_that(".build_linear_pred skips zero coefficients", {
  terms <- c("(Intercept)", "x", "y", "z")
  coefs <- c(1, 0, 2, 0)

  # Terms whose coefficient is exactly zero are dropped from the output
  expect_identical(.build_linear_pred(terms, coefs), "1 + (`y` * 2)")
})

test_that(".build_linear_pred returns '0' when all coefficients are zero", {
  terms <- c("(Intercept)", "x", "y")

  expect_equal(.build_linear_pred(terms, numeric(3)), "0")
})

test_that(".build_linear_pred handles negative coefficients", {
  terms <- c("(Intercept)", "x")
  coefs <- -c(1.5, 2)

  # Negative values are printed with their sign, not folded into a minus op
  expect_equal(.build_linear_pred(terms, coefs), "-1.5 + (`x` * -2)")
})

test_that(".build_linear_pred handles special characters in variable names", {
  terms <- c("(Intercept)", "var with space", "var.with.dots")
  coefs <- c(1, 2, 3)

  # Backticks keep non-syntactic names usable in the generated expression
  expect_identical(
    .build_linear_pred(terms, coefs),
    "1 + (`var with space` * 2) + (`var.with.dots` * 3)"
  )
})

test_that(".build_linear_pred handles no intercept", {
  # Without an "(Intercept)" term the expression starts at the first product
  expect_equal(.build_linear_pred(c("x", "y"), c(2, 3)), "(`x` * 2) + (`y` * 3)")
})

test_that(".build_linear_pred handles zero intercept", {
  terms <- c("(Intercept)", "x")
  coefs <- c(0, 2)

  # A zero intercept is omitted entirely rather than rendered as "0 + ..."
  expect_equal(.build_linear_pred(terms, coefs), "(`x` * 2)")
})
