import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from batorch.rkhs.discrepancies import KSD
from batorch.utils.misc import parameters_to_vector, vector_to_parameters
from batorch.likelihoods.regression_negloglike import GaussianNLL
from batorch.priordists.normal import Normal

from batorch.sgmcmc.samplers import SamplerSGLD
from batorch.rkhs.quantization import QuantizationMMD

from tqdm import tqdm

def get_model():
    """Build the regression network used by the tutorial.

    A fully connected MLP mapping a scalar input to a scalar output:
    ``1 -> 50 -> 50 -> 1`` with ``tanh`` activations between linear layers.

    Returns:
        A freshly initialised ``nn.Sequential`` module.
    """
    width = 50
    # Layers are constructed in forward order so parameter initialisation
    # consumes the RNG in the same sequence as an inline Sequential(...) call.
    layers = [
        nn.Linear(1, width),
        nn.Tanh(),
        nn.Linear(width, width),
        nn.Tanh(),
        nn.Linear(width, 1),
    ]
    return nn.Sequential(*layers)

def prepare_data(sig_noise: float, num_train: int = 50, num_test: int = 200):
    """Generate a synthetic 1-D regression problem.

    The ground-truth function is ``f(x) = cos(2x) + sin(x)``.

    - Training inputs are drawn uniformly from ``[-2, 2)``.
    - Training targets are observed with additive Gaussian noise of standard
      deviation ``sig_noise``.
    - Test inputs are an evenly spaced, noise-free grid on ``[-3, 3]``.

    Args:
        sig_noise: Standard deviation of the observation noise added to
            ``y_train`` (the noise variance is ``sig_noise**2``).
        num_train: Number of training points (default 50).
        num_test: Number of points in the test grid (default 200).

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Every tensor has shape
        ``(n, 1)``, with ``n = num_train`` or ``n = num_test`` respectively.
    """
    def f(x):
        return torch.cos(2.0 * x) + torch.sin(x)

    x_train = 4 * torch.rand((num_train, 1)) - 2
    # sig_noise is a standard deviation, consistent with a Gaussian likelihood
    # of variance sig_noise**2. So unit-normal noise is scaled by sig_noise,
    # not by sig_noise**2.
    y_train = f(x_train) + torch.randn(*x_train.shape) * sig_noise

    x_test = torch.linspace(-3, 3, num_test).view((-1, 1))
    y_test = f(x_test)

    return (x_train, y_train), (x_test, y_test)

if __name__ == "__main__":
    # End-to-end tutorial:
    #   1. draw network weights with SGLD,
    #   2. thin the chain with QuantizationMMD,
    #   3. plot the predictions of the thinned samples,
    #   4. report the kernelized Stein discrepancy (KSD) of the thinned samples.

    # Pick a device
    device = torch.device("cpu")
    torch.manual_seed(0)

    # Prepare data
    sig_noise = 0.3
    (x_train, y_train), (x_test, y_test) = prepare_data(sig_noise)

    train_dataset = TensorDataset(x_train, y_train)
    test_dataset = TensorDataset(x_test, y_test)
    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
    # No shuffling: concatenated test predictions must line up with x_test for plotting.
    test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

    # Define log-target: Gaussian likelihood with variance sig_noise**2 plus a
    # Normal(0, 1) prior.
    surrogate_model = get_model()
    nll = GaussianNLL(model=surrogate_model, variance_prior=sig_noise**2).to(device)
    nlp = Normal(loc=0.0, scale=1.0).to(device)

    # Sample with SGLD
    sampler = SamplerSGLD(negloglikelihood=nll, neglogprior=nlp, init_params=None, dataloader=train_loader, step_size=1e-5)

    # Each outer iteration is one full pass (epoch) over train_loader. The
    # parameters are recorded once per epoch after the first num_burnin epochs.
    num_iter, num_burnin = 20000, 10000

    samples = []
    for epoch in tqdm(range(num_iter), desc="Sampling with SGLD"):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            sampler.step(x, y)
        # Keeps exactly num_iter - num_burnin samples.
        if epoch >= num_burnin:
            samples.append(sampler.get_params())
    samples = torch.stack(samples, dim=0)

    # Thinning: QuantizationMMD selects m indices into the chain.
    thinner = QuantizationMMD(m=2000)
    thinning_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    idx = thinner.select({"samples": samples}, memory_is_bottleneck=False, device=thinning_device)
    subsamples = torch.stack([samples[i] for i in idx], dim=0)

    # Predict: evaluate the network on the test grid for every thinned sample.
    predictions = []
    with torch.no_grad():
        surrogate_model.eval()

        for theta in tqdm(subsamples, desc="Predicting"):
            vector_to_parameters(theta, surrogate_model.parameters(), grad=False)
            predictions_i = [surrogate_model(x.to(device)) for x, _ in test_loader]
            predictions.append(torch.concat(predictions_i, dim=0))

        # Squeeze only the trailing output dim so the sample axis survives even
        # when a single subsample is selected: shape (num_subsamples, num_test).
        predictions = torch.stack(predictions, dim=0).cpu().detach().numpy().squeeze(-1)

    # Computing the kernelized Stein discrepancy.
    # KSD takes the gradient of the log-posterior at each thinned sample. The
    # negative log-posterior gradient is accumulated across all train_loader
    # batches (prior added once), then negated.
    grads = torch.zeros_like(subsamples)
    for i, theta in enumerate(subsamples):
        # Evaluate at the *thinned* sample so that grads[i] pairs with subsamples[i].
        vector_to_parameters(theta, nll.parameters(), grad=False)
        nll.zero_grad()
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            neglogpost = nll(x, y)
            # Add the prior term on the first batch only, so it is not counted once per batch.
            if batch_idx == 0:
                neglogpost = neglogpost + nlp(nll.parameters())
            neglogpost.backward()
        grads[i] = parameters_to_vector(nll.parameters(), grad=True, both=False)
    grads.mul_(-1.0)
    ksd = KSD(x=subsamples, grad_x=grads)

    # Plot the results (text.usetex requires a working LaTeX installation).
    import matplotlib.pyplot as plt
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
    })

    fig, ax = plt.subplots(1, 1, figsize = (8,6))
    ax.plot(x_test, predictions.T, linestyle = "-", linewidth = 1, color = "k", alpha = 0.2, zorder = -1)
    # Re-plot one trajectory only so the sample curves get a single legend entry.
    ax.plot(x_test, predictions[0], linestyle = "-", linewidth = 1, color = "k", alpha = 0.2, label = r"$\mathrm{Predictions}$", zorder = -1)
    ax.plot(x_test, predictions.mean(axis=0), linestyle = "-", linewidth = 2, color = "tab:red", label = r"$\mathrm{Mean}$ $\mathrm{prediction}$")
    ax.plot(x_test, y_test, linestyle = "-", linewidth = 2, color = "tab:blue", label = r"$\mathrm{Target}$")
    ax.scatter(x_train, y_train, color = "b", marker = "o", s = 32, label = r"$\mathrm{Training}$ $\mathrm{data}$", zorder=0)
    ax.set_xlabel(r"$x$", fontsize = 18)
    ax.set_ylabel(r"$y$", fontsize = 18)
    # NOTE(review): if KSD returns a tensor, this renders as "tensor(...)";
    # consider formatting float(ksd) instead once the return type is confirmed.
    ax.set_title(r"$\mathrm{KSD}(P,Q)$" + rf"$ = {ksd}$", fontsize = 18)
    ax.legend(fontsize = 18)
    ax.tick_params(labelsize = 18)
    fig.tight_layout()
    plt.savefig("batorch_tutorial.png", format="png")