2
2
Tool for animating Markov Chain Monte Carlo simulations in 2D.
3
3
"""
4
4
from manim import *
5
+ import matplotlib
6
+ import matplotlib .pyplot as plt
7
+ from manim_ml .utils .mobjects .plotting import convert_matplotlib_figure_to_image_mobject
5
8
import numpy as np
6
9
import scipy
7
10
import scipy .stats
8
11
from tqdm import tqdm
12
+ import seaborn as sns
9
13
10
14
from manim_ml .utils .mobjects .probability import GaussianDistribution
11
15
12
16
######################## MCMC Algorithms #########################
13
17
14
- def gaussian_proposal (x , sigma = 1.0 ):
18
+ def gaussian_proposal (x , sigma = 0.3 ):
15
19
"""
16
20
Gaussian proposal distribution.
17
21
@@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
94
98
iterations = 25 ,
95
99
warm_up = 0 ,
96
100
ndim = 2 ,
101
+ sampling_seed = 1
97
102
):
98
103
"""Samples using a Metropolis-Hastings sampler.
99
104
@@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
119
124
candidate_samples: np.ndarray
120
125
numpy array of the candidate samples for each time step
121
126
"""
122
- assert warm_up == 0 , "Warmup not implemented yet"
127
+ np . random . seed ( sampling_seed )
123
128
# initialize chain, acceptance rate and lnprob
124
129
chain = np .zeros ((iterations , ndim ))
125
130
proposals = np .zeros ((iterations , ndim ))
@@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
156
161
157
162
#################### MCMC Visualization Tools ######################
158
163
164
+ def make_dist_image_mobject_from_samples (samples , ylim , xlim ):
165
+ # Make the plot
166
+ matplotlib .use ('Agg' )
167
+ plt .figure (figsize = (10 ,10 ), dpi = 100 )
168
+ print (np .shape (samples [:, 0 ]))
169
+ displot = sns .displot (
170
+ x = samples [:, 0 ],
171
+ y = samples [:, 1 ],
172
+ cmap = "Reds" ,
173
+ kind = "kde" ,
174
+ norm = matplotlib .colors .LogNorm ()
175
+ )
176
+ plt .ylim (ylim [0 ], ylim [1 ])
177
+ plt .xlim (xlim [0 ], xlim [1 ])
178
+ plt .axis ('off' )
179
+ fig = displot .fig
180
+ image_mobject = convert_matplotlib_figure_to_image_mobject (fig )
181
+
182
+ return image_mobject
183
+
184
+ class Uncreate (Create ):
185
+ def __init__ (
186
+ self ,
187
+ mobject ,
188
+ reverse_rate_function : bool = True ,
189
+ introducer : bool = True ,
190
+ remover : bool = True ,
191
+ ** kwargs ,
192
+ ) -> None :
193
+ super ().__init__ (
194
+ mobject ,
195
+ reverse_rate_function = reverse_rate_function ,
196
+ introducer = introducer ,
197
+ remover = remover ,
198
+ ** kwargs ,
199
+ )
200
+
159
201
class MCMCAxes (Group ):
160
202
"""Container object for visualizing MCMC on a 2D axis"""
161
203
@@ -166,7 +208,7 @@ def __init__(
166
208
accept_line_color = GREEN ,
167
209
reject_line_color = RED ,
168
210
line_color = BLUE ,
169
- line_stroke_width = 3 ,
211
+ line_stroke_width = 2 ,
170
212
x_range = [- 3 , 3 ],
171
213
y_range = [- 3 , 3 ],
172
214
x_length = 5 ,
@@ -180,6 +222,10 @@ def __init__(
180
222
self .line_color = line_color
181
223
self .line_stroke_width = line_stroke_width
182
224
# Make the axes
225
+ self .x_length = x_length
226
+ self .y_length = y_length
227
+ self .x_range = x_range
228
+ self .y_range = y_range
183
229
self .axes = Axes (
184
230
x_range = x_range ,
185
231
y_range = y_range ,
@@ -290,6 +336,7 @@ def visualize_metropolis_hastings_chain_sampling(
290
336
log_prob_fn = MultidimensionalGaussianPosterior (),
291
337
prop_fn = gaussian_proposal ,
292
338
show_dots = False ,
339
+ true_samples = None ,
293
340
sampling_kwargs = {},
294
341
):
295
342
"""
@@ -318,12 +365,14 @@ def visualize_metropolis_hastings_chain_sampling(
318
365
"""
319
366
# Compute the chain samples using a Metropolis Hastings Sampler
320
367
mcmc_samples , warm_up_samples , candidate_samples = metropolis_hastings_sampler (
321
- log_prob_fn = log_prob_fn , prop_fn = prop_fn , ** sampling_kwargs
368
+ log_prob_fn = log_prob_fn ,
369
+ prop_fn = prop_fn ,
370
+ ** sampling_kwargs
322
371
)
323
372
# print(f"MCMC samples: {mcmc_samples}")
324
373
# print(f"Candidate samples: {candidate_samples}")
325
374
# Make the animation for visualizing the chain
326
- animations = []
375
+ transition_animations = []
327
376
# Place the initial point
328
377
current_point = mcmc_samples [0 ]
329
378
current_point = Dot (
@@ -332,10 +381,11 @@ def visualize_metropolis_hastings_chain_sampling(
332
381
radius = self .dot_radius ,
333
382
)
334
383
create_initial_point = Create (current_point )
335
- animations .append (create_initial_point )
384
+ transition_animations .append (create_initial_point )
336
385
# Show the initial point's proposal distribution
337
386
# NOTE: visualize the warm up and the iterations
338
387
lines = []
388
+ warmup_points = []
339
389
num_iterations = len (mcmc_samples ) + len (warm_up_samples )
340
390
for iteration in tqdm (range (1 , num_iterations )):
341
391
next_sample = mcmc_samples [iteration ]
@@ -362,14 +412,50 @@ def visualize_metropolis_hastings_chain_sampling(
362
412
transition_animation , line = self .make_transition_animation (
363
413
current_point , next_point , candidate_point
364
414
)
415
+ # Save assets
365
416
lines .append (line )
366
- animations .append (transition_animation )
417
+ if iteration < len (warm_up_samples ):
418
+ warmup_points .append (candidate_point )
419
+
420
+ # Add the transition animation
421
+ transition_animations .append (transition_animation )
367
422
# Setup for next iteration
368
423
current_point = next_point
369
- # Make the final animation group
370
- animation_group = AnimationGroup (
371
- * animations ,
424
+ # Overall MCMC animation
425
+ # 1. Fade in the distribution
426
+ image_mobject = make_dist_image_mobject_from_samples (
427
+ true_samples ,
428
+ xlim = (self .x_range [0 ], self .x_range [1 ]),
429
+ ylim = (self .y_range [0 ], self .y_range [1 ])
430
+ )
431
+ image_mobject .scale_to_fit_height (
432
+ self .y_length
433
+ )
434
+ image_mobject .move_to (self .axes )
435
+ fade_in_distribution = FadeIn (
436
+ image_mobject ,
437
+ run_time = 0.5
438
+ )
439
+ # 2. Start sampling the chain
440
+ chain_sampling_animation = AnimationGroup (
441
+ * transition_animations ,
442
+ lag_ratio = 1.0 ,
443
+ run_time = 5.0
444
+ )
445
+ # 3. Convert the chain to points, excluding the warmup
446
+ lines = VGroup (* lines )
447
+ warm_up_points = VGroup (* warmup_points )
448
+ fade_out_lines_and_warmup = AnimationGroup (
449
+ Uncreate (lines ),
450
+ Uncreate (warm_up_points ),
451
+ lag_ratio = 0.0
452
+ )
453
+ # Make the final animation
454
+ animation_group = Succession (
455
+ fade_in_distribution ,
456
+ chain_sampling_animation ,
457
+ fade_out_lines_and_warmup ,
372
458
lag_ratio = 1.0
373
459
)
374
460
375
- return animation_group , VGroup ( * lines )
461
+ return animation_group
0 commit comments