Source code for batorch.sgmcmc.samplers

import copy
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.distributions as td

from collections import OrderedDict
from typing import Callable, Optional, Union

import functools
from batorch.sgmcmc.diffusions import DiffusionCyclicSGLD, DiffusionCyclicSGHMC, DiffusionSGLD, DiffusionpSGLD, DiffusionSGHMC, DiffusionSGHMCSA
from batorch.utils.misc import parameters_to_vector, vector_to_parameters

class Sampler(object):
    """Base class for implementing sampler.

    The constructor optionally initializes the model from a state dict,
    instantiates the diffusion on the model parameters and records the size
    of the training set.

    :param DiffusionType: Diffusion *class* (not an instance). It is
        instantiated as ``DiffusionType(params=..., **kwargs)``.
    :param negloglikelihood: :class:`torch.nn.Module` called as
        ``negloglikelihood(x, y)``; must own at least one parameter, since
        the device is read from its first parameter.
    :param neglogprior: Callable invoked with the iterator returned by
        ``negloglikelihood.parameters()``.
    :param init_params: ``None`` (keep the current weights), a path to a
        saved state dict (str), or a state dict (``OrderedDict`` or plain
        ``dict``).
    :param dataloader: An already instantiated dataloader;
        ``len(dataloader.dataset)`` is stored as ``N_train``.
    :param kwargs: Any optional keyword arguments that will be passed to
        the Diffusion.
    :raises TypeError: If ``DiffusionType`` is not a type.
    """

    def __init__(
        self,
        DiffusionType: type,
        negloglikelihood: nn.Module,
        neglogprior: Callable,
        init_params: Union[None, str, OrderedDict],
        dataloader: DataLoader,
        **kwargs
    ):
        # Fail fast: reject a bad diffusion before touching the model weights.
        if not isinstance(DiffusionType, type):
            raise TypeError("DiffusionType should be a type not an instance")

        self.init_params = init_params
        self.neglogprior = neglogprior
        self.negloglikelihood = negloglikelihood
        # The device of the first parameter is taken as the model's device.
        self.device = next(self.negloglikelihood.parameters()).device

        if isinstance(self.init_params, str):
            # map_location keeps the load working when the checkpoint was
            # saved from a different device than the one the model lives on.
            state_dict = torch.load(self.init_params, map_location=self.device)
            self.negloglikelihood.load_state_dict(state_dict)
        elif isinstance(self.init_params, dict):
            # ``dict`` also covers OrderedDict, so plain-dict state dicts are
            # no longer silently ignored.
            self.negloglikelihood.load_state_dict(self.init_params)

        self.diffusion = DiffusionType(params=self.negloglikelihood.parameters(), **kwargs)
        self.N_train = len(dataloader.dataset)

    def step(self, x, y):
        """This method updates the parameters and returns the loss.

        :param x: Input data that will be passed to the negative log-likelihood.
        :param y: Output tensor that will be compared to the prediction.
        :raises NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def get_params(self):
        """Function that flattens the network parameters.

        Delegates to ``parameters_to_vector`` with ``grad=False``.

        :return: Flattened network parameters.
        """
        vec_params = parameters_to_vector(
            self.negloglikelihood.parameters(), grad=False, both=False, cpu=False
        )
        return vec_params

    def get_grads(self):
        """Function that flattens the network gradients.

        Delegates to ``parameters_to_vector`` with ``grad=True``.

        :return: Flattened network gradients.
        """
        vec_grads = parameters_to_vector(
            self.negloglikelihood.parameters(), grad=True, both=False, cpu=False
        )
        return vec_grads

    def set_params(self, vec_params):
        """Function that loads flattened parameters into the model.

        :param vec_params: Flattened network parameters.
        """
        vector_to_parameters(vec_params, self.negloglikelihood.parameters(), grad=False)

    def compute_gradients_log_target(self, dataloader):
        """Accumulate the gradient of the log target over a whole dataloader.

        The gradients are first reset through ``self.diffusion.zero_grad()``.
        Every batch then contributes the gradient of
        ``-negloglikelihood(x, y)``, and the gradient of ``-neglogprior`` is
        added exactly once, together with the first batch. The result is left
        in the ``.grad`` fields of the model parameters; nothing is returned.

        :param dataloader: Iterable yielding ``(x, y)`` batches.
        """
        self.diffusion.zero_grad()
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(self.device), y.to(self.device)
            logpost = -self.negloglikelihood(x, y)
            if batch_idx == 0:
                # The prior must be counted once, not once per batch.
                logpost = logpost + (-self.neglogprior(self.negloglikelihood.parameters()))
            logpost.backward()
class EulerExplicit(Sampler):
    """SGMCMC update built from a standard mini-batch stochastic gradient
    estimate and an Euler integrator.

    :param DiffusionType: type of Diffusion.
    :param negloglikelihood: :class:`torch.nn.Module` object corresponding to
        a neg. log-likelihood function.
    :param neglogprior: Callable evaluated on the model parameters and added
        to the rescaled negative log-likelihood.
    :param init_params: Optional initial state for the model parameters,
        forwarded unchanged to the base class.
    :param dataloader: An already instantiated dataloader.
    :param kwargs: Any optional keyword arguments that will be passed to the
        Diffusion.
    """

    def __init__(
        self,
        DiffusionType: type,
        negloglikelihood: nn.Module,
        neglogprior: Callable,
        init_params: Union[None, str, OrderedDict],
        dataloader: DataLoader,
        **kwargs
    ):
        # Nothing specific to set up here: everything is handled by the base.
        super().__init__(
            DiffusionType=DiffusionType,
            negloglikelihood=negloglikelihood,
            neglogprior=neglogprior,
            init_params=init_params,
            dataloader=dataloader,
            **kwargs
        )

    def estimate_gradients(self, x, y):
        r"""Estimate the gradient of the negative log posterior with a
        standard mini-batched stochastic gradient.

        .. math::
            \hat{\nabla}U(\theta) = -\frac{N}{|B_k|} \sum_{i\in B_k}
            \nabla_\theta\log(p(y_i|x_i,\theta)) - \nabla\log(p(\theta))

        where :math:`N` is the number of samples in the whole dataset,
        :math:`B_k` is a mini batch of indices drawn by the dataloader,
        :math:`p(y_i|x_i,\theta)` denotes the likelihood function, and
        :math:`p(\theta)` is the prior distribution of the weights
        :math:`\theta`.

        The gradients are written to the ``.grad`` fields of the model
        parameters by ``backward()``.

        :param x: Input data that will be passed to the negative log-likelihood.
        :param y: Output tensor that will be compared to the prediction.
        :return: Value of the rescaled objective as a Python number.
        """
        self.diffusion.zero_grad()
        batch_nll = self.negloglikelihood(x, y)
        prior_term = self.neglogprior(self.negloglikelihood.parameters())
        # N / |B_k| rescales the batch term to the size of the full dataset.
        rescale = self.N_train / x.size(0)
        neg_logpost = rescale * batch_nll + prior_term
        neg_logpost.backward()
        return neg_logpost.detach().item()

    def step(self, x, y):
        r"""Update the network parameters according to the chosen
        :class:`Diffusion` using the standard stochastic gradient
        approximation.

        If :math:`\varphi` denotes the update function of the underlying
        :class:`Diffusion`, this method:

        1. zeroes the gradients,
        2. estimates the stochastic gradient :math:`\hat{\nabla}U(\theta^k)`
           for the current value of the network parameters,
        3. updates the parameters:

        .. math::
            \theta^{k+1} = \varphi(\theta^k, \hat{\nabla}U(\theta^k))

        This sampler can be used with any :class:`Diffusion`. Note that if
        :class:`DiffusionSGHMC` is used here, then no leapfrogs will be
        performed. For :math:`L > 1` leapfrog steps, the original
        documentation points to :class:`LeapfrogStochasticGradient`, which is
        not defined in this module.

        :param x: Input data that will be passed to the negative log-likelihood.
        :param y: Output tensor that will be compared to the prediction.
        :return: The value returned by :meth:`estimate_gradients`.
        """
        current_loss = self.estimate_gradients(x, y)
        self.diffusion.step()
        return current_loss
class EulerExplicitControlVariates(EulerExplicit):
    r"""This class implements a SGMCMC update using the variance reduction
    technique proposed by the paper `Control variates for stochastic gradient
    MCMC <https://link.springer.com/article/10.1007/s11222-018-9826-2>`_, and
    an Euler integrator.

    The centering value :math:`\hat{\theta}` is provided by the user via the
    ``init_control_params`` parameter. It is loaded into a deep copy of the
    model (``self.model_center``) and the full-data gradient at that point is
    computed once at construction time.

    :param DiffusionType: type of Diffusion.
    :param negloglikelihood: :class:`torch.nn.Module` object corresponding to
        a neg. log-likelihood function.
    :param neglogprior: Callable evaluated on the model parameters.
    :param init_params: Optional initial state for the model parameters,
        forwarded unchanged to the base class.
    :param init_control_params: Control variate: a path to a saved state dict
        (str) or a state dict (``OrderedDict`` or plain ``dict``).
    :param dataloader: An already instantiated dataloader.
    :param kwargs: Any additional keyword arguments that will be passed to
        the Diffusion.
    :raises TypeError: If ``init_control_params`` is neither a str nor a dict.
    """

    def __init__(
        self,
        DiffusionType: type,
        negloglikelihood: nn.Module,
        neglogprior: Callable,
        init_params: Union[None, str, OrderedDict],
        init_control_params: Union[str, OrderedDict],
        dataloader: DataLoader,
        **kwargs
    ):
        super().__init__(
            DiffusionType=DiffusionType,
            negloglikelihood=negloglikelihood,
            neglogprior=neglogprior,
            init_params=init_params,
            dataloader=dataloader,
            **kwargs
        )
        # Independent copy of the network that holds the centering value.
        self.model_center = copy.deepcopy(self.negloglikelihood)

        if isinstance(init_control_params, str):
            # map_location keeps the load working when the checkpoint was
            # saved from a different device than the one the model lives on.
            center_state = torch.load(init_control_params, map_location=self.device)
        elif isinstance(init_control_params, dict):
            # ``dict`` also covers OrderedDict.
            center_state = init_control_params
        else:
            raise TypeError("Please provide a state_dict for the control variates. Either a path to a file (str) or a state_dict (OrderedDict).")
        self.model_center.load_state_dict(center_state)

        self.compute_full_center_gradient(dataloader)

    def compute_full_center_gradient(self, dataloader):
        r"""Method that computes the full-data gradient of the negative log
        likelihood at the centering value :math:`\hat{\theta}` chosen by the
        user.

        .. math::
            \nabla U(\hat{\theta}) = -\sum_{i=1}^{N} \nabla_{\theta}
            \log(p(y_i|x_i,\hat{\theta}))

        where :math:`N` denotes the number of samples in the whole dataset
        and :math:`p(y_i|x_i,\hat{\theta})` is the likelihood function.

        The result is stored in ``self.full_center_grads``, one tensor per
        parameter of ``self.model_center`` (in ``parameters()`` order).
        Parameters that received no gradient get a zero tensor.

        :param dataloader: Iterable yielding ``(x, y)`` batches.
        """
        self.model_center.zero_grad()
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)
            # backward() accumulates, so the sum over batches builds up in .grad
            self.model_center(x, y).backward()

        self.full_center_grads = [
            torch.zeros_like(p) if p.grad is None else p.grad.detach().clone()
            for p in self.model_center.parameters()
        ]

    def estimate_gradients(self, x, y):
        r"""Method that estimates the gradient of the negative log posterior
        with a variance reduction technique (control variates).

        .. math::
            \hat{\nabla}U(\theta) = -\sum_{i=1}^{N} \nabla_{\theta}
            \log(p(y_i|x_i,\hat{\theta}))
            -\frac{N}{|B_k|} \sum_{i \in B_k} \left(
            \nabla_{\theta}\log(p(y_i|x_i,\theta))
            - \nabla_{\theta}\log(p(y_i|x_i,\hat{\theta}))\right)
            - \nabla\log(p(\theta))

        where :math:`N` is the number of samples in the whole dataset,
        :math:`B_k` is a mini batch of indices drawn by the dataloader,
        :math:`\theta \mapsto p(\cdot|\cdot,\theta)` denotes the likelihood
        function, :math:`p(\theta)` is the prior distribution of the weights
        :math:`\theta`, and :math:`\hat{\theta}` denotes the control variates.

        :param x: Input data that will be passed to the negative log-likelihood.
        :param y: Output tensor that will be compared to the prediction.
        :return: Value of the rescaled mini-batch objective (without the
            control-variate correction) as a Python number.
        """
        rescale = self.N_train / x.size(0)

        # Mini-batch gradient at the centering value.
        self.model_center.zero_grad()
        (rescale * self.model_center(x, y)).backward()

        # Mini-batch gradient at the current parameters (prior included).
        self.diffusion.zero_grad()
        loglike = self.negloglikelihood(x, y)
        logprior = self.neglogprior(self.negloglikelihood.parameters())
        logpost = rescale * loglike + logprior
        logpost.backward()

        # Assemble: current batch grad - center batch grad + full center grad.
        with torch.no_grad():
            triples = zip(
                self.model_center.parameters(),
                self.negloglikelihood.parameters(),
                self.full_center_grads,
            )
            for p_center, p, full_grad in triples:
                if p.grad is None or p_center.grad is None:
                    # Frozen / unused parameter: nothing to correct.
                    continue
                p.grad.add_(p_center.grad, alpha=-1.0)
                p.grad.add_(full_grad, alpha=1.0)

        return logpost.detach().item()
class EulerExplicitFixedPoint(EulerExplicitControlVariates):
    r"""This class implements a SGMCMC update using the variance reduction
    technique proposed by the paper `Variance reduction in stochastic gradient
    Langevin dynamics <https://papers.nips.cc/paper/6293-safe-policy-improvement-by-minimizing-robust-baseline-regret.pdf>`_,
    and an Euler integrator.

    It extends the class :class:`EulerExplicitControlVariates` by refreshing
    the centering parameter :math:`\hat{\theta}` every ``m_iter`` calls to
    :meth:`step`.

    :param DiffusionType: type of Diffusion.
    :param negloglikelihood: :class:`torch.nn.Module` object corresponding to
        a neg. log-likelihood function.
    :param neglogprior: Callable evaluated on the model parameters.
    :param init_params: Optional initial state for the model parameters,
        forwarded unchanged to the parent class.
    :param init_control_params: Initial control variate (path to a saved
        state dict or a state dict), forwarded to the parent class.
    :param dataloader: An already instantiated dataloader; it is kept and
        reused for every full-gradient refresh.
    :param m_iter: Number of iterations after which the centering parameters
        are replaced by the current values of the network parameters. Must
        be positive.
    :param kwargs: Any additional keyword arguments that will be passed to
        the Diffusion.
    :raises ValueError: If ``m_iter`` is not positive.
    """

    def __init__(
        self,
        DiffusionType: type,
        negloglikelihood: nn.Module,
        neglogprior: Callable,
        init_params: Union[None, str, OrderedDict],
        init_control_params: Union[str, OrderedDict],
        dataloader: DataLoader,
        m_iter: int,
        **kwargs
    ):
        # Checked up front: otherwise the problem only shows up as a
        # ZeroDivisionError in step(), after a full pass over the data.
        if m_iter <= 0:
            raise ValueError("m_iter must be a positive integer")

        super().__init__(
            DiffusionType=DiffusionType,
            negloglikelihood=negloglikelihood,
            neglogprior=neglogprior,
            init_params=init_params,
            init_control_params=init_control_params,
            dataloader=dataloader,
            **kwargs
        )
        self.iter = 0
        self.m_iter = m_iter
        self.dataloader = dataloader

    def step(self, x, y):
        r"""Perform one control-variate update, refreshing the centering
        parameter :math:`\hat{\theta}` when due.

        Whenever the iteration counter is a multiple of ``m_iter`` (this
        includes the very first call, where the counter is 0), the current
        network parameters are copied into ``self.model_center`` and the
        full-data gradient at that point is recomputed over
        ``self.dataloader``. The gradient is then estimated, the diffusion
        takes a step and the counter is incremented.

        :param x: Input data that will be passed to the negative log-likelihood.
        :param y: Output tensor that will be compared to the prediction.
        :return: The value returned by ``estimate_gradients``.
        """
        if self.iter % self.m_iter == 0:
            self.model_center.load_state_dict(self.negloglikelihood.state_dict())
            self.compute_full_center_gradient(self.dataloader)

        loss = self.estimate_gradients(x, y)
        self.diffusion.step()
        self.iter += 1
        return loss
def SamplerFactory(SamplerType, DiffusionType):
    """Bind a type of Diffusion (transition kernel) to a chosen
    discretization (gradient estimator, integrator).

    The returned object is a :func:`functools.partial` of ``SamplerType``
    with the keyword ``DiffusionType`` pre-filled; calling it with the
    remaining arguments builds the sampler. Several samplers are already
    built-in. An example is shown below.

    .. code-block:: python

        from batorch.sgmcmc.samplers import EulerExplicit, SamplerFactory
        from batorch.sgmcmc.diffusions import DiffusionSGLD

        SamplerSGLD = SamplerFactory(EulerExplicit, DiffusionSGLD)
        sampler = SamplerSGLD(negloglikelihood=negloglikelihood,
                              neglogprior=neglogprior,
                              dataloader=train_dataloader,
                              init_params=None,
                              step_size=1e-3)

    :param SamplerType: Sampler class to be partially applied.
    :param DiffusionType: Diffusion class bound as the ``DiffusionType``
        keyword argument.
    :return: ``functools.partial`` object wrapping ``SamplerType``.
    """
    bound_keywords = {"DiffusionType": DiffusionType}
    return functools.partial(SamplerType, **bound_keywords)
# Built-in samplers: each one pairs a discretization class with a diffusion
# class through SamplerFactory and gets a short description as its __doc__.

# --- SGLD family -----------------------------------------------------------
SamplerSGLD = SamplerFactory(EulerExplicit, DiffusionSGLD)
SamplerSGLD.__doc__ = "SGLD sampler."

SamplerPSGLD = SamplerFactory(EulerExplicit, DiffusionpSGLD)
SamplerPSGLD.__doc__ = "Preconditioned SGLD sampler."

SamplerSGLDCV = SamplerFactory(EulerExplicitControlVariates, DiffusionSGLD)
SamplerSGLDCV.__doc__ = "SGLD sampler with control variates."

SamplerSGLDSVRG = SamplerFactory(EulerExplicitFixedPoint, DiffusionSGLD)
SamplerSGLDSVRG.__doc__ = "SGLD sampler with fixed point variance reduction."

# --- SGHMC family ----------------------------------------------------------
SamplerSGHMC = SamplerFactory(EulerExplicit, DiffusionSGHMC)
SamplerSGHMC.__doc__ = "SGHMC sampler with a single leapfrog step."

SamplerSGHMCSA = SamplerFactory(EulerExplicit, DiffusionSGHMCSA)
SamplerSGHMCSA.__doc__ = "Scale adaptive SGHMC sampler with a single leapfrog step."

SamplerSGHMCCV = SamplerFactory(EulerExplicitControlVariates, DiffusionSGHMC)
SamplerSGHMCCV.__doc__ = "SGHMC sampler with control variates and a single leapfrog step."

SamplerSGHMCSVRG = SamplerFactory(EulerExplicitFixedPoint, DiffusionSGHMC)
SamplerSGHMCSVRG.__doc__ = "SGHMC sampler with fixed point variance reduction and a single leapfrog step."

# --- Cyclic step-size variants ---------------------------------------------
SamplerCyclicSGLD = SamplerFactory(EulerExplicit, DiffusionCyclicSGLD)
SamplerCyclicSGLD.__doc__ = "Cyclic SGLD sampler."

SamplerCyclicSGHMC = SamplerFactory(EulerExplicit, DiffusionCyclicSGHMC)
SamplerCyclicSGHMC.__doc__ = "Cyclic SGHMC sampler with a single leapfrog step."