Mixture Density Network


A PyTorch implementation of a Mixture Density Network (MDN).

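An MDN models the conditional distribution over a scalar response $t$ given an input $x$ as a mixture of $K$ Gaussians:

$$p(t \mid x) = \sum_{k=1}^{K} \pi_k(x)\, \mathcal{N}\big(t \mid \mu_k(x), \sigma_k^2(x)\big)$$

where the mixture distribution parameters are output by a neural network trained to maximize the overall log-likelihood. The set of mixture distribution parameters is the following:

$$\{\pi_k(x),\ \mu_k(x),\ \sigma_k^2(x)\}_{k=1}^{K}$$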
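The log-likelihood objective described above can be sketched in PyTorch as follows. This is a minimal illustration, not the repository's `MDN_loss`: the function name `mdn_nll` and the tensor shapes are assumptions, and the sum over components is done in log space with `torch.logsumexp` for numerical stability.

```python
import math

import torch


def mdn_nll(t, pi, mu, sigma):
    """Mean negative log-likelihood of targets under a 1-D Gaussian mixture.

    Assumed shapes (hypothetical, not taken from the repository):
      t:     (N, 1) targets
      pi:    (N, K) mixture weights, each row summing to 1
      mu:    (N, K) component means
      sigma: (N, K) component standard deviations, all > 0
    """
    # log N(t | mu_k, sigma_k^2) for every sample and component -> (N, K)
    log_normal = (
        -0.5 * ((t - mu) / sigma) ** 2
        - torch.log(sigma)
        - 0.5 * math.log(2.0 * math.pi)
    )
    # log sum_k pi_k * N(...), computed stably in log space -> (N,)
    log_mix = torch.logsumexp(torch.log(pi) + log_normal, dim=1)
    # Maximizing the log-likelihood is the same as minimizing its negative mean.
    return -log_mix.mean()


# Two samples, two identical unit-variance components centred on each target.
t = torch.tensor([[0.0], [1.0]])
pi = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
mu = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
sigma = torch.ones(2, 2)
loss = mdn_nll(t, pi, mu, sigma)
```

In this toy case every target sits exactly at the mean of a standard-width Gaussian, so the loss reduces to `0.5 * log(2 * pi)`.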

To predict a multivariate response with Gaussian components (for example, in [2]), we assume a fully factored distribution (i.e., a diagonal covariance matrix) and predict each dimension separately. That is, the dimensions of the response are treated as statistically independent.
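One way to write the factored assumption down, as a sketch rather than the repository's exact formulation: for a $D$-dimensional response $\mathbf{t} = (t_1, \dots, t_D)$, a diagonal covariance turns each component density into a product of one-dimensional Gaussians. The per-dimension symbols $\mu_{k,d}$ and $\sigma_{k,d}$ are notation introduced here for illustration.

```latex
p(\mathbf{t} \mid x)
  = \sum_{k=1}^{K} \pi_k(x)
    \prod_{d=1}^{D} \mathcal{N}\!\big(t_d \mid \mu_{k,d}(x),\ \sigma_{k,d}^2(x)\big)
```

Under this reading, the log-density of each component is simply the sum of the per-dimension log-densities.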

Usage

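import torch
import torch.nn as nn
import torch.optim as optim
from models.mdn import MixtureDensityNetworks
from utils.loss import MDN_loss
from utils.utils import sample

# x_var, t_var (training inputs and targets), num_epochs, and mini (inputs to
# predict for) are assumed to be defined elsewhere.

# Small MLP with an MDN head; the head's arguments appear to be
# (input features, output dimension, number of mixture components).
model = nn.Sequential(
    nn.Linear(1, 20),
    nn.Tanh(),
    MixtureDensityNetworks(20, 1, 5),
)

opt = optim.Adam(model.parameters())

# Minimize the MDN loss with full-batch gradient steps.
for e in range(num_epochs):
    opt.zero_grad()
    pi, mu, sigma = model(x_var)
    loss = MDN_loss(t_var, pi, mu, sigma)
    loss.backward()
    opt.step()

# Predict mixture parameters for new inputs and draw samples from the mixture.
pi, mu, sigma = model(mini)
samples = sample(pi, mu, sigma)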
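Sampling from the predicted mixture is typically a two-step draw: pick a component according to the weights, then draw from that component's Gaussian. The sketch below illustrates this under assumed `(N, K)` tensor shapes; `sample_mixture` is a hypothetical helper, and the repository's `utils.utils.sample` may be implemented differently.

```python
import torch


def sample_mixture(pi, mu, sigma):
    """Draw one sample per row from a 1-D Gaussian mixture.

    Assumed shapes (hypothetical): pi, mu, sigma are all (N, K).
    Returns a tensor of shape (N, 1).
    """
    # Step 1: pick one component index per row according to the weights -> (N, 1)
    k = torch.multinomial(pi, num_samples=1)
    # Step 2: look up the chosen component's mean and standard deviation -> (N, 1)
    chosen_mu = mu.gather(1, k)
    chosen_sigma = sigma.gather(1, k)
    # Step 3: draw from the chosen Gaussian
    return chosen_mu + chosen_sigma * torch.randn_like(chosen_mu)


torch.manual_seed(0)
# Row 0 always selects the component at -5, row 1 the component at +5.
pi = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
mu = torch.tensor([[-5.0, 5.0], [-5.0, 5.0]])
sigma = torch.full((2, 2), 1e-3)
draws = sample_mixture(pi, mu, sigma)
```

Because the weights here are one-hot and the standard deviations are tiny, each draw lands very close to the mean of the selected component.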
[Figure: Original data]

[Figure: Inverse data]

[Figure: Inverse and sampled data]

References

Bishop, C. M. (1994). Mixture density networks. Technical Report NCRG/94/004, Aston University.