
# Shared inequality block of the OT-barycenter Hadamard-derivative LP.
#
# Builds the ROI linear constraint `A x <= b` where
#   * A is the transpose of `ot_barycenter_constrmat(K, N)`,
#   * there are K * N^2 rows, all with direction "<=",
#   * b is the concatenation, over the entries of `w`, of
#     `w[k] * c_byrow(costm)`.
#
# Args:
#   w:     weights; one scaled copy of the flattened cost is emitted per entry.
#   costm: cost matrix (flattened by `c_byrow`; presumably row-major --
#          confirm against that helper).
#   N, K:  problem dimensions forwarded to `ot_barycenter_constrmat`.
#
# Returns: an `ROI::L_constraint` object.
had_deriv_ot_barycenter_base_constr <- \(w, costm, N, K) {
    lhs_mat <- t(ot_barycenter_constrmat(K, N))
    directions <- ROI::leq(K * N^2)

    # One weighted copy of the flattened cost per weight, stacked end to end.
    flat_cost <- c_byrow(costm)
    rhs_vec <- do.call(c, lapply(w, \(w_k) w_k * flat_cost))

    ROI::L_constraint(lhs_mat, directions, rhs_vec)
}

# Constraint set for the Hadamard derivative of the OT barycenter under the
# null hypothesis.
#
# Let `idx` be the positions where the first row of `mu` is strictly positive
# and m = length(idx). For every such position an equality row is added whose
# coefficients are 1 on columns idx[j] + N * (k - 1), k = 1..K, and whose
# right-hand side is 0. These m rows are stacked on top of the shared
# inequality block from `had_deriv_ot_barycenter_base_constr`.
#
# Args:
#   mu:    matrix of measures; only its first row is inspected here
#          (presumably all rows coincide under the null -- confirm with caller).
#   w:     weights, forwarded to the base constraint builder.
#   costm: cost matrix, forwarded to the base constraint builder.
#   N, K:  problem dimensions; the LP has K * N + (K - 1) * (N - 1) variables.
#
# Returns: list with
#   constr:   the combined `ROI` constraint object,
#   add.info: a descriptive label string.
had_deriv_ot_barycenter_null_constr <- \(mu, w, costm, N, K) {
    idx <- which(mu[1, ] > 0)
    m <- length(idx)

    # Column indices for the equality rows. `outer()` always produces a
    # K x m numeric matrix (column j holds idx[j] + N * (0:(K - 1))), so the
    # flattened result is a plain numeric vector even when `idx` is empty.
    # The previous `sapply()` returned `list()` in that case.
    col_idx <- as.vector(outer(N * (seq_len(K) - 1L), idx, `+`))

    ids <- slam::simple_triplet_matrix(
        rep(seq_len(m), each = K),
        col_idx,
        rep(1, m * K), nrow = m, ncol = K * N + (K - 1) * (N - 1)
    ) |> ROI::L_constraint(ROI::eq(m), rep(0, m))

    list(
        constr   = rbind(ids, had_deriv_ot_barycenter_base_constr(w, costm, N, K)),
        add.info = "Hadamard derivative of OT barycenter under null"
    )
}

# Constraint set for the Hadamard derivative of the OT barycenter under the
# alternative hypothesis.
#
# Adds a single equality row that pins the linear form with coefficients
# `c_byrow(mu)` on the first K * N variables to `objval`, and stacks it on top
# of the shared inequality block from `had_deriv_ot_barycenter_base_constr`.
#
# Args:
#   mu:     K x N matrix of measures (K = nrow, N = ncol).
#   w:      weights, forwarded to the base constraint builder.
#   costm:  cost matrix, forwarded to the base constraint builder.
#   objval: right-hand side of the equality row (presumably the optimal
#           value of the underlying barycenter problem -- confirm with caller).
#
# Returns: list with
#   constr:   the combined `ROI` constraint object,
#   add.info: a descriptive label string.
had_deriv_ot_barycenter_alt_constr <- \(mu, w, costm, objval) {
    K <- nrow(mu)
    N <- ncol(mu)
    n_lead <- K * N
    n_vars <- n_lead + (K - 1) * (N - 1)

    # 1 x n_vars sparse row: entries of mu on the leading K * N columns,
    # zeros elsewhere.
    pin_row <- slam::simple_triplet_matrix(
        i = rep(1, n_lead),
        j = seq_len(n_lead),
        v = c_byrow(mu),
        nrow = 1,
        ncol = n_vars
    )
    pin_constr <- ROI::L_constraint(pin_row, ROI::eq(1), objval)

    shared_constr <- had_deriv_ot_barycenter_base_constr(w, costm, N, K)

    list(
        constr   = rbind(pin_constr, shared_constr),
        add.info = "Hadamard derivative of OT barycenter under alternative"
    )
}

# Evaluate the Hadamard derivative of the OT barycenter functional at `G` by
# solving a linear program through `lp_objval`.
#
# The LP maximises the linear objective whose coefficients are `c_byrow(G)`
# followed by (K - 1) * (N - 1) zeros, subject to `dual_constr$constr`, with
# every variable unbounded in both directions.
#
# Args:
#   G:           K x N matrix giving the direction (K = nrow, N = ncol).
#   dual_constr: list with elements `constr` (ROI constraint) and `add.info`,
#                as produced by the `*_null_constr` / `*_alt_constr` builders.
#   solver:      solver identifier passed through to `lp_objval`.
#
# Returns: whatever `lp_objval` returns (presumably the optimal objective
#          value -- confirm against that helper).
had_deriv_ot_barycenter <- \(G, dual_constr, solver) {

    n_zero_coef <- (nrow(G) - 1) * (ncol(G) - 1)
    obj_coef <- c(c_byrow(G), rep(0, n_zero_coef))

    # Override ROI's default lower bound of 0: all variables are free.
    free_bounds <- ROI::V_bound(
        nobj = ncol(dual_constr$constr),
        ld   = -Inf,
        ud   = Inf
    )

    lp_objval(
        objective   = ROI::L_objective(obj_coef),
        constraints = dual_constr$constr,
        types       = NULL,
        bounds      = free_bounds,
        maximum     = TRUE,
        solver      = solver,
        add.info    = dual_constr$add.info
    )
}
