# Sets the testthat context label ("Momentum") for the tests in this file.
# NOTE(review): context() is deprecated under testthat's 3rd edition, where the
# test file name is used instead; keep it only while this package runs edition
# 2 -- confirm against the package's DESCRIPTION (Config/testthat/edition).
context("Momentum")

# Momentum schedule factories: each make_*() returns a function of the
# iteration number giving the momentum value to use at that iteration.

test_that("constant schedule returns its value at every iteration", {
  constant_m <- make_constant(0.5)
  expect_equal(constant_m(1), 0.5)
  expect_equal(constant_m(500), 0.5)
  expect_equal(constant_m(1000), 0.5)

  constant_m <- make_constant(0.25)
  expect_equal(constant_m(1), 0.25)
  expect_equal(constant_m(500), 0.25)
  expect_equal(constant_m(1000), 0.25)

  constant_m <- make_constant(0.95)
  expect_equal(constant_m(1), 0.95)
  expect_equal(constant_m(500), 0.95)
  expect_equal(constant_m(1000), 0.95)
})

test_that("switch schedule changes value at switch_iter", {
  step_m <- make_switch(init_value = 0.2, final_value = 0.6, switch_iter = 100)
  # Iterations before switch_iter use init_value; the switch itself happens
  # exactly at switch_iter (99 -> init, 100 -> final).
  expect_equal(step_m(1), 0.2)
  expect_equal(step_m(99), 0.2)
  expect_equal(step_m(100), 0.6)
  expect_equal(step_m(101), 0.6)
  expect_equal(step_m(1000), 0.6)
})

test_that("ramp schedule moves from init_value to final_value", {
  linear_m <- make_ramp(init_value = 0.1, final_value = 0.8)
  expect_equal(linear_m(1, max_iter = 1000), 0.1)
  # 0.45 is only the approximate midpoint: presumably the ramp is linear in
  # (iter - 1) / (max_iter - 1), giving ~0.4497 at iter 500, hence the loose
  # tolerance. Spelled out as `tolerance` (not `tol`) because the argument
  # comes after `...` in expect_equal() and is not partially matched there.
  expect_equal(linear_m(500, max_iter = 1000), 0.45, tolerance = 1e-3)
  expect_equal(linear_m(1000, max_iter = 1000), 0.8)
})

# The Nesterov expectations below all follow mu(t) = 1 - 3 / (t + 5) for
# t >= 1 (e.g. t = 5 -> 0.7, t = 50 -> 0.94545). At t = 0 the schedule gives 0,
# or 1 - 3 / 5 = 0.4 when use_init_mu = TRUE.

test_that("nesterov schedule without burn-in starts at zero", {
  nest_m <- make_nesterov_convex_approx(burn_in = 0, use_init_mu = FALSE)
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0.5)
  expect_equal(nest_m(5), 0.7)
  expect_equal(nest_m(10), 0.8)
  expect_equal(nest_m(20), 0.88)
  expect_equal(nest_m(50), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(500), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1000), 0.9970, tolerance = 0.0001)
})

test_that("nesterov schedule without burn-in can use the initial mu", {
  nest_m <- make_nesterov_convex_approx(burn_in = 0, use_init_mu = TRUE)
  expect_equal(nest_m(0), 0.4)
  expect_equal(nest_m(1), 0.5)
  expect_equal(nest_m(5), 0.7)
  expect_equal(nest_m(10), 0.8)
  expect_equal(nest_m(20), 0.88)
  expect_equal(nest_m(50), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(500), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1000), 0.9970, tolerance = 0.0001)
})

# With burn_in = 1 the same sequence is shifted by one iteration: iterations up
# to and including burn_in return 0 (except the initial-mu case below).

test_that("nesterov schedule with burn-in is shifted by burn_in", {
  nest_m <- make_nesterov_convex_approx(burn_in = 1, use_init_mu = FALSE)
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0)
  expect_equal(nest_m(2), 0.5)
  expect_equal(nest_m(6), 0.7)
  expect_equal(nest_m(11), 0.8)
  expect_equal(nest_m(21), 0.88)
  expect_equal(nest_m(51), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(501), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1001), 0.9970, tolerance = 0.0001)
})

test_that("nesterov schedule with burn-in applies initial mu after burn-in", {
  nest_m <- make_nesterov_convex_approx(burn_in = 1, use_init_mu = TRUE)
  # Iteration 0 is still inside the burn-in; the initial mu (0.4) appears at
  # iteration burn_in, i.e. one step later than without burn-in.
  expect_equal(nest_m(0), 0)
  expect_equal(nest_m(1), 0.4)
  expect_equal(nest_m(2), 0.5)
  expect_equal(nest_m(6), 0.7)
  expect_equal(nest_m(11), 0.8)
  expect_equal(nest_m(21), 0.88)
  expect_equal(nest_m(51), 0.9455, tolerance = 0.0001)
  expect_equal(nest_m(501), 0.9941, tolerance = 0.0001)
  expect_equal(nest_m(1001), 0.9970, tolerance = 0.0001)
})
