Probabilistic Programming with Programmable Variational Inference
Author(s)
Becker, McCoy R.; Lew, Alexander K.; Wang, Xiaoyan; Ghavami, Matin; Huot, Mathieu; Rinard, Martin C.; Mansinghka, Vikash K.
Terms of use
Publisher with Creative Commons License: Creative Commons Attribution
Abstract
Compared to the wide array of advanced Monte Carlo methods supported by modern probabilistic programming languages (PPLs), PPL support for variational inference (VI) is underdeveloped: users are typically limited to a small selection of predefined variational objectives and gradient estimators, which are implemented monolithically (and without explicit correctness arguments) in PPL backends. In this paper, we propose a modular approach to supporting VI in PPLs, based on compositional program transformation. First, we present a probabilistic programming language for defining models, variational families, and compositional strategies for propagating gradients. Second, we present a differentiable programming language for defining variational objectives. Models and variational families from the first language are automatically compiled into new differentiable functions that can be called from the second language, for estimating densities and expectations. Finally, we present an automatic differentiation algorithm that differentiates these variational objectives, yielding provably unbiased gradient estimators for use during optimization. We also extend our source language with features not previously supported for VI in PPLs, including approximate marginalization and normalization. This makes it possible to concisely express many models, variational families, objectives, and gradient estimators from the machine learning literature, including importance-weighted autoencoders (IWAE), hierarchical variational inference (HVI), and reweighted wake-sleep (RWS). We implement our approach in an extension to the Gen probabilistic programming system (genjax.vi, implemented in JAX), and evaluate our automation on several deep generative modeling tasks, showing minimal performance overhead vs. hand-coded implementations and performance competitive with well-established open-source PPLs.
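To make the abstract's notion of an "unbiased gradient estimator" for a variational objective concrete, here is a minimal sketch in plain Python (standard library only). It does not use genjax.vi or reproduce the paper's method. The toy model (z ~ Normal(0, 1), x ~ Normal(z, 1)), the variational family Normal(mu, 1), the observed value X_OBS, and all function names are assumptions chosen for illustration. For this model the ELBO gradient with respect to mu is X_OBS - 2*mu in closed form, so two standard estimators (reparameterization and score-function) can be checked against it.

```python
import math
import random

X_OBS = 2.0  # hypothetical observed data point (an assumption for this sketch)


def log_normal(v, mean, std=1.0):
    """Log density of Normal(mean, std) evaluated at v."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def log_joint(z):
    """Toy model: z ~ Normal(0, 1), x ~ Normal(z, 1), with x fixed to X_OBS."""
    return log_normal(z, 0.0) + log_normal(X_OBS, z)


def exact_grad(mu):
    """Closed-form d/dmu of the ELBO for the variational family q = Normal(mu, 1)."""
    return X_OBS - 2.0 * mu


def reparam_grad(mu, n, rng):
    """Reparameterization estimator: write z = mu + eps and differentiate through z."""
    total = 0.0
    for _ in range(n):
        z = mu + rng.gauss(0.0, 1.0)
        # d/dz log_joint(z) = -z + (X_OBS - z); with z = mu + eps, log q(z)
        # depends only on eps, so it contributes nothing to the gradient.
        total += X_OBS - 2.0 * z
    return total / n


def score_grad(mu, n, rng):
    """Score-function (REINFORCE) estimator: weight d/dmu log q(z) by the ELBO integrand."""
    total = 0.0
    for _ in range(n):
        z = mu + rng.gauss(0.0, 1.0)
        weight = log_joint(z) - log_normal(z, mu)
        total += weight * (z - mu)  # d/dmu log Normal(z; mu, 1) = z - mu
    return total / n


if __name__ == "__main__":
    rng = random.Random(0)
    mu = 0.3
    print("exact  :", exact_grad(mu))
    print("reparam:", reparam_grad(mu, 20000, rng))
    print("score  :", score_grad(mu, 20000, rng))
```

Both estimators target the same gradient, but in this toy setting the score-function estimate has noticeably higher variance. That difference is one reason a choice of gradient estimation strategy, which the abstract describes as user-programmable, matters in practice.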
Date issued
2024-06-20
Department
Massachusetts Institute of Technology. Department of Electrical Engineering and Computer Science; Massachusetts Institute of Technology. Computer Science and Artificial Intelligence Laboratory
Journal
Proceedings of the ACM on Programming Languages
Publisher
Association for Computing Machinery
Citation
Becker, McCoy R., Lew, Alexander K., Wang, Xiaoyan, Ghavami, Matin, Huot, Mathieu et al. 2024. "Probabilistic Programming with Programmable Variational Inference." Proceedings of the ACM on Programming Languages, 8 (PLDI).
Version: Final published version
ISSN
2475-1421