import torch
from typing import Optional, Union
from tqdm import tqdm
from omegaconf import DictConfig
from batorch.rkhs.kernels import Kernel, EnergyKernel
import logging
log = logging.getLogger(__name__)
class Quantization(object):
    """Base class for implementing a quantization method.

    A quantization method selects ``m`` representative points out of a sample set.
    For more details, see for instance the paper
    `Optimal quantisation of probability measures using maximum mean discrepancy <https://arxiv.org/abs/2010.07064>`_.

    :param m: Quantization size, i.e. the number of points to select.
    """

    def __init__(
        self,
        m: int
    ):
        self.m = m

    def select(self, data: dict):
        """Select the quantization points from ``data``.

        Subclasses must override this method.

        :param data: Input data; its expected layout is defined by each subclass.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class QuantizationMMD(Quantization):
    """Quantization by greedily minimizing the maximum mean discrepancy (MMD).

    For more details, see for instance the paper
    `Optimal quantisation of probability measures using maximum mean discrepancy <https://arxiv.org/abs/2010.07064>`_.

    :param m: Quantization size.
    :param kernel: Underlying kernel used in the MMD. If None, ``EnergyKernel()`` is used.
    """

    def __init__(
        self,
        m: int,
        kernel: Optional[Union[DictConfig, Kernel]] = None,
    ):
        super().__init__(m=m)
        # NOTE(review): a DictConfig is stored as-is; nothing in this class turns it
        # into a Kernel, so ``select`` only works if it was instantiated elsewhere. Confirm.
        self.kernel = EnergyKernel() if kernel is None else kernel

    def select(self, data: dict, memory_is_bottleneck=True, device=None) -> torch.Tensor:
        r"""Greedily select sample indices whose empirical distribution has small MMD to the data.

        Given a set of samples :math:`\{ x_i \}_{i=1}^{N}` with empirical distribution

        .. math::
            P(x) = \frac{1}{N}\sum_{i=1}^{N} \delta_x(x_i)\,,

        the goal is to select a subset of :math:`m \le N` samples
        :math:`\{ x_{\pi(\ell)} \}_{\ell=1}^{m}` such that

        .. math::
            Q_m(x) = \frac{1}{m}\sum_{\ell=1}^{m} \delta_x(x_{\pi(\ell)})\,, \quad Q_m \approx P\,.

        The list of selected indices :math:`\pi(1), \dots, \pi(m)` is constructed iteratively.
        At the :math:`(i+1)`-th iteration, the index :math:`\pi(i+1)` is obtained by solving

        .. math::
            \pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, \mathrm{MMD}^2(P,Q_j^{i+1})\,,

        where :math:`\mathrm{MMD}(P,Q)` denotes the maximum mean discrepancy between two
        distributions :math:`P` and :math:`Q`,

        .. math::
            \mathrm{MMD}^2(P,Q) = \mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim P}k(x,x^{\prime})
            + \mathbb{E}_{x\sim Q}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})
            - 2\mathbb{E}_{x\sim P}\mathbb{E}_{x^{\prime}\sim Q}k(x,x^{\prime})\,,

        and :math:`Q^{i+1}_j` denotes the empirical distribution

        .. math::
            Q^{i+1}_j(x) = \frac{1}{i+1}\sum_{\ell=1}^{i} \delta_x(x_{\pi(\ell)}) + \frac{1}{i+1} \delta_x(x_j)\,.

        The default kernel function :math:`k: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}`
        is chosen as the energy kernel with parameters :math:`c=1` and :math:`\beta=0.5`,
        but the user can specify a different kernel in the constructor.

        Plug the expressions of :math:`P` and :math:`Q^{i+1}_j` into the squared MMD,
        discard all terms that do not depend on :math:`x_j` and multiply by :math:`(i+1)^2`.
        The optimization then reduces to

        .. math::
            \pi(i+1) = \underset{j\in\{1,\dots,N\}}{\mathrm{argmin}} \, 2\sum_{\ell=1}^{i}k(x_{\pi(\ell)},x_j)
            + k(x_j,x_j) - \frac{2(i+1)}{N}\sum_{\ell=1}^{N}k(x_\ell,x_j)\,.

        Because the argmin ranges over all :math:`N` samples, the same index may be selected
        more than once.

        :param data: Dict with key ``"samples"`` holding the samples in a ``torch.Tensor``
            whose first dimension indexes the samples.
        :param memory_is_bottleneck: If True, kernel row means are computed one row at a time
            instead of materializing the full :math:`N \times N` Gram matrix. Set to True when
            the Gram matrix would not fit in memory (e.g. many or high-dimensional samples).
        :param device: Device the samples are moved to for the computation. If None, the device
            the samples already live on is used.
        :return: int32 tensor of selected indices :math:`\pi(1), \dots, \pi(m)`, with length
            ``min(m, N)`` (empty if that is not positive).
        """
        samples = data["samples"]
        if device is not None:
            samples = samples.to(device)
        n = samples.shape[0]
        # Cannot request more points than there are samples.
        m = min(self.m, n)
        idx = torch.zeros(max(m, 0), dtype=torch.int32)
        if m <= 0:
            # Nothing to select (m <= 0 or empty data): previously crashed on idx[0].
            return idx

        # diagonal term: k(x_j, x_j) for every candidate j, kept on the samples' device
        diag = torch.as_tensor(self.kernel.Diagonal(samples), device=samples.device)
        # row mean: (1/N) sum_l k(x_j, x_l) for every candidate j
        row_mean = self._row_means(samples, memory_is_bottleneck, dtype=diag.dtype)
        # Running sum over already selected indices: sum_{l<=i} k(x_j, x_{pi(l)}).
        # Replaces the former N x m buffer that was re-summed at every iteration.
        cross = torch.zeros_like(diag)

        idx[0] = torch.argmin(diag - 2.0 * row_mean)
        for i in tqdm(range(1, m), "Greedy minimization"):
            cross += self.kernel.Line(samples, idx[i - 1])
            # argmin of: k(x, x) + 2 sum_{l=1}^{i} k(x, x_{pi(l)}) - 2 (i+1) row_mean
            idx[i] = torch.argmin(diag + 2.0 * cross - 2.0 * (i + 1) * row_mean)
        return idx

    def _row_means(self, samples, memory_is_bottleneck, dtype):
        """Return (1/N) sum_l k(x_j, x_l) for every sample j, on the samples' device."""
        if not memory_is_bottleneck:
            # Materialize the full Gram matrix: fewer kernel calls, O(N^2) memory.
            gram_matrix = self.kernel.GramMatrix(x=samples, y=samples)
            return torch.mean(gram_matrix, dim=1)
        # One kernel row at a time; allocate on the samples' device so later
        # arithmetic with the diagonal term does not mix devices.
        n = samples.shape[0]
        row_mean = torch.empty(n, dtype=dtype, device=samples.device)
        for j in tqdm(range(n), desc="Computing row means"):
            row_mean[j] = torch.mean(self.kernel.Line(samples, j))
        return row_mean
if __name__ == "__main__":
    # Demo: quantize 20k 2-D standard-normal samples down to 200 representatives.
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)  # make the demo reproducible
    x = torch.randn(20000, 2)
    quantMMD = QuantizationMMD(m=200)
    idxMMD = quantMMD.select({"samples": x}).numpy().tolist()
    # Report the result instead of silently discarding it.
    log.info(
        "Selected %d indices (%d distinct); first ten: %s",
        len(idxMMD),
        len(set(idxMMD)),
        idxMMD[:10],
    )