Skip to content

NUTSでカスタムJVPを用いる

Hajime Kawahara edited this page Feb 6, 2021 · 1 revision

NumPyro 0.5.0からForward 微分でNUTSを実行できるようになった。ここではforward微分のカスタムJVPを自分で定義してNUTSを実行してみる。

コードはGISTにおいた。下記で一部説明する。

f(x) = A sin (x)をカスタムJVPで定義している部分が以下。

from jax import custom_jvp
@custom_jvp
def f(x, A):
    return A*jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    x, A = primals
    ux, uA = tangents
    dfdx=A * jnp.cos(x)
    dfdA=jnp.sin(x)
    primal_out = f(x, A)
    tangent_out = dfdx * ux  + dfdA * uA
    return primal_out, tangent_out

以下がモデル。f(x) = A sin (x - x0)として、A, x0、およびノイズsigmaを推定する。

def model(x,y):
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
    A = numpyro.sample('A', dist.Exponential(1.))
    mu=f(x-x0,A)
    numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

NUTS部分。forward_mode_differentiation=Trueでforward微分モードとなる。

from jax import random
from numpyro.infer import MCMC, NUTS

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model,forward_mode_differentiation=True)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()
Clone this wiki locally