import torch
import torch.autograd as autograd
import torch.nn as nn
import numpy as np
import logging
# Module-level logger (not referenced by the kernel classes in this section).
log = logging.getLogger(__name__)
class Kernel:
    r"""Base class for implementing kernel functions.

    Subclasses override the methods below; every method of this base class
    raises :class:`NotImplementedError`. Kernel hyper-parameters are stored in
    the ``defaults`` dictionary.
    """

    def __init__(self):
        # Hyper-parameters of the kernel; populated by subclasses.
        self.defaults = {}

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""Function that computes the entries of the Gram matrix :math:`k(x_i,y_j)` where :math:`x_i` (resp. :math:`y_j`) is the i-th line of :math:`x` (resp. j-th line of :math:`y`).

        :param x: Matrix of shape :math:`[m,d]`
        :param y: Matrix of shape :math:`[n,d]`
        :raises NotImplementedError:
        """
        raise NotImplementedError()

    def SteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor):
        r"""Function that computes the Stein kernel matrix :math:`k_P(x_i,x_j)`, where
        the Stein kernel is defined as

        .. math::
            k_P(x,y) = \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log p(x) \cdot \nabla_y k(x,y) + \nabla_y \log p(y) \cdot \nabla_x k(x,y) + (\nabla_x \log p(x))\cdot (\nabla_y\log p(y)) k(x,y)

        with :math:`k(x,y)` a chosen kernel function.

        :param y: Matrix of samples of shape :math:`[n,d]`.
        :param grad_y: Matrix of shape :math:`[n,d]` gathering the gradients of the log posterior :math:`\nabla_{x_i}\log P(x_i)`.
        :raises NotImplementedError:
        """
        raise NotImplementedError()

    def SlicedSteinMatrix(self, r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor):
        r"""This function computes the entries of the sliced Stein kernel matrix :math:`h_{P,r,g}` given by

        .. math::
            h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)

        with :math:`s_p^r(x) = \nabla_x \log P(x)^T r` (see `Sliced Kernelized Stein Discrepancy <https://arxiv.org/abs/2006.16531>`_.)

        :param r_dot_g: Scalar product :math:`r^Tg`.
        :param w: Projected samples :math:`w = x^Tg`.
        :param grad_w: Projected gradients :math:`s_p^r(x)`.
        :raises NotImplementedError:
        """
        raise NotImplementedError()

    def Diagonal(self, y: torch.Tensor):
        r"""Function that computes the diagonal terms of the Gram matrix :math:`k(x_i,x_i)`.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n` samples :math:`y_1, \dots, y_n`.
        :raises NotImplementedError:
        """
        raise NotImplementedError()

    def Line(self, y: torch.Tensor, idx: int):
        r"""Function that can be used to compute one line of the Gram matrix :math:`k(y_i, y)`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :raises NotImplementedError:
        """
        raise NotImplementedError()

    def LineSteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor, idx: int):
        r"""Function that returns one line of the Stein matrix :math:`k_P(x_i,x_j)`.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n` samples :math:`y_1, \dots, y_n`.
        :param grad_y: Gradients of the log posterior with respect to :math:`y`.
        :param idx: Line to be computed.
        :raises NotImplementedError:
        """
        raise NotImplementedError()
class EnergyKernel(Kernel):
    r"""Distance-induced kernel function.

    .. math::
        k(x,y) = \|x\|_2 + \|y\|_2 - \|x-y\|_2

    See, e.g., `Stein points <http://proceedings.mlr.press/v80/chen18f/chen18f.pdf>`_ and
    `Equivalence of distance-based and RKHS-based statistics in hypothesis testing <https://projecteuclid.org/journals/annals-of-statistics/volume-41/issue-5/Equivalence-of-distance-based-and-RKHS-based-statistics-in-hypothesis/10.1214/13-AOS1140.pdf>`_.
    """

    def __init__(self):
        super().__init__()

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""This function implements the Gram matrix of the energy kernel,
        :math:`k(x_i,y_j) = \|x_i\|_2 + \|y_j\|_2 - \|x_i-y_j\|_2`.

        :param x: Matrix of shape :math:`[m,d]` (a single vector of shape :math:`[d]` is treated as :math:`[1,d]`).
        :param y: Matrix of shape :math:`[n,d]` (a single vector of shape :math:`[d]` is treated as :math:`[1,d]`).
        :return: Entries of the Gram matrix :math:`k(x_i,y_j)` of shape :math:`[m,n]`.
        """
        # Accept single vectors, as ImqKernel.GramMatrix does; torch.cdist needs 2-D inputs.
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        if len(y.shape) == 1:
            y = y.unsqueeze(0)
        # Plain (not squared) Euclidean distances ||x_i - y_j||_2.
        dist = torch.cdist(x, y)
        return torch.linalg.norm(x, dim=1).unsqueeze(1) + torch.linalg.norm(y, dim=1).unsqueeze(0) - dist

    def Diagonal(self, y: torch.Tensor):
        r"""This function returns the diagonal of the Gram matrix, :math:`k(y_i,y_i) = 2\|y_i\|_2`.

        :param y: Matrix of shape :math:`[n,d]`.
        :return: Tensor of shape :math:`[n]` with the diagonal entries.
        """
        # dim=-1 gives row norms for a [n,d] matrix and the full norm for a single vector.
        return 2.0*torch.linalg.norm(y, dim=-1)

    def Line(self, y: torch.Tensor, idx: int):
        r"""Function that can be used to compute one line of the Gram matrix :math:`k(y_i, y_j)`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with the entries :math:`k(y_{idx}, y_j)`.
        """
        x = y[idx]
        return torch.linalg.norm(x) + torch.linalg.norm(y, dim=1) - torch.linalg.norm(y - x, dim=1)
class ImqKernel(Kernel):
    r"""Inverse multi-quadratic kernel.

    .. math::
        k(x,y) = (c^2 + \|x-y\|_2^2/\ell^2)^{-\beta}

    :param c: Parameter :math:`c`, default value :math:`1.0`.
    :param beta: Parameter :math:`\beta`, default value :math:`1/2`.
    :param lengthscale: Lengthscale (bandwidth) :math:`\ell`, default value :math:`1.0`.
    """

    def __init__(self, c=1.0, beta=0.5, lengthscale=1.0):
        super().__init__()
        # Store c^2 and 1/ell^2 directly: these are the quantities that appear in q_f.
        self.defaults["c"] = c**2
        self.defaults["beta"] = beta
        self.defaults["linv"] = 1.0/lengthscale**2

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""Function that computes the entries of the Gram matrix for the inverse multi-quadratic kernel, i.e.,

        .. math::
            k(x_i,y_j) = (c^2 + \|x_i-y_j\|_2^2/\ell^2)^{-\beta}

        :param x: Matrix of shape :math:`[m,d]` (a single vector of shape :math:`[d]` is treated as :math:`[1,d]`).
        :param y: Matrix of shape :math:`[n,d]` (a single vector of shape :math:`[d]` is treated as :math:`[1,d]`).
        :return: Entries of the Gram matrix :math:`k(x_i,y_j)` of shape :math:`[m,n]`
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        if len(y.shape) == 1:
            y = y.unsqueeze(0)
        sqdist = torch.cdist(x, y)**2
        qf = self.defaults["c"] + sqdist*self.defaults["linv"]
        return (1.0/qf**self.defaults["beta"])

    def Diagonal(self, y: torch.Tensor):
        r"""Diagonal elements of the Gram matrix, :math:`k(y_i,y_i) = c^{-2\beta}`.

        Since :math:`\|y_i-y_i\|_2 = 0`, all diagonal entries share the same value, so a
        single scalar is returned rather than a per-sample tensor; it broadcasts against
        tensors of shape :math:`[n]`.

        :param y: Matrix of shape :math:`[n,d]` (not used: the diagonal does not depend on the samples).
        :return: Scalar :math:`c^{-2\beta}`, the common value of every diagonal entry.
        """
        qf = self.defaults["c"]
        return (1.0/qf**self.defaults["beta"])

    def SteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor):
        r"""This function implements a closed-form expression of the Stein kernel :math:`k_P` for an inverse multi-quadratic kernel:

        .. math::
            k(x_i,x_j) = q_f(x_i,x_j)^{-\beta}\,, \quad q_f(x_i, x_j) = c^2 + \|x_i-x_j\|^2_2/\ell^2

        The closed-form expression is given by

        .. math::
            k_P(x_i,x_j) = -4\beta(\beta+1)\ell^{-4} \|x_i-x_j\|^2_2 q_f(x_i,x_j)^{-\beta-2} + 2\beta\ell^{-2} (d + A_{ij} + A_{ji}) q_f(x_i,x_j)^{-\beta-1} + (s_p(x_i)\cdot s_p(x_j)) k(x_i,x_j)

        where :math:`d` is the dimension of :math:`x`, :math:`\ell` is the bandwidth, and

        .. math::
            s_p(x_i) = \nabla_{x_i} \log P(x_i)\,, \quad A_{ij} = \nabla_{x_i}\log P(x_i) \cdot (x_i - x_j)

        :param y: Matrix of shape :math:`[n,d]`
        :param grad_y: Matrix of shape :math:`[n,d]`, gradients of the log posterior
        :return: Entries of the Stein matrix :math:`k_P(y_i,y_j)` of shape :math:`[n,n]`
        """
        _, d = y.shape
        linv = self.defaults["linv"]
        sqdist = torch.cdist(y, y)**2
        qf = self.defaults["c"] + sqdist*linv
        # Spx[i,j] = s_p(y_i) . y_j, hence Spxy[i,j] = s_p(y_i) . (y_i - y_j) = A_ij.
        Spx = torch.mm(grad_y, y.T)
        Spxy = Spx.diag().unsqueeze(1) - Spx
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*beta0*beta1*linv*linv*sqdist / qf ** beta2
        t2 = 2.0*beta0*linv*(d + Spxy + Spxy.T) / qf ** beta1
        t3 = torch.mm(grad_y, grad_y.T) / qf ** beta0
        return t1 + t2 + t3

    def SlicedSteinMatrix(self, r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor):
        r"""This function computes the entries of the sliced Stein kernel matrix :math:`h_{P,r,g}` for the inverse multi-quadratic kernel.

        The closed-form expression is given by:

        .. math::
            h_{P,r,g}(x_i,x_j) = -4(r^T g)^2 \beta(\beta+1)\ell^{-4} \|w_i-w_j\|_2^2 q_f(w_i,w_j)^{-\beta-2} + 2\beta\ell^{-2} (r^T g)( r^T g + A_{ij} + A_{ji}) q_f(w_i,w_j)^{-\beta-1} + (s_p^r(x_i) s_p^r(x_j)) q_f(w_i,w_j)^{-\beta}

        where

        .. math::
            q_f(w_i,w_j) = c^2 + \ell^{-2}\|w_i-w_j\|^2_2\,, \quad w_i := g^T x_i\,, \quad s_p^r(x_i) = r^T \nabla_{x_i}\log P(x_i)\,,

        and

        .. math::
            A_{ij} = s_p^r(x_i) (w_i - w_j)

        :param r_dot_g: Scalar product :math:`r^Tg`.
        :param w: Projected samples :math:`w = x^Tg`, a vector of shape :math:`[n]`.
        :param grad_w: Projected gradients :math:`s_p^r(x)`, a vector of shape :math:`[n]`.
        :return: Entries of the sliced Stein matrix :math:`h_{P,r,g}(x_i,x_j)` of shape :math:`[n,n]`
        """
        linv = self.defaults["linv"]
        # Exact pairwise differences diff[i,j] = w_i - w_j by broadcasting. This avoids the
        # sqrt-then-square round-trip of torch.cdist and its matrix-multiplication distance
        # path, which loses precision through cancellation for nearby projections.
        diff = w.unsqueeze(1) - w.unsqueeze(0)
        sqdist = diff*diff
        qf = self.defaults["c"] + sqdist*linv
        # Spxy[i,j] = s_p^r(x_i) (w_i - w_j) = A_ij
        Spxy = grad_w.unsqueeze(1)*diff
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*r_dot_g*r_dot_g*beta0*beta1*linv*linv*sqdist / qf ** beta2
        t2 = 2.0*beta0*linv*r_dot_g*(r_dot_g + Spxy + Spxy.T) / qf ** beta1
        # Outer product s_p^r(x_i) s_p^r(x_j).
        t3 = grad_w.unsqueeze(1)*grad_w.unsqueeze(0) / qf ** beta0
        return t1 + t2 + t3

    def Line(self, y: torch.Tensor, idx: int):
        r"""Function that can be used to compute one line of the Gram matrix :math:`k(y_i, y_j)`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with the entries :math:`k(y_{idx}, y_j)`.
        """
        linv = self.defaults["linv"]
        x = y[idx].unsqueeze(0)
        dist = x.T - y.T
        qf = self.defaults["c"] + torch.sum(dist*dist, axis=0)*linv
        return 1.0/(qf ** self.defaults["beta"])

    def LineSteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor, idx: int):
        r"""Function that returns one line of the Stein kernel, :math:`k_P(y_{idx}, y_j)`; see the function `SteinMatrix` for the closed-form expression.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n` samples :math:`y_1, \dots, y_n`.
        :param grad_y: Gradients of the log posterior with respect to :math:`y`, shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with the entries :math:`k_P(y_{idx}, y_j)`.
        """
        x = y[idx].unsqueeze(0)
        grad_x = grad_y[idx].unsqueeze(0)
        _, d = x.shape
        linv = self.defaults["linv"]
        # dist[:, j] = y_idx - y_j
        dist = (x - y).T
        sqdist = torch.sum(dist*dist, axis=0)
        qf = self.defaults["c"] + sqdist*linv
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*beta0*beta1*sqdist*linv*linv / (qf ** beta2)
        # A_ij + A_ji = (s_p(y_i) - s_p(y_j)) . (y_i - y_j)
        t2 = 2.0*beta0*linv*(d + torch.sum((grad_x.T - grad_y.T)*dist, axis=0)) / (qf ** beta1)
        t3 = torch.sum(grad_x.T*grad_y.T, axis=0) / (qf ** beta0)
        return t1 + t2 + t3