# Fix the RNG state and define the numerical tolerance used by the comparisons
# further down, then load the `mpdta` dataset shipped with the `did` package.
set.seed(123)
tol = 0.0005
data("mpdta", package = "did")

# Vignette example ----

# Two-digit codes (named by state abbreviation) that define the `gls` group.
gls_fips = setNames(
  c(17, 18, 26, 27, 36, 39, 42, 55),
  c("IL", "IN", "MI", "MN", "NY", "OH", "PA", "WI")
)

# A row belongs to the group when the first two characters of its `countyreal`
# value match one of the codes above.
# NOTE(review): this presumably assumes the leading two characters of
# `countyreal` are always a state code -- confirm for shorter county codes.
county_prefix = substr(mpdta$countyreal, 1, 2)
mpdta$gls = county_prefix %in% gls_fips

# Fit the model with `gls` supplied as `xvar`, then hand the fit to emfx().
hmod = etwfe(
  lemp ~ lpop,
  tvar = year,
  gvar = first.treat,
  data = mpdta,
  vcov = ~countyreal,
  xvar = gls ## <= het. TEs by gls
)
hmod_att = emfx(hmod)

# Reference values for the vignette example: one row per level of `gls`
# (FALSE, then TRUE). Compared against `hmod_att` below.
hmod_att_known = data.frame(
  term         = rep(".Dtreat", 2L),
  contrast     = rep("mean(TRUE) - mean(FALSE)", 2L),
  .Dtreat      = rep(TRUE, 2L),
  gls          = c(FALSE, TRUE),
  estimate     = c(-0.0636770215654243, -0.0472387691288247),
  std.error    = c(0.037624325357995, 0.0271333509683866),
  statistic    = c(-1.69244287995966, -1.7409854456924),
  p.value      = c(0.0905615623717739, 0.0816861297323847),
  conf.low     = c(-0.137419344209712, -0.100419159806747),
  conf.high    = c(0.010065301078863, 0.00594162154909795),
  predicted    = c(8.64148641537791, 5.4198122639013),
  predicted_hi = c(8.64148641537791, 5.4198122639013),
  predicted_lo = c(8.65392156078294, 5.43931783410604),
  stringsAsFactors = FALSE
)

# The vignette estimates, converted to a plain data frame, must agree with the
# stored reference values to within `tol`.
expect_equal(
  data.frame(hmod_att),
  hmod_att_known,
  tolerance = tol
)

# Simulation example ----

library(data.table)

# 70 indivs
# 20 time periods
# staggered treat rollout at t = 11 and t = 16
# one control group (0), followed by three equi-sized treatment groups 1:3
# (with each treatment group separated across rollout periods)

# Seed the RNG so the simulated panel is reproducible.
set.seed(1234L)

ids = 70
periods = 20

# Panel skeleton: every id crossed with every period.
dat = CJ(id = seq_len(ids), period = seq_len(periods))

# Covariate: linear in `period` plus a uniform draw on [0, 0.1].
dat[, x := 0.1 * period + runif(.N, max = 0.1)]

# Treatment-effect group, assigned by consecutive blocks of ten ids:
# ids 1-10 -> 0, then 1, 2, 3, 1, 2, 3 for each following block.
dat[, te_grp := c(0, 1, 2, 3, 1, 2, 3)[(id - 1L) %/% 10L + 1L]]

# First treated period: never (Inf) for group 0, period 11 for ids up to 40,
# period 16 for the remaining ids.
dat[, first_treat := fifelse(te_grp == 0, Inf, fifelse(id <= 40, 11, 16))]

# Treatment effect: zero before treatment; afterwards it grows with time since
# treatment, scaled by the group number.
dat[, te := 0]
dat[
  period >= first_treat,
  te := te_grp * (period - first_treat) + rnorm(.N, sd = 0.01) # add a little noise to the TEs
]

# The group is stored as a factor from here on.
dat[, te_grp := factor(te_grp)]

# Outcome: covariate plus treatment effect plus Gaussian noise.
dat[, y := x + te + rnorm(.N, sd = 0.1)]
dat[]

## known ATTs for the event study
# dat[
#   period >= first_treat,
#   .(ATE = mean(te)),
#   by = .(te_grp, event = period - first_treat)
# ]

# Fit the model on the simulated panel, with `te_grp` supplied as `xvar`.
sim_mod = etwfe(
  y ~ x,
  tvar = period,
  gvar = first_treat,
  xvar = te_grp,
  data = dat
)

# emfx() output for the default type and for the "event" type, each converted
# to a plain data frame.
sim_att = data.frame(emfx(sim_mod))
sim_es = data.frame(emfx(sim_mod, type = "event"))

# Reference values for `sim_att`: one row per treated level of `te_grp`
# ("1", "2", "3"); the factor keeps "0" as an unused first level.
sim_att_known = data.frame(
  term         = rep(".Dtreat", 3L),
  contrast     = rep("mean(TRUE) - mean(FALSE)", 3L),
  .Dtreat      = rep(TRUE, 3L),
  te_grp       = factor(c("1", "2", "3"), levels = c("0", "1", "2", "3")),
  estimate     = c(3.64409742004379, 7.31766444931729, 10.992499227732),
  std.error    = c(
    0.00179407380719288, 0.00186140973502182, 0.00176875240027723
  ),
  statistic    = c(2031.18589961779, 3931.24861852702, 6214.83211896081),
  p.value      = rep(0, 3L),
  conf.low     = c(3.64058109999608, 7.31401615327618, 10.9890325367299),
  conf.high    = c(3.64761374009149, 7.32131274535841, 10.9959659187341),
  predicted    = c(1.11928884683436, 1.21850313334936, 1.12187654842797),
  predicted_hi = c(1.11928884683436, 1.21850313334936, 1.12187654842797),
  predicted_lo = c(1.16033613149056, 1.18535353775717, 1.12366093209766),
  stringsAsFactors = FALSE
)

# Reference values for `sim_es`: 30 rows, ordered by event time 0..9 and,
# within each event time, by `te_grp` level "1", "2", "3". Each line of
# numbers below therefore holds one event time (groups 1, 2, 3).
sim_es_known = data.frame(
  term = rep(".Dtreat", 30L),
  contrast = rep("mean(TRUE) - mean(FALSE)", 30L),
  event = rep(c(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), each = 3L),
  te_grp = factor(
    rep(c("1", "2", "3"), times = 10L),
    levels = c("0", "1", "2", "3")
  ),
  estimate = c(
    -0.0433747860105764, 0.00938397936930824, -0.0237484443237233,
    0.957378771772232, 1.98084154437699, 2.9512740289894,
    2.05926108133062, 4.00769294397011, 6.01763770859873,
    2.96004156883571, 5.95023186845035, 9.01045082176087,
    3.96227702406557, 7.98393664546649, 12.0024108692038,
    4.9681600484654, 9.97378176955636, 14.9746049439207,
    5.99790531685268, 12.0085642277481, 17.9713002198987,
    7.00419009415206, 14.0485277165416, 21.0492535202485,
    7.9562384175109, 15.9211608341976, 23.9627880615387,
    8.94380010368864, 17.9487582284492, 27.0134917019151
  ),
  std.error = c(
    0.00552369840254662, 0.00308039552453581, 0.00522544419598851,
    0.00205792076001212, 0.0026789275267014, 0.00226141711725387,
    0.00573571931679048, 0.00473955753152871, 0.00547305722516499,
    0.00603406011590234, 0.00529658914750381, 0.00729124032872791,
    0.00640318471551457, 0.00653206197111518, 0.00663945606125071,
    0.00295432965113791, 0.00299065385436665, 0.0029365265650914,
    0.00328914289021445, 0.00325696625420919, 0.00324725444927996,
    0.00355570109388208, 0.00359971776313436, 0.00364003698320623,
    0.00393673706399228, 0.00398142824742617, 0.00394463731991408,
    0.00432447553116933, 0.00426840396777843, 0.00433361672006238
  ),
  statistic = c(
    -7.85248991700545, 3.04635534448854, -4.5447704411339,
    465.216538156013, 739.415876179386, 1305.05513842279,
    359.024033010547, 845.583773867064, 1099.50206274654,
    490.555531761232, 1123.40823551598, 1235.79122556956,
    618.797863891883, 1222.26896817138, 1807.74008570557,
    1681.6539232688, 3334.9836708764, 5099.42771229606,
    1823.5465946758, 3687.03980651585, 5534.30613479754,
    1969.84783287983, 3902.67477645504, 5782.70320256688,
    2021.02357566202, 3998.85665263212, 6074.77598525094,
    2068.18145673962, 4205.02800670739, 6233.47505026385
  ),
  p.value = c(
    4.07857686425568e-15, 0.00231633906586648, 5.49951565749448e-06,
    rep(0, 27L)
  ),
  conf.low = c(
    -0.0542010359410292, 0.00334651508307969, -0.0339901267510846,
    0.95334532119957, 1.97559094290747, 2.94684173288556,
    2.04801927804428, 3.99840358190566, 6.00691071355208,
    2.94821502832799, 5.93985074448034, 8.99616025331394,
    3.9497270126368, 7.97113403925832, 11.9893977744468,
    4.96236966875071, 9.96792019571157, 14.9688494576134,
    5.99145871524785, 12.002180691191, 17.9649357181295,
    6.99722104806826, 14.0414723993714, 21.0421191788591,
    7.94852255464887, 15.9133573782256, 23.9550567144596,
    8.93532428739553, 17.9403923104009, 27.004997969221
  ),
  conf.high = c(
    -0.0325485360801236, 0.0154214436555368, -0.0135067618963619,
    0.961412222344893, 1.98609214584652, 2.95570632509324,
    2.07050288461696, 4.01698230603457, 6.02836470364538,
    2.97186810934343, 5.96061299242036, 9.02474139020781,
    3.97482703549433, 7.99673925167466, 12.0154239639607,
    4.97395042818009, 9.97964334340114, 14.9803604302279,
    6.0043519184575, 12.0149477643052, 17.9776647216679,
    7.01115914023585, 14.0555830337119, 21.056387861638,
    7.96395428037293, 15.9289642901696, 23.9705194086178,
    8.95227591998176, 17.9571241464975, 27.0219854346092
  ),
  predicted = c(
    1.11928884683436, 1.21850313334936, 1.12187654842797,
    2.27223927934572, 3.25075526143081, 4.22562319157724,
    3.41773624503622, 5.3304812086502, 7.34077718000793,
    4.47983069450461, 7.39825140864402, 10.4551744912022,
    5.51898758174659, 9.56024486593919, 13.5889707129183,
    6.66219858423564, 11.6340618775755, 16.6295239026893,
    7.78022284239805, 13.7830340706269, 19.7257106348443,
    8.80670806738194, 15.8441030375044, 22.8421436158167,
    9.92746440908355, 17.8971638091109, 25.9496961215649,
    11.0738448212718, 20.0302999313819, 29.1008653474644
  ),
  predicted_hi = c(
    1.11928884683436, 1.21850313334936, 1.12187654842797,
    2.27223927934572, 3.25075526143081, 4.22562319157724,
    3.41773624503622, 5.3304812086502, 7.34077718000793,
    4.47983069450461, 7.39825140864402, 10.4551744912022,
    5.51898758174659, 9.56024486593919, 13.5889707129183,
    6.66219858423564, 11.6340618775755, 16.6295239026893,
    7.78022284239805, 13.7830340706269, 19.7257106348443,
    8.80670806738194, 15.8441030375044, 22.8421436158167,
    9.92746440908355, 17.8971638091109, 25.9496961215649,
    11.0738448212718, 20.0302999313819, 29.1008653474644
  ),
  predicted_lo = c(
    1.16033613149056, 1.18535353775717, 1.12366093209766,
    1.26205099780809, 1.27212519929651, 1.27203272619431,
    1.32089535087513, 1.4208849736772, 1.33648273429766,
    1.50627669548398, 1.44663199346231, 1.43147511648423,
    1.50807469539785, 1.52173158588081, 1.52536019185747,
    1.70003561860265, 1.65563606422989, 1.6523401019225,
    1.77668540575171, 1.77121638557215, 1.76313994558782,
    1.78773310213999, 1.78309603229577, 1.78997830079291,
    1.97307133184263, 1.90832808312756, 2.02874202293141,
    2.12682633399609, 2.08191648914996, 2.08931404157561
  ),
  stringsAsFactors = FALSE
)


# Tests ----

# Compare the point estimates and their uncertainty columns against the stored
# reference values, for both the default and the event-type output. The
# remaining columns are not compared here.
checked_cols = c("estimate", "std.error", "conf.low", "conf.high")
for (k in seq_along(checked_cols)) {
  nm = checked_cols[[k]]
  expect_equivalent(sim_att[[nm]], sim_att_known[[nm]], tolerance = tol)
  expect_equivalent(sim_es[[nm]], sim_es_known[[nm]], tolerance = tol)
}
