# Source code for batorch.sgmcmc.diffusions

from torch.optim.optimizer import Optimizer
from torch.optim.optimizer import required
from batorch.utils.data import DataModule
import numpy as np
import torch

class Diffusion(Optimizer):
    """Base class for implementing transition kernels of MCMC methods.

    These methods result from the discretization of stochastic differential
    equations whose solutions are diffusion processes. The class inherits from
    :class:`torch.optim.optimizer.Optimizer` although it does not perform an
    optimization.

    Subclasses implement :meth:`init_state` and :meth:`update_fn`; the loop
    over parameter groups and parameters lives in :meth:`step`.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param kwargs: Extra hyperparameters stored next to ``step_size`` and
        ``weight_decay`` in the optimizer defaults (and hence in every
        parameter group).
    """

    def __init__(self, params, step_size, weight_decay=0.0, **kwargs):
        defaults = dict(step_size=step_size, weight_decay=weight_decay, **kwargs)
        super().__init__(params, defaults)

    def init_state(self, p, state, group):
        """Function used to initialize optional state variables.

        Called by :meth:`step` the first time a parameter with a gradient is
        visited, right after ``state['iteration']`` has been set to 0.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        :raises NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    def update_fn(self, p, d_p, state, group):
        """Function that updates the parameters.

        :param p: A parameter (weight or bias).
        :param d_p: Gradient of the objective function with respect to
            parameter p (including the weight-decay term when enabled).
        :param state: State dict.
        :param group: Group dict.
        :raises NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        """Basic step function that loops over the model parameters.

        Parameters whose ``grad`` is ``None`` are skipped. For every other
        parameter the per-parameter ``iteration`` counter is incremented and
        :meth:`update_fn` is called.

        :param closure: Closure function that can be useful if the loss
            function has to be called several times, defaults to None. It is
            evaluated with gradients enabled.
        :return: Value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['iteration'] = 0
                    self.init_state(p, state, group)
                state['iteration'] += 1

                d_p = p.grad
                if weight_decay != 0.0:
                    # Out-of-place on purpose: an in-place add_ would write
                    # the decay term into p.grad, so calling step() twice on
                    # the same gradients would apply the decay twice.
                    d_p = d_p.add(p, alpha=weight_decay)

                self.update_fn(p, d_p, state, group)

        return loss
class DiffusionSGLD(Diffusion):
    """Stochastic gradient Langevin dynamics (SGLD) transition kernel.

    Follows the `Bayesian learning via stochastic gradient Langevin dynamics
    <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.441.3813&rep=rep1&type=pdf>`_
    paper.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    """

    def __init__(self, params, step_size, weight_decay: float = 0.0):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
        )

    def init_state(self, p, state, group):
        """SGLD keeps no per-parameter state, so there is nothing to set up.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        return None

    def update_fn(self, p, d_p, state, group):
        r"""Perform the Langevin update.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\hat{\nabla}U(\theta^k)
            + \sqrt{\epsilon_k}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th
        iteration, :math:`\epsilon_k` is the step size,
        :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient of the
        potential function :math:`U`, and :math:`\Delta W^{k+1}` is a centered
        normalized Gaussian random variable. If the sequence
        :math:`\{\epsilon_k \}_{k\geq 0}` and the potential function :math:`U`
        satisfy a few conditions, the stationary distribution of the discrete
        Markov chain approximates the target distribution given by

        .. math::

            p_{\Theta}(\theta) = c_0 \exp(-U(\theta))

        where :math:`c_0` is a (possibly unknown) normalization constant.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary, not used for SGLD.
        :param group: Group dictionary that contains at least the key
            `step_size`.
        """
        eps = group['step_size']
        # Standard normal sample with the same shape, dtype and device as p.
        gaussian = p.new(p.size()).normal_(mean=0, std=1)
        half_gradient = 0.5 * d_p
        # Drift term: -(eps / 2) * grad.
        p.add_(half_gradient, alpha=-eps)
        # Diffusion term: sqrt(eps) * N(0, I).
        p.add_(gaussian, alpha=np.sqrt(eps))
class DiffusionCyclicSGLD(Diffusion):
    """Implements the cyclic stochastic gradient Langevin dynamics (SGLD)
    proposed in the `Cyclic Stochastic Gradient MCMC for Bayesian Deep Learning
    <https://arxiv.org/abs/1902.03932>`_ paper.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Initial (maximum) step size of the discretized
        stochastic differential equation; the effective step size follows a
        cosine schedule that restarts every ``num_iterations // num_cycles``
        iterations.
    :param num_cycles: Number of cycles of the cyclic step size.
    :param num_iterations: Total number of sampling iterations.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :raises ValueError: If ``num_cycles`` is not positive or
        ``num_iterations`` is smaller than ``num_cycles`` (the cycle length
        would be zero).
    """

    def __init__(
        self,
        params,
        step_size,
        num_cycles,
        num_iterations,
        weight_decay: float = 0.0,
    ):
        # A zero cycle length would otherwise only surface later as a
        # ZeroDivisionError in the middle of sampling.
        if num_cycles <= 0 or num_iterations < num_cycles:
            raise ValueError(
                "num_cycles must be positive and num_iterations must be at "
                f"least num_cycles (got num_cycles={num_cycles}, "
                f"num_iterations={num_iterations})"
            )
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay)
        self.num_cycles = num_cycles
        self.num_iterations = num_iterations
        # Number of iterations in one cycle of the step-size schedule.
        self.km = self.num_iterations // self.num_cycles

    def init_state(self, p, state, group):
        """Function used to initialize optional state variables.

        Cyclic SGLD needs no state besides the iteration counter maintained by
        the base class, so this is a no-op.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        pass

    def update_fn(self, p, d_p, state, group):
        r"""Perform the Langevin update with a cyclic step size.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\hat{\nabla}U(\theta^k)
            + \sqrt{\epsilon_k}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th
        iteration, :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient
        of the potential function :math:`U`, and :math:`\Delta W^{k+1}` is a
        centered normalized Gaussian random variable. The step size follows

        .. math::

            \epsilon_k = \frac{\epsilon_0}{2}
            \left(\cos\left(\pi \frac{k \bmod K_m}{K_m}\right) + 1\right)

        with :math:`\epsilon_0` the group's ``step_size``, :math:`K_m` the
        cycle length and :math:`k` the zero-based iteration of the parameter.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary; only the key `iteration` is read.
        :param group: Group dictionary that contains at least the key
            `step_size`.
        """
        it = state["iteration"] - 1
        phase = np.pi * (it % self.km) / self.km
        # Plain Python float: avoids allocating a 0-dim tensor per parameter
        # and feeding torch tensors to numpy / to the ``alpha`` argument.
        lr = float(0.5 * group["step_size"] * (np.cos(phase) + 1.0))

        langevin_noise = torch.randn_like(p)
        p.add_(0.5 * d_p, alpha=-lr)
        p.add_(langevin_noise, alpha=float(np.sqrt(lr)))
class DiffusionpSGLD(Diffusion):
    """Implements the preconditioned stochastic gradient Langevin dynamics
    (pSGLD) proposed in the `Preconditioned stochastic gradient Langevin
    dynamics for deep neural networks
    <https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/download/11835/11805>`_
    paper.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param alpha: Momentum factor with values in :math:`[0,1]` used for the
        moving average of the squared gradient. Default: 0.99.
    :param eps: Diagonal perturbation that keeps the preconditioner from
        degenerating. Default: 1e-8.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay: float = 0.0,
        alpha=0.99,
        eps=1e-8,
    ):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            alpha=alpha,
            eps=eps,
        )

    def init_state(self, p, state, group):
        r"""Initializes a state variable for storing the following estimation
        of the squared stochastic gradient:

        .. math::

            L(\theta^k) = \alpha L(\theta^{k-1}) + (1-\alpha)
            \hat{\nabla}U(\theta^{k-1}) \circ \hat{\nabla}U(\theta^{k-1})

        The operator :math:`\circ` denotes the element-wise product. The
        estimate starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict; receives the key `square_avg`.
        :param group: Group dict.
        """
        state['square_avg'] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the preconditioned Langevin update.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\mathbf{D}(\theta^k)
            \hat{\nabla}U(\theta^k)
            + \sqrt{\epsilon_k\mathbf{D}(\theta^k)}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th
        iteration, :math:`\epsilon_k` is the step size,
        :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient of the
        potential function :math:`U`, and :math:`\Delta W^{k+1}` is a centered
        normalized Gaussian random variable. The matrix :math:`\mathbf{D}` is
        referred to as the preconditioner and takes the form

        .. math::

            \mathbf{D}(\theta^k) = \mathrm{diag}\left(\lambda \mathbf{I}
            + \sqrt{L(\theta^k)}\right)^{-1}

        It should be noted that this algorithm is missing an extra term,
        :math:`\Gamma(\theta^k)`, which is discarded for computational
        efficiency.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p. It is left unmodified.
        :param state: State dictionary that has the key `square_avg`; the
            moving average is updated in place.
        :param group: Group dictionary that has the keys `step_size`, `alpha`,
            and `eps`.
        """
        step_size = group['step_size']
        alpha = group['alpha']
        square_avg = state['square_avg']

        # Exponential moving average of the element-wise squared gradient.
        square_avg.mul_(alpha).addcmul_(d_p, d_p, value=1.0 - alpha)
        # Inverse of the diagonal preconditioner: eps + sqrt(L).
        avg = square_avg.sqrt().add_(group['eps'])

        # Dividing by sqrt(step_size) here and multiplying by step_size below
        # yields the sqrt(step_size) noise scale of the formula.
        langevin_noise = torch.randn_like(p) / np.sqrt(step_size)

        # Out-of-place division: div_ would overwrite p.grad with the
        # preconditioned gradient as a hidden side effect.
        preconditioned_grad = d_p.div(avg)
        p.add_(
            0.5 * preconditioned_grad + langevin_noise / torch.sqrt(avg),
            alpha=-step_size,
        )
class DiffusionSGHMC(Diffusion):
    """Implements the stochastic gradient Hamiltonian Monte Carlo (SGHMC)
    algorithm of the `Stochastic Gradient Hamiltonian Monte Carlo
    <http://proceedings.mlr.press/v32/cheni14.pdf>`_ paper but without
    performing the change of variable.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param damping: Damping parameter. Default: 0.01.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay: float = 0.0,
        damping: float = 0.01,
    ):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            damping=damping,
        )

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm
        and denoted by :math:`v^k`. The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict; receives the key `momentum`.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def resample_momentum(self):
        """Resamples the `momentum` state variable of every parameter from a
        standard normal distribution.

        May be called before the first :meth:`step`: the iteration counter is
        created when missing, because the base class only initializes it for
        parameters whose state dict is still empty.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Without this, a state holding only `momentum` would make the
                # next step() fail on the missing `iteration` key.
                state.setdefault('iteration', 0)
                state["momentum"] = torch.randn_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the SGHMC update.

        .. math::

            & \theta^{k+1} = \theta^k + \epsilon_k \mathbf{M}^{-1} v^k

            & v^{k+1} = v^k - \epsilon_k\hat{\nabla}U(\theta^k)
            - \epsilon_k \mathbf{C}\mathbf{M}^{-1}v^k
            + \sqrt{2\mathbf{C}\epsilon_k} \, \Delta W^{k+1}

        The mass and damping matrices are chosen as
        :math:`\mathbf{M} = \mathbf{I}` and :math:`\mathbf{C} = f\mathbf{I}`,
        where :math:`f` is the damping parameter. In the code the momentum is
        updated first and the parameter then moves along the updated momentum.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary that has the key `momentum`; the
            momentum is updated in place.
        :param group: Group dictionary that has the keys `step_size` and
            `damping`.
        """
        damping = group['damping']
        step_size = group['step_size']
        momentum = state['momentum']

        sigma = np.sqrt(2.0 * damping)
        # N(0, 2f) / sqrt(eps); multiplied by eps below this gives the
        # sqrt(2 f eps) noise scale of the formula.
        sample_t = p.new(p.size()).normal_(mean=0.0, std=sigma) / np.sqrt(step_size)

        # In-place: v <- (1 - f * eps) * v - eps * (grad - noise).
        momentum.mul_(1.0 - damping * step_size).add_(d_p - sample_t, alpha=-step_size)
        p.add_(momentum, alpha=step_size)
class DiffusionCyclicSGHMC(Diffusion):
    """Implements the cyclic stochastic gradient Hamiltonian Monte Carlo
    (SGHMC) algorithm proposed in the `Cyclic Stochastic Gradient MCMC for
    Bayesian Deep Learning <https://arxiv.org/abs/1902.03932>`_ paper.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Initial (maximum) step size of the discretized
        stochastic differential equation; the effective step size follows a
        cosine schedule that restarts every ``num_iterations // num_cycles``
        iterations.
    :param num_cycles: Number of cycles of the cyclic step size.
    :param num_iterations: Total number of sampling iterations.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param damping: Damping parameter. Default: 0.01.
    :raises ValueError: If ``num_cycles`` is not positive or
        ``num_iterations`` is smaller than ``num_cycles`` (the cycle length
        would be zero).
    """

    def __init__(
        self,
        params,
        step_size,
        num_cycles,
        num_iterations,
        weight_decay: float = 0.0,
        damping: float = 0.01,
    ):
        # A zero cycle length would otherwise only surface later as a
        # ZeroDivisionError in the middle of sampling.
        if num_cycles <= 0 or num_iterations < num_cycles:
            raise ValueError(
                "num_cycles must be positive and num_iterations must be at "
                f"least num_cycles (got num_cycles={num_cycles}, "
                f"num_iterations={num_iterations})"
            )
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            damping=damping,
        )
        self.num_cycles = num_cycles
        self.num_iterations = num_iterations
        # Number of iterations in one cycle of the step-size schedule.
        self.km = self.num_iterations // self.num_cycles

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm
        and denoted by :math:`v^k`. The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict; receives the key `momentum`.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def resample_momentum(self):
        """Resamples the `momentum` state variable of every parameter from a
        standard normal distribution.

        May be called before the first :meth:`step`: the iteration counter is
        created when missing, because the base class only initializes it for
        parameters whose state dict is still empty.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Without this, a state holding only `momentum` would make the
                # next step() fail on the missing `iteration` key.
                state.setdefault('iteration', 0)
                state["momentum"] = torch.randn_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the SGHMC update with a cyclic step size.

        .. math::

            & \theta^{k+1} = \theta^k + \epsilon_k \mathbf{M}^{-1} v^k

            & v^{k+1} = v^k - \epsilon_k\hat{\nabla}U(\theta^k)
            - \epsilon_k \mathbf{C}\mathbf{M}^{-1}v^k
            + \sqrt{2\mathbf{C}\epsilon_k} \, \Delta W^{k+1}

        The mass and damping matrices are chosen as
        :math:`\mathbf{M} = \mathbf{I}` and :math:`\mathbf{C} = f\mathbf{I}`,
        where :math:`f` is the damping parameter. The step size follows

        .. math::

            \epsilon_k = \frac{\epsilon_0}{2}
            \left(\cos\left(\pi \frac{k \bmod K_m}{K_m}\right) + 1\right)

        with :math:`\epsilon_0` the group's ``step_size``, :math:`K_m` the
        cycle length and :math:`k` the zero-based iteration of the parameter.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary that has the keys `momentum` and
            `iteration`; the momentum is updated in place.
        :param group: Group dictionary that has the keys `step_size` and
            `damping`.
        """
        damping = group['damping']
        momentum = state['momentum']

        it = state["iteration"] - 1
        phase = np.pi * (it % self.km) / self.km
        # Plain Python float: avoids allocating a 0-dim tensor per parameter
        # and feeding torch tensors to numpy / to the ``alpha`` argument.
        step_size = float(0.5 * group["step_size"] * (np.cos(phase) + 1.0))

        sigma = np.sqrt(2.0 * damping)
        # N(0, 2f) / sqrt(eps); multiplied by eps below this gives the
        # sqrt(2 f eps) noise scale of the formula.
        sample_t = p.new(p.size()).normal_(mean=0.0, std=sigma) / np.sqrt(step_size)

        # In-place: v <- (1 - f * eps) * v - eps * (grad - noise).
        momentum.mul_(1.0 - damping * step_size).add_(d_p - sample_t, alpha=-step_size)
        p.add_(momentum, alpha=step_size)
class DiffusionSGHMCSA(Diffusion):
    """Implements the scale adaptive SGHMC algorithm proposed in the `Bayesian
    optimization with robust Bayesian neural networks
    <https://papers.nips.cc/paper/6117-gaussian-process-bandit-optimisation-with-multi-fidelity-evaluations.pdf>`_
    paper. This algorithm uses the burn-in phase to adapt its hyperparameters
    (not including the step size).

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param num_burnin_steps: Number of burn-in steps used to estimate the
        algorithm hyperparameters.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param mdecay: Momentum decay per time step. Default: 0.05.
    """

    def __init__(
        self,
        params,
        step_size,
        num_burnin_steps: int,
        weight_decay: float = 0.0,
        mdecay: float = 0.05,
    ):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            mdecay=mdecay,
            num_burnin_steps=num_burnin_steps,
        )

    def init_state(self, p, state, group):
        r"""Initializes a momentum and three additional state variables
        :math:`\tau, g, \hat{v}`.

        The momentum starts at zero; :math:`\tau`, :math:`g` and
        :math:`\hat{v}` start at one.

        :param p: A parameter (weight or bias).
        :param state: State dict; receives the keys `tau`, `g`, `v_hat` and
            `momentum`.
        :param group: Group dict.
        """
        state["tau"] = torch.ones_like(p)
        state["g"] = torch.ones_like(p)
        state["v_hat"] = torch.ones_like(p)
        state["momentum"] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the scale adaptive SGHMC update.

        .. math::

            & \theta^{k+1} = \theta^k + v^k

            & v^{k+1} = -\epsilon^2_k\hat{\mathbf{L}}^{-1/2}\hat{\nabla}U(\theta^k)
            - \epsilon_k \hat{\mathbf{L}}^{-1/2}\mathbf{C} v^k
            + \sqrt{2\epsilon_k^3\hat{\mathbf{L}}^{-1/2}\mathbf{C}
            \hat{\mathbf{L}}^{-1/2}-\epsilon_k^4\mathbf{I}} \, \Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th
        iteration, :math:`v^k` is the momentum, :math:`\epsilon_k` is the step
        size, :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient of the
        potential function :math:`U`, and :math:`\Delta W^{k+1}` is a centered
        normalized Gaussian random variable. The stationary distribution of
        the discrete Markov chain approximates the target distribution

        .. math::

            p_{\Theta,V}(\theta,v) = c_0 \exp\left(-U(\theta)
            - \frac{1}{2}v^T v \right)

        The matrix :math:`\hat{\mathbf{L}}` denotes the second-order moment of
        the gradient which is estimated with an exponential moving average
        during the burn-in phase. The damping matrix :math:`\mathbf{C}` is
        chosen such that
        :math:`\epsilon_k \mathbf{C}\hat{\mathbf{L}} = 0.05\mathbf{I}`.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary that has the keys `tau`, `g`, `v_hat`,
            `momentum` and `iteration`.
        :param group: Group dictionary that has the keys `step_size`, `mdecay`
            and `num_burnin_steps`.
        """
        mdecay, step_size = group["mdecay"], group['step_size']
        tau, g, v_hat = state["tau"], state["g"], state["v_hat"]
        momentum = state["momentum"]

        # Both quantities are computed from the estimates as they are *before*
        # the burn-in update below.
        r_t = 1. / (tau + 1.)
        minv_t = 1. / torch.sqrt(v_hat)

        # The moving-average estimates are only adapted during burn-in and
        # stay frozen afterwards.
        if state["iteration"] <= group["num_burnin_steps"]:
            tau.add_(1. - tau * (g * g / v_hat))
            g.add_(-g * r_t + r_t * d_p)
            v_hat.add_(-v_hat * r_t + r_t * (d_p ** 2))

        noise_scale = 2. * (step_size ** 2) * mdecay * minv_t - (step_size ** 4)
        # Clamp so that the square root is never taken of a negative variance.
        sigma = torch.sqrt(torch.clamp(noise_scale, min=1e-16))
        sample_t = torch.normal(mean=0., std=sigma)

        # add_ is in place, so momentum_t is the same tensor as the stored
        # momentum.
        momentum_t = momentum.add_(
            - (step_size ** 2) * minv_t * d_p - mdecay * momentum + sample_t
        )
        p.add_(momentum_t)
class DiffusionStormerVerletSGHMC(Diffusion):
    """Implements Hamiltonian dynamics solved with a Stormer-Verlet
    discretization scheme. It is an alternative to
    :class:`sgmcmc.diffusions.DiffusionSGHMC`.

    :param params: Generator that yields the network parameters
        (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential
        equation.
    :param weight_decay: Optional L2 regularization that can be seen as
        placing a Gaussian prior distribution on the network parameters.
        Default: 0.0.
    :param damping: Scalar damping parameter. Default: 1.0.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay: float = 0.0,
        damping: float = 1.0,
    ):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            damping=damping,
        )

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm
        and denoted by :math:`v^k`. The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict; receives the key `momentum`.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the Stormer-Verlet update.

        .. math::

            & \theta^{k+1/2} = \theta^k + 0.5\epsilon_k v^{k}

            & v^{k+1} = \frac{1-\alpha}{1+\alpha} v^k
            - \frac{\epsilon_k}{1+\alpha} \hat{\nabla}U(\theta^{k+1/2})
            + \sqrt{\frac{f\epsilon}{1+\alpha}} \, \Delta W^{k+1}

            & \theta^{k+1} = \theta^{k+1/2} + 0.5\epsilon_k v^{k+1}

        where :math:`f` is the damping parameter and
        :math:`\alpha = f\epsilon_k/4`.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to
            parameter p.
        :param state: State dictionary that has the key `momentum`; the
            momentum is updated in place.
        :param group: Group dictionary that has the keys `step_size` and
            `damping`.
        """
        step_size = group['step_size']
        damping = group['damping']
        momentum = state["momentum"]

        varsigma = step_size * damping / 4.0   # alpha in the formula above
        beta = 1.0 / (1.0 + varsigma)
        alpha = (1.0 - varsigma) * beta         # momentum decay factor
        sigma = np.sqrt(damping) * beta

        sample_t = p.new(p.size()).normal_(mean=0.0, std=sigma) / np.sqrt(step_size)

        # mul_/add_ work in place, so momentum_t and momentum are the same
        # tensor holding the *updated* momentum from here on.
        momentum_t = momentum.mul_(alpha).add_(beta * d_p - sample_t, alpha=-step_size)

        # Two half steps, both along the updated momentum.
        p.add_(0.5 * momentum_t, alpha=step_size)
        p.add_(0.5 * momentum, alpha=step_size)


if __name__ == "__main__":

    class Model(torch.nn.Module):
        """Minimal single-layer model used to smoke-test the samplers."""

        def __init__(self):
            super().__init__()
            self.f = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.f(x)

    model = Model()
    # DiffusionSGHMC has no `alpha` argument (passing it raised TypeError);
    # its damping hyperparameter is called `damping`.
    diffusion = DiffusionSGHMC(
        params=model.parameters(),
        step_size=1.0,
        weight_decay=0.0,
        damping=0.01,
    )