import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
import numpy as np
# Use the GPU when one is available; fall back to the CPU otherwise so the
# script still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def overlay_y_on_x(x, y):
    """Return a copy of ``x`` whose first 10 features one-hot encode the label ``y``.

    The first 10 columns of every row are zeroed, then the column selected by
    that row's label is set to the maximum value found anywhere in ``x``.

    Args:
        x: 2-D batch of samples; the first 10 features are overwritten.
        y: per-row labels (a 1-D integer tensor) or a single int label that
            is applied to every row.

    Returns:
        A new tensor; ``x`` itself is left unmodified.
    """
    marked = x.clone()
    # Wipe the 10 label slots before writing the new label in.
    marked[:, :10] *= 0.0
    marker_value = x.max()
    batch_rows = range(x.shape[0])
    marked[batch_rows, y] = marker_value
    return marked
def generate_negative_data(x, y):
    """Return ``x`` with a randomly chosen *wrong* label overlaid on each sample.

    For every entry of ``y`` a label is drawn uniformly from the nine labels
    in ``0..9`` that differ from it, and the result is written into ``x`` via
    ``overlay_y_on_x``.

    Args:
        x: 2-D batch of samples (its first 10 features hold the label).
        y: 1-D integer tensor of true labels in ``[0, 10)``.

    Returns:
        A new tensor shaped like ``x`` carrying the incorrect labels.
    """
    # Adding a random offset in [1, 9] modulo 10 never lands back on the true
    # label and is uniform over the other nine classes. Drawing with torch on
    # y's own device keeps this vectorised (no per-sample host->device copy)
    # and makes it reproducible under torch.manual_seed.
    offsets = torch.randint(1, 10, y.shape, device=y.device, dtype=y.dtype)
    y_neg = (y + offsets) % 10
    return overlay_y_on_x(x, y_neg)
class Layer(nn.Linear):
    """A fully connected layer trained locally with a Forward-Forward objective.

    Each layer owns its own Adam optimiser and is trained by ``train``: the
    mean squared activation ("goodness") is pushed above ``threshold`` for
    positive samples and below it for negative samples.

    Args:
        in_features, out_features, bias, device, dtype: as for ``nn.Linear``.
        lr: Adam learning rate for this layer's parameters.
        threshold: goodness threshold separating positive from negative data.
        num_epochs: number of optimisation steps per call to ``train``.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None,
                 lr=0.03, threshold=3.0, num_epochs=1000):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        """Scale each row of ``x`` to unit length, then apply the affine map and ReLU.

        Only the direction of each input row, not its length, reaches the
        linear map.
        """
        # The 1e-4 guards against division by zero for all-zero rows.
        x_direction = x / (x.norm(dim=1, keepdim=True) + 1e-4)
        # nn.Linear.forward also handles bias=False (self.bias is None), which
        # an explicit ``self.bias.unsqueeze(0)`` would crash on.
        return self.relu(super().forward(x_direction))

    def train(self, x_pos=True, x_neg=None):
        """Run ``num_epochs`` Forward-Forward steps, or switch train/eval mode.

        Called as ``train(x_pos, x_neg)`` this optimises the layer on a batch
        of positive and negative samples and returns its detached outputs
        ``(h_pos, h_neg)`` for the next layer.

        This method shadows ``nn.Module.train(mode)``. When called without
        ``x_neg`` (e.g. ``layer.train()`` or ``layer.eval()``, which calls
        ``train(False)``), the first argument is treated as the mode flag and
        the standard behaviour is used, returning ``self``.
        """
        if x_neg is None:
            return super().train(x_pos)
        for _ in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # Positive goodness should exceed the threshold, negative stay below.
            margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
            # softplus(z) == log(1 + exp(z)) but does not overflow to inf (and
            # produce NaN gradients) when goodness is large.
            loss = nn.functional.softplus(margins).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()
class Net(torch.nn.Module):
    """A stack of Forward-Forward ``Layer``s trained greedily, one layer at a time.

    Args:
        dims: layer widths; consecutive pairs define the layers, e.g.
            ``[784, 512, 512]`` builds 784->512 and 512->512.
        device: where to place the layers; defaults to the module-level
            ``DEVICE`` when omitted.
    """

    def __init__(self, dims, device=None):
        super().__init__()
        device = DEVICE if device is None else device
        # ModuleList (not a plain list) registers the layers as submodules,
        # so net.parameters(), net.to(), state_dict() and train/eval mode
        # switching all reach them.
        self.layers = nn.ModuleList(
            Layer(d_in, d_out).to(device)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    @torch.no_grad()
    def predict(self, x):
        """Return, per row, the label 0-9 whose overlay yields the highest goodness.

        Each candidate label is overlaid on ``x`` in turn, and the goodness
        (mean squared activation) is summed over all layers. This is inference
        only, so no autograd graph is recorded.
        """
        goodness_per_label = []
        for label in range(10):
            h = overlay_y_on_x(x, label)
            total_goodness = 0
            for layer in self.layers:
                h = layer(h)
                total_goodness = total_goodness + h.pow(2).mean(1)
            goodness_per_label.append(total_goodness.unsqueeze(1))
        return torch.cat(goodness_per_label, 1).argmax(1)

    def train(self, x_pos=True, x_neg=None):
        """Train the layers greedily, or switch train/eval mode.

        Called as ``train(x_pos, x_neg)``, each layer is trained on the
        detached outputs of the previous one. This shadows
        ``nn.Module.train(mode)``. When called without ``x_neg`` (e.g.
        ``net.train()`` or ``net.eval()``), the first argument is treated as
        the mode flag and the standard behaviour is used, returning ``self``.
        """
        if x_neg is None:
            return super().train(x_pos)
        h_pos, h_neg = x_pos, x_neg
        for layer in self.layers:
            h_pos, h_neg = layer.train(h_pos, h_neg)
# NOTE(review): the training script below calls MNIST_loaders, but nothing in
# this file defined it, so it failed with NameError. This definition uses the
# torchvision / DataLoader imports at the top of the file. If another copy
# exists elsewhere (e.g. a notebook cell), keep only one.
def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    """Build MNIST train/test DataLoaders yielding flattened, normalised images.

    Images are converted to tensors, normalised with the MNIST pixel
    mean/std, and flattened to 784-element vectors. The dataset is downloaded
    to ``./data/`` if it is not already present.

    Returns:
        ``(train_loader, test_loader)``; the training loader is shuffled.
    """
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda img: torch.flatten(img)),
    ])
    train_loader = DataLoader(
        MNIST('./data/', train=True, download=True, transform=transform),
        batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(
        MNIST('./data/', train=False, download=True, transform=transform),
        batch_size=test_batch_size, shuffle=False)
    return train_loader, test_loader


if __name__ == "__main__":
    torch.manual_seed(1234)
    # torch.manual_seed does not seed NumPy's global RNG; seed both so runs
    # are reproducible.
    np.random.seed(1234)
    train_loader, test_loader = MNIST_loaders()
    net = Net([784, 512, 512])
    for x, y in tqdm(train_loader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        x_pos = overlay_y_on_x(x, y)
        x_neg = generate_negative_data(x, y)
        net.train(x_pos, x_neg)