Source code for batorch.rkhs.kernels

import torch
import torch.autograd as autograd
import torch.nn as nn
import numpy as np

import logging
log = logging.getLogger(__name__)

class Kernel:
    """Base class for implementing kernel functions.

    Every method of this base class raises :class:`NotImplementedError`;
    concrete kernels override the methods they support. Kernel
    hyper-parameters are kept in the ``defaults`` dictionary, which is empty
    for the base class and filled in by subclasses.
    """

    def __init__(self):
        # Hyper-parameters of the kernel; subclasses populate this dict.
        self.defaults = {}

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""Compute the entries of the Gram matrix :math:`k(x_i,y_j)`, where
        :math:`x_i` (resp. :math:`y_j`) is the i-th row of :math:`x`
        (resp. the j-th row of :math:`y`).

        :param x: Matrix of shape :math:`[m,d]`.
        :param y: Matrix of shape :math:`[n,d]`.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def SteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor):
        r"""Compute the entire Stein kernel matrix :math:`k_P(x_i,x_j)`, where

        .. math::

            k_P(x,y) = \nabla_x\cdot \nabla_y k(x,y)
                     + \nabla_x \log p(x) \cdot \nabla_y k(x,y)
                     + \nabla_y \log p(y) \cdot \nabla_x k(x,y)
                     + (\nabla_x \log p(x))\cdot (\nabla_y\log p(y)) k(x,y)

        and :math:`k(x,y)` is the chosen kernel function.

        :param y: Matrix of samples of shape :math:`[n,d]`.
        :param grad_y: Matrix of shape :math:`[n,d]` gathering the gradients
            of the log posterior :math:`\nabla_{x_i}\log P(x_i)`.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def SlicedSteinMatrix(self, r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor):
        r"""Compute the entries of the sliced Stein kernel matrix
        :math:`h_{P,r,g}` given by

        .. math::

            h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y)
                           + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg)
                           + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg)
                           + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)

        with :math:`s_p^r(x) = \nabla_x \log P(x)^T r` (see
        `Sliced Kernelized Stein Discrepancy <https://arxiv.org/abs/2006.16531>`_).

        :param r_dot_g: Scalar product :math:`r^Tg`.
        :param w: Projected samples :math:`w = x^Tg`.
        :param grad_w: Projected gradients :math:`s_p^r(x)`.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def Diagonal(self, y: torch.Tensor):
        r"""Compute the diagonal terms :math:`k(y_i,y_i)` of the Gram matrix.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n`
            samples :math:`y_1, \dots, y_n`.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def Line(self, y: torch.Tensor, idx: int):
        r"""Compute one line of the Gram matrix, :math:`k(y_{idx}, y_j)` for
        all :math:`j`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def LineSteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor, idx: int):
        r"""Compute one line of the Stein matrix, :math:`k_P(y_{idx}, y_j)`
        for all :math:`j`.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n`
            samples :math:`y_1, \dots, y_n`.
        :param grad_y: Gradients of the log posterior with respect to
            :math:`y`, same shape as ``y``.
        :param idx: Index of the line to be computed.
        :raises NotImplementedError: Always; must be implemented by subclasses.
        """
        raise NotImplementedError()
class EnergyKernel(Kernel):
    r"""Distance-induced (energy) kernel function.

    .. math::

        k(x,y) = \|x\|_2 + \|y\|_2 - \|x-y\|_2

    See, e.g., `Stein points <http://proceedings.mlr.press/v80/chen18f/chen18f.pdf>`_
    and `Equivalence of distance-based and RKHS-based statistics in hypothesis
    testing <https://projecteuclid.org/journals/annals-of-statistics/volume-41/issue-5/Equivalence-of-distance-based-and-RKHS-based-statistics-in-hypothesis/10.1214/13-AOS1140.pdf>`_.

    Row norms are computed with :func:`torch.linalg.norm` rather than
    ``torch.sqrt(torch.sum(...))``: the square root has an infinite derivative
    at zero, which turned gradients into NaN whenever a distance or a sample
    norm was exactly zero (always the case for the self-distance in
    :meth:`Line`).
    """

    def __init__(self):
        super().__init__()

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""Compute the Gram matrix of the energy kernel,
        :math:`k(x_i,y_j) = \|x_i\|_2 + \|y_j\|_2 - \|x_i-y_j\|_2`.

        :param x: Matrix of shape :math:`[m,d]`; a 1-D tensor of shape
            :math:`[d]` is treated as a single sample.
        :param y: Matrix of shape :math:`[n,d]`; a 1-D tensor of shape
            :math:`[d]` is treated as a single sample.
        :return: Entries of the Gram matrix :math:`k(x_i,y_j)` of shape
            :math:`[m,n]`.
        """
        # Promote single samples to [1,d] (torch.cdist needs 2-D inputs).
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if y.dim() == 1:
            y = y.unsqueeze(0)
        # Pairwise (non-squared) Euclidean distances, shape [m,n].
        dist = torch.cdist(x, y)
        # Column vector [m,1] plus row vector [1,n] broadcasts to [m,n].
        return torch.linalg.norm(x, dim=1).unsqueeze(1) + torch.linalg.norm(y, dim=1).unsqueeze(0) - dist

    def Diagonal(self, y: torch.Tensor):
        r"""Compute the diagonal of the Gram matrix,
        :math:`k(y_i,y_i) = 2\|y_i\|_2`.

        :param y: Matrix of shape :math:`[n,d]`.
        :return: Tensor of shape :math:`[n]`.
        """
        return 2.0 * torch.linalg.norm(y, dim=1)

    def Line(self, y: torch.Tensor, idx: int):
        r"""Compute one line of the Gram matrix, :math:`k(y_{idx}, y_j)` for
        all :math:`j`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with entries
            :math:`k(y_{idx}, y_j)`.
        """
        x = y[idx].unsqueeze(0)  # [1,d], broadcasts against y
        t1 = -torch.linalg.norm(x - y, dim=1)  # -||x - y_j||, zero at j == idx
        t2 = torch.linalg.norm(x, dim=1)  # ||x||, shape [1]
        t3 = torch.linalg.norm(y, dim=1)  # ||y_j||, shape [n]
        return t1 + t2 + t3
class ImqKernel(Kernel):
    r"""Inverse multi-quadratic kernel.

    .. math::

        k(x,y) = (c^2 + \|x-y\|_2^2/\ell^2)^{-\beta}

    :param c: Parameter :math:`c`, default value :math:`1.0`.
    :param beta: Parameter :math:`\beta`, default value :math:`1/2`.
    :param lengthscale: Lengthscale (bandwidth) :math:`\ell`, default value
        :math:`1.0`.
    """

    def __init__(self, c=1.0, beta=0.5, lengthscale=1.0):
        super().__init__()
        # Note: "c" holds c^2, not c, so it can be added directly to q_f.
        self.defaults["c"] = c**2
        self.defaults["beta"] = beta
        # Inverse squared lengthscale, 1/ell^2.
        self.defaults["linv"] = 1.0/lengthscale**2

    def GramMatrix(self, x: torch.Tensor, y: torch.Tensor):
        r"""Compute the entries of the Gram matrix of the inverse
        multi-quadratic kernel,

        .. math::

            k(x_i,y_j) = (c^2 + \|x_i-y_j\|_2^2/\ell^2)^{-\beta}

        :param x: Matrix of shape :math:`[m,d]`; a 1-D tensor of shape
            :math:`[d]` is treated as a single sample.
        :param y: Matrix of shape :math:`[n,d]`; a 1-D tensor of shape
            :math:`[d]` is treated as a single sample.
        :return: Entries of the Gram matrix :math:`k(x_i,y_j)` of shape
            :math:`[m,n]`.
        """
        # Promote single samples to [1,d] (torch.cdist needs 2-D inputs).
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if y.dim() == 1:
            y = y.unsqueeze(0)
        sqdist = torch.cdist(x, y)**2
        qf = self.defaults["c"] + sqdist*self.defaults["linv"]
        return (1.0/qf**self.defaults["beta"])

    def Diagonal(self, y: torch.Tensor):
        r"""Diagonal elements of the Gram matrix.

        Since :math:`\|y_i-y_i\|_2 = 0`, every diagonal entry equals
        :math:`(c^2)^{-\beta}`; ``y`` is not read.

        :param y: Matrix of shape :math:`[n,d]` (unused).
        :return: The common diagonal value as a single scalar, not a
            per-sample tensor; it broadcasts against tensors of any shape.
        """
        qf = self.defaults["c"]
        return (1.0/qf**self.defaults["beta"])

    def SteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor):
        r"""Compute the Stein kernel matrix :math:`k_P` of the inverse
        multi-quadratic kernel in closed form.

        With

        .. math::

            k(x_i,x_j) = q_f(x_i,x_j)^{-\beta}\,, \quad
            q_f(x_i, x_j) = c^2 + \|x_i-x_j\|^2_2/\ell^2

        the Stein kernel is

        .. math::

            k_P(x_i,x_j) = -4\beta(\beta+1)\ell^{-4} \|x_i-x_j\|^2_2 q_f(x_i,x_j)^{-\beta-2}
                         + 2\beta\ell^{-2} (d + A_{ij} + A_{ji}) q_f(x_i,x_j)^{-\beta-1}
                         + (s_p(x_i)\cdot s_p(x_j)) k(x_i,x_j)

        where :math:`d` is the dimension of :math:`x`, :math:`\ell` is the
        bandwidth, and

        .. math::

            s_p(x_i) = \nabla_{x_i} \log P(x_i)\,, \quad
            A_{ij} = \nabla_{x_i}\log P(x_i) \cdot (x_i - x_j)

        :param y: Matrix of shape :math:`[n,d]`.
        :param grad_y: Matrix of shape :math:`[n,d]`, gradients of the log
            posterior.
        :return: Entries of the Stein matrix :math:`k_P(y_i,y_j)` of shape
            :math:`[n,n]`.
        """
        _, d = y.shape
        linv = self.defaults["linv"]
        sqdist = torch.cdist(y, y)**2
        qf = self.defaults["c"] + sqdist*linv
        # Spx[i, j] = s_p(y_i) . y_j, hence Spxy[i, j] = s_p(y_i) . (y_i - y_j) = A_ij.
        Spx = torch.mm(grad_y, y.T)
        Spxy = Spx.diag().unsqueeze(1) - Spx
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*beta0*beta1*linv*linv*sqdist / qf ** beta2
        t2 = 2.0*beta0*linv*(d + Spxy + Spxy.T) / qf ** beta1
        t3 = torch.mm(grad_y, grad_y.T) / qf ** beta0
        return t1 + t2 + t3

    def SlicedSteinMatrix(self, r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor):
        r"""Compute the sliced Stein kernel matrix :math:`h_{P,r,g}` of the
        inverse multi-quadratic kernel in closed form:

        .. math::

            h_{P,r,g}(x_i,x_j) = -4(r^T g)^2 \beta(\beta+1)\ell^{-4} \|w_i-w_j\|_2^2 q_f(w_i,w_j)^{-\beta-2}
                               + 2\beta\ell^{-2} (r^T g)( r^T g + A_{ij} + A_{ji}) q_f(w_i,w_j)^{-\beta-1}
                               + (s_p^r(x_i) s_p^r(x_j)) q_f(w_i,w_j)^{-\beta}

        where

        .. math::

            q_f(w_i,w_j) = c^2 + \ell^{-2}\|w_i-w_j\|^2_2\,, \quad
            w_i := g^T x_i\,, \quad
            s_p^r(x_i) = r^T \nabla_{x_i}\log P(x_i)

        and

        .. math::

            A_{ij} = s_p^r(x_i) (w_i - w_j)

        :param r_dot_g: Scalar product :math:`r^Tg`.
        :param w: Projected samples :math:`w = x^Tg`, a 1-D tensor of
            shape :math:`[n]`.
        :param grad_w: Projected gradients :math:`s_p^r(x)`, a 1-D tensor
            of shape :math:`[n]`.
        :return: Entries of the sliced Stein matrix :math:`h_{P,r,g}(x_i,x_j)`
            of shape :math:`[n,n]`.
        """
        linv = self.defaults["linv"]
        w_col = w.unsqueeze(1)  # [n,1]
        g_col = grad_w.unsqueeze(1)  # [n,1]
        sqdist = torch.cdist(w_col, w_col)**2
        qf = self.defaults["c"] + sqdist*linv
        # Spx[i, j] = s_i * w_j, hence Spxy[i, j] = s_i * (w_i - w_j) = A_ij.
        Spx = torch.mm(g_col, w_col.T)
        Spxy = (grad_w*w).unsqueeze(1) - Spx
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*r_dot_g*r_dot_g*beta0*beta1*linv*linv*sqdist / qf ** beta2
        t2 = 2.0*beta0*linv*r_dot_g*(r_dot_g + Spxy + Spxy.T) / qf ** beta1
        t3 = torch.mm(g_col, g_col.T) / qf ** beta0
        return t1 + t2 + t3

    def Line(self, y: torch.Tensor, idx: int):
        r"""Compute one line of the Gram matrix, :math:`k(y_{idx}, y_j)` for
        all :math:`j`.

        :param y: Matrix of shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with entries
            :math:`k(y_{idx}, y_j)`.
        """
        linv = self.defaults["linv"]
        x = y[idx].unsqueeze(0)  # [1,d]
        dist = x.T - y.T  # [d,n], column j is y_idx - y_j
        qf = self.defaults["c"] + torch.sum(dist*dist, axis=0)*linv
        return 1.0/(qf ** self.defaults["beta"])

    def LineSteinMatrix(self, y: torch.Tensor, grad_y: torch.Tensor, idx: int):
        r"""Compute one line of the Stein kernel, :math:`k_P(y_{idx}, y_j)`
        for all :math:`j`; see :meth:`SteinMatrix` for the closed form.

        :param y: Matrix of size :math:`n` by :math:`d` containing :math:`n`
            samples :math:`y_1, \dots, y_n`.
        :param grad_y: Gradients of the log posterior with respect to
            :math:`y`, shape :math:`[n,d]`.
        :param idx: Index of the line to be computed.
        :return: Tensor of shape :math:`[n]` with entries
            :math:`k_P(y_{idx}, y_j)`.
        """
        x = y[idx].unsqueeze(0)  # [1,d]
        grad_x = grad_y[idx].unsqueeze(0)  # [1,d]
        _, d = x.shape
        linv = self.defaults["linv"]
        dist = (x-y).T  # [d,n], column j is y_idx - y_j
        sqdist = torch.sum(dist*dist, axis=0)  # [n]
        qf = self.defaults["c"] + sqdist*linv
        beta0 = self.defaults["beta"]
        beta1 = self.defaults["beta"] + 1.0
        beta2 = self.defaults["beta"] + 2.0
        t1 = -4.0*beta0*beta1*sqdist*linv*linv / (qf ** beta2)
        # A_ij + A_ji = (s_p(x_i) - s_p(x_j)) . (x_i - x_j)
        t2 = 2.0*beta0*linv*(d + torch.sum((grad_x.T-grad_y.T)*dist, axis=0)) / (qf ** beta1)
        t3 = torch.sum(grad_x.T * grad_y.T, axis=0) / (qf ** beta0)
        return t1 + t2 + t3