
flow_mcmc

Commits

List of commits on branch master (all committed by hh2o64, 2 years ago):

  • mcmc: global_local: Fix target_x_s_t and prop_x_s_t not up-to-date after warmup (a1fb870b33676f7e3e686f24719f0e47d332f08b)
  • mcmc: isir: Update x_s_t online in one_step (9e0926dc1eac4b9bf439aff2bd46ea052b5dd5e2)
  • mcmc: Make sure the current state is updated (56b9c39b8d1237d89626f96bbb71ab3dc67f5699)
  • mcmc: mala: Improve speed (dccd0afefbe6bffef5c9f30732d3b79f3d18911b)
  • flow_mcmc: Add Arxiv link (58b333283945a05439737f6a92b44b0967cab4fa)
  • flow_mcmc: Initial commit (65db318e113a671e0ffca3c4bd553589fb0021aa)

README

Flow MCMC

This is the official PyTorch repository for the preprint "On Sampling with Approximate Transport Maps" (Arxiv).

Installation

Install the dependencies with

pip install -r requirements.txt

To rerun the Alanine Dipeptide experiments, use conda and additionally install the following packages:

conda install -c conda-forge openmm openmmtools mdtraj
pip install git+https://github.com/VincentStimper/boltzmann-generators.git

MCMC Samplers

Flow MCMC provides many popular MCMC samplers under a common API:

  • Metropolis-Adjusted Langevin Algorithm (MALA) with mcmc.mala.MALA
  • Hamiltonian Monte Carlo (HMC) (Neal et al., 2011) with mcmc.hmc.HMC
  • Random Walk Metropolis-Hastings (RWMH) with mcmc.rwhm.RWHM
  • Elliptical Slice Sampling (ESS) (Murray et al., 2010) with mcmc.ess.ESS
  • Independent Metropolis-Hastings (IMH) with mcmc.imh.IndependentMetropolisHastings
  • Iterated Sampling Importance Resampling (i-SIR) (Andrieu et al., 2010) with mcmc.isir.iSIR
  • Pyro wrappers for HMC and NUTS (Bingham et al., 2019) with mcmc.pyro_mcmc.HMC and mcmc.pyro_mcmc.NUTS
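To illustrate one of the samplers listed above, here is a minimal i-SIR transition for a batch of chains, written from the textbook description of the algorithm: keep the current state, draw fresh proposals, and resample one of them in proportion to its importance weight. This is a hypothetical stand-alone function (isir_step), not the code of mcmc.isir.iSIR; the proposal is a plain Gaussian, whereas flow-MCMC would use a normalizing flow in its place.

```python
import torch

def isir_step(x, log_target, proposal, n_particles=8):
    """One i-SIR transition for a batch of chains (illustrative sketch).

    x: (batch_size, dim) current states, one row per chain
    log_target: batched callable (n, dim) -> (n,), unnormalized log-density
    proposal: torch.distributions object with event shape (dim,)
    """
    batch_size, dim = x.shape
    # Draw n_particles - 1 fresh proposals per chain; the current state is the extra particle
    fresh = proposal.sample((n_particles - 1, batch_size))    # (N-1, B, D)
    particles = torch.cat([x.unsqueeze(0), fresh], dim=0)     # (N, B, D)
    flat = particles.reshape(-1, dim)
    # Importance log-weights: log pi(x) - log q(x)
    log_w = (log_target(flat) - proposal.log_prob(flat)).reshape(n_particles, batch_size)
    # For each chain, pick the next state with probability proportional to its weight
    idx = torch.distributions.Categorical(logits=log_w.T).sample()  # (B,)
    return particles[idx, torch.arange(batch_size)]

# Usage on a toy target N((1, 1), I), with 100 chains run in parallel
torch.manual_seed(0)
dim = 2
proposal = torch.distributions.MultivariateNormal(torch.zeros(dim), 4.0 * torch.eye(dim))
log_target = lambda x: -0.5 * ((x - 1.0) ** 2).sum(-1)
x = torch.zeros(100, dim)
for _ in range(200):
    x = isir_step(x, log_target, proposal, n_particles=16)
```

Because the current state is always one of the particles, the chain never "loses" a good state, and the mean of x across chains should drift toward the target mean (1, 1).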

Sampling works by calling

sampler.sample(x_s_t_0, n_steps, target, temp=1.0, warmup_steps=0, verbose=False)

Where

  • x_s_t_0 (tensor of shape (batch_size, dim)) is the initial state of the chains (one row per chain)
  • n_steps (int) is the length of the chain
  • target (callable) is the log-likelihood of the target distribution (it must support batched inputs)
  • temp (float) is the temperature factor of the log-likelihood
  • warmup_steps (int) is the number of burn-in steps (these samples are discarded)
  • verbose (bool) displays a progress bar during sampling

Note that, unlike many MCMC samplers in PyTorch, all the samplers here support running multiple chains in parallel. Each sampler records diagnostics (acceptance rates, ...), which can be retrieved with sampler.get_diagnostics(diag_name).
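The calling convention above can be illustrated with a toy sampler. The class below (ToyRWMH) is hypothetical and only mimics the documented sample(...) and get_diagnostics(...) names. Several details are assumptions, not the repository's behavior: temp is taken to divide the log-density, the diagnostic is named "acceptance", progress is printed instead of drawn as a bar, and the returned tensor has shape (batch_size, n_steps, dim).

```python
import torch

class ToyRWMH:
    """Hypothetical random-walk Metropolis-Hastings sampler mimicking the
    sample(...) / get_diagnostics(...) API described above (not mcmc.rwhm.RWHM)."""

    def __init__(self, step_size=0.5):
        self.step_size = step_size
        self.diagnostics = {"acceptance": []}

    def sample(self, x_s_t_0, n_steps, target, temp=1.0, warmup_steps=0, verbose=False):
        self.diagnostics = {"acceptance": []}  # reset diagnostics for each run
        x = x_s_t_0.clone()
        # Assumption: temp divides the log-density, i.e. we sample pi(x)^(1 / temp)
        log_p = target(x) / temp
        samples = []
        for t in range(warmup_steps + n_steps):
            prop = x + self.step_size * torch.randn_like(x)
            log_p_prop = target(prop) / temp
            # Symmetric proposal: accept with probability min(1, pi(prop) / pi(x)), per chain
            accept = torch.rand(x.shape[0]).log() < log_p_prop - log_p
            x = torch.where(accept.unsqueeze(-1), prop, x)
            log_p = torch.where(accept, log_p_prop, log_p)
            if t >= warmup_steps:  # burn-in states are discarded
                samples.append(x.clone())
                self.diagnostics["acceptance"].append(accept.float().mean().item())
            if verbose and (t + 1) % 100 == 0:
                print(f"step {t + 1}/{warmup_steps + n_steps}")
        return torch.stack(samples, dim=1)  # (batch_size, n_steps, dim)

    def get_diagnostics(self, diag_name):
        return self.diagnostics[diag_name]

# Usage: 64 chains in parallel on a standard 2D Gaussian (batched log-density)
torch.manual_seed(0)
log_target = lambda x: -0.5 * (x ** 2).sum(-1)
sampler = ToyRWMH(step_size=1.0)
chains = sampler.sample(torch.zeros(64, 2), n_steps=500, target=log_target, warmup_steps=100)
acc = sampler.get_diagnostics("acceptance")  # one mean acceptance rate per kept step
```

Batching over chains costs almost nothing extra here: each step evaluates target once on a (batch_size, dim) tensor, which is why the target must support batched inputs.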

You can also use normalizing flows to enhance your sampler in two ways:

  • Using flow-MCMC algorithms (i.e., using the flow as global proposals)
  • Using neutra-MCMC algorithms (i.e., using the flow as a reparametrization map) (Parno & Marzouk, 2018; Hoffman et al., 2019) by wrapping the sampler with mcmc.neutra.NeuTra(inner_sampler, flow)
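The reparametrization idea behind neutra-MCMC can be sketched in a few lines: if the flow maps latent z to x = T(z), the inner sampler runs on the latent log-density log pi(T(z)) + log|det J_T(z)| and its samples are pushed through T. The helper neutra_log_target and the flow interface (a callable returning (x, log_det)) are hypothetical; the real mcmc.neutra.NeuTra wrapper may expect a different flow interface. An exact affine map is used so the effect is easy to check.

```python
import torch

def neutra_log_target(target, flow):
    """Pull a batched log-density back through a flow x = T(z) (illustrative sketch).

    flow: hypothetical callable z -> (x, log_det_jacobian), both batched.
    """
    def latent_log_target(z):
        x, log_det = flow(z)
        return target(x) + log_det  # change of variables
    return latent_log_target

# Toy target N(mu, diag(sigma^2)): badly scaled in x-space
mu = torch.tensor([3.0, -1.0])
sigma = torch.tensor([0.5, 2.0])
target = lambda x: -0.5 * (((x - mu) / sigma) ** 2).sum(-1)  # unnormalized

# Exact affine "flow" T(z) = mu + sigma * z, with constant log|det J| = sum(log sigma)
affine_flow = lambda z: (mu + sigma * z, torch.log(sigma).sum().expand(z.shape[0]))

latent = neutra_log_target(target, affine_flow)
torch.manual_seed(0)
z = torch.randn(5, 2)
# With an exact map, the latent target equals a standard Gaussian up to a constant,
# so this offset is the same for every z (it equals sum(log sigma))
offset = latent(z) + 0.5 * (z ** 2).sum(-1)
```

In practice the flow is only approximate, so the latent target is merely closer to isotropic; the inner sampler (MALA, HMC, ...) still corrects for the remaining mismatch.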

Importance Sampling is also available as mcmc.classic_is.IS, which has the same API as the MCMC samplers (batch_size and n_steps are ignored) and can also be enhanced by a flow.
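For reference, self-normalized importance sampling with a flow as proposal amounts to the sketch below. The function snis is hypothetical, not mcmc.classic_is.IS; any object exposing .sample(shape) and .log_prob(x) can serve as the proposal, and a torch.distributions Gaussian stands in for a trained flow here.

```python
import torch

def snis(log_target, proposal, n_samples, f):
    """Self-normalized importance sampling estimate of E_pi[f(x)] (illustrative sketch).

    proposal: anything with .sample(shape) and .log_prob(x), e.g. a trained flow.
    f: batched callable (n, dim) -> (n, k).
    """
    x = proposal.sample((n_samples,))
    log_w = log_target(x) - proposal.log_prob(x)  # unnormalized log-weights
    w = torch.softmax(log_w, dim=0)               # self-normalized weights (sum to 1)
    estimate = (w.unsqueeze(-1) * f(x)).sum(0)
    ess = 1.0 / (w ** 2).sum()                    # effective sample size
    return estimate, ess

# Usage: estimate the mean of N((1, -1), I) with a broader Gaussian proposal
torch.manual_seed(0)
proposal = torch.distributions.MultivariateNormal(torch.zeros(2), 2.0 * torch.eye(2))
log_target = lambda x: -0.5 * ((x - torch.tensor([1.0, -1.0])) ** 2).sum(-1)
est, ess = snis(log_target, proposal, 10_000, lambda x: x)
```

The effective sample size is a cheap diagnostic of how well the proposal (here, the flow) matches the target: it drops sharply when a few samples carry most of the weight.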

We also provide a way to perform adaptive learning of normalizing flows with mcmc.learn_flow.LearnMCMC (Gabrie et al., 2022).
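The adaptive scheme can be sketched as a loop that alternates local MCMC moves, global independent Metropolis-Hastings moves proposed by the flow, and maximum-likelihood updates of the flow on the current chain states. This is a simplified illustration in the spirit of Gabrie et al. (2022), not mcmc.learn_flow.LearnMCMC: the "flow" is a hypothetical learnable diagonal affine map standing in for a real normalizing flow, and the step size, learning rate, and iteration count are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
dim, n_chains = 2, 256
target_mean = torch.tensor([2.0, -2.0])

def log_target(x):
    # Toy unnormalized target N(target_mean, I), batched over rows
    return -0.5 * ((x - target_mean) ** 2).sum(-1)

# Hypothetical stand-in for a normalizing flow: z -> mu + exp(log_sigma) * z
mu = torch.zeros(dim, requires_grad=True)
log_sigma = torch.zeros(dim, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

def flow_log_prob(x):
    # Log-density of N(0, I) pushed through the affine map (change of variables)
    z = (x - mu) * torch.exp(-log_sigma)
    return -0.5 * (z ** 2).sum(-1) - log_sigma.sum() - 0.5 * dim * math.log(2 * math.pi)

def flow_sample(n):
    return mu + torch.exp(log_sigma) * torch.randn(n, dim)

x = torch.zeros(n_chains, dim)  # chains start away from the target mode
for _ in range(500):
    with torch.no_grad():
        # Local move: one random-walk Metropolis step per chain
        prop = x + 0.5 * torch.randn_like(x)
        accept = torch.rand(n_chains).log() < log_target(prop) - log_target(x)
        x = torch.where(accept.unsqueeze(-1), prop, x)
        # Global move: independent Metropolis-Hastings with the flow as proposal
        prop = flow_sample(n_chains)
        log_ratio = (log_target(prop) - flow_log_prob(prop)) - (log_target(x) - flow_log_prob(x))
        accept = torch.rand(n_chains).log() < log_ratio
        x = torch.where(accept.unsqueeze(-1), prop, x)
    # Flow update: maximize the flow log-likelihood of the current chain states
    opt.zero_grad()
    loss = -flow_log_prob(x).mean()
    loss.backward()
    opt.step()
```

Each MCMC move leaves the target invariant whatever the current flow is, so the flow only affects how fast the chains mix; in return, the chains supply the training data that gradually turns the flow into a good global proposal.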

We provide an implementation of RealNVP (Dinh et al., 2016), based on marylou-gabrie/adapt-flow-ergo's implementation, as well as a wrapper for the flows from VincentStimper/normalizing-flows.
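The core building block of RealNVP is the affine coupling layer: masked coordinates pass through unchanged and condition a scale and shift applied to the others, which keeps the inverse and the log-determinant cheap. The class below (AffineCoupling) is a hypothetical minimal sketch; the tanh-bounded scale and the small MLP are arbitrary choices, and it does not reproduce the repository's implementation.

```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """One RealNVP affine coupling layer (illustrative sketch).

    y = m * x + (1 - m) * (x * exp(s(m * x)) + t(m * x)), with a binary mask m.
    """

    def __init__(self, dim, mask, hidden=64):
        super().__init__()
        self.register_buffer("mask", mask)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * dim),
        )

    def _scale_shift(self, x_masked):
        s, t = self.net(x_masked).chunk(2, dim=-1)
        s = torch.tanh(s) * (1 - self.mask)  # bounded scales, only on updated coordinates
        t = t * (1 - self.mask)
        return s, t

    def forward(self, x):
        s, t = self._scale_shift(x * self.mask)
        y = x * torch.exp(s) + t
        return y, s.sum(-1)  # log|det J| of the block-triangular Jacobian

    def inverse(self, y):
        # Masked coordinates of y equal those of x, so s and t can be recomputed exactly
        s, t = self._scale_shift(y * self.mask)
        x = (y - t) * torch.exp(-s)
        return x, -s.sum(-1)

# Usage: round trip through one layer
torch.manual_seed(0)
layer = AffineCoupling(dim=4, mask=torch.tensor([1.0, 1.0, 0.0, 0.0]))
x = torch.randn(8, 4)
y, log_det = layer(x)
x_rec, log_det_inv = layer.inverse(y)
```

A full flow stacks several such layers with alternating masks so that every coordinate gets transformed, summing the per-layer log-determinants.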

"On Sampling with Approximate Transport Maps"

Here we explain how to rerun the experiments presented in the paper. Note that the output paths are defined at the top of the configuration files in configs/flow_approx/.

Synthetic case studies

All the experiments from the synthetic case studies can be rerun using the following commands:

python experiments/flow_approx/gaussians_three_flows.py configs/flow_approx/gaussians_three_flows.yaml --seed {INSERT_SEED}
python experiments/flow_approx/funnel.py configs/flow_approx/funnel.yaml --seed {INSERT_SEED}
python experiments/flow_approx/gaussian_mixture.py configs/flow_approx/gaussians_mixture.yaml --seed {INSERT_SEED}
python experiments/flow_approx/banana.py {OUTPUT_PATH}/backward_dim{DIMENSION}.pkl --loss_type backward_kl --dim {DIMENSION} --seed {SEED}

The hyper-parameter grid search can be rerun by using the *_debug.yaml configs. The flows for the mixture of Gaussians can be re-trained using

python experiments/flow_approx/gaussian_mixture_flow.py --dim {DIMENSION} --checkpoint_path {SAVE_PATH}/dim_{DIMENSION}/

Benchmarks on real tasks

Alanine Dipeptide

The flow (also available in experiments/flow_approx/models/aldp/flow_aldp.pt) can be retrained using the procedure described in lollcat/fab-torch (Midgley et al., 2022). Sampling can be performed using

python experiments/flow_approx/aldp.py configs/flow_approx/aldp.yaml --seed {SEED} --save_samples

The data used for the ground truth are available on the authors' Zenodo.

Logistic Regression

The flow for the logistic regression experiment can be obtained by running

python experiments/flow_approx/logistic_regression_flow.py --save_path {OUTPUT_PATH} --neutra_flow

and sampling can be done with

python experiments/flow_approx/logistic_regression.py configs/flow_approx/logistic_regression.yaml --seed {SEED} --neutra_flow

Note that you will need ground-truth samples, obtained with NUTS by running

python experiments/flow_approx/logistic_regression_gt.py --save_path {OUTPUT_PATH}

Phi Four

The flows for the Phi Four experiment can be obtained by running

python experiments/flow_approx/phi_four_parameters.py configs/flow_approx/phi_four_parameters/global_{DIMENSION}.yaml configs/flow_approx/phi_four_parameters/best_flows_{DIMENSION}.yaml --mala_sampler 

and sampling can be done with

python experiments/flow_approx/phi_four.py configs/flow_approx/phi_four.yaml --save_samples --seed {SEED}

Appendix

The flows for Figure 8 can be retrained using

python experiments/flow_approx/many_flows_two_moons.py --load_path {OUTPUT_PATH} --seed {SEED}

🏗️ TODO

  • Fix mcmc.hmc.HMC: the warmup phase is currently broken
  • Allow learning a preconditioning matrix for mcmc.mala.MALA