Skip to content

Commit 2b21261

Browse files
committed
Added changes to the MCMC sampling code. Added an MCMC example.
1 parent 7538e2b commit 2b21261

File tree

4 files changed

+175
-14
lines changed

4 files changed

+175
-14
lines changed

examples/mcmc/warmup_mcmc.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from manim import *
2+
3+
import scipy.stats
4+
from manim_ml.diffusion.mcmc import MCMCAxes
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
8+
plt.style.use('dark_background')
9+
10+
# Make the specific scene
11+
config.pixel_height = 720
12+
config.pixel_width = 720
13+
config.frame_height = 7.0
14+
config.frame_width = 7.0
15+
16+
class MCMCWarmupScene(Scene):
17+
18+
def construct(self):
19+
# Define the Gaussian Mixture likelihood
20+
def gaussian_mm_logpdf(x):
21+
"""Gaussian Mixture Model Log PDF"""
22+
# Choose two arbitrary Gaussians
23+
# Big Gaussian
24+
big_gaussian_pdf = scipy.stats.multivariate_normal(
25+
mean=[-0.5, -0.5],
26+
cov=[1.0, 1.0]
27+
).pdf(x)
28+
# Little Gaussian
29+
little_gaussian_pdf = scipy.stats.multivariate_normal(
30+
mean=[2.3, 1.9],
31+
cov=[0.3, 0.3]
32+
).pdf(x)
33+
# Sum their likelihoods and take the log
34+
logpdf = np.log(big_gaussian_pdf + little_gaussian_pdf)
35+
36+
return logpdf
37+
38+
# Generate a bunch of true samples
39+
true_samples = []
40+
# Generate samples for little gaussian
41+
little_gaussian_samples = np.random.multivariate_normal(
42+
mean=[2.3, 1.9],
43+
cov=[[0.3, 0.0], [0.0, 0.3]],
44+
size=(10000)
45+
)
46+
big_gaussian_samples = np.random.multivariate_normal(
47+
mean=[-0.5, -0.5],
48+
cov=[[1.0, 0.0], [0.0, 1.0]],
49+
size=(10000)
50+
)
51+
true_samples = np.concatenate((little_gaussian_samples, big_gaussian_samples))
52+
# Make the MCMC axes
53+
axes = MCMCAxes(
54+
x_range=[-5, 5],
55+
y_range=[-5, 5],
56+
x_length=7.0,
57+
y_length=7.0
58+
)
59+
axes.move_to(ORIGIN)
60+
self.play(
61+
Create(axes)
62+
)
63+
# Make the chain sampling animation
64+
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
65+
log_prob_fn=gaussian_mm_logpdf,
66+
true_samples=true_samples,
67+
sampling_kwargs={
68+
"iterations": 2000,
69+
"warm_up": 50,
70+
"initial_location": np.array([-3.5, 3.5]),
71+
"sampling_seed": 4
72+
},
73+
)
74+
self.play(chain_sampling_animation)
75+
self.wait(3)

manim_ml/diffusion/mcmc.py

+97-11
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22
Tool for animating Markov Chain Monte Carlo simulations in 2D.
33
"""
44
from manim import *
5+
import matplotlib
6+
import matplotlib.pyplot as plt
7+
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
58
import numpy as np
69
import scipy
710
import scipy.stats
811
from tqdm import tqdm
12+
import seaborn as sns
913

1014
from manim_ml.utils.mobjects.probability import GaussianDistribution
1115

1216
######################## MCMC Algorithms #########################
1317

14-
def gaussian_proposal(x, sigma=1.0):
18+
def gaussian_proposal(x, sigma=0.3):
1519
"""
1620
Gaussian proposal distribution.
1721
@@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
9498
iterations=25,
9599
warm_up=0,
96100
ndim=2,
101+
sampling_seed=1
97102
):
98103
"""Samples using a Metropolis-Hastings sampler.
99104
@@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
119124
candidate_samples: np.ndarray
120125
numpy array of the candidate samples for each time step
121126
"""
122-
assert warm_up == 0, "Warmup not implemented yet"
127+
np.random.seed(sampling_seed)
123128
# initialize chain, acceptance rate and lnprob
124129
chain = np.zeros((iterations, ndim))
125130
proposals = np.zeros((iterations, ndim))
@@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
156161

157162
#################### MCMC Visualization Tools ######################
158163

164+
def make_dist_image_mobject_from_samples(samples, ylim, xlim):
165+
# Make the plot
166+
matplotlib.use('Agg')
167+
plt.figure(figsize=(10,10), dpi=100)
168+
print(np.shape(samples[:, 0]))
169+
displot = sns.displot(
170+
x=samples[:, 0],
171+
y=samples[:, 1],
172+
cmap="Reds",
173+
kind="kde",
174+
norm=matplotlib.colors.LogNorm()
175+
)
176+
plt.ylim(ylim[0], ylim[1])
177+
plt.xlim(xlim[0], xlim[1])
178+
plt.axis('off')
179+
fig = displot.fig
180+
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
181+
182+
return image_mobject
183+
184+
class Uncreate(Create):
185+
def __init__(
186+
self,
187+
mobject,
188+
reverse_rate_function: bool = True,
189+
introducer: bool = True,
190+
remover: bool = True,
191+
**kwargs,
192+
) -> None:
193+
super().__init__(
194+
mobject,
195+
reverse_rate_function=reverse_rate_function,
196+
introducer=introducer,
197+
remover=remover,
198+
**kwargs,
199+
)
200+
159201
class MCMCAxes(Group):
160202
"""Container object for visualizing MCMC on a 2D axis"""
161203

@@ -166,7 +208,7 @@ def __init__(
166208
accept_line_color=GREEN,
167209
reject_line_color=RED,
168210
line_color=BLUE,
169-
line_stroke_width=3,
211+
line_stroke_width=2,
170212
x_range=[-3, 3],
171213
y_range=[-3, 3],
172214
x_length=5,
@@ -180,6 +222,10 @@ def __init__(
180222
self.line_color = line_color
181223
self.line_stroke_width = line_stroke_width
182224
# Make the axes
225+
self.x_length = x_length
226+
self.y_length = y_length
227+
self.x_range = x_range
228+
self.y_range = y_range
183229
self.axes = Axes(
184230
x_range=x_range,
185231
y_range=y_range,
@@ -290,6 +336,7 @@ def visualize_metropolis_hastings_chain_sampling(
290336
log_prob_fn=MultidimensionalGaussianPosterior(),
291337
prop_fn=gaussian_proposal,
292338
show_dots=False,
339+
true_samples=None,
293340
sampling_kwargs={},
294341
):
295342
"""
@@ -318,12 +365,14 @@ def visualize_metropolis_hastings_chain_sampling(
318365
"""
319366
# Compute the chain samples using a Metropolis Hastings Sampler
320367
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
321-
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
368+
log_prob_fn=log_prob_fn,
369+
prop_fn=prop_fn,
370+
**sampling_kwargs
322371
)
323372
# print(f"MCMC samples: {mcmc_samples}")
324373
# print(f"Candidate samples: {candidate_samples}")
325374
# Make the animation for visualizing the chain
326-
animations = []
375+
transition_animations = []
327376
# Place the initial point
328377
current_point = mcmc_samples[0]
329378
current_point = Dot(
@@ -332,10 +381,11 @@ def visualize_metropolis_hastings_chain_sampling(
332381
radius=self.dot_radius,
333382
)
334383
create_initial_point = Create(current_point)
335-
animations.append(create_initial_point)
384+
transition_animations.append(create_initial_point)
336385
# Show the initial point's proposal distribution
337386
# NOTE: visualize the warm up and the iterations
338387
lines = []
388+
warmup_points = []
339389
num_iterations = len(mcmc_samples) + len(warm_up_samples)
340390
for iteration in tqdm(range(1, num_iterations)):
341391
next_sample = mcmc_samples[iteration]
@@ -362,14 +412,50 @@ def visualize_metropolis_hastings_chain_sampling(
362412
transition_animation, line = self.make_transition_animation(
363413
current_point, next_point, candidate_point
364414
)
415+
# Save assets
365416
lines.append(line)
366-
animations.append(transition_animation)
417+
if iteration < len(warm_up_samples):
418+
warmup_points.append(candidate_point)
419+
420+
# Add the transition animation
421+
transition_animations.append(transition_animation)
367422
# Setup for next iteration
368423
current_point = next_point
369-
# Make the final animation group
370-
animation_group = AnimationGroup(
371-
*animations,
424+
# Overall MCMC animation
425+
# 1. Fade in the distribution
426+
image_mobject = make_dist_image_mobject_from_samples(
427+
true_samples,
428+
xlim=(self.x_range[0], self.x_range[1]),
429+
ylim=(self.y_range[0], self.y_range[1])
430+
)
431+
image_mobject.scale_to_fit_height(
432+
self.y_length
433+
)
434+
image_mobject.move_to(self.axes)
435+
fade_in_distribution = FadeIn(
436+
image_mobject,
437+
run_time=0.5
438+
)
439+
# 2. Start sampling the chain
440+
chain_sampling_animation = AnimationGroup(
441+
*transition_animations,
442+
lag_ratio=1.0,
443+
run_time=5.0
444+
)
445+
# 3. Convert the chain to points, excluding the warmup
446+
lines = VGroup(*lines)
447+
warm_up_points = VGroup(*warmup_points)
448+
fade_out_lines_and_warmup = AnimationGroup(
449+
Uncreate(lines),
450+
Uncreate(warm_up_points),
451+
lag_ratio=0.0
452+
)
453+
# Make the final animation
454+
animation_group = Succession(
455+
fade_in_distribution,
456+
chain_sampling_animation,
457+
fade_out_lines_and_warmup,
372458
lag_ratio=1.0
373459
)
374460

375-
return animation_group, VGroup(*lines)
461+
return animation_group

manim_ml/utils/mobjects/plotting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
1414
matplotlib figure
1515
"""
1616
fig.tight_layout(pad=0)
17-
plt.axis('off')
17+
# plt.axis('off')
1818
fig.canvas.draw()
1919
# Save data into a buffer
2020
image_buffer = io.BytesIO()

tests/test_mcmc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# Make the specific scene
1616
config.pixel_height = 1200
1717
config.pixel_width = 1200
18-
config.frame_height = 10.0
19-
config.frame_width = 10.0
18+
config.frame_height = 7.0
19+
config.frame_width = 7.0
2020

2121
def test_metropolis_hastings_sampler(iterations=100):
2222
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)

0 commit comments

Comments
 (0)