
PyTorch implementation of the mixture distribution family with implicit reparametrisation gradients.


vsimkus/torch-reparametrised-mixture-distribution


Reparametrisable PyTorch MixtureSameFamily distribution

PyTorch implementation of the implicit reparametrisation trick for mixture distributions. It is based on Figurnov et al., 2019, "Implicit Reparameterization Gradients", and on the corresponding implementation in TensorFlow Probability.
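The identity behind the implicit gradient can be written out for a univariate mixture. This is a worked derivation, not code from this repository. θ collects the mixture weights and the component parameters, and the notation is chosen for this sketch. Only F and p have to be evaluated. The sample itself can therefore be drawn by ordinary ancestral sampling, and F never has to be inverted.

```latex
% Univariate mixture density and CDF
p(z;\theta) = \sum_{k=1}^{K} \pi_k(\theta)\, q_k(z;\theta), \qquad
F(z;\theta) = \sum_{k=1}^{K} \pi_k(\theta)\, Q_k(z;\theta)

% A sample satisfies F(z;\theta) = u, with u \sim \mathcal{U}(0,1) independent of \theta.
% Differentiate both sides with respect to \theta, using \partial F / \partial z = p:
p(z;\theta)\,\frac{dz}{d\theta} + \nabla_\theta F(z;\theta) = 0
\quad\Longrightarrow\quad
\frac{dz}{d\theta} = -\frac{\nabla_\theta F(z;\theta)}{p(z;\theta)}

% Example: Gaussian components, Q_k(z) = \Phi((z-\mu_k)/\sigma_k), so
% \partial Q_k / \partial \mu_k = -\mathcal{N}(z;\mu_k,\sigma_k^2) and
\frac{dz}{d\mu_k} = \frac{\pi_k\,\mathcal{N}(z;\mu_k,\sigma_k^2)}{p(z;\theta)}
```

In the Gaussian example, the gradient with respect to a component location is that component's posterior responsibility for the sample.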

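The distribution can be used directly as a mixture variational family in variational inference. Its samples are differentiable with respect to both the mixture weights and the component parameters, so the ELBO can be optimised with reparametrised gradients.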
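The sketch below illustrates the variational-inference use case. It is self-contained and does not call this repository's class, whose exact API is not described here.

- It draws implicitly reparametrised samples from a 1-D Gaussian mixture q using the surrogate `z0 - (F(z0) - stop_grad(F(z0))) / stop_grad(q(z0))`. The surrogate's value is the sample `z0`, and its gradient is the implicit one.
- It then maximises a Monte Carlo ELBO against a bimodal target.

The helper names (`log_target`, `mixture_log_prob`, `mixture_rsample`, `elbo`), the target and the optimiser settings are all hypothetical choices for the example.

```python
import torch
from torch.distributions import Categorical, Normal

torch.manual_seed(0)


def log_target(z):
    # Hypothetical target density: 0.3 N(-2, 0.5^2) + 0.7 N(2, 0.5^2).
    comps = Normal(torch.tensor([-2.0, 2.0]), torch.tensor([0.5, 0.5]))
    log_w = torch.log(torch.tensor([0.3, 0.7]))
    return torch.logsumexp(log_w + comps.log_prob(z.unsqueeze(-1)), dim=-1)


def mixture_log_prob(z, logits, loc, scale):
    # log q(z) for a 1-D Gaussian mixture; z has shape (n,).
    log_w = torch.log_softmax(logits, dim=-1)
    return torch.logsumexp(log_w + Normal(loc, scale).log_prob(z.unsqueeze(-1)), dim=-1)


def mixture_rsample(logits, loc, scale, n):
    # Sample values come from ancestral sampling, without gradients.
    with torch.no_grad():
        k = Categorical(logits=logits).sample((n,))
        z0 = Normal(loc[k], scale[k]).sample()
    # Mixture CDF at z0, tracked by autograd.
    w = torch.softmax(logits, dim=-1)
    cdf = (w * Normal(loc, scale).cdf(z0.unsqueeze(-1))).sum(-1)
    # Mixture density at z0, treated as a constant.
    pdf = mixture_log_prob(z0, logits, loc, scale).exp().detach()
    # Value equals z0; gradient equals -dF/dtheta / q(z0).
    return z0 - (cdf - cdf.detach()) / pdf


def elbo(params, n):
    logits, loc, log_scale = params
    scale = log_scale.exp()
    z = mixture_rsample(logits, loc, scale, n)
    return (log_target(z) - mixture_log_prob(z, logits, loc, scale)).mean()


# Variational parameters of a 2-component mixture q: logits, locations, log-scales.
params = [
    torch.zeros(2, requires_grad=True),
    torch.tensor([-0.5, 0.5], requires_grad=True),
    torch.zeros(2, requires_grad=True),
]
opt = torch.optim.Adam(params, lr=0.05)

with torch.no_grad():
    elbo_before = elbo(params, 10_000).item()
for step in range(500):
    opt.zero_grad()
    loss = -elbo(params, 256)
    loss.backward()
    opt.step()
with torch.no_grad():
    elbo_after = elbo(params, 10_000).item()
print(elbo_before, elbo_after)
```

Sampling the component index is not differentiable, but the mixture weights still receive gradients because they enter the CDF F. Simply calling `rsample` on the selected component would give the weights no gradient at all.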

Remarks:

  • For multivariate mixtures, the class currently supports only the case where the mixture component distributions fully factorise, i.e. each component is a product of independent univariate distributions.
  • The repository also includes a StableNormal distribution. It overrides the default cdf method with a numerically more stable implementation taken from a comment on pytorch/pytorch#52973. It additionally provides a _log_cdf method, but that method is not used by the implicit reparametrisation.
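The first remark concerns fully factorising components. In that case the multivariate sample can be handled one dimension at a time through the conditional CDFs F_d(z_d | z_<d). The conditional component weights are p(k | z_<d) ∝ π_k ∏_{j<d} q_kj(z_j).

The sketch below shows one way to do this for a mixture of diagonal Gaussians, reusing the univariate surrogate trick per dimension. It is an assumption-based illustration, not code taken from this repository or from TensorFlow Probability. `factorised_mixture_rsample` is a hypothetical name.

```python
import torch
from torch.distributions import Categorical, Normal


def factorised_mixture_rsample(logits, loc, scale, n):
    """Implicitly reparametrised samples from a mixture of diagonal Gaussians.

    logits: (K,); loc, scale: (K, D). Returns a tensor of shape (n, D).
    """
    K, D = loc.shape
    with torch.no_grad():
        # Ancestral sampling supplies the sample values; no gradient here.
        k = Categorical(logits=logits).sample((n,))
        z0 = Normal(loc[k], scale[k]).sample()  # (n, D)
    # Unnormalised log p(k | z_<d), starting from log pi_k.
    log_w = torch.log_softmax(logits, dim=-1).expand(n, K)
    cols = []
    for d in range(D):
        comp_d = Normal(loc[:, d], scale[:, d])  # K univariate marginals
        w = torch.softmax(log_w, dim=-1)  # p(k | z_<d), shape (n, K)
        x0 = z0[:, d : d + 1]  # (n, 1), constant
        # Conditional CDF F_d(z0_d | z_<d) and conditional density (constant).
        cdf = (w * comp_d.cdf(x0)).sum(-1)
        pdf = (w * comp_d.log_prob(x0).exp()).sum(-1).detach()
        # Value stays z0_d. w depends on the earlier reparametrised columns,
        # so autograd performs forward substitution through the triangular system.
        z_d = z0[:, d] - (cdf - cdf.detach()) / pdf
        cols.append(z_d)
        # Update the weights with this dimension for the next conditional.
        log_w = log_w + comp_d.log_prob(z_d.unsqueeze(-1))
    return torch.stack(cols, dim=-1)


torch.manual_seed(0)
logits = torch.tensor([0.3, -0.2], requires_grad=True)
loc = torch.tensor([[-1.0, 0.5], [1.0, -0.5]], requires_grad=True)
scale = torch.tensor([[1.0, 0.8], [0.7, 1.0]], requires_grad=True)
z = factorised_mixture_rsample(logits, loc, scale, 400_000)
# E[z_0 z_1] = sum_k pi_k loc[k,0] loc[k,1].
# Its gradient w.r.t. loc[k,0] is pi_k loc[k,1], a cross-dimensional check.
(z[:, 0] * z[:, 1]).mean().backward()
print(loc.grad)
```

The product z_0 z_1 is a useful check because its gradient depends on how z_1 reacts to parameters of dimension 0. That dependence only arises through the conditional weights.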
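For the StableNormal remark, here is a minimal sketch of the idea. It assumes the stabilisation amounts to two changes:

- evaluating the normal CDF with `torch.erfc` instead of `torch.erf`;
- evaluating the log-CDF with `torch.special.log_ndtr`.

The code in the linked comment may differ. `StableNormalSketch` is a hypothetical name; only the method name `_log_cdf` mirrors the README.

```python
import math

import torch
from torch.distributions import Normal


class StableNormalSketch(Normal):
    """Hypothetical Normal subclass with an erfc-based cdf (a sketch)."""

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        x = (value - self.loc) / self.scale
        # Phi(x) = 0.5 * erfc(-x / sqrt(2)) avoids the 1 + erf(...) cancellation,
        # so very negative x keeps a small positive value instead of 0.
        return 0.5 * torch.erfc(-x / math.sqrt(2))

    def _log_cdf(self, value):
        # log Phi(x) evaluated directly, avoiding log(0) in the lower tail.
        x = (value - self.loc) / self.scale
        return torch.special.log_ndtr(x)


std = Normal(0.0, 1.0)
stable = StableNormalSketch(0.0, 1.0)
x = torch.tensor(-10.0)
print(std.cdf(x).item(), stable.cdf(x).item(), stable._log_cdf(x).item())
```

The erf form 0.5 * (1 + erf(x / sqrt(2))) subtracts two nearly equal numbers for very negative x. In float32 it rounds to 0 long before the true value underflows. The erfc form avoids that subtraction.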