BIJAX - Bayesian Inference in JAX

  • Contributed to an open-source Python library that takes a unified and transparent approach to distribution approximation techniques such as the Laplace approximation and Markov chain Monte Carlo (MCMC) sampling.
  • Extended the approximation methods with full-rank, low-rank, mean-field, subnetwork, last-layer, and Kronecker-Factored Approximate Curvature (KFAC) variants.

Laplace Approximation

The figures below show Laplace approximation results for a fully Bayesian network, a last-layer Bayesian network, and a subnetwork Bayesian network.

Laplace Approximation using all layers for regression on a 2D dataset. MAP, fullrank, lowrank, diag, kron (left to right)

Laplace Approximation using last layer for regression on a 2D dataset. MAP, fullrank, diag, kron (left to right)

Laplace Approximation using subnetwork for regression on a 2D dataset. MAP, fullrank (left to right)
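
To make the MAP, full-rank, and diagonal variants named in the captions above concrete, here is a minimal JAX sketch on a toy linear model. It is not BIJAX's API: the data, the names `neg_log_posterior`, `NOISE_STD`, and `PRIOR_PREC`, and the choice of a Gaussian likelihood and prior are all assumptions made for illustration. The idea is to find the MAP weights, take the Hessian of the negative log posterior there, and use its inverse (or the inverse of its diagonal only) as the covariance of a Gaussian approximation.

```python
import jax
import jax.numpy as jnp

# Toy data for y = w0 + w1 * x + noise (illustrative, not from the project).
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 20)
X = jnp.stack([jnp.ones_like(x), x], axis=1)
y = 0.5 + 2.0 * x + 0.1 * jax.random.normal(key, x.shape)

NOISE_STD = 0.1   # assumed known observation noise
PRIOR_PREC = 1.0  # assumed isotropic Gaussian prior precision


def neg_log_posterior(w):
    """Negative log joint (up to a constant): Gaussian likelihood + Gaussian prior."""
    resid = y - X @ w
    nll = 0.5 * jnp.sum(resid ** 2) / NOISE_STD ** 2
    neg_log_prior = 0.5 * PRIOR_PREC * jnp.sum(w ** 2)
    return nll + neg_log_prior


# Step 1: MAP estimate. The objective is quadratic here, so a single
# Newton step from zero lands on the minimum; a neural network would
# need an iterative optimizer instead.
w_init = jnp.zeros(2)
w_map = w_init - jnp.linalg.solve(
    jax.hessian(neg_log_posterior)(w_init),
    jax.grad(neg_log_posterior)(w_init),
)

# Step 2: curvature of the negative log posterior at the MAP estimate.
hess = jax.hessian(neg_log_posterior)(w_map)

# Full-rank Laplace: covariance is the inverse Hessian.
cov_full = jnp.linalg.inv(hess)

# Diagonal (mean-field) Laplace: keep only the Hessian diagonal.
cov_diag = jnp.diag(1.0 / jnp.diag(hess))

# Predictive variance of the linear output at a new input [1, x*].
x_star = jnp.array([1.0, 0.5])
pred_var_full = x_star @ cov_full @ x_star
pred_var_diag = x_star @ cov_diag @ x_star
```

Last-layer and subnetwork variants follow the same recipe but differentiate only with respect to a subset of the parameters, which keeps the Hessian small enough to invert for larger networks.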

Hamiltonian Monte Carlo (HMC) on the Digits dataset

For out-of-distribution data (the letter ‘A’ is out of distribution for a dataset of digits), a deterministic model would still predict a class with some confidence. We define an uncertainty metric, ‘entropy’, which measures how uncertain the model is (higher means more uncertain). The results below show that the HMC model predicts a class with low confidence and high entropy.
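
The entropy metric described above can be sketched as the Shannon entropy of the class probabilities averaged over posterior samples. This is a minimal illustration, not the project's code: the function names and the two hypothetical sets of logits are assumptions, and the exact definition used in BIJAX may differ.

```python
import jax
import jax.numpy as jnp


def predictive_entropy(probs, eps=1e-12):
    """Shannon entropy (in nats) of class probabilities along the last axis."""
    return -jnp.sum(probs * jnp.log(probs + eps), axis=-1)


# Hypothetical logits for one input from two posterior samples (3 classes).
# The samples disagree about the class, as might happen for an
# out-of-distribution input.
logits_samples = jnp.array([[5.0, 0.0, 0.0],
                            [0.0, 5.0, 0.0]])

# Each sample on its own is confident.
per_sample_probs = jax.nn.softmax(logits_samples, axis=-1)

# Averaging over samples spreads the probability mass, which lowers the
# top-class confidence and raises the entropy.
mean_probs = per_sample_probs.mean(axis=0)

entropy_single = predictive_entropy(per_sample_probs[0])
entropy_mean = predictive_entropy(mean_probs)
```

A deterministic model corresponds to a single sample, so disagreement of this kind never shows up in its predicted probabilities.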

Ground Truth, Predicted Class, Predicted Class Probability and Entropy
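
For readers unfamiliar with HMC itself, the following is a minimal JAX sketch of one sampler step (leapfrog integration followed by a Metropolis accept/reject) on a toy 2D standard Gaussian target. It is an assumption-laden illustration rather than BIJAX's implementation: the target `log_prob`, the names `hmc_step` and `run_chain`, the step size, and the number of leapfrog steps are all chosen for the example.

```python
import jax
import jax.numpy as jnp


def log_prob(q):
    """Toy target: standard 2D Gaussian log density (up to a constant)."""
    return -0.5 * jnp.sum(q ** 2)


grad_log_prob = jax.grad(log_prob)


def hmc_step(key, q, step_size=0.2, n_leapfrog=10):
    """One HMC transition with an identity mass matrix."""
    key_p, key_u = jax.random.split(key)
    p0 = jax.random.normal(key_p, q.shape)  # resample momentum

    # Leapfrog integration: half momentum step, alternating full steps,
    # then a final half momentum step.
    p = p0 + 0.5 * step_size * grad_log_prob(q)
    q_new = q
    for i in range(n_leapfrog):
        q_new = q_new + step_size * p
        if i < n_leapfrog - 1:
            p = p + step_size * grad_log_prob(q_new)
    p = p + 0.5 * step_size * grad_log_prob(q_new)

    # Metropolis correction based on the change in total energy.
    h_old = -log_prob(q) + 0.5 * jnp.sum(p0 ** 2)
    h_new = -log_prob(q_new) + 0.5 * jnp.sum(p ** 2)
    accept = jnp.log(jax.random.uniform(key_u)) < h_old - h_new
    return jnp.where(accept, q_new, q)


def run_chain(key, q0, n_samples):
    """Run n_samples HMC transitions and return all visited states."""
    def body(q, k):
        q = hmc_step(k, q)
        return q, q

    keys = jax.random.split(key, n_samples)
    _, samples = jax.lax.scan(body, q0, keys)
    return samples


samples = run_chain(jax.random.PRNGKey(0), jnp.zeros(2), 1000)
```

For a Bayesian neural network, `log_prob` would be the log posterior over the network weights, and the class probabilities would be averaged over the resulting weight samples.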

Accuracy and Timing Analysis

Comparable accuracies for different models

Time taken for training different models
