batorch.rkhs

batorch.rkhs.discrepancies

batorch.rkhs.discrepancies.MMD

Function that computes an estimator of the Maximum Mean Discrepancy (MMD) with a V-statistic.

batorch.rkhs.discrepancies.KSD

Function that computes the squared Kernelized Stein Discrepancy (KSD) between a distribution \(Q\) and an intractable distribution \(P\) (see, e.g., A kernelized Stein discrepancy for goodness-of-fit tests).

batorch.rkhs.discrepancies.SlicedKSD

Function that computes the squared sliced KSD using the methodology proposed in the Sliced Kernelized Stein Discrepancy paper.

batorch.rkhs.discrepancies.MMD(x: torch.Tensor, y: torch.Tensor, kernel: Optional[Kernel] = None)[source]#

Function that computes an estimator of the Maximum Mean Discrepancy (MMD) with a V-statistic.

The squared maximum mean discrepancy is defined as:

\[\mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime}) + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime}) - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})\]

Given two samples \(\{x_i\}_{i=1}^{n}\) and \(\{y_j\}_{j=1}^{m}\) from \(P\) and \(Q\), the V-estimator takes the form

\[\widehat{\mathrm{MMD}}^2(P,Q) = \frac{1}{n^2}\sum_{i,j=1}^{n}k(x_i,x_j) + \frac{1}{m^2}\sum_{i,j=1}^{m}k(y_i,y_j) - \frac{2}{mn}\sum_{i=1}^{n}\sum_{j=1}^{m}k(x_i,y_j)\]
Parameters
  • x – Matrix of shape \([n,d]\) corresponding to the \(n\) samples \(x_1,\dots,x_n\) of the distribution \(P\).

  • y – Matrix of shape \([m,d]\) corresponding to the \(m\) samples \(y_1,\dots,y_m\) of the distribution \(Q\).

  • kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.

Returns

Value of the squared MMD

Return type

torch.Tensor
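
The V-estimator above can be sketched in a few lines of plain Python. This only illustrates the formula and is not the batorch implementation: the helper names `energy_kernel` and `mmd2_v` are hypothetical, and nested lists stand in for tensors.

```python
import math

def energy_kernel(x, y):
    """Energy kernel k(x, y) = ||x|| + ||y|| - ||x - y|| (hypothetical helper)."""
    norm = lambda v: math.sqrt(sum(t * t for t in v))
    diff = [a - b for a, b in zip(x, y)]
    return norm(x) + norm(y) - norm(diff)

def mmd2_v(xs, ys, kernel=energy_kernel):
    """V-statistic estimate of the squared MMD between samples xs ~ P and ys ~ Q."""
    n, m = len(xs), len(ys)
    k_xx = sum(kernel(a, b) for a in xs for b in xs) / n**2
    k_yy = sum(kernel(a, b) for a in ys for b in ys) / m**2
    k_xy = sum(kernel(a, b) for a in xs for b in ys) / (n * m)
    return k_xx + k_yy - 2.0 * k_xy

xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]   # n = 3 samples in d = 2
ys = [[2.0, 2.0], [3.0, 2.0]]               # m = 2 samples in d = 2
same = mmd2_v(xs, xs)        # identical samples: the three terms cancel
different = mmd2_v(xs, ys)   # distinct samples: positive value
```

Because the V-statistic keeps the diagonal terms \(k(x_i,x_i)\), the estimate is zero when both sample sets coincide.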

batorch.rkhs.discrepancies.KSD(x: torch.Tensor, grad_x: torch.Tensor, kernel: Optional[Kernel] = None, stats='V', method='direct')[source]#

Function that computes the squared Kernelized Stein Discrepancy (KSD) between a distribution \(Q\) and an intractable distribution \(P\) (see, e.g., A kernelized Stein discrepancy for goodness-of-fit tests). The squared KSD between \(P\) and \(Q\) is defined as

\[\mathrm{KSD}^2(P,Q) := \mathbb{E}_{x\sim Q} \mathbb{E}_{x^{\prime}\sim Q} k_P(x,x^{\prime})\]

where \(k_P\) denotes the Stein kernel defined as

\[\begin{split}k_P(x,y) := & \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log P(x) \cdot \nabla_y k(x,y) + \nabla_y \log P(y) \cdot \nabla_x k(x,y) \\ & + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y).\end{split}\]

The underlying kernel function \(k\) used in the Stein kernel \(k_P(x,y)\) can be chosen as the inverse multi-quadratic kernel or the RBF kernel.

If the string stats is set to "V", the squared KSD is estimated with a V-statistic:

\[\widehat{\mathrm{KSD}}^2_V(P,Q) = \frac{1}{n^2} \sum_{i=1}^{n}\sum_{j=1}^{n} k_P(x_i, x_j)\]

where \(x_1, \dots, x_n\) are samples of the distribution \(Q\). Two methods are available for computing the above estimator.

If the string stats is set to "U", the squared KSD is estimated with a U-statistic:

\[\widehat{\mathrm{KSD}}^2_U(P,Q) = \frac{1}{n(n-1)} \sum_{1\leq i \neq j \leq n} k_P(x_i, x_j)\,.\]

The direct method (method="direct") computes and stores the full Stein matrix, from which the U- or V-estimator is obtained. If memory usage is the bottleneck, the second method (method="iterative") might be more appropriate, as it computes the V-estimator iteratively as follows:

\[\widehat{\mathrm{KSD}}_V(P,Q_{i+1})^2 = \frac{i^2}{(i+1)^2}\widehat{\mathrm{KSD}}_V(P,Q_i)^2 + \frac{1}{(i+1)^2}k_P(x_{i+1},x_{i+1}) + \frac{2}{(i+1)^2}\sum_{j=1}^{i}k_P(x_j,x_{i+1})\]

for \(i = 0,\dots, n-1\), where

\[Q_{i+1}(x) := \frac{1}{i+1} \sum_{j=1}^{i} \delta_x(x_j) + \frac{1}{i+1} \delta_x(x_{i+1})\]

and \(Q_n := Q\).
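
The iterative scheme can be sketched as follows. Each new sample contributes one diagonal entry and twice the new off-diagonal entries to a running sum of Stein kernel values, which is then normalized by \((i+1)^2\), so only one line of the Stein matrix is needed per step. The function name `ksd2_v_iterative` and the callback `stein_line` are hypothetical, and the small symmetric matrix is a toy stand-in for a Stein matrix.

```python
def ksd2_v_iterative(stein_line, n):
    """Squared KSD V-statistics of the prefixes Q_1, ..., Q_n.

    stein_line(i) must return [k_P(x_0, x_i), ..., k_P(x_i, x_i)] (0-based),
    i.e. the part of line i of the Stein matrix up to the diagonal.
    """
    running_sum = 0.0  # sum of k_P(x_a, x_b) over all pairs a, b <= i
    values = []
    for i in range(n):
        line = stein_line(i)
        # New diagonal entry plus twice the new off-diagonal entries (symmetry).
        running_sum += line[i] + 2.0 * sum(line[:i])
        values.append(running_sum / (i + 1) ** 2)
    return values

# Toy symmetric matrix used only to check the update against the direct method.
K = [[2.0, 0.5, -1.0],
     [0.5, 3.0, 0.25],
     [-1.0, 0.25, 1.0]]
iterative = ksd2_v_iterative(lambda i: K[i][: i + 1], len(K))
direct = sum(sum(row) for row in K) / len(K) ** 2
```

The last entry of `iterative` agrees with the direct V-statistic computed from the full matrix, while the earlier entries give the estimates for the shorter prefixes.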

Parameters
  • x – Matrix of shape \([n,d]\) corresponding to the \(n\) samples of the empirical distribution \(Q\).

  • grad_x – Matrix of shape \([n,d]\) corresponding to the gradients of the log distribution \(P\) evaluated for the \(n\) samples \(x_1,\dots,x_n\).

  • kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.

  • stats – Type of estimator, U or V, defaults to V

  • method – Chosen method for estimating the KSD (direct or iterative), default: direct.

Returns

Vector of squared KSD values \(\widehat{\mathrm{KSD}}_U(P,Q_i)\) or \(\widehat{\mathrm{KSD}}_V(P,Q_i)\), \(i=1,\dots,n\).

Return type

torch.Tensor

batorch.rkhs.discrepancies.SlicedKSD(x: torch.Tensor, grad_x: torch.Tensor, r: torch.Tensor, g: torch.Tensor, kernel: Optional[Kernel] = None, stats='V')[source]#

Function that computes the squared sliced KSD using the methodology proposed in the Sliced Kernelized Stein Discrepancy paper.

Let \(U(\mathbb{S}^{D-1})\) be the uniform distribution over the hypersphere \(\mathbb{S}^{D-1}\), where \(D\) is the dimension of the network parameters. Then the squared sliced KSD between an intractable distribution \(P\) and a distribution \(Q\) is given by

\[\mathrm{SKSD}^2(P,Q) := \mathbb{E}_{r\sim U, g\sim U} [ \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime} \sim Q} h_{P,r,g}(x,x^{\prime}) ]\]

where \(h_{P,r,g}\) is given by Equation (8) in the Sliced Kernelized Stein Discrepancy paper, i.e.,

\[h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)\]

with \(s_p^r(x) = \nabla_x \log P(x)^T r\). The kernel function \(k\) is set through the kernel argument, for instance an ImqKernel or an RbfKernel object.

The inner expectations are estimated with a V- or U-statistic (stats="V" or stats="U"), while the outer expectation is estimated with a standard Monte Carlo estimator, i.e.,

\[\widehat{\mathrm{SKSD}}_V^2(P,Q) := \frac{1}{N \, n^2} \sum_{\ell=1}^{N} \sum_{i=1}^{n}\sum_{j=1}^{n} h_{P,r_\ell, g_\ell}(x_i,x_j)\]

and

\[\widehat{\mathrm{SKSD}}_U^2(P,Q) := \frac{1}{N \, n(n-1)} \sum_{\ell=1}^{N} \sum_{1\leq i \neq j \leq n} h_{P,r_\ell, g_\ell}(x_i,x_j)\]

where \(x_1, \dots, x_n\) are samples of the distribution \(Q\). The vectors \(r_1,\dots,r_{N}\) and \(g_1,\dots,g_N\) are samples of distribution \(U(\mathbb{S}^{D-1})\).

Parameters
  • x – Matrix of shape \([n,d]\).

  • grad_x – Matrix of shape \([n,d]\) corresponding to the gradients of the log posterior.

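  • r – Tensor containing the directions \(r_1,\dots,r_N\) sampled from \(U(\mathbb{S}^{D-1})\).

  • g – Tensor containing the directions \(g_1,\dots,g_N\) sampled from \(U(\mathbb{S}^{D-1})\).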

  • kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.

  • stats – Type of estimator, U or V, defaults to V

Returns

Vector of estimated sliced KSD values \(\widehat{\mathrm{SKSD}}(P,Q_i)\), \(i=1,\dots,n\).

Return type

torch.Tensor
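
As an illustration of the quantities entering the sliced estimator, the sketch below draws directions uniformly on the unit sphere by normalizing Gaussian vectors and forms the projections \(w_i = x_i^T g\) and the scalar \(r^T g\). The helper names are hypothetical and plain lists are used instead of tensors; the tensor layout that SlicedKSD expects for r and g is not stated on this page, so none is assumed here.

```python
import math
import random

def sample_unit_directions(n_dirs, dim, seed=0):
    """Draw n_dirs vectors uniformly on the unit sphere S^{dim-1}."""
    rng = random.Random(seed)
    directions = []
    for _ in range(n_dirs):
        # A normalized standard Gaussian vector is uniform on the sphere.
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in v))
        directions.append([t / norm for t in v])
    return directions

def project(points, direction):
    """Project each row of points onto a direction: w_i = x_i^T g."""
    return [sum(a * b for a, b in zip(p, direction)) for p in points]

r = sample_unit_directions(4, dim=3, seed=1)   # N = 4 directions r_l
g = sample_unit_directions(4, dim=3, seed=2)   # N = 4 directions g_l
x = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]        # n = 2 samples in D = 3
w = project(x, g[0])                           # projected samples for slice 0
r_dot_g = sum(a * b for a, b in zip(r[0], g[0]))
```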

batorch.rkhs.kernels

batorch.rkhs.kernels.Kernel

Base class for implementing kernel functions.

batorch.rkhs.kernels.ImqKernel

Inverse multi-quadratic kernel.

batorch.rkhs.kernels.EnergyKernel

Distance-induced kernel function.

class batorch.rkhs.kernels.Kernel[source]#

Base class for implementing kernel functions.

GramMatrix(x: torch.Tensor, y: torch.Tensor)[source]#

Function that computes the entries of the Gram matrix \(k(x_i,y_j)\) where \(x_i\) (resp. \(y_j\)) is the i-th line of \(x\) (resp. j-th line of \(y\)).

Parameters
  • x – Matrix of shape \([m,d]\)

  • y – Matrix of shape \([n,d]\)

Raises

NotImplementedError

SteinMatrix(y: torch.Tensor, grad_y: torch.Tensor)[source]#

Function that computes the entries of the Stein kernel matrix \(k_P(y_i,y_j)\), where the Stein kernel is defined as

\[k_P(x,y) = \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log P(x) \cdot \nabla_y k(x,y) + \nabla_y \log P(y) \cdot \nabla_x k(x,y) + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y)\]

with \(k(x,y)\) a chosen kernel function.

Parameters
  • y – Matrix of samples of shape \([n,d]\).

  • grad_y – Matrix of shape \([n,d]\) gathering the gradients of the log posterior \(\nabla_{x_i}\log P(x_i)\).

Raises

NotImplementedError

SlicedSteinMatrix(r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor)[source]#

This function computes the entries of the sliced Stein kernel matrix \(h_{P,r,g}\) given by

\[h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)\]

with \(s_p^r(x) = \nabla_x \log P(x)^T r\) (see Sliced Kernelized Stein Discrepancy.)

Parameters
  • r_dot_g – Scalar product \(r^Tg\)

  • w – Projected samples \(w = x^Tg\).

  • grad_w – Projected gradients \(s_p^r(x)\).

Raises

NotImplementedError

Diagonal(y: torch.Tensor)[source]#

Function that computes the diagonal terms of the Gram matrix \(k(y_i,y_i)\).

Parameters

y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).

Raises

NotImplementedError

Line(y: torch.Tensor, idx: int)[source]#

Function that can be used to compute one line of the Gram matrix \(k(y_i, y)\).

Parameters
  • y – Matrix of shape \([n,d]\).

  • idx – Index of the line to be computed.

Raises

NotImplementedError

LineSteinMatrix(y: torch.Tensor, grad_y: torch.Tensor, idx: int)[source]#

Function that returns one line of the Stein matrix \(k_P(y_i,y_j)\).

Parameters
  • y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).

  • grad_y – Gradients of the log posterior with respect to \(y\).

  • idx – Line to be computed.

Raises

NotImplementedError

class batorch.rkhs.kernels.EnergyKernel[source]#

Distance-induced kernel function.

\[k(x,y) = \|x\|_2 + \|y\|_2 - \|x-y\|_2\]

See, e.g., Stein points and Equivalence of distance-based and RKHS-based statistics in hypothesis testing.

GramMatrix(x: torch.Tensor, y: torch.Tensor)[source]#

This function implements the Gram matrix of the energy kernel.

Diagonal(y: torch.Tensor)[source]#

This function returns the diagonal of the Gram matrix.

Line(y: torch.Tensor, idx: int)[source]#

Function that can be used to compute one line of the Gram matrix \(k(x_i, x_j)\).

Parameters
  • y – Matrix of shape \([n,d]\).

  • idx – Index of the line to be computed.

Returns

Entries of the Gram matrix \(k(x_i, y_j)\).

class batorch.rkhs.kernels.ImqKernel(c=1.0, beta=0.5, lengthscale=1.0)[source]#

Inverse multi-quadratic kernel.

\[k(x,y) = (c^2 + \|x-y\|_2^2/\ell^2)^{-\beta}\]
Parameters
  • c – Parameter \(c\), default value \(1.0\).

  • beta – Parameter \(\beta\), default value \(1/2\).

  • lengthscale – Lengthscale (bandwidth).

GramMatrix(x: torch.Tensor, y: torch.Tensor)[source]#

Function that computes the entries of the Gram matrix for the inverse multi-quadratic kernel, i.e.,

\[k(x_i,y_j) = (c^2 + \|x_i-y_j\|_2^2/\ell^2)^{-\beta}\]
Parameters
  • x – Matrix of shape \([m,d]\)

  • y – Matrix of shape \([n,d]\)

Returns

Entries of the Gram matrix \(k(x_i,y_j)\) of shape \([m,n]\)
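
A plain-Python sketch of this Gram matrix, given only to make the formula and the \([m,n]\) output shape concrete. The function name `imq_gram` is hypothetical and nested lists stand in for tensors; the defaults mirror those listed for ImqKernel.

```python
def imq_gram(xs, ys, c=1.0, beta=0.5, lengthscale=1.0):
    """Gram matrix of shape [m, n] with entries (c^2 + ||x_i - y_j||^2 / l^2)^(-beta)."""
    gram = []
    for x in xs:
        row = []
        for y in ys:
            sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
            row.append((c**2 + sq_dist / lengthscale**2) ** (-beta))
        gram.append(row)
    return gram

xs = [[0.0, 0.0], [3.0, 4.0]]               # m = 2 points in d = 2
ys = [[0.0, 0.0], [0.0, 1.0], [3.0, 4.0]]   # n = 3 points in d = 2
gram = imq_gram(xs, ys)                     # 2 rows, 3 columns
```

With \(c=1\), entries for coinciding points equal 1, and entries decay polynomially with the distance between \(x_i\) and \(y_j\).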

Diagonal(y: torch.Tensor)[source]#

Diagonal elements of the Gram matrix.

Parameters

y – Matrix of shape \([n,d]\)

Returns

Diagonal entries \(k(y_i,y_i)\) of the Gram matrix, of shape \([n]\)

SteinMatrix(y: torch.Tensor, grad_y: torch.Tensor)[source]#

This function implements a closed-form expression of the Stein kernel \(k_P\) for an inverse multi-quadratic kernel:

\[k(x_i,x_j) = q_f(x_i,x_j)^{-\beta}\,, \quad q_f(x_i, x_j) = c^2 + \|x_i-x_j\|^2_2/\ell^2\]

The closed-form expression is given by

\[k_P(x_i,x_j) = -4\beta(\beta+1)\ell^{-4} \|x_i-x_j\|^2_2 q_f(x_i,x_j)^{-\beta-2} + 2\beta\ell^{-2} (d + A_{ij} + A_{ji}) q_f(x_i,x_j)^{-\beta-1} + (s_p(x_i)\cdot s_p(x_j)) k(x_i,x_j)\]

where \(d\) is the dimension of \(x\), \(\ell\) is the bandwidth, and

\[s_p(x_i) = \nabla_{x_i} \log P(x_i)\,, \quad A_{ij} = \nabla_{x_i}\log P(x_i) \cdot (x_i - x_j)\]
Parameters
  • y – Matrix of shape \([n,d]\)

  • grad_y – Matrix of shape \([n,d]\), gradients of the log posterior

Returns

Entries of the Stein matrix \(k_P(y_i,y_j)\) of shape \([n,n]\)
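
The closed-form expression can be transcribed term by term as in the sketch below. It is not the batorch implementation: the function name `imq_stein_matrix` is hypothetical, nested lists replace tensors, and a standard normal target with score \(\nabla_y \log P(y) = -y\) is assumed purely for the example.

```python
def imq_stein_matrix(ys, scores, c=1.0, beta=0.5, lengthscale=1.0):
    """Stein matrix k_P(y_i, y_j) for the IMQ kernel; scores[i] = grad log P(y_i)."""
    d = len(ys[0])
    inv_l2 = 1.0 / lengthscale**2
    dot = lambda u, v: sum(a * b for a, b in zip(u, v))
    matrix = []
    for yi, si in zip(ys, scores):
        row = []
        for yj, sj in zip(ys, scores):
            diff = [a - b for a, b in zip(yi, yj)]
            sq_dist = dot(diff, diff)
            q = c**2 + sq_dist * inv_l2
            a_ij = dot(si, diff)     # s_p(y_i) . (y_i - y_j)
            a_ji = -dot(sj, diff)    # s_p(y_j) . (y_j - y_i)
            row.append(
                -4.0 * beta * (beta + 1.0) * inv_l2**2 * sq_dist * q ** (-beta - 2.0)
                + 2.0 * beta * inv_l2 * (d + a_ij + a_ji) * q ** (-beta - 1.0)
                + dot(si, sj) * q ** (-beta)
            )
        matrix.append(row)
    return matrix

# Assumed target for the example: standard normal, so grad log P(y) = -y.
ys = [[1.0, 2.0], [0.0, -1.0], [0.5, 0.5]]
scores = [[-a for a in y] for y in ys]
stein = imq_stein_matrix(ys, scores)
```

On the diagonal the distance vanishes, so with \(c=\ell=1\) the entry reduces to \(2\beta d + \|s_p(y_i)\|_2^2\); for the first sample this gives \(2 + 5 = 7\).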

SlicedSteinMatrix(r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor)[source]#

This function computes the entries of the sliced Stein kernel matrix \(h_{P,r,g}\) for the inverse multi-quadratic kernel.

The closed-form expression is given by:

\[h_{P,r,g}(x_i,x_j) = -4(r^T g)^2 \beta(\beta+1)\ell^{-4} \|w_i-w_j\|_2^2 q_f(w_i,w_j)^{-\beta-2} + 2\beta\ell^{-2} (r^T g)( r^T g + A_{ij} + A_{ji}) q_f(w_i,w_j)^{-\beta-1} + (s_p^r(x_i) s_p^r(x_j)) q_f(w_i,w_j)^{-\beta}\]

where

\[q_f(w_i,w_j) = c^2 + \ell^{-2}\|w_i-w_j\|^2_2\,, \quad w_i := g^T x_i\,, \quad s_p^r(x_i) = r^T \nabla_{x_i}\log P(x_i)\,,\]

and

\[A_{ij} = s_p^r(x_i) (w_i - w_j)\]
Parameters
  • r_dot_g – Scalar product \(r^Tg\)

  • w – Projected samples \(w = x^Tg\).

  • grad_w – Projected gradients \(s_p^r(x)\).

Returns

Entries of the sliced Stein matrix \(h_{P,r,g}(x_i,y_j)\) of shape \([n,n]\)
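
A direct transcription of this closed form for scalar projections is sketched below. The function name `imq_sliced_stein_matrix` is hypothetical, plain lists replace tensors, and the numerical inputs are arbitrary values chosen for illustration.

```python
def imq_sliced_stein_matrix(r_dot_g, w, grad_w, c=1.0, beta=0.5, lengthscale=1.0):
    """Sliced Stein matrix h_{P,r,g}(x_i, x_j) from projections w and projected scores grad_w."""
    inv_l2 = 1.0 / lengthscale**2
    n = len(w)
    h = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            diff = w[i] - w[j]
            q = c**2 + inv_l2 * diff**2
            a_ij = grad_w[i] * diff      # s_p^r(x_i) (w_i - w_j)
            a_ji = -grad_w[j] * diff     # s_p^r(x_j) (w_j - w_i)
            h[i][j] = (
                -4.0 * r_dot_g**2 * beta * (beta + 1.0) * inv_l2**2
                * diff**2 * q ** (-beta - 2.0)
                + 2.0 * beta * inv_l2 * r_dot_g * (r_dot_g + a_ij + a_ji)
                * q ** (-beta - 1.0)
                + grad_w[i] * grad_w[j] * q ** (-beta)
            )
    return h

w = [0.0, 1.0, -0.5]          # projected samples w_i = g^T x_i (arbitrary values)
grad_w = [2.0, -1.0, 0.5]     # projected scores s_p^r(x_i) (arbitrary values)
h = imq_sliced_stein_matrix(0.5, w, grad_w)
```

For \(i=j\) and \(c=\ell=1\) the entry reduces to \(2\beta (r^Tg)^2 + s_p^r(x_i)^2\), which is \(0.25 + 4 = 4.25\) for the first sample here.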

Line(y: torch.Tensor, idx: int)[source]#

Function that can be used to compute one line of the Gram matrix \(k(x_i, x_j)\).

Parameters
  • y – Matrix of shape \([n,d]\).

  • idx – Index of the line to be computed.

Returns

Entries of the Gram matrix \(k(x_i, y_j)\).

LineSteinMatrix(y: torch.Tensor, grad_y: torch.Tensor, idx: int)[source]#

Function that returns one line of the Stein kernel matrix, \(k_P(y_i, \cdot)\); see the function SteinMatrix for more details.

Parameters
  • y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).

  • grad_y – Gradients of the log posterior with respect to \(y\).

  • idx – Line to be computed.

Returns

Entries of the Stein matrix \(k_P(y_i, y_j)\) for the requested line.

batorch.rkhs.quantization

batorch.rkhs.quantization.Quantization

Base class for implementing a quantization method.

batorch.rkhs.quantization.QuantizationMMD

Class that implements quantization by greedily minimizing the maximum mean discrepancy.

batorch.rkhs.quantization.median_heuristic(samples: torch.Tensor)[source]#

Function that computes the median heuristic for a kernel’s lengthscale.

\[\ell := \mathrm{median}\{\|X_i-X_j\|: 1 \leq i < j \leq n\}\]
Parameters

samples – Matrix of size \(n\) by \(d\) where \(n\) is the number of samples and \(d\) the dimension of each sample.

Returns

The scalar median heuristic that can be used as a lengthscale/bandwidth.
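
A minimal standard-library sketch of the heuristic. The name `median_heuristic_sketch` is hypothetical, and the median is taken over distinct pairs \(i < j\), which is an assumption of this sketch (including the zero self-distances would lower the median).

```python
import math
import statistics

def median_heuristic_sketch(samples):
    """Median of the pairwise Euclidean distances over distinct pairs i < j."""
    n = len(samples)
    dists = [
        math.dist(samples[i], samples[j])
        for i in range(n)
        for j in range(i + 1, n)
    ]
    return statistics.median(dists)

samples = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]   # pairwise distances: 5, 10, 5
lengthscale = median_heuristic_sketch(samples)
```

The resulting scalar can then be used as the lengthscale of a kernel, for example the `lengthscale` argument of ImqKernel.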

class batorch.rkhs.quantization.Quantization(m: int)[source]#

Base class for implementing a quantization method. For more details, see for instance the paper Optimal quantisation of probability measures using maximum mean discrepancy.

Parameters

m – Quantization size.

select(data: dict)[source]#

class batorch.rkhs.quantization.QuantizationMMD(m: int, kernel: Optional[Union[omegaconf.DictConfig, Kernel]] = None)[source]#

Class that implements quantization by greedily minimizing the maximum mean discrepancy. For more details, see for instance the paper Optimal quantisation of probability measures using maximum mean discrepancy.

Parameters
  • m – Quantization size.

  • kernel – Underlying kernel used in the MMD.

select(data: dict, memory_is_bottleneck=True, device=None)[source]#

This function implements a simple algorithm that greedily minimizes the maximum mean discrepancy. Given a set of samples \(\{ x_i \}_{i=1}^{N}\) with empirical distribution

\[P(x) = \frac{1}{N}\sum_{i=1}^{N} \delta_x(x_i)\,,\]

the goal is to select a subset of \(m < N\) samples \(\{ x_{\pi(\ell)} \}_{\ell=1}^{m}\) such that

\[Q_m(x) = \frac{1}{m}\sum_{i=1}^{m} \delta_x(x_{\pi(i)})\,, \quad Q_m \approx P\,.\]

The list of selected indices \(\pi(1), \dots, \pi(m)\) is constructed iteratively. At the \((i+1)\)-th iteration, the index \(\pi(i+1)\) is obtained by solving:

\[\pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, \mathrm{MMD}^2(P,Q_j^{i+1})\]

where \(\mathrm{MMD}^2(P,Q)\) denotes the squared maximum mean discrepancy between two distributions \(P\) and \(Q\),

\[\mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime}) + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime}) - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})\]

and \(Q^{i+1}_j(x)\) denotes the empirical distribution

\[Q^{i+1}_j(x) = \frac{1}{i+1}\sum_{\ell=1}^{i} \delta_x(x_{\pi(\ell)}) + \frac{1}{i+1} \delta_x(x_j)\]

The default kernel function \(k: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\) is chosen as the energy kernel with parameters \(c=1\) and \(\beta=0.5\), but the user can specify a different kernel with an OmegaConf DictConfig or a Kernel object.

By plugging the expression of \(P\) and \(Q^{i+1}_j\) into the squared \(\mathrm{MMD}\) and discarding all the terms that do not depend on \(x_j\), it can be shown that the optimization reduces to

\[\pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, 2\sum_{\ell=1}^{i}k(x_{\pi(\ell)},x_j) + k(x_j,x_j) - \frac{2(i+1)}{N}\sum_{\ell=1}^{N}k(x_\ell,x_j)\]
Parameters
  • data – Dict with key “samples” that contains the samples in a torch.Tensor

  • memory_is_bottleneck – Boolean parameter. Set to true for high dimensional data.

  • device – Device to be used for the quantization algorithm. If None, the device on which the data currently resides is used.

Returns

List of selected indices \(\pi(1), \dots, \pi(m)\).
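
The reduced objective can be sketched on a precomputed Gram matrix as follows. This is an illustration of the greedy rule only, not the batorch implementation: the function name `greedy_mmd_indices` is hypothetical, and the toy Gram matrix is built with an inverse multi-quadratic kernel (\(c=1\), \(\beta=0.5\), unit lengthscale) chosen for the example.

```python
def greedy_mmd_indices(gram, m):
    """Greedily pick m indices from an N x N Gram matrix using the reduced objective."""
    n_total = len(gram)
    # (1/N) * sum_l k(x_l, x_j), computed once for every candidate j.
    col_means = [
        sum(gram[l][j] for l in range(n_total)) / n_total for j in range(n_total)
    ]
    selected = []
    coupling = [0.0] * n_total  # sum over selected l of k(x_pi(l), x_j)
    for i in range(m):
        scores = [
            2.0 * coupling[j] + gram[j][j] - 2.0 * (i + 1) * col_means[j]
            for j in range(n_total)
        ]
        best = min(range(n_total), key=scores.__getitem__)
        selected.append(best)
        # Update the coupling term with the newly selected sample.
        for j in range(n_total):
            coupling[j] += gram[best][j]
    return selected

points = [0.0, 0.1, 5.0]   # two nearby samples and one far away
gram = [[(1.0 + (a - b) ** 2) ** -0.5 for b in points] for a in points]
selected = greedy_mmd_indices(gram, m=2)
```

In this toy example the first pick is the sample that best represents the two nearby points, and the second pick is the distant sample, since re-selecting a neighbor of an already chosen point is penalized by the coupling term.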