library(distfreereg)

# Internal (non-exported) helpers from distfreereg. Binding
# all.equal.distfreereg in this environment lets the all.equal() calls below
# dispatch to it (presumably the method is not reachable otherwise -- confirm).
all.equal.distfreereg <- distfreereg:::all.equal.distfreereg
test_dfr_functions <- distfreereg:::test_dfr_functions

# Simulated data: n observations from the mean function
# theta[1] + theta[2]*X[,1] + theta[3]*X[,2] with exponential covariates and
# additive errors drawn via rmvnorm() using the square root of a random
# Wishart matrix Sig (presumably multivariate normal with covariance Sig).
# The order of the random draws below fixes every downstream result.
n <- 1e2
func <- function(X, theta) theta[1] + theta[2]*X[,1] + theta[3]*X[,2]
set.seed(20250516)
Sig <- rWishart(1, df = n, Sigma = diag(n))[,,1]
X <- matrix(rexp(2*n, rate = 1), nrow = n)
theta <- c(2,5,1)
Y <- distfreereg:::f2ftheta(f = func, X)(theta) +
  as.vector(distfreereg:::rmvnorm(n = n, reps = 1, mean = rep(0,n), SqrtSigma = distfreereg:::matsqrt(Sig)))

# Data frame for the formula interface: response z, covariates x and y (the
# two columns of X) and a variable g cycling through 1..10.
# NOTE(review): rep(1:10, 10) has length 100, so it only matches n = 1e2.
df_nls <- as.data.frame(cbind(Y, X, rep(1:10, 10)))
names(df_nls) <- c("z", "x", "y", "g")
# Observation weights of the form 1 + Exp(1) for the weighted nls fits.
wt <- rexp(n) + 1
# The same mean function as func, written as an nls formula with parameters
# d, e and f.
form_nls <- z ~ d + e*x + f*y

# Weighted nls fit through the formula interface. The seed is fixed right
# before the call so any random draws made by distfreereg() are reproducible.
set.seed(20250516)
dfr_form_nls <- distfreereg(
  test_mean = form_nls,
  data = df_nls,
  method_args = list(weights = wt),
  method = "nls",
  verbose = FALSE,
  control = list(return_on_error = FALSE)
)
print(dfr_form_nls)


# Same formula-based fit without observation weights (verbose is left at its
# default here).
set.seed(20250516)
dfr_form_nls_no_weights <- distfreereg(test_mean = form_nls, data = df_nls,
                                       method = "nls",
                                       control = list(return_on_error = FALSE))

# New covariate values for the prediction checks in test_dfr_functions().
# The columns must carry the covariate names used in form_nls (x and y).
# With other names, predict() on an nls fit can resolve x and y from the data
# stored with the model, and the new data would be silently ignored.
newdata_nls <- data.frame(x = rnorm(10), y = rnorm(10))
test_dfr_functions(dfr_form_nls, newdata = newdata_nls)

# Fit the same weighted model directly with nls(). m_nls_na_omit is a refit
# that differs only in na.action = na.omit.
m_nls <- nls(formula = form_nls, data = df_nls, weights = wt)
m_nls_na_omit <- update(object = m_nls, na.action = na.omit)

# Model-object interface, run under the same seed as the formula-based fit.
set.seed(20250516)
dfr_nls <- distfreereg(
  test_mean = m_nls,
  verbose = FALSE,
  control = list(return_on_error = FALSE)
)
print(dfr_nls)

test_dfr_functions(dfr_nls, newdata = newdata_nls)

# The model-object and formula interfaces must give equivalent results.
stopifnot(all.equal(dfr_nls, dfr_form_nls))

# Exercise the override argument: fit the na.omit variant of m_nls while
# supplying J and fitted_values taken from dfr_nls instead of recomputing them.
# NOTE(review): dfr_nls_override is not checked after creation in the code
# shown; consider comparing it against dfr_nls.
set.seed(20250516)
dfr_nls_override <- distfreereg(test_mean = m_nls_na_omit,
                                override = list(J = dfr_nls[["J"]],
                                                fitted_values = dfr_nls[["fitted_values"]]),
                                control = list(return_on_error = FALSE))


# Small asymptotics() simulations for the formula-based and model-based
# objects. Re-seed before each run, as every other stochastic step in this
# script does. Each result is then reproducible regardless of RNG use earlier
# in the script, and both runs use the same random draws, so their rejection
# tables can be compared directly.
set.seed(20250516)
cdfr_form_nls <- asymptotics(dfr_form_nls, reps = 5)
set.seed(20250516)
cdfr_nls <- asymptotics(dfr_nls, reps = 5)

signif(rejection(cdfr_form_nls, alpha = c(0.1, 0.5))[,2:3], digits = 3)
signif(rejection(cdfr_nls, alpha = c(0.1, 0.5))[,2:3], digits = 3)




# Orderings
#
# For each named residual ordering, refit both the model-object (dfr_nls) and
# formula (dfr_form_nls) objects with update() under the same seed, and check
# that the two interfaces still give all.equal() results.

# ordering = "asis"
set.seed(20250516)
dfr_nls_asis <- update(dfr_nls, ordering = "asis")
set.seed(20250516)
dfr_form_nls_asis <- update(dfr_form_nls, ordering = "asis")
stopifnot(all.equal(dfr_nls_asis, dfr_form_nls_asis))

# ordering = "optimal"
set.seed(20250516)
dfr_nls_optimal <- update(dfr_nls, ordering = "optimal")
set.seed(20250516)
dfr_form_nls_optimal <- update(dfr_form_nls, ordering = "optimal")
stopifnot(all.equal(dfr_nls_optimal, dfr_form_nls_optimal))

# ordering = "natural"
set.seed(20250516)
dfr_nls_natural <- update(dfr_nls, ordering = "natural")
set.seed(20250516)
dfr_form_nls_natural <- update(dfr_form_nls, ordering = "natural")
stopifnot(all.equal(dfr_nls_natural, dfr_form_nls_natural))

# Order residuals by the column g. The extra term I(0*is.numeric(g)) is
# identically zero, so the mean function is unchanged. It only brings g into
# the model formula (presumably so ordering = list("g") can refer to it).
form_nls_g <- z ~ d + e*x + f*y + I(0*is.numeric(g))
m_nls_g <- nls(formula = form_nls_g, data = df_nls, weights = wt)

set.seed(20250516)
dfr_nls_g <- update(dfr_nls, ordering = list("g"), test_mean = m_nls_g)
set.seed(20250516)
dfr_form_nls_g <- update(dfr_form_nls, ordering = list("g"),
                         test_mean = form_nls_g)
stopifnot(all.equal(dfr_nls_g, dfr_form_nls_g))

# Values of g in the residual order used by dfr_nls_g.
df_nls[["g"]][dfr_nls_g[["res_order"]]]

# Same comparison with group = TRUE.
set.seed(20250516)
dfr_nls_g_grouped <- update(dfr_nls_g, group = TRUE)
set.seed(20250516)
dfr_form_nls_g_grouped <- update(dfr_form_nls_g, group = TRUE)
stopifnot(all.equal(dfr_nls_g_grouped, dfr_form_nls_g_grouped))





### Partial output
#
# An extremely small orth_tol is passed, and return_on_error is not set to
# FALSE here (presumably so an incomplete object is returned instead of an
# error). The names show which components the returned object contains.
dfr_nls_partial <- distfreereg(
  test_mean = m_nls,
  verbose = FALSE,
  control = list(orth_tol = 1e-100)
)
print(names(dfr_nls_partial))



### Failures
#
# Each call below passes an invalid or unsupported argument and is expected to
# fail. The error is caught and re-signalled with warning(e), so its message is
# still reported while the rest of the script keeps running.

# test_mean is a character string rather than a formula.
tryCatch(distfreereg:::distfreereg.formula(test_mean = "a", data = df_nls,
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE)),
         error = function(e) warning(e)
)

# theta_init = 1:3 supplied together with method = "nls".
# NOTE(review): the exact reason this is rejected is not visible here -- confirm
# against distfreereg.formula.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = df_nls,
                                           theta_init = 1:3,
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE)),
         error = function(e) warning(e)
)

# data is a matrix instead of a data frame.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = as.matrix(df_nls),
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE)),
         error = function(e) warning(e)
)

# data is an empty data frame.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = data.frame(),
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE)),
         error = function(e) warning(e)
)

# method_args is a character string instead of a list.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = df_nls,
                                           method_args = "weights = wt",
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE)),
         error = function(e) warning(e)
)

# override supplies theta_hat = 1:3 to the formula method.
# NOTE(review): presumably theta_hat is not an allowed override here -- confirm.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = df_nls,
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE),
                                           override = list(theta_hat = 1:3)),
         error = function(e) warning(e)
)

# override is an unnamed list.
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = df_nls,
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE),
                                           override = list(1:3)),
         error = function(e) warning(e)
)

# override has an element named "a" (presumably not a recognized component).
tryCatch(distfreereg:::distfreereg.formula(test_mean = form_nls, data = df_nls,
                                           method_args = list(weights = wt),
                                           method = "nls", verbose = FALSE,
                                           control = list(return_on_error = FALSE),
                                           override = list(a = 1:3)),
         error = function(e) warning(e)
)

# control is a character string instead of a list (nls model-object input).
tryCatch(distfreereg(test_mean = m_nls, control = "hello"),
         error = function(e) warning(e)
)
