# Copyright 2024 DARWIN EU®
#
# This file is part of TreatmentPatterns
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' export
#'
#' Export the andromeda object generated by
#' \link[TreatmentPatterns]{computePathways} to sharable csv-files and/or a
#' zip archive.
#'
#' @export
#'
#' @template param_andromeda
#' @param outputPath (`character`: `NULL`) Output path where to write output
#' files to. When set to `NULL` no files will be written, and only the results
#' object is returned.
#' @template param_ageWindow
#' @template param_minCellCount
#' @template param_censorType
#' @template param_archiveName
#' @param nonePaths (`logical(1)`) Should `None` paths be included? This will
#' fetch all persons included in the target cohort and assign them a `"None"`
#' pathway. Significantly impacts performance.
#' @param stratify (`logical(1)`) Should pathways be stratified? This will
#' perform pairwise stratification between age, sex, and index year.
#' Significantly impacts performance.
#'
#' @return `TreatmentPatternsResults` object 
#'
#' @examples
#' \donttest{
#' ableToRun <- all(
#'   require("CirceR", character.only = TRUE, quietly = TRUE),
#'   require("CDMConnector", character.only = TRUE, quietly = TRUE),
#'   require("TreatmentPatterns", character.only = TRUE, quietly = TRUE),
#'   require("dplyr", character.only = TRUE, quietly = TRUE)
#' )
#'
#' if (ableToRun) {
#'   library(TreatmentPatterns)
#'   library(CDMConnector)
#'   library(dplyr)
#'
#'   withr::local_envvar(
#'     R_USER_CACHE_DIR = tempfile(),
#'     EUNOMIA_DATA_FOLDER = Sys.getenv("EUNOMIA_DATA_FOLDER", unset = tempfile())
#'   )
#'
#'   tryCatch({
#'     if (Sys.getenv("skip_eunomia_download_test") != "TRUE") {
#'       CDMConnector::downloadEunomiaData(overwrite = TRUE)
#'     }
#'   }, error = function(e) NA)
#'
#'   con <- DBI::dbConnect(duckdb::duckdb(), dbdir = eunomiaDir())
#'   cdm <- cdmFromCon(con, cdmSchema = "main", writeSchema = "main")
#'
#'   cohortSet <- readCohortSet(
#'     path = system.file(package = "TreatmentPatterns", "exampleCohorts")
#'   )
#'
#'   cdm <- generateCohortSet(
#'     cdm = cdm,
#'     cohortSet = cohortSet,
#'     name = "cohort_table"
#'   )
#'
#'   cohorts <- cohortSet %>%
#'     # Remove 'cohort' and 'json' columns
#'     select(-"cohort", -"json") %>%
#'     mutate(type = c("event", "event", "event", "event", "exit", "event", "event", "target")) %>%
#'     rename(
#'       cohortId = "cohort_definition_id",
#'       cohortName = "cohort_name",
#'     ) %>%
#'     select("cohortId", "cohortName", "type")
#'
#'   outputEnv <- computePathways(
#'     cohorts = cohorts,
#'     cohortTableName = "cohort_table",
#'     cdm = cdm
#'   )
#'
#'   results <- export(
#'     andromeda = outputEnv
#'   )
#'
#'   Andromeda::close(outputEnv)
#'   DBI::dbDisconnect(con, shutdown = TRUE)
#' }
#' }
export <- function(
    andromeda,
    outputPath = NULL,
    ageWindow = 10,
    minCellCount = 5,
    censorType = "minCellCount",
    archiveName = NULL,
    nonePaths = FALSE,
    stratify = FALSE) {
  # Keep this call directly in the body of export(): validateExport() reads
  # the arguments from the frame that called it.
  validateExport()

  nrows <- andromeda$treatmentHistoryFinal %>%
    dplyr::summarize(n()) %>%
    dplyr::pull()

  # Nothing to summarise: hand back an empty results object.
  if (nrows == 0) {
    message("Treatment History table is empty. Nothing to export.")
    return(TreatmentPatternsResults$new())
  }

  writeFiles <- !is.null(outputPath)
  if (writeFiles) {
    dir.create(outputPath, showWarnings = FALSE, recursive = TRUE)
  }

  # Attach the target cohort name to every treatment history record.
  treatmentHistory <- andromeda$treatmentHistoryFinal %>%
    dplyr::inner_join(
      andromeda$cohorts,
      dplyr::join_by(targetCohortId == cohortId)
    ) %>%
    dplyr::collect() %>%
    dplyr::select(
      "personId", "indexYear", "age", "sex", "eventCohortName", "eventCohortId",
      targetCohortName = "cohortName",
      "targetCohortId", "eventSeq", "durationEra", "n_target"
    )

  # One data frame per target cohort.
  targetsTH <- treatmentHistory %>%
    dplyr::group_by(.data$targetCohortName) %>%
    dplyr::group_split()

  analysisId <- andromeda$analyses %>%
    dplyr::pull(.data$analysis_id)

  analyses <- andromeda$analyses %>%
    dplyr::collect()

  metadata <- andromeda$metadata %>%
    dplyr::collect() %>%
    dplyr::mutate(analysis_id = analysisId)

  cdmSourceInfo <- andromeda$cdm_source_info %>%
    dplyr::collect() %>%
    dplyr::mutate(analysis_id = analysisId)

  arguments <- andromeda$arguments %>%
    dplyr::collect()

  # Neither of these depends on the target being processed, so they are
  # fetched once instead of once per target.
  attritionBase <- andromeda$attrition %>%
    dplyr::collect()

  # NOTE(review): getFilteredSubjects() returns subjects of every target
  # cohort, so with several targets each target receives the same "None"
  # rows. Confirm whether that is intended.
  noneSubjects <- if (nonePaths) getFilteredSubjects(andromeda) else NULL

  results <- lapply(targetsTH, function(targetHistory) {
    targetCohortId <- unique(targetHistory$targetCohortId)
    targetCohortName <- unique(targetHistory$targetCohortName)

    # Stamp a table with the analysis and the target it belongs to.
    tagTarget <- function(tbl) {
      dplyr::mutate(
        tbl,
        analysis_id = analysisId,
        target_cohort_id = targetCohortId,
        target_cohort_name = targetCohortName
      )
    }

    if (nonePaths) {
      targetHistory <- dplyr::bind_rows(targetHistory, noneSubjects)
      # Assign outside of mutate(): inside the data mask the names
      # `targetCohortId` / `targetCohortName` resolve to the existing columns,
      # which would leave the added "None" rows unlabelled (NA).
      targetHistory$targetCohortId <- targetCohortId
      targetHistory$targetCohortName <- targetCohortName
    }

    attrition <- tagTarget(attritionBase)

    treatmentPathways <- computeTreatmentPathways(
      treatmentHistory = targetHistory,
      ageWindow = ageWindow,
      minCellCount = minCellCount,
      censorType = censorType,
      stratify = stratify
    ) %>%
      dplyr::distinct() %>%
      dplyr::rename(
        index_year = "indexYear",
        pathway = "path"
      ) %>%
      tagTarget()

    summaryEventDuration <- tagTarget(computeStatsTherapy(targetHistory))

    counts <- lapply(computeCounts(targetHistory, minCellCount), tagTarget)

    TreatmentPatternsResults$new(
      attrition = attrition,
      treatmentPathways = treatmentPathways,
      summaryEventDuration = summaryEventDuration,
      countsAge = counts$age,
      countsSex = counts$sex,
      countsYear = counts$year
    )
  })

  # Stack one field of the per-target results into a single table.
  bindField <- function(field) {
    dplyr::bind_rows(lapply(results, function(res) res[[field]]))
  }

  tpr <- TreatmentPatternsResults$new(
    attrition = bindField("attrition"),
    metadata = metadata,
    treatmentPathways = bindField("treatment_pathways"),
    summaryEventDuration = bindField("summary_event_duration"),
    countsAge = bindField("counts_age"),
    countsSex = bindField("counts_sex"),
    countsYear = bindField("counts_year"),
    cdmSourceInfo = cdmSourceInfo,
    analyses = analyses,
    arguments = arguments
  )

  if (writeFiles) {
    tpr$saveAsCsv(path = outputPath)
  }

  # Scalar condition: use the short-circuiting `&&`, not the elementwise `&`.
  if (writeFiles && !is.null(archiveName)) {
    tpr$saveAsZip(path = outputPath, name = archiveName)
  }
  return(tpr)
}

# Validates the arguments of the function that called it (export()).
#
# The arguments are not passed in: they are collected by evaluating
# `mget(names(formals()))` in the frame one level up the call stack. All
# assertions are added to a single collection, so every problem is reported
# together by `checkmate::reportAssertions()`.
validateExport <- function() {
  args <- eval(
    expr = expression(mget(names(formals()))),
    envir = sys.frame(sys.nframe() - 1)
  )

  assertCol <- checkmate::makeAssertCollection()
  checkmate::assertTRUE(
    x = Andromeda::isAndromeda(args$andromeda),
    add = assertCol,
    .var.name = "andromeda"
  )

  checkmate::assertCharacter(
    x = args$outputPath,
    len = 1,
    null.ok = TRUE,
    add = assertCol,
    .var.name = "outputPath"
  )

  checkmate::assertIntegerish(
    x = args$ageWindow,
    min.len = 1,
    any.missing = FALSE,
    unique = TRUE,
    add = assertCol,
    .var.name = "ageWindow"
  )

  checkmate::assertIntegerish(
    x = args$minCellCount,
    len = 1,
    lower = 1,
    add = assertCol,
    .var.name = "minCellCount"
  )

  # Previously this assertion was not added to the collection, so it aborted
  # on its own and hid any other failures.
  checkmate::assertChoice(
    x = args$censorType,
    choices = c("minCellCount", "remove", "mean"),
    add = assertCol,
    .var.name = "censorType"
  )

  checkmate::assertCharacter(
    x = args$archiveName,
    len = 1,
    add = assertCol,
    null.ok = TRUE,
    .var.name = "archiveName"
  )

  # Both flags are used as `if ()` conditions, so NA is rejected as well.
  checkmate::assertLogical(
    x = args$nonePaths,
    len = 1,
    any.missing = FALSE,
    add = assertCol,
    null.ok = FALSE,
    .var.name = "nonePaths"
  )

  checkmate::assertLogical(
    x = args$stratify,
    len = 1,
    any.missing = FALSE,
    add = assertCol,
    null.ok = FALSE,
    .var.name = "stratify"
  )
  checkmate::reportAssertions(assertCol)
}

#' computeStatsTherapy
#'
#' Summarises `durationEra` (min, quartiles, median, max, mean, sd and the
#' number of records) at four levels: per event type overall, per event type
#' and event sequence, per event cohort overall, and per event cohort and
#' event sequence. Records named `"None"` are left out of the two per-cohort
#' summaries.
#'
#' @noRd
#'
#' @template param_treatmentHistory
#'
#' @return (`data.frame()`) One row per summary level with the columns
#' `event_name`, `duration_*`, `event_count` and `line` (`"overall"` or the
#' event sequence as character).
computeStatsTherapy <- function(treatmentHistory) {
  # Label every record as mono- or combination-event.
  # NOTE(review): the test is on the number of characters of eventCohortId,
  # so any id longer than one character (e.g. "10", or the "-1" used for
  # "None" rows) is labelled "combination-event". Confirm this is intended.
  labelEventType <- function(history) {
    dplyr::mutate(history, eventName = dplyr::case_when(
      nchar(.data$eventCohortId) > 1 ~ "combination-event",
      .default = "mono-event"
    ))
  }

  # Duration statistics for whatever grouping is active on `grouped`.
  # `.groups = "drop"` always returns ungrouped output and avoids the
  # regrouping message of summarise().
  summariseDuration <- function(grouped) {
    dplyr::summarise(
      grouped,
      duration_min = min(.data$durationEra, na.rm = TRUE),
      duration_q1 = stats::quantile(.data$durationEra, probs = 0.25, na.rm = TRUE),
      duration_median = stats::median(.data$durationEra, na.rm = TRUE),
      # Column name kept as-is for compatibility; it holds the 75th percentile.
      duration_q2 = stats::quantile(.data$durationEra, probs = 0.75, na.rm = TRUE),
      duration_max = max(.data$durationEra, na.rm = TRUE),
      duration_average = mean(.data$durationEra, na.rm = TRUE),
      duration_sd = stats::sd(.data$durationEra, na.rm = TRUE),
      event_count = dplyr::n(),
      .groups = "drop"
    )
  }

  # Turn the eventSeq grouping column into the character `line` column.
  seqToLine <- function(stats) {
    stats %>%
      dplyr::mutate(line = as.character(.data$eventSeq)) %>%
      dplyr::select(-"eventSeq")
  }

  typed <- labelEventType(treatmentHistory)
  named <- dplyr::filter(treatmentHistory, .data$eventCohortName != "None")

  typeOverall <- typed %>%
    dplyr::group_by(.data$eventName) %>%
    summariseDuration() %>%
    dplyr::mutate(line = "overall")

  typePerLine <- typed %>%
    dplyr::group_by(.data$eventName, .data$eventSeq) %>%
    summariseDuration() %>%
    seqToLine()

  cohortOverall <- named %>%
    dplyr::group_by(.data$eventCohortName) %>%
    summariseDuration() %>%
    dplyr::mutate(line = "overall") %>%
    dplyr::rename(eventName = "eventCohortName")

  cohortPerLine <- named %>%
    dplyr::group_by(.data$eventSeq, .data$eventCohortName) %>%
    summariseDuration() %>%
    seqToLine() %>%
    dplyr::rename(eventName = "eventCohortName")

  dplyr::bind_rows(typeOverall, typePerLine, cohortOverall, cohortPerLine) %>%
    dplyr::rename(event_name = "eventName")
}

# Number of persons per index year. Each person contributes the single record
# selected by `which.min(indexYear)`; counts below `minCellCount` are reported
# as the text "<minCellCount".
countYear <- function(treatmentHistory, minCellCount) {
  firstRecords <- treatmentHistory %>%
    dplyr::group_by(.data$personId) %>%
    dplyr::slice(which.min(.data$indexYear))

  perYear <- firstRecords %>%
    dplyr::group_by(.data$indexYear) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")

  perYear %>%
    dplyr::mutate(
      n = dplyr::case_when(
        .data$n < minCellCount ~ sprintf("<%s", minCellCount),
        .default = as.character(.data$n)
      )
    ) %>%
    dplyr::rename(index_year = "indexYear")
}

# Number of persons per sex. Each person contributes the single record
# selected by `which.min(indexYear)`; counts below `minCellCount` are reported
# as the text "<minCellCount".
countSex <- function(treatmentHistory, minCellCount) {
  firstRecords <- treatmentHistory %>%
    dplyr::group_by(.data$personId) %>%
    dplyr::slice(which.min(.data$indexYear))

  perSex <- firstRecords %>%
    dplyr::group_by(.data$sex) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")

  dplyr::mutate(
    perSex,
    n = dplyr::case_when(
      .data$n < minCellCount ~ sprintf("<%s", minCellCount),
      .default = as.character(.data$n)
    )
  )
}

# Number of persons per age. Each person contributes the single record
# selected by `which.min(indexYear)`; counts below `minCellCount` are reported
# as the text "<minCellCount".
countAge <- function(treatmentHistory, minCellCount) {
  firstRecords <- treatmentHistory %>%
    dplyr::group_by(.data$personId) %>%
    dplyr::slice(which.min(.data$indexYear))

  perAge <- firstRecords %>%
    dplyr::group_by(.data$age) %>%
    dplyr::summarise(n = dplyr::n(), .groups = "drop")

  dplyr::mutate(
    perAge,
    n = dplyr::case_when(
      .data$n < minCellCount ~ sprintf("<%s", minCellCount),
      .default = as.character(.data$n)
    )
  )
}

#' computeCounts
#'
#' Runs the year, age and sex counters on the same treatment history.
#'
#' @noRd
#'
#' @template param_treatmentHistory
#' @template param_minCellCount
#'
#' @return (`list()`) Named list with the elements `year`, `age` and `sex`.
computeCounts <- function(treatmentHistory, minCellCount) {
  counters <- list(
    year = countYear,
    age = countAge,
    sex = countSex
  )
  # lapply() keeps the names, so the result is list(year =, age =, sex =).
  lapply(counters, function(counter) {
    counter(treatmentHistory, minCellCount)
  })
}

#' censorminCellCount
#'
#' Raises every frequency below `minCellCount` to `minCellCount`; all other
#' frequencies (including missing ones) are left as they are.
#'
#' @param treatmentPathways data.frame()
#' @param minCellCount numeric(1)
#' 
#' @noRd
censorminCellCount <- function(treatmentPathways, minCellCount) {
  dplyr::mutate(
    treatmentPathways,
    freq = dplyr::case_when(
      .data$freq < minCellCount ~ minCellCount,
      .default = .data$freq
    )
  )
}

#' censorRemove
#'
#' Keeps only the pathways whose frequency is at least `minCellCount`.
#'
#' @param treatmentPathways data.frame()
#' @param minCellCount numeric(1)
#' 
#' @noRd
censorRemove <- function(treatmentPathways, minCellCount) {
  dplyr::filter(treatmentPathways, .data$freq >= minCellCount)
}

#' censorMean
#'
#' Replaces every frequency below `minCellCount` with the rounded mean of all
#' frequencies that are below `minCellCount`. Other frequencies are unchanged.
#'
#' @param treatmentPathways data.frame()
#' @param minCellCount numeric(1)
#' 
#' @noRd
censorMean <- function(treatmentPathways, minCellCount) {
  belowThreshold <- dplyr::filter(treatmentPathways, .data$freq < minCellCount)

  # When nothing is below the threshold this is NaN, which is then never
  # assigned to any row.
  meanFreq <- round(mean(belowThreshold$freq))

  dplyr::mutate(
    treatmentPathways,
    freq = dplyr::case_when(
      .data$freq < minCellCount ~ meanFreq,
      .default = .data$freq
    )
  )
}

#' censorData
#'
#' Dispatches to the censoring strategy named by `censorType` and reports via
#' `message()` how many pathways have a frequency below `minCellCount`.
#'
#' @param treatmentPathways data.frame()
#' @param minCellCount numeric(1)
#' @param censorType character(1) One of `"minCellCount"`, `"remove"` or
#' `"mean"`. Any other value raises an error.
#'
#' @return data.frame() The censored pathways.
#' 
#' @noRd
censorData <- function(treatmentPathways, minCellCount, censorType) {
  nCensored <- treatmentPathways %>%
    dplyr::filter(.data$freq < minCellCount) %>%
    nrow()

  switch(
    censorType,
    "minCellCount" = {
      message(sprintf("Censoring %s pathways with a frequency <%s to %s.", nCensored, minCellCount, minCellCount))
      censorminCellCount(treatmentPathways, minCellCount)
    },
    "remove" = {
      message(sprintf("Removing %s pathways with a frequency <%s.", nCensored, minCellCount))
      censorRemove(treatmentPathways, minCellCount)
    },
    "mean" = {
      message(sprintf("Censoring %s pathways with a frequency <%s to mean.", nCensored, minCellCount))
      censorMean(treatmentPathways, minCellCount)
    },
    # Without a default branch switch() would silently return NULL here.
    stop(sprintf("Unknown censorType: '%s'.", censorType), call. = FALSE)
  )
}

#' makeAgeWindow
#'
#' Returns the break points used for age binning. A vector with more than one
#' element is taken as the break points themselves; otherwise the value is
#' used as the step of a sequence from 0 to 150.
#' 
#' @param ageWindow numeric(n)
#'
#' @return numeric(n) Break points.
#' 
#' @noRd
makeAgeWindow <- function(ageWindow) {
  if (length(ageWindow) <= 1) {
    return(seq(0, 150, ageWindow))
  }
  ageWindow
}

#' groupByAgeWindow
#'
#' Adds an `ageBin` column: each age is cut into the intervals given by
#' `makeAgeWindow(ageWindow)` and the digit runs of the interval label are
#' joined with "-". The result is returned as a row-wise data frame.
#' 
#' @param treatmentHistory data.frame()
#' @param ageWindow numeric(n)
#'
#' @noRd
groupByAgeWindow <- function(treatmentHistory, ageWindow) {
  # Build the bin label for the age of a single row.
  binLabel <- function(age) {
    interval <- as.character(cut(age, makeAgeWindow(ageWindow)))
    bounds <- stringi::stri_extract_all(str = interval, regex = "\\d+")
    paste(unlist(bounds), collapse = "-")
  }

  rowWise <- dplyr::rowwise(treatmentHistory)
  dplyr::mutate(rowWise, ageBin = binLabel(.data$age))
}

#' computeTreatmentPathways
#'
#' Builds the pathway frequency table. With `stratify = TRUE` the pathways are
#' counted per age bin, sex and index year (plus their roll-ups). Otherwise a
#' single overall count per pathway is produced, with `age`, `sex` and
#' `indexYear` set to `"all"`. In both cases the frequencies are censored
#' according to `minCellCount` and `censorType`.
#'
#' @param treatmentHistory data.frame()
#' @param ageWindow numeric(n)
#' @param minCellCount numeric(1)
#' @param censorType character(1)
#' @param stratify (logical(1))
#'
#' @return (`data.frame()`) Columns `path`, `freq`, `age`, `sex`, `indexYear`.
#' 
#' @noRd
computeTreatmentPathways <- function(treatmentHistory, ageWindow, minCellCount, censorType, stratify) {
  if (stratify) {
    # Age binning is only needed (and only paid for) when stratifying.
    treatmentPathways <- groupByAgeWindow(treatmentHistory, ageWindow) %>%
      dplyr::mutate(indexYear = as.character(.data$indexYear)) %>%
      stratisfy()
    treatmentPathways[is.na(treatmentPathways)] <- "all"
    treatmentPathways <- censorData(treatmentPathways, minCellCount, censorType)
    treatmentPathways$path[treatmentPathways$path == "NA"] <- "None"
    return(treatmentPathways)
  }

  treatmentPathways <- treatmentHistory %>%
    dplyr::collect() %>%
    # Records without n_target are treated as the first target occurrence.
    dplyr::mutate(n_target = dplyr::case_when(
      is.na(.data$n_target) ~ 1,
      .default = .data$n_target
    )) %>%
    dplyr::group_by(.data$n_target, .data$personId) %>%
    dplyr::arrange(.data$eventSeq) %>%
    dplyr::distinct(
      .data$personId, .data$eventCohortName, .data$eventCohortId,
      .data$targetCohortName, .data$targetCohortId, .data$eventSeq,
      .data$durationEra, .data$n_target
    ) %>%
    # One pathway string per (n_target, person), events in eventSeq order.
    dplyr::summarise(
      path = paste(.data$eventCohortName, collapse = "-"),
      .groups = "drop"
    ) %>%
    dplyr::group_by(.data$path) %>%
    dplyr::summarise(freq = dplyr::n()) %>%
    dplyr::mutate(age = "all", sex = "all", indexYear = "all")

  # Previously this branch skipped censoring altogether, so minCellCount and
  # censorType had no effect on the pathways unless stratify was TRUE.
  censorData(treatmentPathways, minCellCount, censorType) %>%
    dplyr::arrange(dplyr::desc(.data$freq), .data$path)
}

# Collapses the event records of each person into one pathway string.
#
# Per person and index year the event names (indexed by eventSeq) are stored
# in the list column `pathway`, `freq` holds the number of records sharing the
# same index year and pathway, and `path` is the pathway joined with "-".
# Finally one record per person is kept, selected by `which.min(indexYear)`.
# The result is grouped by personId.
collapsePaths <- function(treatmentHistory) {
  treatmentHistory %>%
    dplyr::arrange(.data$eventSeq) %>%
    dplyr::group_by(.data$personId, .data$indexYear) %>%
    # mutate() has no `.groups` argument: passing it only created a literal
    # column named ".groups", so it is dropped here and below.
    dplyr::mutate(
      pathway = list(.data$eventCohortName[.data$eventSeq])
    ) %>%
    dplyr::ungroup() %>%
    dplyr::group_by(.data$indexYear, .data$pathway) %>%
    dplyr::mutate(freq = length(.data$personId)) %>%
    dplyr::ungroup() %>%
    dplyr::rowwise() %>%
    dplyr::mutate(path = paste(.data$pathway, collapse = "-")) %>%
    dplyr::group_by(.data$personId) %>%
    dplyr::slice(which.min(.data$indexYear))
}

# Counts the collapsed pathways per path, age bin, sex and index year.
# indexYear is returned as character.
stratisfyAgeSexYear <- function(treatmentHistory) {
  collapsed <- collapsePaths(treatmentHistory)

  perStratum <- dplyr::summarise(
    dplyr::group_by(
      collapsed,
      .data$path, .data$ageBin, .data$sex, .data$indexYear
    ),
    freq = dplyr::n(),
    .groups = "drop"
  )

  dplyr::mutate(perStratum, indexYear = as.character(.data$indexYear))
}

# Overall roll-up: total frequency per path, with index year, sex and age bin
# all set to "all".
stratAll <- function(treatmentPathways) {
  totals <- dplyr::summarize(
    dplyr::group_by(treatmentPathways, .data$path),
    freq = sum(.data$freq)
  )
  dplyr::mutate(totals, indexYear = "all", sex = "all", ageBin = "all")
}

# Roll-ups over sex: frequencies summed per (path, index year, age bin),
# per (path, age bin) and per (path, index year); the strata that were summed
# away are set to "all".
stratSex <- function(treatmentPathways) {
  # Sum freq within the given key columns.
  sumBy <- function(keys) {
    treatmentPathways %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(keys))) %>%
      dplyr::summarize(freq = sum(.data$freq), .groups = "drop")
  }

  byYearAge <- dplyr::mutate(
    sumBy(c("path", "indexYear", "ageBin")),
    sex = "all"
  )
  byAge <- dplyr::mutate(
    sumBy(c("path", "ageBin")),
    sex = "all", indexYear = "all"
  )
  byYear <- dplyr::mutate(
    sumBy(c("path", "indexYear")),
    sex = "all", ageBin = "all"
  )

  dplyr::bind_rows(byYearAge, byAge, byYear)
}

# Roll-ups over age bin: frequencies summed per (path, index year, sex),
# per (path, sex) and per (path, index year); the strata that were summed
# away are set to "all".
stratAgeBin <- function(treatmentPathways) {
  # Sum freq within the given key columns.
  sumBy <- function(keys) {
    treatmentPathways %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(keys))) %>%
      dplyr::summarize(freq = sum(.data$freq), .groups = "drop")
  }

  byYearSex <- dplyr::mutate(
    sumBy(c("path", "indexYear", "sex")),
    ageBin = "all"
  )
  bySex <- dplyr::mutate(
    sumBy(c("path", "sex")),
    ageBin = "all", indexYear = "all"
  )
  byYear <- dplyr::mutate(
    sumBy(c("path", "indexYear")),
    ageBin = "all", sex = "all"
  )

  dplyr::bind_rows(byYearSex, bySex, byYear)
}

# Roll-ups over index year: frequencies summed per (path, sex, age bin),
# per (path, age bin) and per (path, sex); the strata that were summed away
# are set to "all".
stratIndexYear <- function(treatmentPathways) {
  # Sum freq within the given key columns.
  sumBy <- function(keys) {
    treatmentPathways %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(keys))) %>%
      dplyr::summarize(freq = sum(.data$freq), .groups = "drop")
  }

  bySexAge <- dplyr::mutate(
    sumBy(c("path", "sex", "ageBin")),
    indexYear = "all"
  )
  byAge <- dplyr::mutate(
    sumBy(c("path", "ageBin")),
    sex = "all", indexYear = "all"
  )
  bySex <- dplyr::mutate(
    sumBy(c("path", "sex")),
    indexYear = "all", ageBin = "all"
  )

  dplyr::bind_rows(bySexAge, byAge, bySex)
}

# Builds the stratified pathway table: the full age/sex/year breakdown stacked
# with the overall, age-bin, sex and index-year roll-ups. Sex is lower-cased,
# `ageBin` is renamed to `age`, and the columns are put in the order
# path, freq, age, sex, indexYear.
stratisfy <- function(treatmentHistory) {
  fullBreakdown <- stratisfyAgeSexYear(treatmentHistory)

  stacked <- dplyr::bind_rows(
    fullBreakdown,
    stratAll(fullBreakdown),
    stratAgeBin(fullBreakdown),
    stratSex(fullBreakdown),
    stratIndexYear(fullBreakdown)
  )

  stacked <- dplyr::mutate(stacked, sex = tolower(.data$sex))
  stacked <- dplyr::rename(stacked, age = "ageBin")
  dplyr::relocate(stacked, "path", "freq", "age", "sex", "indexYear")
}

#' getFilteredSubjects
#'
#' Fetches the records of `currentCohorts` that belong to a target cohort and
#' whose `personId` does not occur in `treatmentHistory`, and labels them as a
#' single `"None"` event (`eventCohortId = "-1"`, `eventSeq = 1`).
#' 
#' @noRd
#' 
#' @param andromeda andromeda
#' 
#' @return data.frame(), or `NULL` when there are no such records.
getFilteredSubjects <- function(andromeda) {
  targetCohortId <- andromeda$cohorts %>%
    dplyr::filter(.data$type == "target") %>%
    dplyr::pull(.data$cohortId)

  untreated <- dplyr::anti_join(
    andromeda$currentCohorts,
    andromeda$treatmentHistory,
    dplyr::join_by(personId == personId)
  )

  noneSubjects <- untreated %>%
    dplyr::filter(.data$cohortId %in% targetCohortId) %>%
    # durationEra is assigned here but is not among the selected columns
    # below, so it is absent from the result.
    dplyr::mutate(
      indexYear = floor(.data$startDate / 365.25) + 1970,
      eventCohortName = "None",
      eventCohortId = "-1",
      durationEra = 0,
      eventSeq = 1
    ) %>%
    dplyr::select(
      "personId", "indexYear", "age", "sex",
      "eventCohortName", "eventCohortId", "eventSeq"
    ) %>%
    dplyr::collect()

  if (nrow(noneSubjects) > 0) noneSubjects else NULL
}
