# Source code for batorch.rkhs.quantization

import torch

from typing import Optional, Union
from tqdm import tqdm
from omegaconf import DictConfig

from batorch.rkhs.kernels import Kernel, EnergyKernel

import logging
log = logging.getLogger(__name__)

def median_heuristic(samples: torch.Tensor):
    r"""Compute the median-heuristic lengthscale of a set of samples.

    The returned value is the median of the Euclidean distances between all
    pairs of distinct rows of ``samples``:

    .. math::
        \ell := \mathrm{median}\{\|X_i-X_j\|: 1 \leq i < j \leq n\}

    :param samples: Matrix of size :math:`n` by :math:`d` where :math:`n` is
        the number of samples and :math:`d` the dimension of each sample.
    :return: The scalar median heuristic that can be used as a
        lengthscale/bandwidth.
    """
    # Condensed vector of pairwise L2 distances (one entry per unordered pair).
    pairwise_distances = torch.nn.functional.pdist(samples, p=2)
    return pairwise_distances.median()
class Quantization:
    """Common interface for quantization methods.

    A quantization method picks ``m`` representative samples out of a larger
    set. For more details, see for instance the paper
    `Optimal quantisation of probability measures using maximum mean discrepancy
    <https://arxiv.org/abs/2010.07064>`_.

    :param m: Quantization size.
    """

    def __init__(self, m: int):
        self.m = m

    def select(self, data: dict):
        """Select samples from ``data``; concrete subclasses must override this.

        :param data: Input data handed to the quantization method.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class QuantizationMMD(Quantization):
    """Quantization by greedily minimizing the maximum mean discrepancy.

    For more details, see for instance the paper
    `Optimal quantisation of probability measures using maximum mean discrepancy
    <https://arxiv.org/abs/2010.07064>`_.

    :param m: Quantization size.
    :param kernel: Underlying kernel used in the MMD. When ``None``, an
        :class:`EnergyKernel` built with its default arguments is used.
    """

    def __init__(
        self,
        m: int,
        kernel: Optional[Union[DictConfig, Kernel]] = None,
    ):
        super().__init__(m=m)
        # NOTE(review): the annotation admits a DictConfig, but the value is
        # stored as-is and ``select`` calls ``Diagonal``/``Line``/``GramMatrix``
        # on it directly -- confirm that configs are instantiated by the caller.
        if kernel is None:
            self.kernel: Kernel = EnergyKernel()
        else:
            self.kernel = kernel

    def select(self, data: dict, memory_is_bottleneck=True, device=None):
        r"""Greedily pick the samples that minimize the maximum mean discrepancy.

        Given a set of samples :math:`\{ x_i \}_{i=1}^{N}` with empirical
        distribution

        .. math::
            P(x) = \frac{1}{N}\sum_{i=1}^{N} \delta_x(x_i)\,,

        the goal is to select a subset of :math:`m \leq N` samples
        :math:`\{ x_{\pi(\ell)} \}_{\ell=1}^{m}` such that

        .. math::
            Q_m(x) = \frac{1}{m}\sum_{i=1}^{m} \delta_x(x_{\pi(i)})\,,
            \quad Q_m \approx P\,.

        The list of selected indices :math:`\pi(1), \dots, \pi(m)` is
        constructed iteratively. At the :math:`i+1`-th iteration, the index
        :math:`\pi(i+1)` is obtained by solving:

        .. math::
            \pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \,
            \mathrm{MMD}^2(P,Q_j^{i+1})

        where :math:`\mathrm{MMD}(P,Q)` denotes the maximum mean discrepancy
        between two distributions :math:`P` and :math:`Q`,

        .. math::
            \mathrm{MMD}^2(P,Q) =
            \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime})
            + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})
            - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})

        and :math:`Q^{i+1}_j(x)` denotes the empirical distribution

        .. math::
            Q^{i+1}_j(x) = \frac{1}{i+1}\sum_{\ell=1}^{i}
            \delta_x(x_{\pi(\ell)}) + \frac{1}{i+1} \delta_x(x_j)

        The default kernel function
        :math:`k: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}` is
        chosen as the energy kernel with parameters :math:`c=1` and
        :math:`\beta=0.5`, but a different kernel can be supplied to the
        constructor.

        By plugging the expression of :math:`P` and :math:`Q^{i+1}_j` into the
        squared :math:`\mathrm{MMD}` and discarding all the terms that do not
        depend on :math:`x_j`, it can be shown that the optimization reduces to

        .. math::
            \pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \,
            2\sum_{\ell=1}^{i}k(x_{\pi(\ell)},x_j) + k(x_j,x_j)
            - \frac{2(i+1)}{N}\sum_{\ell=1}^{N}k(x_\ell,x_j)

        :param data: Dict with key "samples" that contains the samples in a
            torch.Tensor (one sample per row).
        :param memory_is_bottleneck: If true, the kernel row means are computed
            one row at a time through ``kernel.Line`` instead of building the
            full Gram matrix. Set to true for large or high dimensional data.
        :param device: Device to be used for the quantization algorithm. If
            ``None``, the device the samples already live on is used.
        :return: ``torch.int32`` tensor holding the selected indices
            :math:`\pi(1), \dots, \pi(m)`; its length is ``min(m, N)`` and it
            is empty when there is nothing to select.
        """
        samples = data["samples"]
        if device is not None:
            samples = samples.to(device)
        work_device = samples.device

        n = samples.shape[0]
        m = min(self.m, n)  # cannot select more points than are available

        idx = torch.zeros(m, dtype=torch.int32)
        if m == 0:
            # Nothing to select (m == 0 or no samples): return the empty
            # selection instead of failing on the first argmin.
            return idx

        # Diagonal term: k(x_j, x_j) for every candidate j.
        diag = torch.zeros(n, device=work_device)
        diag[:] = self.kernel.Diagonal(samples)

        if memory_is_bottleneck:
            # Row mean: (1/N) \sum_{l=1}^{N} k(x_j, x_l) for every j, computed
            # one kernel row at a time. The buffer is allocated on the samples'
            # device so it can be combined with ``diag`` below.
            k0_mean = torch.zeros(n, device=work_device)
            for i in tqdm(range(n), desc="Computing row means"):
                k0_mean[i] = torch.mean(self.kernel.Line(samples, i))
        else:
            gram = self.kernel.GramMatrix(x=samples, y=samples)
            k0_mean = torch.mean(gram, dim=1)

        idx[0] = torch.argmin(diag - 2.0 * k0_mean)

        # Running sum over the selected points of k(x_{pi(l)}, x_j). Keeping
        # the sum instead of an (n, m) buffer that is re-summed at every step
        # makes the loop O(n*m) in time and O(n) in memory; results agree up
        # to floating-point summation order.
        cross_sum = torch.zeros(n, device=work_device)
        row = torch.empty(n, device=work_device)
        for i in tqdm(range(1, m), "Greedy minimization"):
            row[:] = self.kernel.Line(samples, idx[i - 1])
            cross_sum += row
            # argmin of: k(x, x) + 2 \sum_{l=1}^{i} k(x, x_{pi(l)})
            #            - 2 (i+1) * row_mean(x)
            idx[i] = torch.argmin(
                diag + 2.0 * cross_sum - 2.0 * (i + 1) * k0_mean
            )
        return idx
if __name__ == "__main__":
    # Demo run: quantize 20000 random 2-D points down to 200 indices using
    # the default kernel, and convert the result to a plain Python list.
    demo_points = torch.randn(20000, 2)
    demo_quantizer = QuantizationMMD(m=200)
    demo_selection = demo_quantizer.select({"samples": demo_points})
    demo_indices = demo_selection.numpy().tolist()