import torch
import logging
from tqdm import tqdm
from typing import Optional
from batorch.rkhs.kernels import EnergyKernel, Kernel, ImqKernel
log = logging.getLogger(__name__)
def MMD(x: torch.Tensor, y: torch.Tensor, kernel: Optional["Kernel"] = None):
    r"""Compute the V-statistic estimator of the squared Maximum Mean Discrepancy (MMD).

    The squared maximum mean discrepancy between two distributions :math:`P` and :math:`Q` is defined as:

    .. math::

        \mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime}) + \mathbb{E}_{y\sim Q}\mathbb{E}_{y^{\prime}\sim Q}k(y,y^{\prime}) - 2\mathbb{E}_{x\sim P}\mathbb{E}_{y\sim Q}k(x,y)

    Given two samples :math:`\{x_i\}_{i=1}^{n}` and :math:`\{y_j\}_{j=1}^{m}` from :math:`P` and :math:`Q`,
    the V-estimator takes the form

    .. math::

        \widehat{\mathrm{MMD}}^2(P,Q) = \frac{1}{n^2}\sum_{i,j=1}^{n}k(x_i,x_j) + \frac{1}{m^2}\sum_{i,j=1}^{m}k(y_i,y_j) - \frac{2}{mn}\sum_{i=1}^{n}\sum_{j=1}^{m}k(x_i,y_j)

    :param x: Matrix of shape :math:`[n,d]` containing the :math:`n` samples :math:`x_1,\dots,x_n` of :math:`P`.
    :param y: Matrix of shape :math:`[m,d]` containing the :math:`m` samples :math:`y_1,\dots,y_m` of :math:`Q`.
    :param kernel: Kernel object exposing ``GramMatrix(a, b)`` (e.g. ImqKernel, RbfKernel or EnergyKernel).
        Defaults to EnergyKernel.
    :return: Value of the squared MMD (V-statistic, i.e. the biased estimator that includes diagonal terms).
    :rtype: torch.Tensor
    :raises ValueError: If ``x`` or ``y`` contains no samples.
    """
    n, m = x.size(0), y.size(0)
    # Without this guard an empty sample would surface as an opaque ZeroDivisionError below.
    if n == 0 or m == 0:
        raise ValueError("MMD requires at least one sample in both x and y")
    if kernel is None:
        kernel = EnergyKernel()
    kxx = kernel.GramMatrix(x, x)
    kyy = kernel.GramMatrix(y, y)
    kxy = kernel.GramMatrix(x, y)
    mmd = (1.0 / (n**2)) * kxx.sum() + (1.0 / m**2) * kyy.sum() - (2.0 / (m * n)) * kxy.sum()
    return mmd
def KSD(x: torch.Tensor, grad_x: torch.Tensor, kernel: Optional["Kernel"] = None, stats="V", method="direct"):
    r"""Compute the squared Kernelized Stein Discrepancy (KSD) between a distribution :math:`Q` and an intractable
    distribution :math:`P` (see, e.g., `A kernelized Stein discrepancy for goodness-of-fit tests
    <http://proceedings.mlr.press/v48/liub16.html>`_).

    The squared KSD between :math:`P` and :math:`Q` is defined as

    .. math::

        \mathrm{KSD}^2(P,Q) := \mathbb{E}_{x\sim Q} \mathbb{E}_{x^{\prime}\sim Q} k_P(x,x^{\prime})

    where :math:`k_P` denotes the Stein kernel defined as

    .. math::

        k_P(x,y) := & \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log P(x) \cdot \nabla_y k(x,y) + \nabla_y \log P(y) \cdot \nabla_x k(x,y) \\
        & + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y).

    The underlying kernel function :math:`k` used in the Stein kernel :math:`k_P(x,y)` can be chosen, e.g., as the
    inverse multi-quadratic kernel or the RBF kernel.

    Let :math:`x_1, \dots, x_n` be samples of the distribution :math:`Q`. If ``stats="V"``, the squared KSD is
    estimated with a V-statistic:

    .. math::

        \widehat{\mathrm{KSD}}^2_V(P,Q) = \frac{1}{n^2} \sum_{i=1}^{n}\sum_{j=1}^{n} k_P(x_i, x_j)

    If ``stats="U"``, it is estimated with a U-statistic, which drops the diagonal terms :math:`k_P(x_i, x_i)`
    (the second equality uses the symmetry of :math:`k_P`):

    .. math::

        \widehat{\mathrm{KSD}}^2_U(P,Q) = \frac{1}{n(n-1)} \sum_{i\neq j} k_P(x_i, x_j)
        = \frac{2}{n(n-1)} \sum_{1\leq i < j \leq n} k_P(x_i, x_j)\,.

    Two methods are available for computing these estimators. The direct method (``method="direct"``) computes and
    stores the full Stein matrix. If memory usage is the bottleneck, the iterative method (``method="iterative"``)
    only requests one row of the Stein matrix at a time and accumulates the unnormalized sum

    .. math::

        S_{i+1} = S_i + k_P(x_{i+1},x_{i+1}) + 2\sum_{j=1}^{i}k_P(x_j,x_{i+1}), \qquad S_0 = 0,

    for :math:`i = 0,\dots, n-1`. The V-estimator is :math:`S_n / n^2`. The U-estimator uses the same
    accumulation without the diagonal terms :math:`k_P(x_{i+1},x_{i+1})`, divided by :math:`n(n-1)`.

    :param x: Matrix of shape :math:`[n,d]` corresponding to the :math:`n` samples of the empirical distribution :math:`Q`.
    :param grad_x: Matrix of shape :math:`[n,d]` corresponding to the gradients of the log distribution :math:`P`
        evaluated for the :math:`n` samples :math:`x_1,\dots,x_n`.
    :param kernel: Kernel object providing ``SteinMatrix`` (direct method) and ``LineSteinMatrix`` (iterative method),
        such as ImqKernel or RbfKernel. Defaults to ImqKernel.
    :param stats: Type of estimator, ``"U"`` or ``"V"``. Defaults to ``"V"``.
    :param method: Method for computing the estimator, ``"direct"`` or ``"iterative"``. Defaults to ``"direct"``.
    :return: Estimated squared KSD :math:`\widehat{\mathrm{KSD}}^2_U(P,Q)` or :math:`\widehat{\mathrm{KSD}}^2_V(P,Q)`.
    :rtype: torch.Tensor
    :raises ValueError: If ``method`` or ``stats`` is invalid, if ``x`` is empty, or if ``stats="U"`` with fewer
        than two samples.
    """
    if method not in ["direct", "iterative"]:
        raise ValueError("Method should be equal to direct or iterative")
    if stats not in ["U", "V"]:
        raise ValueError("stats should be equal to U or V")
    n = x.size(0)
    if n == 0:
        raise ValueError("KSD requires at least one sample")
    if stats == "U" and n < 2:
        # n(n-1) would be zero: the U-statistic is undefined for a single sample.
        raise ValueError("The U-statistic requires at least two samples")
    if kernel is None:
        kernel = ImqKernel()

    if method == "direct":
        kp = kernel.SteinMatrix(x, grad_x)
        if stats == "V":
            return kp.sum() / n**2
        # Sum over all i != j; this is 2 * sum_{i<j} for the symmetric Stein matrix.
        return (kp.sum() - torch.diagonal(kp).sum()) / (n * (n - 1.0))

    # Iterative method: k_row[j] is the Stein kernel between x_j and x_i for j <= i,
    # with k_row[i] being the diagonal term (as implied by the accumulation below).
    total = 0.0
    off_diag = 0.0
    for i in range(n):
        k_row = kernel.LineSteinMatrix(x[0:(i + 1)], grad_x[0:(i + 1)], i)
        cross = 2.0 * torch.sum(k_row[0:i])
        total = total + (cross + k_row[i])
        off_diag = off_diag + cross
    if stats == "V":
        return total / n**2
    return off_diag / (n * (n - 1.0))
def SlicedKSD(x: torch.Tensor, grad_x: torch.Tensor, r: torch.Tensor, g: torch.Tensor, kernel: Optional["Kernel"] = None, stats="V"):
    r"""
    Compute the squared sliced KSD using the methodology proposed in the
    `Sliced Kernelized Stein Discrepancy <https://arxiv.org/abs/2006.16531>`_ paper.

    Let :math:`U(\mathbb{S}^{d-1})` be the uniform distribution over the hypersphere :math:`\mathbb{S}^{d-1}`, where
    :math:`d` is the dimension of the samples. The squared sliced KSD between an intractable distribution :math:`P`
    and a distribution :math:`Q` is given by

    .. math::

        \mathrm{SKSD}^2(P,Q) := \mathbb{E}_{r\sim U, g\sim U} [ \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime} \sim Q} h_{P,r,g}(x,x^{\prime}) ]

    where :math:`h_{P,r,g}` is given by Equation (8) in the
    `Sliced Kernelized Stein Discrepancy <https://arxiv.org/abs/2006.16531>`_ paper, i.e.,

    .. math::

        h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)

    with :math:`s_p^r(x) = \nabla_x \log P(x)^T r`. The one-dimensional kernel :math:`k` is provided by the
    ``kernel`` object.

    For each slice :math:`\ell`, the inner expectations are estimated with a V- or U-statistic
    (``stats="V"`` or ``stats="U"``):

    .. math::

        \widehat{\mathrm{SKSD}}_{V,\ell}^2 := \frac{1}{n^2} \sum_{i=1}^{n}\sum_{j=1}^{n} h_{P,r_\ell, g_\ell}(x_i,x_j),
        \qquad
        \widehat{\mathrm{SKSD}}_{U,\ell}^2 := \frac{1}{n(n-1)} \sum_{i\neq j} h_{P,r_\ell, g_\ell}(x_i,x_j)

    where :math:`x_1, \dots, x_n` are samples of :math:`Q`. The outer expectation is estimated by the standard
    Monte Carlo average :math:`\frac{1}{N}\sum_{\ell=1}^{N}\widehat{\mathrm{SKSD}}_{\ell}^2` over slices
    :math:`(r_\ell, g_\ell)`, :math:`\ell=1,\dots,N`, which should be samples of :math:`U(\mathbb{S}^{d-1})`.
    This function does not normalize ``r`` or ``g``. It returns the per-slice values; take their mean to obtain
    the Monte Carlo estimate.

    :param x: Matrix of shape :math:`[n,d]` containing the samples of :math:`Q`.
    :param grad_x: Matrix of shape :math:`[n,d]` corresponding to the gradients of the log posterior at ``x``.
    :param r: Matrix of shape :math:`[N,d]`; row :math:`\ell` is the score-projection direction :math:`r_\ell`.
    :param g: Matrix of shape :math:`[N,d]`; row :math:`\ell` is the sample-projection direction :math:`g_\ell`.
    :param kernel: Kernel object providing ``SlicedSteinMatrix`` (e.g. ImqKernel or RbfKernel). Defaults to ImqKernel.
    :param stats: Type of estimator, ``"U"`` or ``"V"``. Defaults to ``"V"``.
    :return: Vector of length :math:`N` with the per-slice estimates :math:`\widehat{\mathrm{SKSD}}_{\ell}^2`,
        :math:`\ell=1,\dots,N`.
    :rtype: torch.Tensor
    :raises ValueError: If ``stats`` is invalid, if ``r`` and ``g`` have different sizes, if ``x`` is empty,
        or if ``stats="U"`` with fewer than two samples.
    """
    if kernel is None:
        kernel = ImqKernel()
    if stats not in ["U", "V"]:
        raise ValueError("stats should be equal to U or V")
    # Validate before r and g are used, so mismatches cannot be silently broadcast.
    if r.size() != g.size():
        raise ValueError("r and g should have the same sizes")
    n = x.size(0)
    if n == 0:
        raise ValueError("SlicedKSD requires at least one sample")
    if stats == "U" and n < 2:
        raise ValueError("The U-statistic requires at least two samples")

    projected_samples = torch.matmul(g, x.T)  # [n_slices, n_samples]
    projected_scores = torch.matmul(r, grad_x.T)  # [n_slices, n_samples]
    r_dot_g = torch.sum(r * g, dim=1)  # [n_slices]

    per_slice = []
    for w, grad_w, rg in zip(projected_samples, projected_scores, r_dot_g):
        skp = kernel.SlicedSteinMatrix(rg, w, grad_w)
        if stats == "U":
            # Sum over all i != j (= 2 * sum_{i<j} for the symmetric kernel matrix).
            per_slice.append((skp.sum() - torch.diagonal(skp).sum()) / (n * (n - 1)))
        else:
            per_slice.append(skp.sum() / n**2)
    if not per_slice:
        return x.new_zeros((0,))
    # Stacking keeps the dtype/device of the results (and autograd history) instead of forcing float32 on CPU.
    return torch.stack(per_slice)
if __name__ == "__main__":
    # Library module: nothing to execute when run directly.
    ...