
Mixture Density Network


Implementation of Mixture Density Network in PyTorch

An MDN models the conditional distribution of a scalar response t given an input x as a mixture of Gaussians:

$$p(t \mid x) = \sum_{k=1}^{K} \pi_k(x)\, \mathcal{N}\!\left(t \mid \mu_k(x), \sigma_k^2(x)\right)$$

where the mixture parameters are output by a neural network trained to maximize the log-likelihood of the training data. For each input, the network predicts the set of mixture parameters

$$\{\pi_k(x), \mu_k(x), \sigma_k(x)\}_{k=1}^{K}, \qquad \pi_k(x) \ge 0, \quad \sum_{k=1}^{K} \pi_k(x) = 1.$$

To predict a multivariate response (for example, in [2]), we assume a fully factored distribution, i.e. a diagonal covariance matrix, so each output dimension is modelled by its own independent set of means and standard deviations.
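
For reference, the negative log-likelihood of such a diagonal-covariance Gaussian mixture can be computed as sketched below. This is a minimal sketch rather than the exact `MDN_loss` shipped in `utils/loss.py`; it assumes `pi` has shape `(batch, K)` and sums to 1 over components, while `mu` and `sigma` have shape `(batch, K, D)`:

```python
import math
import torch

def mdn_nll(t, pi, mu, sigma, eps=1e-8):
    # t:     (batch, D)     targets
    # pi:    (batch, K)     mixture weights (assumed non-negative, summing to 1 over K)
    # mu:    (batch, K, D)  component means
    # sigma: (batch, K, D)  component standard deviations (assumed positive)
    t = t.unsqueeze(1)                                    # (batch, 1, D), broadcasts against K
    # Per-dimension Gaussian log-density, summed over D (diagonal covariance)
    log_comp = (-0.5 * ((t - mu) / sigma) ** 2
                - torch.log(sigma)
                - 0.5 * math.log(2 * math.pi)).sum(dim=-1)             # (batch, K)
    # Mix the components in log space for numerical stability
    log_mix = torch.logsumexp(torch.log(pi + eps) + log_comp, dim=-1)  # (batch,)
    return -log_mix.mean()
```

Mixing the components with a log-sum-exp avoids multiplying very small probabilities directly, which is the usual source of numerical problems when training MDNs.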

Usage

```python
import torch
import torch.nn as nn
import torch.optim as optim
from models.mdn import MixtureDensityNetworks
from utils.loss import MDN_loss
from utils.utils import sample

# MDN head with 20 hidden features, a 1-dimensional target and 5 mixture components
model = nn.Sequential(
    nn.Linear(1, 20),
    nn.Tanh(),
    MixtureDensityNetworks(20, 1, 5),
)

opt = optim.Adam(model.parameters())

for e in range(num_epochs):
    opt.zero_grad()
    pi, mu, sigma = model(x_var)           # mixture weights, means, std-devs
    loss = MDN_loss(t_var, pi, mu, sigma)  # negative log-likelihood
    loss.backward()
    opt.step()

pi, mu, sigma = model(mini)                # mini-batch of inputs to sample for
samples = sample(pi, mu, sigma)
```
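
The `sample` utility imported above draws points from the predicted mixture. One straightforward way to do this, shown as a sketch under the same shape assumptions as the loss above (not necessarily the repo's implementation), is to pick a component per example according to `pi` and then draw from that component's Gaussian:

```python
import torch

def sample_mixture(pi, mu, sigma):
    # pi: (batch, K), mu/sigma: (batch, K, D)
    k = torch.multinomial(pi, num_samples=1)            # (batch, 1) chosen component index
    idx = k.unsqueeze(-1).expand(-1, -1, mu.size(-1))   # (batch, 1, D)
    mu_k = mu.gather(1, idx).squeeze(1)                 # (batch, D) selected means
    sigma_k = sigma.gather(1, idx).squeeze(1)           # (batch, D) selected std-devs
    return mu_k + sigma_k * torch.randn_like(mu_k)      # Gaussian draw from the chosen component
```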
Figure: Original data

Figure: Inverse data

Figure: Inverse and sampled data

References

Bishop, C. M. (1994). Mixture Density Networks. Technical Report NCRG/94/004, Aston University.
