Source code for batorch.rkhs.discrepancies

import torch
import logging
from tqdm import tqdm
from typing import Optional
from batorch.rkhs.kernels import EnergyKernel, Kernel, ImqKernel

log = logging.getLogger(__name__)

def MMD(x: torch.Tensor, y: torch.Tensor, kernel: Optional[Kernel] = None) -> torch.Tensor:
    r"""Estimate the squared Maximum Mean Discrepancy (MMD) with a V-statistic.

    The squared maximum mean discrepancy is defined as:

    .. math::
        \mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime})
        + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})
        - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})

    Given two samples :math:`\{x_i\}_{i=1}^{n}` and :math:`\{y_j\}_{j=1}^{m}`
    from :math:`P` and :math:`Q`, the V-estimator takes the form

    .. math::
        \widehat{\mathrm{MMD}}^2(P,Q) = \frac{1}{n^2}\sum_{i,j=1}^{n}k(x_i,x_j)
        + \frac{1}{m^2}\sum_{i,j=1}^{m}k(y_i,y_j)
        - \frac{2}{mn}\sum_{i=1}^{n}\sum_{j=1}^{m}k(x_i,y_j)

    :param x: Matrix of shape :math:`[n,d]` whose rows are the :math:`n`
        samples of the distribution :math:`P`.
    :param y: Matrix whose rows are the :math:`m` samples of the distribution
        :math:`Q` (presumably of shape :math:`[m,d]`; the shape is not checked
        here and is left to ``kernel.GramMatrix``).
    :param kernel: Kernel object exposing a ``GramMatrix(a, b)`` method, such
        as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.
    :return: Value of the squared MMD
    :rtype: torch.Tensor
    """
    n, m = x.size(0), y.size(0)
    if kernel is None:
        kernel = EnergyKernel()

    # The three Gram matrices behind the within-sample and cross-sample terms.
    kxx = kernel.GramMatrix(x, x)
    kyy = kernel.GramMatrix(y, y)
    kxy = kernel.GramMatrix(x, y)

    mmd = (1.0 / (n**2)) * kxx.sum() + (1.0 / m**2) * kyy.sum() - (2.0 / (m * n)) * kxy.sum()
    return mmd
def KSD(
    x: torch.Tensor,
    grad_x: torch.Tensor,
    kernel: Optional[Kernel] = None,
    stats: str = "V",
    method: str = "direct",
) -> torch.Tensor:
    r"""Estimate the squared Kernelized Stein Discrepancy (KSD).

    Computes the squared KSD between a distribution :math:`Q` and an
    untractable distribution :math:`P` (see, e.g., `A kernelized Stein
    discrepancy for goodness-of-fit tests
    <http://proceedings.mlr.press/v48/liub16.html>`_).

    The squared KSD between :math:`P` and :math:`Q` is defined as

    .. math::
        \mathrm{KSD}^2(P,Q) := \mathbb{E}_{x\sim Q} \mathbb{E}_{x^{\prime}\sim Q} k_P(x,x^{\prime})

    where :math:`k_P` denotes the Stein kernel defined as

    .. math::
        k_P(x,y) := \nabla_x\cdot \nabla_y k(x,y)
        + \nabla_x \log P(x) \cdot \nabla_y k(x,y)
        + \nabla_y \log P(y) \cdot \nabla_x k(x,y)
        + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y).

    If `stats` is set to "V", the squared KSD is estimated with a V-statistic:

    .. math::
        \widehat{\mathrm{KSD}}^2_V(P,Q) = \frac{1}{n^2} \sum_{i=1}^{n}\sum_{j=1}^{n} k_P(x_i, x_j)

    where :math:`x_1, \dots, x_n` are samples of the distribution :math:`Q`.
    If `stats` is set to "U", the squared KSD is estimated with a U-statistic:

    .. math::
        \widehat{\mathrm{KSD}}^2_U(P,Q) = \frac{1}{n(n-1)} \sum_{i \neq j} k_P(x_i, x_j)
        = \frac{2}{n(n-1)} \sum_{1\leq i < j \leq n} k_P(x_i, x_j)\,.

    The direct method (method="direct") computes and stores the full Stein
    matrix. If memory usage is the bottleneck, the iterative method
    (method="iterative") only requests one line of the Stein matrix at a time
    and accumulates the unnormalised sum

    .. math::
        S_{i+1} = S_i + k_P(x_{i+1},x_{i+1}) + 2\sum_{j=1}^{i}k_P(x_j,x_{i+1}),
        \qquad S_0 = 0,

    so that :math:`\widehat{\mathrm{KSD}}^2_V(P,Q) = S_n / n^2`. The
    U-estimator is obtained the same way by leaving out the diagonal terms
    and dividing by :math:`n(n-1)`.

    :param x: Matrix of shape :math:`[n,d]` corresponding to the :math:`n`
        samples of the empirical distribution :math:`Q`.
    :param grad_x: Matrix of shape :math:`[n,d]` corresponding to the
        gradients of the log distribution :math:`P` evaluated for the
        :math:`n` samples :math:`x_1,\dots,x_n`.
    :param kernel: Kernel object exposing ``SteinMatrix`` (direct method) and
        ``LineSteinMatrix`` (iterative method). Defaults to ImqKernel.
    :param stats: Type of estimator, "U" or "V", defaults to "V".
    :param method: Chosen method for estimating the KSD ("direct" or
        "iterative"), default: "direct".
    :return: Scalar estimate of the squared KSD.
    :rtype: torch.Tensor
    :raises ValueError: if `method` or `stats` is not one of the accepted
        values, or if the iterative method is called with an empty sample.
    """
    if method not in ("direct", "iterative"):
        raise ValueError("Method should be equal to direct or iterative")
    # Validate up front so that the iterative path cannot silently ignore an
    # unknown estimator name.
    if stats not in ("U", "V"):
        raise ValueError("stats should be equal to U or V")
    if kernel is None:
        kernel = ImqKernel()

    n = x.size(0)

    if method == "direct":
        kp = kernel.SteinMatrix(x, grad_x)
        if stats == "V":
            return kp.sum() / n**2
        # U-statistic: sum over the whole strict upper triangle (all pairs
        # i < j), not just the first superdiagonal that torch.diag would give.
        return 2.0 * torch.triu(kp, diagonal=1).sum() / (n * (n - 1.0))

    # Iterative method: one line of the Stein matrix per sample.
    if n == 0:
        raise ValueError("x should contain at least one sample")

    off_diagonal = 0.0  # 2 * sum_{j < i} k_P(x_j, x_i), accumulated over i
    diagonal = 0.0  # sum_i k_P(x_i, x_i)
    for i in range(n):
        line = kernel.LineSteinMatrix(x[: i + 1], grad_x[: i + 1], i)
        off_diagonal = off_diagonal + 2.0 * torch.sum(line[:i])
        diagonal = diagonal + line[i]

    if stats == "V":
        return (off_diagonal + diagonal) / n**2
    return off_diagonal / (n * (n - 1.0))
def SlicedKSD(
    x: torch.Tensor,
    grad_x: torch.Tensor,
    r: torch.Tensor,
    g: torch.Tensor,
    kernel: Optional[Kernel] = None,
    stats: str = "V",
) -> torch.Tensor:
    r"""Estimate the squared sliced KSD, one value per slice.

    Uses the methodology proposed in the `Sliced Kernelized Stein Discrepancy
    <https://arxiv.org/abs/2006.16531>`_ paper.

    Let :math:`U(\mathbb{S}^{D-1})` be the uniform distribution over the
    hypersphere :math:`\mathbb{S}^{D-1}` where :math:`D` is the dimension of
    the network parameters. Then the squared sliced KSD between an
    untractable distribution :math:`P` and a distribution :math:`Q` is

    .. math::
        \mathrm{SKSD}^2(P,Q) := \mathbb{E}_{r\sim U, g\sim U} [
        \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime} \sim Q} h_{P,r,g}(x,x^{\prime}) ]

    where :math:`h_{P,r,g}` is given by Equation (8) of the paper, i.e.,

    .. math::
        h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y)
        + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg)
        + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg)
        + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)

    with :math:`s_p^r(x) = \nabla_x \log P(x)^T r`.

    For each slice :math:`(r_\ell, g_\ell)` the inner expectations are
    estimated with a V- or U-statistic (stats="V" or stats="U"):

    .. math::
        \widehat{\mathrm{SKSD}}_{V,\ell}^2 = \frac{1}{n^2}
        \sum_{i=1}^{n}\sum_{j=1}^{n} h_{P,r_\ell, g_\ell}(x_i,x_j),
        \qquad
        \widehat{\mathrm{SKSD}}_{U,\ell}^2 = \frac{2}{n(n-1)}
        \sum_{1\leq i<j \leq n} h_{P,r_\ell, g_\ell}(x_i,x_j)

    where :math:`x_1, \dots, x_n` are samples of the distribution :math:`Q`.
    This function returns the per-slice values; the Monte Carlo estimate of
    the outer expectation is their mean, which is left to the caller.

    :param x: Matrix of shape :math:`[n,d]`.
    :param grad_x: Matrix of shape :math:`[n,d]` corresponding to the
        gradients of the log posterior.
    :param r: Matrix of shape :math:`[N,d]` whose rows are the directions
        :math:`r_1,\dots,r_N` onto which the scores are projected.
    :param g: Matrix of the same shape as `r` whose rows are the directions
        :math:`g_1,\dots,g_N` onto which the samples are projected.
    :param kernel: Kernel object exposing ``SlicedSteinMatrix``. Defaults to
        ImqKernel.
    :param stats: Type of estimator, "U" or "V", defaults to "V".
    :return: Vector of length :math:`N` holding the estimated squared sliced
        KSD of each slice.
    :rtype: torch.Tensor
    :raises ValueError: if `stats` is not "U" or "V", or if `r` and `g` do
        not have the same size.
    """
    if kernel is None:
        kernel = ImqKernel()
    if stats not in ("U", "V"):
        raise ValueError("stats should be equal to U or V")
    # Check the slice shapes before they are multiplied together, otherwise a
    # mismatch surfaces as an opaque broadcasting error.
    if r.size() != g.size():
        raise ValueError("r and g should have the same sizes")

    n = x.size(0)
    n_slices = r.size(0)

    projected_samples = torch.matmul(g, x.T)  # [n_slices, n_samples]
    projected_scores = torch.matmul(r, grad_x.T)  # [n_slices, n_samples]
    r_dot_g_all = torch.sum(r * g, dim=1)  # [n_slices]

    # NOTE(review): the output uses the default dtype and device rather than
    # those of the inputs; kept as is for backward compatibility.
    sliced_ksd = torch.zeros((n_slices,))
    for s in range(n_slices):
        skp = kernel.SlicedSteinMatrix(r_dot_g_all[s], projected_samples[s], projected_scores[s])
        if stats == "U":
            # Sum over the whole strict upper triangle (all pairs i < j), not
            # just the first superdiagonal that torch.diag would give.
            sliced_ksd[s] = 2.0 * torch.triu(skp, diagonal=1).sum() / (n * (n - 1))
        else:
            sliced_ksd[s] = skp.sum() / n**2
    return sliced_ksd
# Script entry-point guard: running this module directly does nothing.
if __name__ == "__main__": pass