@@ -52,55 +52,87 @@ plot_corr <- function(results) {
52
52
# ' @param results A list containing the MCMC output, typically with at least
53
53
# ' `results[[1]]` for chain data and `results[[2]]` for trace plot data.
54
54
# ' Should include an element `n_chains` for the number of chains.
55
+ # ' @param params A character vector specifying which diagnostics to display.
56
+ # ' Options include: "trace", "gelman", "corr", "ess", "acf", "quantiles", "acceptance".
57
+ # ' @param thin An integer specifying thinning interval for the chains.
58
+ # ' Thinning reduces the number of samples by keeping every `thin`th sample.
55
59
# '
56
- # ' @return This function returns a variety of printed diagnostic outputs:
57
- # ' - Trace plot for parameter convergence
58
- # ' - Gelman-Rubin statistic if multiple chains are present
59
- # ' - Correlation plot between parameters
60
- # ' - Effective sample size for each parameter
61
- # ' - Autocorrelation function (ACF) plot for parameter dependence
62
- # ' - Posterior quantiles for each parameter
63
- # ' - Acceptance rate
60
+ # ' @return Diagnostic outputs based on the selected `params` are printed or plotted.
64
61
# '
65
62
# ' @export
66
63
# '
67
64
# ' @examples
68
65
# ' # Assuming 'results' is a valid MCMC result list with required structure:
69
- # ' MCMC_diag(results)
70
- MCMC_diag <- function (results ) {
66
+ # ' MCMC_diag(results, params = c("trace", "ess") )
67
MCMC_diag <- function(results,
                      params = c("trace", "gelman", "corr", "ess", "acf",
                                 "quantiles", "acceptance"),
                      thin = 1) {
  # Print/plot the MCMC diagnostics selected in `params`.
  #
  # results: list whose first element holds the pooled samples in `$pars`
  #          (one row per draw, chains stacked one after another) and which
  #          has a top-level `n_chains` element.
  # params:  character vector naming the diagnostics to produce.
  # thin:    thinning interval; values <= 1 leave the chains untouched.
  #
  # All input validation happens up front, outside the warning/message
  # suppression further down, so problems with the input are never hidden.
  known_diagnostics <- c("trace", "gelman", "corr", "ess", "acf",
                         "quantiles", "acceptance")
  unknown <- setdiff(params, known_diagnostics)
  if (length(unknown) > 0) {
    warning(
      "Ignoring unknown diagnostics in 'params': ",
      paste(unknown, collapse = ", "),
      call. = FALSE
    )
  }

  if (!is.numeric(thin) || length(thin) != 1L || is.na(thin)) {
    stop("'thin' must be a single non-missing number.", call. = FALSE)
  }

  pars <- results[[1]]$pars
  if (is.null(pars)) {
    stop("'results[[1]]$pars' is missing.", call. = FALSE)
  }
  # Coerce once so that a data frame of samples is handled like a matrix.
  pars <- as.matrix(pars)

  n_chains <- results$n_chains
  if (!is.numeric(n_chains) || length(n_chains) != 1L || is.na(n_chains) ||
        n_chains < 1 || n_chains != round(n_chains)) {
    stop("'results$n_chains' must be a single positive whole number.",
         call. = FALSE)
  }

  n_samples <- nrow(pars) / n_chains
  if (n_samples < 1 || n_samples != round(n_samples)) {
    stop(
      sprintf(
        "Number of samples (%d) is not divisible by n_chains (%d).",
        nrow(pars), as.integer(n_chains)
      ),
      call. = FALSE
    )
  }

  suppressWarnings(suppressMessages({
    # Slice consecutive row blocks, one per chain. Row indexing with
    # drop = FALSE keeps the column names and works for single-parameter fits.
    chains <- lapply(seq_len(n_chains), function(k) {
      rows <- seq.int((k - 1) * n_samples + 1, k * n_samples)
      chunk <- pars[rows, , drop = FALSE]
      rownames(chunk) <- NULL
      coda::as.mcmc(chunk)
    })
    mcmc_chains <- coda::as.mcmc.list(chains)

    # Apply thinning if requested
    if (thin > 1) {
      mcmc_chains <- coda::as.mcmc.list(
        lapply(mcmc_chains, function(chain) window(chain, thin = thin))
      )
    }

    # Samples from all chains stacked together; shared by the ACF and
    # quantile diagnostics, so build it once and only when needed.
    if (any(c("acf", "quantiles") %in% params)) {
      pooled <- as.matrix(do.call(rbind, mcmc_chains))
    }

    # Trace plot
    if ("trace" %in% params) {
      cat("\nTRACE PLOT\n")
      plot(mcmc_chains) # Trace plot for convergence assessment
    }

    # Gelman-Rubin statistic (only defined for more than one chain)
    if ("gelman" %in% params && n_chains > 1) {
      cat("\nGELMAN-RUBIN STATISTIC\n")
      print(coda::gelman.diag(mcmc_chains)) # Numerical convergence assessment
    }

    # Correlation plot
    if ("corr" %in% params) {
      cat("\nCORRELATION PLOT\n")
      print(plot_corr(results)) # Retain the previous large and clear plot
    }

    # Effective Sample Size
    if ("ess" %in% params) {
      cat("\nEFFECTIVE SAMPLE SIZE\n")
      print(coda::effectiveSize(mcmc_chains))
    }

    # Autocorrelation Function (ACF)
    if ("acf" %in% params) {
      cat("\nAUTOCORRELATION FUNCTION\n")
      acf(pooled, main = "Autocorrelation")
    }

    # Posterior Quantiles
    if ("quantiles" %in% params) {
      cat("\nPOSTERIOR QUANTILES\n")
      quantiles <- apply(pooled, 2, quantile, probs = c(0.025, 0.5, 0.975))
      print(quantiles)
    }

    # Acceptance Rate
    if ("acceptance" %in% params) {
      cat("\nACCEPTANCE RATE\n")
      acceptance_rates <- 1 - coda::rejectionRate(mcmc_chains)
      print(acceptance_rates)
    }
  }))
}
102
132
103
133
134
+
135
+
104
136
# ' Plot Posterior Distributions of Estimated Parameters
105
137
# '
106
138
# ' Generates histogram plots for each estimated parameter in the MCMC output, with annotated quantiles and
@@ -125,6 +157,11 @@ MCMC_diag <- function(results) {
125
157
# ' dim_plot <- c(1, 2) # Arrange in 1 row, 2 columns
126
158
# ' post_plot(results, params_to_estimate, dim_plot, show_true = TRUE, true_value = c(param1 = 0.5, param2 = -0.2), title = "Posterior Distributions")
127
159
post_plot <- function (results , params_to_estimate , dim_plot , show_true = TRUE , true_value = NULL , title = " " ) {
160
+ # Check if params_to_estimate is a named vector
161
+ if (is.null(names(params_to_estimate )) || any(names(params_to_estimate ) == " " )) {
162
+ stop(" Error: 'params_to_estimate' must be a named vector where each parameter has an associated name." )
163
+ }
164
+
128
165
# Extract posterior samples
129
166
posterior <- data.frame (results [[1 ]]$ pars )
130
167
post_melt <- reshape2 :: melt(posterior )
@@ -135,15 +172,21 @@ post_plot <- function(results, params_to_estimate, dim_plot, show_true = TRUE, t
135
172
quantiles <- apply(posterior , 2 , function (x ) quantile(x , probs = c(0.005 , 0.025 , 0.50 , 0.975 , 0.995 )))
136
173
quantiles_df <- as.data.frame(t(quantiles ))
137
174
colnames(quantiles_df ) <- c(" 0.5%" , " 2.5%" , " 50%" , " 97.5%" , " 99.5%" )
175
+ rownames(quantiles_df ) <- colnames(posterior )
138
176
139
177
# Set up true values if provided by user
140
- if (show_true ) {
178
+ if (show_true ) {
141
179
vertical_lines <- data.frame (variable = names(true_value ), line_position = as.vector(true_value ))
142
- vertical_lines <- vertical_lines [vertical_lines $ variable %in% params_to_estimate ,]
180
+ vertical_lines <- vertical_lines [vertical_lines $ variable %in% names( params_to_estimate ) ,]
143
181
}
144
182
145
183
# Generate a histogram plot for each parameter
146
184
plot_list <- lapply(names(params_to_estimate ), function (param ) {
185
+ if (! (param %in% colnames(posterior ))) {
186
+ warning(sprintf(" Parameter '%s' not found in posterior samples." , param ))
187
+ return (NULL )
188
+ }
189
+
147
190
p <- ggplot(subset(prior_post , variable == param ), aes(x = value , fill = variable2 )) +
148
191
geom_histogram(data = subset(prior_post , variable2 == " posterior" & variable == param ),
149
192
aes(y = ..density.. ), bins = 100 ) + # Histogram with density normalization
@@ -163,12 +206,21 @@ post_plot <- function(results, params_to_estimate, dim_plot, show_true = TRUE, t
163
206
color = " black" , linetype = " solid" , size = 0.8 )
164
207
}
165
208
209
+ # Remove the legend if show_true is FALSE
210
+ if (! show_true ) {
211
+ p <- p + theme(legend.position = " none" )
212
+ }
213
+
166
214
p
167
215
})
168
216
217
+ # Remove any NULL plots (if parameters are missing)
218
+ plot_list <- Filter(Negate(is.null ), plot_list )
219
+
169
220
# Combine individual parameter plots into a grid
170
221
combined_plot <- patchwork :: wrap_plots(plot_list , ncol = dim_plot [2 ]) +
171
222
patchwork :: plot_annotation(title = title )
172
223
173
224
return (combined_plot )
174
225
}
226
+
0 commit comments