batorch.rkhs#
batorch.rkhs.discrepancies#
MMD – Function that computes an estimator of the Maximum Mean Discrepancy (MMD) with a V-statistic.
KSD – Function that computes the squared Kernelized Stein Discrepancy (KSD) between a distribution \(Q\) and an intractable distribution \(P\) (see, e.g., A kernelized Stein discrepancy for goodness-of-fit tests).
SlicedKSD – Function that computes the squared sliced KSD using the methodology proposed in the Sliced Kernelized Stein Discrepancy paper.
- batorch.rkhs.discrepancies.MMD(x: torch.Tensor, y: torch.Tensor, kernel: Optional[Kernel] = None)[source]#
Function that computes an estimator of the Maximum Mean Discrepancy (MMD) with a V-statistic.
The squared maximum mean discrepancy is defined as:
\[\mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime}) + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime}) - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})\]Given two samples \(\{x_i\}_{i=1}^{n}\) and \(\{y_j\}_{j=1}^{m}\) from \(P\) and \(Q\), the V-estimator takes the form
\[\widehat{\mathrm{MMD}}^2(P,Q) = \frac{1}{n^2}\sum_{i,j=1}^{n}k(x_i,x_j) + \frac{1}{m^2}\sum_{i,j=1}^{m}k(y_i,y_j) - \frac{2}{mn}\sum_{i=1}^{n}\sum_{j=1}^{m}k(x_i,y_j)\]- Parameters
x – Matrix of shape \([n,d]\) corresponding to the \(n\) samples \(x_1,\dots,x_n\) of the distribution \(P\).
y – Matrix of shape \([m,d]\) corresponding to the \(m\) samples \(y_1,\dots,y_m\) of the distribution \(Q\).
kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.
- Returns
Value of the squared MMD
- Return type
torch.Tensor
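The V-estimator above can be sketched in a few lines. The following is a minimal, dependency-free illustration in plain Python, not the batorch implementation: the names `energy_kernel` and `mmd2_v` are hypothetical, and lists of lists stand in for the `torch.Tensor` inputs. The energy kernel is used because it is the documented default.

```python
import math


def _norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(t * t for t in v))


def energy_kernel(x, y):
    """Energy kernel k(x, y) = ||x|| + ||y|| - ||x - y||."""
    diff = [a - b for a, b in zip(x, y)]
    return _norm(x) + _norm(y) - _norm(diff)


def mmd2_v(xs, ys, kernel=energy_kernel):
    """V-statistic estimate of the squared MMD between two samples."""
    n, m = len(xs), len(ys)
    k_xx = sum(kernel(a, b) for a in xs for b in xs) / n**2
    k_yy = sum(kernel(a, b) for a in ys for b in ys) / m**2
    k_xy = sum(kernel(a, b) for a in xs for b in ys) / (n * m)
    return k_xx + k_yy - 2.0 * k_xy


# Hypothetical toy data: n = 3 samples from P, m = 2 samples from Q, d = 2.
xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
ys = [[2.0, 2.0], [3.0, 2.0]]
same = mmd2_v(xs, xs)        # identical samples: the estimate vanishes
different = mmd2_v(xs, ys)   # separated samples: the estimate is positive
```

Because the V-statistic keeps the diagonal terms \(k(x_i,x_i)\), the estimate is exactly zero when both arguments are the same sample.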
- batorch.rkhs.discrepancies.KSD(x: torch.Tensor, grad_x: torch.Tensor, kernel: Optional[Kernel] = None, stats='V', method='direct')[source]#
Function that computes the squared Kernelized Stein Discrepancy (KSD) between a distribution \(Q\) and an intractable distribution \(P\) (see, e.g., A kernelized Stein discrepancy for goodness-of-fit tests). The squared KSD between \(P\) and \(Q\) is defined as
\[\mathrm{KSD}^2(P,Q) := \mathbb{E}_{x\sim Q} \mathbb{E}_{x^{\prime}\sim Q} k_P(x,x^{\prime})\]where \(k_P\) denotes the Stein kernel defined as
\[\begin{split}k_P(x,y) := & \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log P(x) \cdot \nabla_y k(x,y) + \nabla_y \log P(y) \cdot \nabla_x k(x,y) \\ & + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y).\end{split}\]The underlying kernel function \(k\) used in the Stein kernel \(k_P(x,y)\) can be chosen as the inverse multiquadric (IMQ) kernel or the RBF kernel.
If the string stats is set to "V", the squared KSD is estimated with a V-statistic:
\[\widehat{\mathrm{KSD}}^2_V(P,Q) = \frac{1}{n^2} \sum_{i=1}^{n}\sum_{j=1}^{n} k_P(x_i, x_j)\]where \(x_1, \dots, x_n\) are samples of the distribution \(Q\). Two methods are available for computing the above estimator.
If the string stats is set to "U", the squared KSD is estimated with a U-statistic:
\[\widehat{\mathrm{KSD}}^2_U(P,Q) = \frac{1}{n(n-1)} \sum_{1\leq i \neq j \leq n} k_P(x_i, x_j)\,.\]The direct method (method="direct") simply computes and stores the full Stein matrix to compute the U- or V-estimators. If memory usage is the bottleneck, the second method (method="iterative") might be more appropriate as it computes the V-estimator iteratively as follows:
\[\widehat{\mathrm{KSD}}_V(P,Q_{i+1})^2 = \frac{i^2}{(i+1)^2}\widehat{\mathrm{KSD}}_V(P,Q_i)^2 + \frac{1}{(i+1)^2}k_P(x_{i+1},x_{i+1}) + \frac{2}{(i+1)^2}\sum_{j=1}^{i}k_P(x_j,x_{i+1})\]for \(i = 0,\dots, n-1\), where
\[Q_{i+1}(x) := \frac{1}{i+1} \sum_{j=1}^{i} \delta_x(x_j) + \frac{1}{i+1} \delta_x(x_{i+1})\]and \(Q_n := Q\).
- Parameters
x – Matrix of shape \([n,d]\) corresponding to the \(n\) samples of the empirical distribution \(Q\).
grad_x – Matrix of shape \([n,d]\) corresponding to the gradients of the log distribution \(P\) evaluted for the \(n\) samples \(x_1,\dots,x_n\).
kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.
stats – Type of estimator, U or V, defaults to V
method – Chosen method for estimating the KSD (direct or iterative), default: direct.
- Returns
Vector of squared KSD values \(\widehat{\mathrm{KSD}}^2_U(P,Q_i)\) or \(\widehat{\mathrm{KSD}}^2_V(P,Q_i)\), \(i=1,\dots,n\).
- Return type
torch.Tensor
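The relation between the direct and the iterative V-estimators can be illustrated on a small Stein matrix. The sketch below is an assumption-laden illustration rather than the batorch code: the function names and the matrix values are hypothetical, and the Stein matrix is assumed symmetric. The iterative variant keeps a running unnormalized sum, which amounts to rescaling the previous normalized estimate by \(i^2/(i+1)^2\) before adding the contribution of the new sample.

```python
def ksd_v_direct(stein, i):
    """V-statistic over the leading i-by-i block of the Stein matrix."""
    return sum(stein[a][b] for a in range(i) for b in range(i)) / i**2


def ksd_v_iterative(stein):
    """Running V-statistics for Q_1, ..., Q_n, reading one new line per step."""
    values, total = [], 0.0
    for i in range(len(stein)):
        # New diagonal entry plus twice the new off-diagonal entries
        # (the factor two relies on the symmetry of the Stein matrix).
        total += stein[i][i] + 2.0 * sum(stein[j][i] for j in range(i))
        values.append(total / (i + 1) ** 2)
    return values


# Hypothetical symmetric Stein matrix k_P(x_i, x_j) for n = 3 samples.
stein = [
    [2.0, 0.5, -0.3],
    [0.5, 1.5, 0.2],
    [-0.3, 0.2, 1.0],
]
iterative = ksd_v_iterative(stein)
direct = [ksd_v_direct(stein, i) for i in range(1, len(stein) + 1)]
```

Only one line of the Stein matrix is needed per step, which is why the iterative method is attractive when the full \(n \times n\) matrix does not fit in memory.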
- batorch.rkhs.discrepancies.SlicedKSD(x: torch.Tensor, grad_x: torch.Tensor, r: torch.Tensor, g: torch.Tensor, kernel: Optional[Kernel] = None, stats='V')[source]#
Function that computes the squared sliced KSD using the methodology proposed in the Sliced Kernelized Stein Discrepancy paper.
Let \(U(\mathbb{S}^{D-1})\) be the uniform distribution over the hypersphere \(\mathbb{S}^{D-1}\) where \(D\) is the dimension of the network parameters. Then the squared sliced KSD between an intractable distribution \(P\) and a distribution \(Q\) is given by
\[\mathrm{SKSD}^2(P,Q) := \mathbb{E}_{r\sim U, g\sim U} [ \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime} \sim Q} h_{P,r,g}(x,x^{\prime}) ]\]where \(h_{P,r,g}\) is given by Equation (8) in the Sliced Kernelized Stein Discrepancy paper, i.e.,
\[h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)\]with \(s_p^r(x) = \nabla_x \log P(x)^T r\). The kernel function \(k\) is set by the kernel argument, e.g., an inverse multiquadric kernel (ImqKernel) or an RBF kernel (RbfKernel).
The inner expectations are estimated with a V- or U-statistic (stats="V" or stats="U"), while the outer expectation is estimated with a standard Monte Carlo estimator, i.e.,
\[\widehat{\mathrm{SKSD}}_V^2(P,Q) := \frac{1}{N \, n^2} \sum_{\ell=1}^{N} \sum_{i=1}^{n}\sum_{j=1}^{n} h_{P,r_\ell, g_\ell}(x_i,x_j)\]and
\[\widehat{\mathrm{SKSD}}_U^2(P,Q) := \frac{1}{N \, n(n-1)} \sum_{\ell=1}^{N} \sum_{1\leq i \neq j \leq n} h_{P,r_\ell, g_\ell}(x_i,x_j)\]where \(x_1, \dots, x_n\) are samples of the distribution \(Q\). The vectors \(r_1,\dots,r_{N}\) and \(g_1,\dots,g_N\) are samples of distribution \(U(\mathbb{S}^{D-1})\).
- Parameters
x – Matrix of shape \([n,d]\).
grad_x – Matrix of shape \([n,d]\) corresponding to the gradients of the log posterior.
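r – Directions \(r_1,\dots,r_N\) sampled from \(U(\mathbb{S}^{D-1})\), where \(N\) is the number of random projections.
g – Directions \(g_1,\dots,g_N\) sampled from \(U(\mathbb{S}^{D-1})\).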
kernel – Kernel object such as ImqKernel, RbfKernel, or EnergyKernel. Defaults to EnergyKernel.
stats – Type of estimator, U or V, defaults to V
- Returns
Vector of estimated sliced KSD values \(\widehat{\mathrm{SKSD}}(P,Q_i)\), \(i=1,\dots,n\).
- Return type
torch.Tensor
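The slicing step can be sketched as follows: directions are drawn uniformly on the unit sphere by normalizing Gaussian vectors, and the samples and scores are then projected onto them. This is a plain-Python illustration under stated assumptions, not the batorch implementation: `sample_unit_vectors` and `dot` are hypothetical helpers, and the score \(\nabla_x \log P(x) = -x\) of a standard Gaussian is assumed only to have concrete numbers.

```python
import math
import random


def sample_unit_vectors(n_slices, dim, rng):
    """Draw directions uniformly on the unit sphere by normalizing Gaussians."""
    vectors = []
    for _ in range(n_slices):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in v))
        vectors.append([t / norm for t in v])
    return vectors


def dot(a, b):
    """Euclidean inner product of two vectors."""
    return sum(s * t for s, t in zip(a, b))


rng = random.Random(0)
x = [[0.0, 1.0, 2.0], [1.0, 0.0, -1.0]]      # n = 2 samples, d = 3
grad_x = [[-t for t in row] for row in x]    # assumed standard Gaussian score
r = sample_unit_vectors(4, 3, rng)           # N = 4 directions r_l
g = sample_unit_vectors(4, 3, rng)           # N = 4 directions g_l

slices = []
for r_l, g_l in zip(r, g):
    slices.append({
        "r_dot_g": dot(r_l, g_l),                     # scalar r^T g
        "w": [dot(row, g_l) for row in x],            # projected samples x^T g
        "grad_w": [dot(row, r_l) for row in grad_x],  # projected scores s_p^r(x)
    })
```

Each entry of `slices` holds the three quantities that a sliced Stein matrix needs for one pair \((r_\ell, g_\ell)\); averaging the resulting V- or U-statistics over the \(N\) pairs gives the Monte Carlo estimate.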
batorch.rkhs.kernels#
Kernel – Base class for implementing kernel functions.
ImqKernel – Inverse multiquadric kernel.
EnergyKernel – Distance-induced kernel function.
- class batorch.rkhs.kernels.Kernel[source]#
Base class for implementing kernel functions.
- GramMatrix(x: torch.Tensor, y: torch.Tensor)[source]#
Function that computes the entries of the Gram matrix \(k(x_i,y_j)\) where \(x_i\) (resp. \(y_j\)) is the i-th row of \(x\) (resp. j-th row of \(y\)).
- Parameters
x – Matrix of shape \([m,d]\)
y – Matrix of shape \([n,d]\)
- Raises
NotImplementedError –
- SteinMatrix(y: torch.Tensor, grad_y: torch.Tensor)[source]#
Function that computes the entries of the Stein kernel matrix \(k_P(x_i,x_j)\) defined as
\[k_P(x,y) = \nabla_x\cdot \nabla_y k(x,y) + \nabla_x \log P(x) \cdot \nabla_y k(x,y) + \nabla_y \log P(y) \cdot \nabla_x k(x,y) + (\nabla_x \log P(x))\cdot (\nabla_y\log P(y)) k(x,y)\]where \(k(x,y)\) is the chosen kernel function.
- Parameters
y – Matrix of samples of shape \([n,d]\).
grad_y – Matrix of shape \([n,d]\) gathering the gradients of the log posterior \(\nabla_{x_i}\log P(x_i)\).
- Raises
NotImplementedError –
- SlicedSteinMatrix(r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor)[source]#
This function computes the entries of the sliced Stein kernel matrix \(h_{P,r,g}\) given by
\[h_{P,r,g}(x,y) = s_p^r(x)k(x^Tg,y^Tg)s_p^r(y) + r^Tg \, s_p^r(y)\nabla_{x^Tg}k(x^Tg,y^Tg) + r^Tg \, s_p^r(x) \nabla_{y^Tg} k(x^Tg,y^Tg) + (r^Tg)^2\nabla_{x^Tg}\cdot\nabla_{y^Tg}k(x^Tg,y^Tg)\]with \(s_p^r(x) = \nabla_x \log P(x)^T r\) (see Sliced Kernelized Stein Discrepancy.)
- Parameters
r_dot_g – Scalar product \(r^Tg\)
w – Projected samples \(w = x^Tg\).
grad_w – Projected gradients \(s_p^r(x)\).
- Raises
NotImplementedError –
- Diagonal(y: torch.Tensor)[source]#
Function that computes the diagonal terms of the Gram matrix \(k(x_i,x_i)\).
- Parameters
y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).
- Raises
NotImplementedError –
- Line(y: torch.Tensor, idx: int)[source]#
Function that can be used to compute one line of the Gram matrix \(k(y_i, y)\).
- Parameters
y – Matrix of shape \([n,d]\).
idx – Index of the line to be computed.
- Raises
NotImplementedError –
- LineSteinMatrix(y: torch.Tensor, grad_y: torch.Tensor, idx: int)[source]#
Function that returns one line of the Stein matrix \(k_P(y_i,y_j)\).
- Parameters
y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).
grad_y – Gradients of the log posterior with respect to \(y\).
idx – Line to be computed.
- Raises
NotImplementedError –
- class batorch.rkhs.kernels.EnergyKernel[source]#
Distance-induced kernel function.
\[k(x,y) = \|x\|_2 + \|y\|_2 - \|x-y\|_2\]See, e.g., Stein points and Equivalence of distance-based and RKHS-based statistics in hypothesis testing.
- class batorch.rkhs.kernels.ImqKernel(c=1.0, beta=0.5, lengthscale=1.0)[source]#
Inverse multiquadric kernel.
\[k(x,y) = (c^2 + \|x-y\|_2^2/\ell^2)^{-\beta}\]- Parameters
c – Parameter \(c\), default value \(1.0\).
beta – Parameter \(\beta\), default value \(1/2\).
lengthscale – Lengthscale (bandwidth).
- GramMatrix(x: torch.Tensor, y: torch.Tensor)[source]#
Function that computes the entries of the Gram matrix for the inverse multiquadric kernel, i.e.,
\[k(x_i,y_j) = (c^2 + \|x_i-y_j\|_2^2/\ell^2)^{-\beta}\]- Parameters
x – Matrix of shape \([m,d]\)
y – Matrix of shape \([n,d]\)
- Returns
Entries of the Gram matrix \(k(x_i,y_j)\) of shape \([m,n]\)
- Diagonal(y: torch.Tensor)[source]#
Diagonal elements of the Gram matrix.
- Parameters
y – Matrix of shape \([n,d]\)
- Returns
Diagonal entries \(k(y_i,y_i)\) of the Gram matrix, of shape \([n]\)
- SteinMatrix(y: torch.Tensor, grad_y: torch.Tensor)[source]#
This function implements a closed-form expression of the Stein kernel \(k_P\) for an inverse multiquadric kernel:
\[k(x_i,x_j) = q_f(x_i,x_j)^{-\beta}\,, \quad q_f(x_i, x_j) = c^2 + \|x_i-x_j\|^2_2/\ell^2\]The closed-form expression is given by
\[k_P(x_i,x_j) = -4\beta(\beta+1)\ell^{-4} \|x_i-x_j\|^2_2 q_f(x_i,x_j)^{-\beta-2} + 2\beta\ell^{-2} (d + A_{ij} + A_{ji}) q_f(x_i,x_j)^{-\beta-1} + (s_p(x_i)\cdot s_p(x_j)) k(x_i,x_j)\]where \(d\) is the dimension of \(x\), \(\ell\) is the bandwidth, and
\[s_p(x_i) = \nabla_{x_i} \log P(x_i)\,, \quad A_{ij} = \nabla_{x_i}\log P(x_i) \cdot (x_i - x_j)\]- Parameters
y – Matrix of shape \([n,d]\)
grad_y – Matrix of shape \([n,d]\), gradients of the log posterior
- Returns
Entries of the Stein matrix \(k_p(x_i,y_j)\) of shape \([n,n]\)
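The closed-form expression above translates almost term by term into code. The following plain-Python sketch is illustrative only: `imq_stein_matrix` is a hypothetical name, lists stand in for tensors, and a standard Gaussian target with score \(\nabla_x \log P(x) = -x\) is assumed for the example data.

```python
def imq_stein_matrix(y, grad_y, c=1.0, beta=0.5, lengthscale=1.0):
    """Closed-form Stein kernel entries k_P(y_i, y_j) for the IMQ kernel."""
    n, d = len(y), len(y[0])
    inv_l2 = 1.0 / lengthscale**2
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            diff = [a - b for a, b in zip(y[i], y[j])]
            sq = sum(t * t for t in diff)          # ||y_i - y_j||^2
            q = c**2 + sq * inv_l2                 # q_f(y_i, y_j)
            a_ij = sum(s * t for s, t in zip(grad_y[i], diff))
            a_ji = -sum(s * t for s, t in zip(grad_y[j], diff))
            score_dot = sum(s * t for s, t in zip(grad_y[i], grad_y[j]))
            out[i][j] = (
                -4.0 * beta * (beta + 1.0) * inv_l2**2 * sq * q ** (-beta - 2.0)
                + 2.0 * beta * inv_l2 * (d + a_ij + a_ji) * q ** (-beta - 1.0)
                + score_dot * q ** (-beta)
            )
    return out


# Assumed standard Gaussian target: grad log P(x) = -x.
y = [[0.0, 0.0], [1.0, -1.0], [0.5, 2.0]]
grad_y = [[-t for t in row] for row in y]
stein = imq_stein_matrix(y, grad_y)
ksd2_v = sum(map(sum, stein)) / len(y) ** 2   # V-statistic of the squared KSD
```

With \(c = \ell = 1\), the diagonal reduces to \(2\beta d + \|s_p(y_i)\|_2^2\), which gives a quick sanity check of an implementation.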
- SlicedSteinMatrix(r_dot_g: torch.Tensor, w: torch.Tensor, grad_w: torch.Tensor)[source]#
This function computes the entries of the sliced Stein kernel matrix \(h_{P,r,g}\) for the inverse multiquadric kernel.
The closed-form expression is given by:
\[h_{P,r,g}(x_i,x_j) = -4(r^T g)^2 \beta(\beta+1)\ell^{-4} \|w_i-w_j\|_2^2 q_f(w_i,w_j)^{-\beta-2} + 2\beta\ell^{-2} (r^T g)( r^T g + A_{ij} + A_{ji}) q_f(w_i,w_j)^{-\beta-1} + (s_p^r(x_i) s_p^r(x_j)) q_f(w_i,w_j)^{-\beta}\]where
\[q_f(w_i,w_j) = c^2 + \ell^{-2}\|w_i-w_j\|^2_2\,, \quad w_i := g^T x_i\,, \quad s_p^r(x_i) = r^T \nabla_{x_i}\log P(x_i)\,,\]and
\[A_{ij} = s_p^r(x_i) (w_i - w_j)\]- Parameters
r_dot_g – Scalar product \(r^Tg\)
w – Projected samples \(w = x^Tg\).
grad_w – Projected gradients \(s_p^r(x)\).
- Returns
Entries of the sliced Stein matrix \(h_{P,r,g}(x_i,y_j)\) of shape \([n,n]\)
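The sliced closed form works on scalars, since both the samples and the scores have already been projected. A minimal plain-Python sketch for a single pair \((r, g)\) is given below; the function name and the numerical values are hypothetical, and the code is meant as an illustration of the formula rather than as the library's implementation.

```python
def imq_sliced_stein_matrix(r_dot_g, w, grad_w, c=1.0, beta=0.5, lengthscale=1.0):
    """Closed-form sliced Stein kernel entries h(x_i, x_j) for the IMQ kernel."""
    n = len(w)
    inv_l2 = 1.0 / lengthscale**2
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            diff = w[i] - w[j]                 # w_i - w_j (scalars)
            q = c**2 + inv_l2 * diff**2        # q_f(w_i, w_j)
            a_ij = grad_w[i] * diff
            a_ji = -grad_w[j] * diff
            out[i][j] = (
                -4.0 * r_dot_g**2 * beta * (beta + 1.0) * inv_l2**2
                * diff**2 * q ** (-beta - 2.0)
                + 2.0 * beta * inv_l2 * r_dot_g * (r_dot_g + a_ij + a_ji)
                * q ** (-beta - 1.0)
                + grad_w[i] * grad_w[j] * q ** (-beta)
            )
    return out


# Hypothetical projected quantities for n = 3 samples and one pair (r, g).
r_dot_g = 0.4
w = [0.0, 0.7, -1.2]        # projected samples x_i^T g
grad_w = [0.3, -0.5, 1.1]   # projected scores s_p^r(x_i)
h = imq_sliced_stein_matrix(r_dot_g, w, grad_w)
sksd2_v = sum(map(sum, h)) / len(w) ** 2   # V-statistic for this single slice
```

Compared with the full Stein matrix, the dimension \(d\) is replaced by \(r^Tg\) in the middle term and every kernel evaluation is one-dimensional, so the cost per slice no longer depends on \(d\) once the projections are available.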
- Line(y: torch.Tensor, idx: int)[source]#
Function that can be used to compute one line of the Gram matrix \(k(x_i, x_j)\).
- Parameters
y – Matrix of shape \([n,d]\).
idx – Index of the line to be computed.
- Returns
Entries of one line of the Gram matrix \(k(x_i, y_j)\).
- LineSteinMatrix(y: torch.Tensor, grad_y: torch.Tensor, idx: int)[source]#
Function that returns one line of the Stein kernel, \(k_P(y_i, \cdot)\), see the function SteinMatrix for more details.
- Parameters
y – Matrix of size \(n\) by \(d\) containing \(n\) samples \(y_1, \dots, y_n\).
grad_y – Gradients of the log posterior with respect to \(y\).
idx – Line to be computed.
- Returns
Entries of one line of the Stein matrix \(k_P(y_i, y_j)\).
batorch.rkhs.quantization#
Quantization – Base class for implementing a quantization method.
QuantizationMMD – Class that implements quantization by greedily minimizing the maximum mean discrepancy.
- batorch.rkhs.quantization.median_heuristic(samples: torch.Tensor)[source]#
Function that computes the median heuristic for a kernel’s lengthscale.
\[\ell := \mathrm{median}\{\|X_i-X_j\|: 1 \leq i \leq j \leq n\}\]- Parameters
samples – Matrix of size \(n\) by \(d\) where \(n\) is the number of samples and \(d\) the dimension of each sample.
- Returns
The scalar median heuristic that can be used as a lengthscale/bandwidth.
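A minimal plain-Python sketch of the heuristic is shown below. It is not the batorch implementation: it takes the median over distinct pairs \(i < j\) only, which excludes the zero self-distances and is a choice made here for illustration.

```python
import math
import statistics


def median_heuristic(samples):
    """Median of the pairwise Euclidean distances between distinct samples."""
    n = len(samples)
    dists = [
        math.dist(samples[i], samples[j])
        for i in range(n)
        for j in range(i + 1, n)
    ]
    return statistics.median(dists)


# Hypothetical data: three collinear points with pairwise distances 5, 10 and 5.
samples = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
lengthscale = median_heuristic(samples)
```

The resulting scalar can then be passed as the lengthscale of a kernel such as ImqKernel.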
- class batorch.rkhs.quantization.Quantization(m: int)[source]#
Base class for implementing a quantization method. For more details, see for instance the paper Optimal quantisation of probability measures using maximum mean discrepancy.
- Parameters
m – Quantization size.
- class batorch.rkhs.quantization.QuantizationMMD(m: int, kernel: Optional[Union[omegaconf.DictConfig, Kernel]] = None)[source]#
Class that implements quantization by greedily minimizing the maximum mean discrepancy. For more details, see for instance the paper Optimal quantisation of probability measures using maximum mean discrepancy.
- Parameters
m – Quantization size.
kernel – Underlying kernel used in the MMD.
- select(data: dict, memory_is_bottleneck=True, device=None)[source]#
This function implements a simple algorithm that greedily minimizes the maximum mean discrepancy. Given a set of samples \(\{ x_i \}_{i=1}^{N}\) with empirical distribution
\[P(x) = \frac{1}{N}\sum_{i=1}^{N} \delta_x(x_i)\,,\]the goal is to select a subset of \(m < N\) samples \(\{ x_{\pi(\ell)} \}_{\ell=1}^{m}\) such that
\[Q_m(x) = \frac{1}{m}\sum_{i=1}^{m} \delta_x(x_{\pi(i)})\,, \quad Q_m \approx P\,.\]The list of selected indices \(\pi(1), \dots, \pi(m)\) is constructed iteratively. At the \((i+1)\)-th iteration, the index \(\pi(i+1)\) is obtained by solving:
\[\pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, \mathrm{MMD}^2(P,Q_j^{i+1})\]where \(\mathrm{MMD}(P,Q)\) denotes the maximum mean discrepancy between two distributions \(P\) and \(Q\),
\[\mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime}) + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime}) - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})\]and \(Q^{i+1}_j(x)\) denotes the empirical distribution
\[Q^{i+1}_j(x) = \frac{1}{i+1}\sum_{\ell=1}^{i} \delta_x(x_{\pi(\ell)}) + \frac{1}{i+1} \delta_x(x_j)\]The default kernel function \(k: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\) is chosen as the energy kernel with parameters \(c=1\) and \(\beta=0.5\) but the user can specify a different kernel with a Kernel object or an OmegaConf DictConfig.
By plugging the expression of \(P\) and \(Q^{i+1}_j\) into the squared \(\mathrm{MMD}\) and discarding all the terms that do not depend on \(x_j\), it can be shown that the optimization reduces to
\[\pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, 2\sum_{\ell=1}^{i}k(x_{\pi(\ell)},x_j) + k(x_j,x_j) - \frac{2(i+1)}{N}\sum_{\ell=1}^{N}k(x_\ell,x_j)\]- Parameters
data – Dict with key “samples” that contains the samples in a torch.Tensor
memory_is_bottleneck – Boolean parameter. Set to true for high dimensional data.
device – Device to be used for the quantization algorithm. If None, the device on which the data currently resides is used.
- Returns
List of selected indices \(\pi(1), \dots, \pi(m)\).
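The reduced objective lends itself to a short greedy loop. The sketch below is a plain-Python illustration under stated assumptions, not the `select` method itself: `greedy_mmd_select` and `energy_kernel` are hypothetical names, the full Gram matrix is stored (so it corresponds to the case where memory is not the bottleneck), and the samples are passed directly as a list rather than through a dict.

```python
import math


def energy_kernel(x, y):
    """Energy kernel k(x, y) = ||x|| + ||y|| - ||x - y||."""
    zero = [0.0] * len(x)
    return math.dist(x, zero) + math.dist(y, zero) - math.dist(x, y)


def greedy_mmd_select(samples, m, kernel=energy_kernel):
    """Greedily pick m indices by minimizing the reduced MMD objective."""
    n_total = len(samples)
    gram = [[kernel(a, b) for b in samples] for a in samples]
    # (1/N) * sum_l k(x_l, x_j): kernel mean of the full sample set at x_j.
    mean_k = [sum(row) / n_total for row in gram]
    # sum over already selected points of k(x_pi(l), x_j).
    running = [0.0] * n_total
    selected = []
    for i in range(m):
        scores = [
            2.0 * running[j] + gram[j][j] - 2.0 * (i + 1) * mean_k[j]
            for j in range(n_total)
        ]
        best = min(range(n_total), key=scores.__getitem__)
        selected.append(best)
        for j in range(n_total):
            running[j] += gram[best][j]
    return selected


# Hypothetical one-dimensional data set with N = 3 points.
samples = [[-1.0], [0.0], [1.0]]
selected = greedy_mmd_select(samples, 3)
```

Since the argmin runs over all \(j \in \{1,\dots,N\}\), the same index can in principle be selected more than once; the running sums make each iteration cost \(O(N)\) once the Gram matrix is available.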