Skip to content

Commit a61664d

Browse files
committed
EIR output from data_sim function now correctly represents the monthly EIR
1 parent 77e7a73 commit a61664d

5 files changed

+106
-40
lines changed

R/inference_diagnostics.R

+84-32
Original file line numberDiff line numberDiff line change
@@ -52,55 +52,87 @@ plot_corr <- function(results) {
5252
#' @param results A list containing the MCMC output, typically with at least
5353
#' `results[[1]]` for chain data and `results[[2]]` for trace plot data.
5454
#' Should include an element `n_chains` for the number of chains.
55+
#' @param params A character vector specifying which diagnostics to display.
56+
#' Options include: "trace", "gelman", "corr", "ess", "acf", "quantiles", "acceptance".
57+
#' @param thin An integer specifying thinning interval for the chains.
58+
#' Thinning reduces the number of samples by keeping every `thin`th sample.
5559
#'
56-
#' @return This function returns a variety of printed diagnostic outputs:
57-
#' - Trace plot for parameter convergence
58-
#' - Gelman-Rubin statistic if multiple chains are present
59-
#' - Correlation plot between parameters
60-
#' - Effective sample size for each parameter
61-
#' - Autocorrelation function (ACF) plot for parameter dependence
62-
#' - Posterior quantiles for each parameter
63-
#' - Acceptance rate
60+
#' @return Diagnostic outputs based on the selected `params` are printed or plotted.
6461
#'
6562
#' @export
6663
#'
6764
#' @examples
6865
#' # Assuming 'results' is a valid MCMC result list with required structure:
69-
#' MCMC_diag(results)
70-
MCMC_diag <- function(results) {
66+
#' MCMC_diag(results, params = c("trace", "ess"))
67+
MCMC_diag <- function(results, params = c("trace", "gelman", "corr", "ess", "acf", "quantiles", "acceptance"), thin = 1) {
7168
suppressWarnings(suppressMessages({
72-
# Plot trace of parameters from all chains for convergence assessment
73-
#plot(results[[2]])
69+
# Extract parameter samples and split into chains
70+
pars <- results[[1]]$pars
71+
n_chains <- results$n_chains
72+
n_samples <- nrow(pars) / n_chains
73+
74+
# Split parameters into individual chains
75+
chain_list <- split(pars, rep(1:n_chains, each = n_samples))
76+
chains <- lapply(chain_list, function(x) {
77+
as.mcmc(matrix(x, nrow = n_samples, ncol = ncol(pars), dimnames = list(NULL, colnames(pars))))
78+
})
79+
mcmc_chains <- as.mcmc.list(chains)
80+
81+
# Apply thinning if requested
82+
if (thin > 1) {
83+
mcmc_chains <- lapply(mcmc_chains, function(chain) window(chain, thin = thin))
84+
mcmc_chains <- as.mcmc.list(mcmc_chains)
85+
}
86+
87+
# Trace plot
88+
if ("trace" %in% params) {
89+
cat("\n TRACE PLOT \n")
90+
plot(mcmc_chains) # Trace plot for convergence assessment
91+
}
7492

75-
# If multiple chains, create and plot individual chains
76-
if (results$n_chains > 1) {
77-
chains <- plot_chains(results[[1]])
78-
plot(chains) # Visualize convergence with individual chain traces
93+
# Gelman-Rubin statistic
94+
if ("gelman" %in% params && n_chains > 1) {
7995
cat("\n GELMAN-RUBIN STATISTIC \n")
80-
print(gelman.diag(chains)) # Numerical convergence assessment
96+
print(gelman.diag(mcmc_chains)) # Numerical convergence assessment
8197
}
8298

83-
# Display correlation plot for parameter relationships
84-
print(plot_corr(results))
99+
# Correlation plot
100+
if ("corr" %in% params) {
101+
cat("\n CORRELATION PLOT \n")
102+
print(plot_corr(results)) # Retain the previous large and clear plot
103+
}
85104

86-
# Print effective sample sizes to assess chain mixing
87-
cat("\n EFFECTIVE SAMPLE SIZE \n")
88-
print(effectiveSize(results$coda_pars))
105+
# Effective Sample Size
106+
if ("ess" %in% params) {
107+
cat("\n EFFECTIVE SAMPLE SIZE \n")
108+
print(effectiveSize(mcmc_chains))
109+
}
89110

90-
# Plot autocorrelation function (ACF) to check for dependency in the chains
91-
acf(results$coda_pars[,-c(1:3)])
111+
# Autocorrelation Function (ACF)
112+
if ("acf" %in% params) {
113+
cat("\n AUTOCORRELATION FUNCTION \n")
114+
acf(as.matrix(do.call(rbind, mcmc_chains)), main = "Autocorrelation")
115+
}
92116

93-
# Display posterior quantiles for each parameter
94-
cat("\n POSTERIOR QUANTILES \n")
95-
print(summary(results$coda_pars)$quantiles[-c(1:3),])
117+
# Posterior Quantiles
118+
if ("quantiles" %in% params) {
119+
cat("\n POSTERIOR QUANTILES \n")
120+
quantiles <- apply(as.matrix(do.call(rbind, mcmc_chains)), 2, quantile, probs = c(0.025, 0.5, 0.975))
121+
print(quantiles)
122+
}
96123

97-
# Calculate and print acceptance rate
98-
cat("\n ACCEPTANCE RATE \n ")
99-
print(1 - rejectionRate(results$coda_pars))
124+
# Acceptance Rate
125+
if ("acceptance" %in% params) {
126+
cat("\n ACCEPTANCE RATE \n")
127+
acceptance_rates <- 1 - rejectionRate(mcmc_chains)
128+
print(acceptance_rates)
129+
}
100130
}))
101131
}
102132

103133

134+
135+
104136
#' Plot Posterior Distributions of Estimated Parameters
105137
#'
106138
#' Generates histogram plots for each estimated parameter in the MCMC output, with annotated quantiles and
@@ -125,6 +157,11 @@ MCMC_diag <- function(results) {
125157
#' dim_plot <- c(1, 2) # Arrange in 1 row, 2 columns
126158
#' post_plot(results, params_to_estimate, dim_plot, show_true = TRUE, true_value = c(param1 = 0.5, param2 = -0.2), title = "Posterior Distributions")
127159
post_plot <- function(results, params_to_estimate, dim_plot, show_true = TRUE, true_value = NULL, title = "") {
160+
# Check if params_to_estimate is a named vector
161+
if (is.null(names(params_to_estimate)) || any(names(params_to_estimate) == "")) {
162+
stop("Error: 'params_to_estimate' must be a named vector where each parameter has an associated name.")
163+
}
164+
128165
# Extract posterior samples
129166
posterior <- data.frame(results[[1]]$pars)
130167
post_melt <- reshape2::melt(posterior)
@@ -135,15 +172,21 @@ post_plot <- function(results, params_to_estimate, dim_plot, show_true = TRUE, t
135172
quantiles <- apply(posterior, 2, function(x) quantile(x, probs = c(0.005, 0.025, 0.50, 0.975, 0.995)))
136173
quantiles_df <- as.data.frame(t(quantiles))
137174
colnames(quantiles_df) <- c("0.5%", "2.5%", "50%", "97.5%", "99.5%")
175+
rownames(quantiles_df) <- colnames(posterior)
138176

139177
# Set up true values if provided by user
140-
if(show_true) {
178+
if (show_true) {
141179
vertical_lines <- data.frame(variable = names(true_value), line_position = as.vector(true_value))
142-
vertical_lines <- vertical_lines[vertical_lines$variable %in% params_to_estimate,]
180+
vertical_lines <- vertical_lines[vertical_lines$variable %in% names(params_to_estimate),]
143181
}
144182

145183
# Generate a histogram plot for each parameter
146184
plot_list <- lapply(names(params_to_estimate), function(param) {
185+
if (!(param %in% colnames(posterior))) {
186+
warning(sprintf("Parameter '%s' not found in posterior samples.", param))
187+
return(NULL)
188+
}
189+
147190
p <- ggplot(subset(prior_post, variable == param), aes(x = value, fill = variable2)) +
148191
geom_histogram(data = subset(prior_post, variable2 == "posterior" & variable == param),
149192
aes(y = ..density..), bins = 100) + # Histogram with density normalization
@@ -163,12 +206,21 @@ post_plot <- function(results, params_to_estimate, dim_plot, show_true = TRUE, t
163206
color = "black", linetype = "solid", size = 0.8)
164207
}
165208

209+
# Remove the legend if show_true is FALSE
210+
if (!show_true) {
211+
p <- p + theme(legend.position = "none")
212+
}
213+
166214
p
167215
})
168216

217+
# Remove any NULL plots (if parameters are missing)
218+
plot_list <- Filter(Negate(is.null), plot_list)
219+
169220
# Combine individual parameter plots into a grid
170221
combined_plot <- patchwork::wrap_plots(plot_list, ncol = dim_plot[2]) +
171222
patchwork::plot_annotation(title = title)
172223

173224
return(combined_plot)
174225
}
226+

R/model_simulation.R

+6-6
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ data_sim <- function(model, param_inputs, start_date, end_date,
146146
month_no <- 0:(n_months - 1)
147147

148148
if(return_EIR){
149-
EIR <- x[mod$info()$index$EIR2,,][month_ind]
150-
EIR <- EIR[1:n_months]
149+
EIR_monthly <- x[mod$info()$index$EIR_monthly,,][month_ind]
150+
EIR_monthly <- EIR_monthly[1:n_months]
151151
inc_df <- data.frame(date_ymd = month, month_no, inc_A, inc_C,
152-
inc = inc_A + inc_C, EIR = EIR)
152+
inc = inc_A + inc_C, EIR_monthly = EIR_monthly)
153153
} else{inc_df <- data.frame(date_ymd = month, month_no, inc_A, inc_C, inc = inc_A + inc_C)}
154154

155155
} else {
@@ -174,9 +174,9 @@ data_sim <- function(model, param_inputs, start_date, end_date,
174174
week_no <- 0:(n_weeks - 1)
175175

176176
if(return_EIR){
177-
EIR <- x[mod$info()$index$EIR2,,][wk_ind]
178-
EIR <- EIR[1:n_weeks]
179-
inc_df <- data.frame(week, week_no, inc_A, inc_C, inc = inc_A + inc_C, EIR = EIR)
177+
EIR_monthly <- x[mod$info()$index$EIR_monthly,,][wk_ind]
178+
EIR_monthly <- EIR_monthly[1:n_weeks]
179+
inc_df <- data.frame(week, week_no, inc_A, inc_C, inc = inc_A + inc_C, EIR_monthly = EIR_monthly)
180180
} else{inc_df <- data.frame(week, week_no, inc_A, inc_C, inc = inc_A + inc_C)}
181181
}
182182

R/model_visualizations.R

+6-2
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ create_sim_df <- function(results, n, dates_sim, dates_obs, model){
260260
#' met = met_data, climate_facet = TRUE, prewarm_years = 2)
261261
plot_observed_vs_simulated <- function(results, obs_cases, start_date, end_date, model,
262262
add_ribbon = FALSE, n_samples = 100, groups = c("inc_A", "inc_C", "inc"),
263-
met = NULL, climate_facet = FALSE, prewarm_years = 2, days_per_year = 360) {
263+
met = NULL, climate_facet = FALSE, prewarm_years = 2, days_per_year = 360,
264+
title = NULL) {
264265

265266
prewarm_start_date <- paste0(year(as.Date(start_date)) - prewarm_years, "-", format(as.Date(start_date), "%m-%d"))
266267

@@ -317,7 +318,10 @@ plot_observed_vs_simulated <- function(results, obs_cases, start_date, end_date,
317318
geom_line(data = subset(combined_data_long, grepl("_sim$", variable)), aes(y = value),
318319
color = "red", size = 1) + # Simulated data as line
319320
facet_wrap(~ gsub("(_obs|_sim)", "", variable), scales = "free", ncol = 1, labeller = as_labeller(facet_labels)) +
320-
labs(title = "Observed vs Simulated Monthly Malaria Cases",
321+
if(is.null(title)){
322+
title = "Observed vs Simulated Monthly Malaria Cases"
323+
}
324+
labs(title = title,
321325
x = "Date",
322326
y = "Number of Monthly Malaria Cases") +
323327
theme_minimal() +

R/observation_functions.R

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ generate_incidence_comparison <- function(month, age_for_inf, incidence_df) {
6868

6969
if(month && age_for_inf == 'u5') {
7070
return(function(state, observed, pars = c("size")) {
71+
if (is.na(observed$inc_C)) {
72+
return(NULL)
73+
}
7174
incidence_observed_C <- observed$inc_C # this is the observed data
7275
mu_C <- state["month_inc_C", , drop = TRUE] # this is "x"
7376
size <- pars$size

inst/models/model_det_1.R

+7
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ C <- a
140140
# update(EIR) <- (A * p_HM * X / (B + (C * p_HM * X)))
141141
EIR <- (A * p_HM * X / (B + (C * p_HM * X)))
142142
#update(EIR) <- (A * p_HM * X / (B + (C * p_HM * X)))
143+
#EIR2 <- (A * p_HM * X / (B + (C * p_HM * X)))
144+
145+
# Define monthly EIR
146+
initial(EIR_monthly) <- EIR
147+
update(EIR_monthly) <- if ((step) %% steps_per_month == 0) EIR2 else EIR_monthly + EIR
148+
149+
# Define Daily EIR
143150
initial(EIR2) <- EIR
144151
update(EIR2) <- EIR
145152

0 commit comments

Comments
 (0)