from torch.optim.optimizer import Optimizer
from torch.optim.optimizer import required
from batorch.utils.data import DataModule
import numpy as np
import torch
class Diffusion(Optimizer):
    """Base class for implementing transition kernels of MCMC methods.

    These methods result from the discretization of stochastic differential
    equations whose solutions are diffusion processes. The class inherits from
    :class:`torch.optim.optimizer.Optimizer` although it does not perform an
    optimization.

    Subclasses provide :meth:`init_state` and :meth:`update_fn`; :meth:`step`
    loops over the parameters and calls them.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param kwargs: Extra hyperparameters stored in the optimizer defaults and therefore
        available to :meth:`update_fn` through the ``group`` dict.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay=0.0,
        **kwargs
    ):
        defaults = dict(step_size=step_size, weight_decay=weight_decay, **kwargs)
        super().__init__(params, defaults)

    def init_state(self, p, state, group):
        """Function used to initialize optional state variables.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        :raises NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def update_fn(self, p, d_p, state, group):
        """Function that updates the parameters.

        :param p: A parameter (weight or bias).
        :param d_p: Gradient of the objective function with respect to parameter p
            (including the weight-decay term when ``weight_decay`` is non-zero).
        :param state: State dict.
        :param group: Group dict.
        :raises NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        """Basic step function that loops over the model parameters.

        Parameters whose ``grad`` is ``None`` are skipped. For every other parameter the
        per-parameter ``iteration`` counter is incremented and :meth:`update_fn` is called.

        :param closure: Closure function that can be useful if the loss function has to be
            called several times, defaults to None.
        :return: The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so gradients must be re-enabled for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                # Key the lazy initialisation on the counter itself rather than on the
                # state being empty, so a state dict that already holds other entries
                # cannot reach the increment below without an 'iteration' key.
                if 'iteration' not in state:
                    state['iteration'] = 0
                    self.init_state(p, state, group)
                state['iteration'] += 1

                d_p = p.grad
                if weight_decay != 0.0:
                    # Out-of-place on purpose: an in-place add would write the decay term
                    # into p.grad and compound it if step() runs again on the same grads.
                    d_p = d_p.add(p, alpha=weight_decay)

                self.update_fn(p, d_p, state, group)
        return loss
class DiffusionSGLD(Diffusion):
    """Implements stochastic gradient Langevin dynamics (SGLD) proposed in the
    `Bayesian learning via stochastic gradient Langevin dynamics <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.441.3813&rep=rep1&type=pdf>`_ paper.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    """

    def __init__(self, params, step_size, weight_decay: float = 0.0):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
        )

    def init_state(self, p, state, group):
        """Function used to initialize optional state variables.

        SGLD keeps no extra per-parameter state, so this is a no-op.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        return None

    def update_fn(self, p, d_p, state, group):
        r"""Perform the Langevin update.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\hat{\nabla}U(\theta^k) + \sqrt{\epsilon_k}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th iteration,
        :math:`\epsilon_k` is the step size, :math:`\hat{\nabla}U(\theta^k)` is the
        stochastic gradient of the potential function :math:`U`, and
        :math:`\Delta W^{k+1}` is a centered normalized Gaussian random variable. If the
        sequence :math:`\{\epsilon_k \}_{k\geq 0}` and the potential function :math:`U`
        satisfy a few conditions, the stationary distribution of the discrete Markov chain
        approximates the target distribution given by

        .. math::

            p_{\Theta}(\theta) = c_0 \exp(-U(\theta))

        where :math:`c_0` is a (possibly unknown) normalization constant.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary, not used for SGLD.
        :param group: Group dictionary that contains at least the key `step_size`.
        """
        epsilon = group['step_size']
        # Standard normal draw with the same shape, dtype and device as the parameter.
        noise = p.new_empty(p.size()).normal_(mean=0, std=1)
        half_gradient = 0.5 * d_p
        p.add_(half_gradient, alpha=-epsilon)
        p.add_(noise, alpha=np.sqrt(epsilon))
class DiffusionCyclicSGLD(Diffusion):
    """Implements the cyclic stochastic gradient Langevin dynamics (SGLD) proposed in the
    `Cyclic Stochastic Gradient MCMC for Bayesian Deep Learning <https://arxiv.org/abs/1902.03932>`_ paper.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Initial (maximum) step size of the discretized stochastic differential
        equation; the effective step size follows a cosine schedule that restarts every cycle.
    :param num_cycles: Number of cycles of the cyclic step size.
    :param num_iterations: Total number of sampling iterations.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :raises ValueError: If ``num_cycles`` is not positive or exceeds ``num_iterations``
        (the cycle length would be zero).
    """

    def __init__(
        self,
        params,
        step_size,
        num_cycles,
        num_iterations,
        weight_decay: float = 0.0,
    ):
        # A cycle length of zero would only surface later as a ZeroDivisionError
        # inside the schedule, so reject it up front with a clear message.
        if num_cycles <= 0 or num_iterations < num_cycles:
            raise ValueError(
                "num_cycles must be positive and not larger than num_iterations, "
                f"got num_cycles={num_cycles}, num_iterations={num_iterations}"
            )
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay)
        self.num_cycles = num_cycles
        self.num_iterations = num_iterations
        # Number of iterations per cycle.
        self.km = self.num_iterations // self.num_cycles

    def _cyclic_step_size(self, lr_0, it):
        """Return the cosine-annealed step size for the zero-based iteration ``it``."""
        phase = np.pi * (it % self.km) / self.km
        # Plain float arithmetic: no per-call tensor allocation and no NumPy ufunc
        # applied to a torch tensor further down.
        return 0.5 * lr_0 * (float(np.cos(phase)) + 1.0)

    def init_state(self, p, state, group):
        """Function used to initialize optional state variables.

        No extra per-parameter state is needed, so this is a no-op.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        pass

    def update_fn(self, p, d_p, state, group):
        r"""Perform the Langevin update with a cyclic step size.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\hat{\nabla}U(\theta^k) + \sqrt{\epsilon_k}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th iteration,
        :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient of the potential function
        :math:`U`, and :math:`\Delta W^{k+1}` is a centered normalized Gaussian random
        variable. The step size is

        .. math::

            \epsilon_k = \frac{\epsilon_0}{2}\left[\cos\left(\frac{\pi \, (k \bmod K)}{K}\right) + 1\right]

        with :math:`\epsilon_0` the group's `step_size` and :math:`K` the cycle length
        ``num_iterations // num_cycles``. If the sequence
        :math:`\{\epsilon_k \}_{k\geq 0}` and the potential function :math:`U` satisfy a
        few conditions, the stationary distribution of the discrete Markov chain
        approximates the target distribution given by

        .. math::

            p_{\Theta}(\theta) = c_0 \exp(-U(\theta))

        where :math:`c_0` is a (possibly unknown) normalization constant.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary; only the key `iteration` is read.
        :param group: Group dictionary that contains at least the key `step_size`.
        """
        # 'iteration' has already been incremented for this step, hence the -1.
        it = state["iteration"] - 1
        lr = self._cyclic_step_size(group["step_size"], it)
        langevin_noise = p.new(p.size()).normal_(mean=0, std=1)
        p.add_(0.5 * d_p, alpha=-lr)
        p.add_(langevin_noise, alpha=np.sqrt(lr))
class DiffusionpSGLD(Diffusion):
    """Implements the preconditioned stochastic gradient Langevin dynamics (pSGLD) proposed in the
    `Preconditioned stochastic gradient Langevin dynamics for deep neural networks <https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/download/11835/11805>`_ paper.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param alpha: Decay factor of the squared-gradient moving average, with values in
        :math:`[0,1]`. Default: 0.99.
    :param eps: Diagonal perturbation that keeps the preconditioner from degenerating.
        Default: 1e-8.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay: float = 0.0,
        alpha=0.99,
        eps=1e-8,
    ):
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay, alpha=alpha, eps=eps)

    def init_state(self, p, state, group):
        r"""Initializes a state variable for storing the following estimation of the squared stochastic gradient:

        .. math::

            L(\theta^k) = \alpha L(\theta^{k-1}) + (1-\alpha) \hat{\nabla}U(\theta^{k-1}) \circ \hat{\nabla}U(\theta^{k-1})

        The operator :math:`\circ` denotes the element-wise product. The estimate is stored
        under the key `square_avg` and starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        state['square_avg'] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the preconditioned Langevin update.

        .. math::

            \theta^{k+1} = \theta^k - \frac{\epsilon_k}{2}\mathbf{D}(\theta^k)\hat{\nabla}U(\theta^k) + \sqrt{\epsilon_k\mathbf{D}(\theta^k)}\Delta W^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th iteration,
        :math:`\epsilon_k` is the step size, :math:`\hat{\nabla}U(\theta^k)` is the
        stochastic gradient of the potential function :math:`U`, and
        :math:`\Delta W^{k+1}` is a centered normalized Gaussian random variable. The
        matrix :math:`\mathbf{D}` is referred to as the preconditioner and takes the form

        .. math::

            \mathbf{D}(\theta^k) = \mathrm{diag}\left(\lambda \mathbf{I} + \sqrt{L(\theta^k)}\right)^{-1}

        where :math:`\lambda` is the `eps` hyperparameter. It should be noted that this
        algorithm is missing an extra term, :math:`\Gamma(\theta^k)`, which is discarded
        for computational efficiency.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p. It is
            read only and never modified.
        :param state: State dictionary that has the key `square_avg`.
        :param group: Group dictionary that has the keys `step_size`, `alpha`, and `eps`.
        """
        step_size = group['step_size']
        alpha = group['alpha']
        square_avg = state['square_avg']

        # Exponential moving average of the element-wise squared gradient.
        square_avg.mul_(alpha).addcmul_(d_p, d_p, value=1.0 - alpha)
        # Inverse preconditioner: sqrt(L) + eps (a fresh tensor, square_avg is untouched).
        avg = square_avg.sqrt().add_(group['eps'])

        # Pre-divide by sqrt(step_size) so that the common factor -step_size below
        # leaves noise of magnitude sqrt(step_size / avg); the sign is irrelevant
        # for a zero-mean Gaussian.
        langevin_noise = p.new(p.size()).normal_(mean=0, std=1) / np.sqrt(step_size)

        # Out-of-place division: d_p may alias p.grad, which must not be rescaled.
        p.add_(0.5 * d_p.div(avg) + langevin_noise / torch.sqrt(avg), alpha=-step_size)
class DiffusionSGHMC(Diffusion):
    """Implements the stochastic gradient Hamiltonian Monte Carlo (SGHMC) algorithm of the
    `Stochastic Gradient Hamiltonian Monte Carlo <http://proceedings.mlr.press/v32/cheni14.pdf>`_ paper but without performing the change of variable.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param damping: Damping parameter. Default: 0.01.
    """

    def __init__(
        self,
        params,
        step_size,
        weight_decay: float = 0.0,
        damping: float = 0.01
    ):
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay, damping=damping)

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm and denoted by :math:`v^k`.

        The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def resample_momentum(self):
        """Resamples the `momentum` state variable from a standard normal distribution.

        Safe to call before the first :meth:`step`: the iteration counter that
        :meth:`step` increments is created here when it does not exist yet.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # step() increments state['iteration']; make sure the counter exists
                # even when the momentum is resampled before any step was taken.
                state.setdefault("iteration", 0)
                state["momentum"] = p.new(p.size()).normal_(mean=0, std=1)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the SGHMC update.

        .. math::

            v^{k+1} = v^k - \epsilon_k\hat{\nabla}U(\theta^k) - \epsilon_k \mathbf{C}\mathbf{M}^{-1}v^k + \sqrt{2\mathbf{C}\epsilon_k} \, \Delta W^{k+1}

            \theta^{k+1} = \theta^k + \epsilon_k \mathbf{M}^{-1} v^{k+1}

        The mass and damping matrices are chosen as :math:`\mathbf{M} = \mathbf{I}` and
        :math:`\mathbf{C} = f\mathbf{I}`, where :math:`f` is the damping parameter. The
        momentum is updated in place first, so the parameter move uses the new momentum.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary that has the key `momentum`.
        :param group: Group dictionary that has the keys `step_size` and `damping`.
        """
        damping = group['damping']
        step_size = group['step_size']
        momentum = state['momentum']

        sigma = np.sqrt(2.0 * damping)
        # Dividing by sqrt(step_size) here and multiplying by step_size below yields
        # noise with standard deviation sqrt(2 * damping * step_size).
        sample_t = p.new(p.size()).normal_(mean=0.0, std=sigma) / np.sqrt(step_size)
        momentum_t = momentum.mul_(1.0 - damping * step_size).add_(d_p - sample_t, alpha=-step_size)
        p.add_(momentum_t, alpha=step_size)
class DiffusionCyclicSGHMC(Diffusion):
    """Implements the cyclic stochastic gradient Hamiltonian Monte Carlo (SGHMC) algorithm proposed in the
    `Cyclic Stochastic Gradient MCMC for Bayesian Deep Learning <https://arxiv.org/abs/1902.03932>`_ paper.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Initial (maximum) step size of the discretized stochastic differential
        equation; the effective step size follows a cosine schedule that restarts every cycle.
    :param num_cycles: Number of cycles of the cyclic step size.
    :param num_iterations: Total number of sampling iterations.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param damping: Damping parameter. Default: 0.01.
    :raises ValueError: If ``num_cycles`` is not positive or exceeds ``num_iterations``
        (the cycle length would be zero).
    """

    def __init__(
        self,
        params,
        step_size,
        num_cycles,
        num_iterations,
        weight_decay: float = 0.0,
        damping: float = 0.01
    ):
        # A cycle length of zero would only surface later as a ZeroDivisionError
        # inside the schedule, so reject it up front with a clear message.
        if num_cycles <= 0 or num_iterations < num_cycles:
            raise ValueError(
                "num_cycles must be positive and not larger than num_iterations, "
                f"got num_cycles={num_cycles}, num_iterations={num_iterations}"
            )
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay, damping=damping)
        self.num_cycles = num_cycles
        self.num_iterations = num_iterations
        # Number of iterations per cycle.
        self.km = self.num_iterations // self.num_cycles

    def _cyclic_step_size(self, lr_0, it):
        """Return the cosine-annealed step size for the zero-based iteration ``it``."""
        phase = np.pi * (it % self.km) / self.km
        # Plain float arithmetic: no per-call tensor allocation and no NumPy ufunc
        # applied to a torch tensor further down.
        return 0.5 * lr_0 * (float(np.cos(phase)) + 1.0)

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm and denoted by :math:`v^k`.

        The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def resample_momentum(self):
        """Resamples the `momentum` state variable from a standard normal distribution.

        Safe to call before the first :meth:`step`: the iteration counter that
        :meth:`step` increments is created here when it does not exist yet.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # step() increments state['iteration']; make sure the counter exists
                # even when the momentum is resampled before any step was taken.
                state.setdefault("iteration", 0)
                state["momentum"] = p.new(p.size()).normal_(mean=0, std=1)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the SGHMC update with a cyclic step size.

        .. math::

            v^{k+1} = v^k - \epsilon_k\hat{\nabla}U(\theta^k) - \epsilon_k \mathbf{C}\mathbf{M}^{-1}v^k + \sqrt{2\mathbf{C}\epsilon_k} \, \Delta W^{k+1}

            \theta^{k+1} = \theta^k + \epsilon_k \mathbf{M}^{-1} v^{k+1}

        The mass and damping matrices are chosen as :math:`\mathbf{M} = \mathbf{I}` and
        :math:`\mathbf{C} = f\mathbf{I}`, where :math:`f` is the damping parameter. The
        step size is

        .. math::

            \epsilon_k = \frac{\epsilon_0}{2}\left[\cos\left(\frac{\pi \, (k \bmod K)}{K}\right) + 1\right]

        with :math:`\epsilon_0` the group's `step_size` and :math:`K` the cycle length
        ``num_iterations // num_cycles``. The momentum is updated in place first, so the
        parameter move uses the new momentum.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary that has the keys `momentum` and `iteration`.
        :param group: Group dictionary that has the keys `step_size` and `damping`.
        """
        damping = group['damping']
        momentum = state['momentum']

        # 'iteration' has already been incremented for this step, hence the -1.
        it = state["iteration"] - 1
        step_size = self._cyclic_step_size(group["step_size"], it)

        sigma = np.sqrt(2.0 * damping)
        # Dividing by sqrt(step_size) here and multiplying by step_size below yields
        # noise with standard deviation sqrt(2 * damping * step_size).
        sample_t = p.new(p.size()).normal_(mean=0.0, std=sigma) / np.sqrt(step_size)
        momentum_t = momentum.mul_(1.0 - damping * step_size).add_(d_p - sample_t, alpha=-step_size)
        p.add_(momentum_t, alpha=step_size)
class DiffusionSGHMCSA(Diffusion):
    """Implements the scale adaptive SGHMC algorithm proposed in the
    `Bayesian optimization with robust Bayesian neural networks <https://papers.nips.cc/paper/6117-gaussian-process-bandit-optimisation-with-multi-fidelity-evaluations.pdf>`_ paper.

    This algorithm uses the burn-in phase to adapt its hyperparameters (not including the step size).

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param num_burnin_steps: Number of burn-in steps used to estimate the algorithm hyperparameters.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param mdecay: Momentum decay per time step. Default: 0.05.
    """

    def __init__(
        self,
        params,
        step_size,
        num_burnin_steps: int,
        weight_decay: float = 0.0,
        mdecay: float = 0.05,
    ):
        super().__init__(params=params, step_size=step_size, weight_decay=weight_decay, mdecay=mdecay, num_burnin_steps=num_burnin_steps)

    def init_state(self, p, state, group):
        r"""Initializes a momentum and three additional state variables :math:`\tau, g, \hat{v}`.

        :math:`\tau`, :math:`g` and :math:`\hat{v}` (keys `tau`, `g`, `v_hat`) start at one;
        the momentum (key `momentum`) starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        state["tau"] = torch.ones_like(p)
        state["g"] = torch.ones_like(p)
        state["v_hat"] = torch.ones_like(p)
        state["momentum"] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the scale adaptive SGHMC update.

        .. math::

            v^{k+1} = v^k - \epsilon^2_k\hat{\mathbf{L}}^{-1/2}\hat{\nabla}U(\theta^k) - \epsilon_k \hat{\mathbf{L}}^{-1/2}\mathbf{C} v^k + \sqrt{2\epsilon_k^3\hat{\mathbf{L}}^{-1/2}\mathbf{C}\hat{\mathbf{L}}^{-1/2}-\epsilon_k^4\mathbf{I}} \, \Delta W^{k+1}

            \theta^{k+1} = \theta^k + v^{k+1}

        where :math:`\theta^k` denotes the network parameters at the k-th iteration,
        :math:`v^k` is the momentum, :math:`\epsilon_k` is the step size,
        :math:`\hat{\nabla}U(\theta^k)` is the stochastic gradient of the potential function
        :math:`U`, and :math:`\Delta W^{k+1}` is a centered normalized Gaussian random
        variable. The stationary distribution of the discrete Markov chain approximates the
        target distribution

        .. math::

            p_{\Theta,V}(\theta,v) = c_0 \exp\left(-U(\theta) - \frac{1}{2}v^T v \right)

        The matrix :math:`\hat{\mathbf{L}}` denotes the second-order moment of the gradient
        (state key `v_hat`), which is estimated with a moving average during the burn-in
        phase. In the code the friction factor
        :math:`\epsilon_k \hat{\mathbf{L}}^{-1/2}\mathbf{C}` is the scalar `mdecay`.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary that has the keys `tau`, `g`, `v_hat`, `momentum`
            and `iteration`.
        :param group: Group dictionary that has the keys `step_size`, `mdecay`, and
            `num_burnin_steps`.
        """
        mdecay, step_size = group["mdecay"], group['step_size']
        tau, g, v_hat = state["tau"], state["g"], state["v_hat"]
        momentum = state["momentum"]

        # Both quantities are evaluated BEFORE the burn-in statistics are refreshed
        # below, i.e. they are based on the estimates from the previous iteration.
        r_t = 1. / (tau + 1.)
        minv_t = 1. / torch.sqrt(v_hat)

        if state["iteration"] <= group["num_burnin_steps"]:
            # Burn-in: update the window length tau and the moving averages of the
            # gradient (g) and of the squared gradient (v_hat), all in place.
            tau.add_(1. - tau * (g * g / v_hat))
            g.add_(-g * r_t + r_t * d_p)
            v_hat.add_(-v_hat * r_t + r_t * (d_p ** 2))

        noise_scale = 2. * (step_size ** 2) * mdecay * minv_t - (step_size ** 4)
        # Clamp so the square root stays real when the variance estimate goes non-positive.
        sigma = torch.sqrt(torch.clamp(noise_scale, min=1e-16))
        sample_t = torch.normal(mean=0., std=sigma)

        momentum_t = momentum.add_(- (step_size ** 2) * minv_t * d_p - mdecay * momentum + sample_t)
        p.add_(momentum_t)
class DiffusionStormerVerletSGHMC(Diffusion):
    """Implements Hamiltonian dynamics solved with a Stormer-Verlet discretization scheme. It is an alternative to :class:`sgmcmc.diffusions.DiffusionSGHMC`.

    :param params: Generator that yields the network parameters (``model.parameters()``).
    :param step_size: Step size of the discretized stochastic differential equation.
    :param weight_decay: Optional L2 regularization that can be seen as placing a Gaussian
        prior distribution on the network parameters. Default: 0.0.
    :param damping: Scalar damping parameter. Default: 1.0.
    """

    def __init__(self, params, step_size, weight_decay: float = 0.0, damping: float = 1.0):
        super().__init__(
            params=params,
            step_size=step_size,
            weight_decay=weight_decay,
            damping=damping,
        )

    def init_state(self, p, state, group):
        """Initializes a state variable `momentum` used by the SGHMC algorithm and denoted by :math:`v^k`.

        The momentum starts at zero.

        :param p: A parameter (weight or bias).
        :param state: State dict.
        :param group: Group dict.
        """
        state["momentum"] = torch.zeros_like(p)

    def update_fn(self, p, d_p, state, group):
        r"""Performs the Stormer-Verlet update.

        .. math::

            \theta^{k+1/2} = \theta^k + 0.5\epsilon_k v^{k}

            v^{k+1} = \frac{1-\alpha}{1+\alpha} v^k - \frac{\epsilon_k}{1+\alpha} \hat{\nabla}U(\theta^{k+1/2}) + \frac{\sqrt{f\epsilon_k}}{1+\alpha} \, \Delta W^{k+1}

            \theta^{k+1} = \theta^{k+1/2} + 0.5\epsilon_k v^{k+1}

        where :math:`f` is the damping parameter and :math:`\alpha = f\epsilon_k/4`.

        As implemented, the momentum is updated in place first and the parameter is then
        moved by two half steps that both use the updated momentum.

        :param p: A parameter (weight or bias); updated in place.
        :param d_p: Gradient of the objective function with respect to parameter p.
        :param state: State dictionary that has the key `momentum`.
        :param group: Group dictionary that has the keys `step_size` and `damping`.
        """
        eps = group['step_size']
        friction = group['damping']
        velocity = state["momentum"]

        quarter = eps * friction / 4.0          # alpha in the formulas above
        inv_denom = 1.0 / (1.0 + quarter)       # 1 / (1 + alpha)
        decay = (1.0 - quarter) * inv_denom     # (1 - alpha) / (1 + alpha)
        noise_std = np.sqrt(friction) * inv_denom

        # Scaled by 1/sqrt(eps) so that the factor -eps below leaves sqrt(eps) noise.
        gaussian = p.new_empty(p.size()).normal_(mean=0.0, std=noise_std)
        scaled_noise = gaussian / np.sqrt(eps)

        velocity.mul_(decay)
        velocity.add_(inv_denom * d_p - scaled_noise, alpha=-eps)

        # Two half steps; `velocity` was modified in place, so both use v^{k+1}.
        for _ in range(2):
            p.add_(0.5 * velocity, alpha=eps)
if __name__ == "__main__":
    # Minimal smoke test: build a tiny model and construct a sampler for it.
    class Model(torch.nn.Module):
        """Single linear layer mapping 10 features to 10 features."""

        def __init__(self):
            super().__init__()
            self.f = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.f(x)

    model = Model()
    # DiffusionSGHMC has no `alpha` argument; its friction hyperparameter is `damping`.
    diffusion = DiffusionSGHMC(params=model.parameters(), step_size=1.0, weight_decay=0.0, damping=0.01)