Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nt-williams committed Apr 25, 2024
1 parent 2be7931 commit b96a9a0
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 37 deletions.
16 changes: 6 additions & 10 deletions R/density_ratios.R
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
cf_r <- function(Task, learners, mtp, lrnr_folds, trim, full_fits, pb) {
fopts <- options("lmtp.bound", "lmtp.trt.length")

cf_r <- function(Task, learners, mtp, control, pb) {
out <- vector("list", length = length(Task$folds))
for (fold in seq_along(Task$folds)) {
out[[fold]] <- future::future({
options(fopts)

estimate_r(
get_folded_data(Task$natural, Task$folds, fold),
get_folded_data(Task$shifted, Task$folds, fold),
Task$trt, Task$cens, Task$risk, Task$tau, Task$node_list$trt,
learners, pb, mtp, lrnr_folds, full_fits
learners, pb, mtp, control
)
},
seed = TRUE)
}

trim_ratios(recombine_ratios(future::value(out), Task$folds), trim)
trim_ratios(recombine_ratios(future::value(out), Task$folds), control$.trim)
}

estimate_r <- function(natural, shifted, trt, cens, risk, tau, node_list, learners, pb, mtp, lrnr_folds, full_fits) {
estimate_r <- function(natural, shifted, trt, cens, risk, tau, node_list, learners, pb, mtp, control) {
densratios <- matrix(nrow = nrow(natural$valid), ncol = tau)
fits <- vector("list", length = tau)

Expand Down Expand Up @@ -47,10 +43,10 @@ estimate_r <- function(natural, shifted, trt, cens, risk, tau, node_list, learne
learners,
"binomial",
stacked[jrt & drt, ]$lmtp_id,
lrnr_folds
control$.learners_trt_folds
)

if (full_fits) {
if (control$.return_full_fits) {
fits[[t]] <- fit
} else {
fits[[t]] <- extract_sl_weights(fit)
Expand Down
19 changes: 9 additions & 10 deletions R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,9 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,

pb <- progressr::progressor(Task$tau*folds*2)

ratios <- cf_r(Task, learners_trt, mtp, control$.learners_trt_folds, control$.trim, control$.return_full_fits, pb)
estims <- cf_tmle(Task, "tmp_lmtp_scaled_outcome", ratios$ratios, learners_outcome, control$.learners_outcome_folds, control$.return_full_fits, pb)
ratios <- cf_r(Task, learners_trt, mtp, control, pb)
estims <- cf_tmle(Task, "tmp_lmtp_scaled_outcome",
ratios$ratios, learners_outcome, control, pb)

theta_dr(
list(
Expand Down Expand Up @@ -319,10 +320,9 @@ lmtp_sdr <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,

pb <- progressr::progressor(Task$tau*folds*2)

ratios <- cf_r(Task, learners_trt, mtp,
control$.learners_trt_folds, control$.trim, control$.return_full_fits, pb)
estims <- cf_sdr(Task, "tmp_lmtp_scaled_outcome", ratios$ratios, learners_outcome,
control$.learners_outcome_folds, control$.return_full_fits, pb)
ratios <- cf_r(Task, learners_trt, mtp, control, pb)
estims <- cf_sdr(Task, "tmp_lmtp_scaled_outcome", ratios$ratios,
learners_outcome, control, pb)

theta_dr(
list(
Expand Down Expand Up @@ -469,8 +469,8 @@ lmtp_sub <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens

pb <- progressr::progressor(Task$tau*folds)

estims <- cf_sub(Task, "tmp_lmtp_scaled_outcome", learners,
control$.learners_outcome_folds, control$.return_full_fits, pb)
estims <- cf_sub(Task, "tmp_lmtp_scaled_outcome",
learners, control, pb)

theta_sub(
eta = list(
Expand Down Expand Up @@ -620,8 +620,7 @@ lmtp_ipw <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens

pb <- progressr::progressor(Task$tau*folds)

ratios <- cf_r(Task, learners, mtp, control$.learners_trt_folds,
control$.trim, control$.return_full_fits, pb)
ratios <- cf_r(Task, learners, mtp, control, pb)

theta_ipw(
eta = list(
Expand Down
10 changes: 5 additions & 5 deletions R/gcomp.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cf_sub <- function(Task, outcome, learners, lrnr_folds, full_fits, pb) {
cf_sub <- function(Task, outcome, learners, control, pb) {
out <- vector("list", length = length(Task$folds))
for (fold in seq_along(Task$folds)) {
out[[fold]] <- future::future({
Expand All @@ -8,7 +8,7 @@ cf_sub <- function(Task, outcome, learners, lrnr_folds, full_fits, pb) {
Task$trt, outcome,
Task$node_list$outcome, Task$cens,
Task$risk, Task$tau, Task$outcome_type,
learners, lrnr_folds, pb, full_fits
learners, control, pb
)
},
seed = TRUE)
Expand All @@ -23,7 +23,7 @@ cf_sub <- function(Task, outcome, learners, lrnr_folds, full_fits, pb) {
}

estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
tau, outcome_type, learners, lrnr_folds, pb, full_fits) {
tau, outcome_type, learners, control, pb) {

m <- matrix(nrow = nrow(natural$valid), ncol = tau)
fits <- vector("list", length = tau)
Expand Down Expand Up @@ -51,10 +51,10 @@ estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
learners,
outcome_type,
id = natural$train[i & rt, ][["lmtp_id"]],
lrnr_folds
control$.learners_outcome_folds
)

if (full_fits) {
if (control$.return_full_fits) {
fits[[t]] <- fit
} else {
fits[[t]] <- extract_sl_weights(fit)
Expand Down
14 changes: 7 additions & 7 deletions R/sdr.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cf_sdr <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb) {
cf_sdr <- function(Task, outcome, ratios, learners, control, pb) {
out <- vector("list", length = length(Task$folds))
for (fold in seq_along(Task$folds)) {
out[[fold]] <- future::future({
Expand All @@ -8,7 +8,7 @@ cf_sdr <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb) {
Task$trt, outcome, Task$node_list$outcome,
Task$cens, Task$risk, Task$tau, Task$outcome_type,
get_folded_data(ratios, Task$folds, fold)$train,
learners, lrnr_folds, pb, full_fits
learners, control, pb
)
},
seed = TRUE)
Expand All @@ -22,7 +22,7 @@ cf_sdr <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb) {
}

estimate_sdr <- function(natural, shifted, trt, outcome, node_list, cens, risk, tau,
outcome_type, ratios, learners, lrnr_folds, pb, full_fits) {
outcome_type, ratios, learners, control, pb) {

m_natural_train <- m_shifted_train <-
cbind(matrix(nrow = nrow(natural$train), ncol = tau), natural$train[[outcome]])
Expand All @@ -49,9 +49,9 @@ estimate_sdr <- function(natural, shifted, trt, outcome, node_list, cens, risk,
learners,
outcome_type,
id = natural$train[i & rt, ][["lmtp_id"]],
lrnr_folds)
control$.learners_outcome_folds)

if (full_fits) {
if (control$.return_full_fits) {
fits[[t]] <- fit
} else {
fits[[t]] <- extract_sl_weights(fit)
Expand All @@ -73,9 +73,9 @@ estimate_sdr <- function(natural, shifted, trt, outcome, node_list, cens, risk,
learners,
"continuous",
id = natural$train[i & rt, ][["lmtp_id"]],
lrnr_folds)
control$.learners_outcome_folds)

if (full_fits) {
if (control$.return_full_fits) {
fits[[t]] <- fit
} else {
fits[[t]] <- extract_sl_weights(fit)
Expand Down
10 changes: 5 additions & 5 deletions R/tmle.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cf_tmle <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb) {
cf_tmle <- function(Task, outcome, ratios, learners, control, pb) {
out <- vector("list", length = length(Task$folds))
ratios <- matrix(t(apply(ratios, 1, cumprod)),
nrow = nrow(ratios),
Expand All @@ -12,7 +12,7 @@ cf_tmle <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb)
Task$trt, outcome, Task$node_list$outcome, Task$cens, Task$risk,
Task$tau, Task$outcome_type,
get_folded_data(ratios, Task$folds, fold)$train,
learners, lrnr_folds, pb, full_fits
learners, control, pb
)
},
seed = TRUE)
Expand All @@ -28,7 +28,7 @@ cf_tmle <- function(Task, outcome, ratios, learners, lrnr_folds, full_fits, pb)
}

estimate_tmle <- function(natural, shifted, trt, outcome, node_list, cens,
risk, tau, outcome_type, ratios, learners, lrnr_folds, pb, full_fits) {
risk, tau, outcome_type, ratios, learners, control, pb) {
m_natural_train <- m_shifted_train <- matrix(nrow = nrow(natural$train), ncol = tau)
m_natural_valid <- m_shifted_valid <- matrix(nrow = nrow(natural$valid), ncol = tau)

Expand Down Expand Up @@ -56,10 +56,10 @@ estimate_tmle <- function(natural, shifted, trt, outcome, node_list, cens,
learners,
outcome_type,
id = natural$train[i & rt,][["lmtp_id"]],
lrnr_folds
control$.learners_outcome_folds
)

if (full_fits) {
if (control$.return_full_fits) {
fits[[t]] <- fit
} else {
fits[[t]] <- extract_sl_weights(fit)
Expand Down

0 comments on commit b96a9a0

Please sign in to comment.