-
Notifications
You must be signed in to change notification settings - Fork 0
NUTSでカスタムJVPを用いる
Hajime Kawahara edited this page Feb 6, 2021
·
1 revision
NumPyro 0.5.0からForward 微分でNUTSを実行できるようになった。ここではforward微分のカスタムJVPを自分で定義してNUTSを実行してみる。
コードはGISTにおいた。下記で一部説明する。
f(x) = A sin (x)をカスタムJVPで定義している部分が以下。
from jax import custom_jvp
@custom_jvp
def f(x, A):
return A*jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
x, A = primals
ux, uA = tangents
dfdx=A * jnp.cos(x)
dfdA=jnp.sin(x)
primal_out = f(x, A)
tangent_out = dfdx * ux + dfdA * uA
return primal_out, tangent_out
以下がモデル。f(x) = A sin (x - x0)として、A, x0、およびノイズsigmaを推定する。
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
mu=f(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
NUTS部分。forward_mode_differentiation=Trueでforward微分モードとなる。
from jax import random
from numpyro.infer import MCMC, NUTS
# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model,forward_mode_differentiation=True)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()