Skip to content

Commit

Permalink
Improved folds
Browse files Browse the repository at this point in the history
  • Loading branch information
nt-williams committed Apr 25, 2024
1 parent b96a9a0 commit 83deac3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion R/lmtp_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ lmtp_Task <- R6::R6Class(
self$bounds <- y_bounds(data[[final_outcome(outcome)]], self$outcome_type, bounds)
data$lmtp_id <- create_ids(data, id)
self$id <- data$lmtp_id
self$folds <- setup_cv(data, data$lmtp_id, V)
self$folds <- setup_cv(data, V, data$lmtp_id, final_outcome(outcome), self$outcome_type)
self$multivariate <- is.list(trt)

shifted <- {
Expand Down
11 changes: 9 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@ determine_tau <- function(outcome, trt) {
length(outcome)
}

setup_cv <- function(data, id, V = 10) {
out <- origami::make_folds(data, cluster_ids = id, V = V)
setup_cv <- function(data, V = 10, id, strata, outcome_type) {
if (length(unique(id)) == nrow(data) & outcome_type == "binomial") {
strata <- data[[strata]]
strata[is.na(strata)] <- 2
out <- origami::make_folds(data, V = V, strata_ids = strata)
} else {
out <- origami::make_folds(data, cluster_ids = id, V = V)
}

if (V > 1) {
return(out)
}
Expand Down

0 comments on commit 83deac3

Please sign in to comment.