Weight Net: a first take on feeding a neural net's weights back into its input [part 1]

In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

The idea is simple: why not feed a neural net's weights directly back into its input?

For a first take, I need to match the dimension of the weights to that of the input. This can be done in two ways (trick 1 is sketched right after this list):

  1. shrink the dimension by applying max pooling,
  2. project a weight vector from a high dimension down to a lower one.

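Here is a minimal toy sketch of trick 1 in isolation: max pooling shrinks a weight matrix by taking the max over non-overlapping windows. The (50, 160) shape matches fc1 of the MNIST net below; the kernel (25, 32) is just an illustrative choice, not the one used later.

In [ ]:
# Toy sketch: shrink a (50, 160) weight matrix to 10 values by max pooling.
W = torch.randn(50, 160)
pooled = F.max_pool2d(W.unsqueeze(0), (25, 32)).squeeze(0)
print(pooled.shape)  # torch.Size([2, 5]) -> 10 values to feed back
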
This is a typical conv net for the MNIST and CIFAR10 tasks:

In [152]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4*4*10, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

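As a throwaway sanity check of the 4*4*10 flatten (not part of the experiment): a 28×28 MNIST image goes 28 → 24 (conv 5×5) → 12 (pool) → 8 (conv 5×5) → 4 (pool), with 10 channels, so 160 features reach fc1.

In [ ]:
# Throwaway shape check: a batch of 28x28 inputs yields 10 log-probabilities each.
net = Net()
print(net(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 10])
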
The fully connected layers are a good place to start. Let's brute-force integers that match the dimensions: fix the layer sizes, sweep the feedback dimension and the pooling kernel sizes over reasonable ranges, and keep the combinations where everything divides evenly:

  • x: dimension of the weight signal fed back into the input (dx in the code)
  • m: max-pooling kernel size over fc1's input dimension (a)
  • n: max-pooling kernel size over fc1's output dimension (b)

The first term below counts fc1's pooled weights and the second counts fc2's pooled weights (fc2's 10 rows are pooled with kernel 2); any slack $x - dw$ is padded with ones:

$\frac{x+160}{m} \frac{50}{n} + \frac{50}{n} \frac{10}{2} \le x$

In [153]:
# fc1: Linear(a + dx, b), so fc1.weight has shape (b, a + dx)
# fc2: Linear(b, c), so fc2.weight has shape (c, b)
a = 160
b = 50
c = 10
for _dx in range(10, 40):        # candidate feedback dimensions x
    for n in range(10, b):       # kernel over fc1's output dimension
        for m in range(10, a):   # kernel over fc1's input dimension
            _mm = (_dx + a) / m
            _nn = b / n
            if int(_mm) == _mm and int(_nn) == _nn:   # both divide evenly
                dw = (_mm + c/2) * _nn                # pooled weight count
                if _dx - dw > 0:                      # leave room for padding
                    print("dw = %d, dx = %d, dx-dw = %d, m = %d, n = %d" % (dw, _dx, _dx-dw, m, n))
dw = 14, dx = 16, dx-dw = 2, m = 88, n = 25
dw = 16, dx = 17, dx-dw = 1, m = 59, n = 25
dw = 14, dx = 18, dx-dw = 4, m = 89, n = 25
dw = 18, dx = 20, dx-dw = 2, m = 45, n = 25
dw = 16, dx = 20, dx-dw = 4, m = 60, n = 25
dw = 14, dx = 20, dx-dw = 6, m = 90, n = 25
dw = 14, dx = 22, dx-dw = 8, m = 91, n = 25
dw = 16, dx = 23, dx-dw = 7, m = 61, n = 25
dw = 18, dx = 24, dx-dw = 6, m = 46, n = 25
dw = 14, dx = 24, dx-dw = 10, m = 92, n = 25
dw = 20, dx = 25, dx-dw = 5, m = 37, n = 25
dw = 22, dx = 26, dx-dw = 4, m = 31, n = 25
dw = 16, dx = 26, dx-dw = 10, m = 62, n = 25
dw = 14, dx = 26, dx-dw = 12, m = 93, n = 25
dw = 18, dx = 28, dx-dw = 10, m = 47, n = 25
dw = 14, dx = 28, dx-dw = 14, m = 94, n = 25
dw = 28, dx = 29, dx-dw = 1, m = 21, n = 25
dw = 24, dx = 29, dx-dw = 5, m = 27, n = 25
dw = 16, dx = 29, dx-dw = 13, m = 63, n = 25
dw = 20, dx = 30, dx-dw = 10, m = 38, n = 25
dw = 14, dx = 30, dx-dw = 16, m = 95, n = 25
dw = 26, dx = 32, dx-dw = 6, m = 24, n = 25
dw = 22, dx = 32, dx-dw = 10, m = 32, n = 25
dw = 18, dx = 32, dx-dw = 14, m = 48, n = 25
dw = 16, dx = 32, dx-dw = 16, m = 64, n = 25
dw = 14, dx = 32, dx-dw = 18, m = 96, n = 25
dw = 14, dx = 34, dx-dw = 20, m = 97, n = 25
dw = 20, dx = 35, dx-dw = 15, m = 39, n = 25
dw = 16, dx = 35, dx-dw = 19, m = 65, n = 25
dw = 35, dx = 36, dx-dw = 1, m = 98, n = 10
dw = 24, dx = 36, dx-dw = 12, m = 28, n = 25
dw = 18, dx = 36, dx-dw = 18, m = 49, n = 25
dw = 14, dx = 36, dx-dw = 22, m = 98, n = 25
dw = 35, dx = 38, dx-dw = 3, m = 99, n = 10
dw = 32, dx = 38, dx-dw = 6, m = 18, n = 25
dw = 28, dx = 38, dx-dw = 10, m = 22, n = 25
dw = 22, dx = 38, dx-dw = 16, m = 33, n = 25
dw = 16, dx = 38, dx-dw = 22, m = 66, n = 25
dw = 14, dx = 38, dx-dw = 24, m = 99, n = 25

Let's use the solution dw = 22, dx = 26, dx-dw = 4, m = 31, n = 25.

The only things we need to update are the extra input dimensions in the initialization and the concatenation in forward:

In [196]:
m = 31
n = 25

a = 160
b = 50
c = 10
dx = 26
dw = 22

class WeightNetMnist(nn.Module):
    def __init__(self):
        super(WeightNetMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4*4*10 + dx, 60)
        self.fc2 = nn.Linear(60, 10)

    def forward(self, x):
        # Pool the fc weights down to dw values; .data detaches them,
        # so no gradient flows back through the feedback path.
        m1 = torch.unsqueeze(self.fc1.weight.data, 0)
        m1 = F.max_pool2d(m1, (n, m))
        m1 = torch.squeeze(m1, 0)
        m2 = torch.unsqueeze(self.fc2.weight.data, 0)
        m2 = F.max_pool2d(m2, (2, n))
        m2 = torch.squeeze(m2, 0)

        me = torch.cat((m1, m2.t()), 1)
        me = me.view(-1, int(dw))

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)

        # Pad the remaining dx - dw slots with ones and append to every sample.
        fill = torch.ones([1, dx - dw], device=x.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

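A quick shape trace of the feedback path (a throwaway check). Note fc1's hidden size is 60 here while the search assumed b = 50; since max_pool2d floors its output size, the pooled weights still come out to dw = 22:

In [ ]:
# fc1.weight: (60, 186) -> pool (25, 31) -> (2, 6)
# fc2.weight: (10, 60)  -> pool (2, 25)  -> (5, 2), transposed to (2, 5)
model = WeightNetMnist()
m1 = F.max_pool2d(model.fc1.weight.data.unsqueeze(0), (n, m)).squeeze(0)
m2 = F.max_pool2d(model.fc2.weight.data.unsqueeze(0), (2, n)).squeeze(0)
print(torch.cat((m1, m2.t()), 1).numel())  # 22 == dw
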
Let's set up the experiment infra:

In [189]:
from types import SimpleNamespace
args = {
    "batch_size": 64,
    "test_batch_size": 1000,
    "epochs": 6,
    "lr": 0.01,
    "momentum": 0.5,
    "no_cuda": False,
    "seed": 1,
    "dataset": "mnist",
    "log_interval": 40,
    "save_model": False
}

args = SimpleNamespace(**args)

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

all_loss = []
all_test_loss = []
all_acc = []

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        all_loss.append(loss.item())
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    all_test_loss.append(test_loss)
    all_acc.append(100. * correct / len(test_loader.dataset))
    
def run(modelClass):
    if args.dataset == "mnist":
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='../data', train=True, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = modelClass().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
In [ ]:
run(Net)
run(WeightNetMnist)
In [202]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

all_acc_mnist_origin = all_acc[0:6]      # first run: plain Net, 6 epochs
all_acc_mnist_weightnet = all_acc[6:12]  # second run: WeightNetMnist, 6 epochs

t = np.arange(0, 6)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ weight feedback')

ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()

This is only slightly better; we need to run against a larger dataset:

In [206]:
m = 45
n = 50
class NetCifar10(nn.Module):
    def __init__(self):
        super(NetCifar10, self).__init__()
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50, 700)
        self.fc2 = nn.Linear(700, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
class WeightNetCifar10(nn.Module):
    def __init__(self):
        super(WeightNetCifar10, self).__init__()
        dx = 460
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50 + dx, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        dx = 460
        # Pool the fc weights; .data detaches them from the graph.
        m1 = torch.unsqueeze(self.fc1.weight.data, 0)
        m1 = F.max_pool2d(m1, (n, m))
        m1 = torch.squeeze(m1, 0)
        m2 = torch.unsqueeze(self.fc2.weight.data, 0)
        m2 = F.max_pool2d(m2, (2, n))
        m2 = torch.squeeze(m2, 0)

        me = torch.cat((m1, m2.t()), 1)

        size = (((5*5*50 + dx) / m) + 5) * (500/n)  # pooled weight count: 430
        me = me.view(-1, int(size))

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)

        # Pad the remaining dx - size = 30 slots with ones.
        fill = torch.ones([1, dx - int(size)], device=x.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
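
Sanity-checking the dimensions for m = 45, n = 50 (a throwaway cell): fc1.weight (500, 1710) pools to (10, 38) and fc2.weight (10, 500) pools to (5, 10), so the concatenation holds 10 × 43 = 430 values, leaving 460 − 430 = 30 slots for the ones-fill.

In [ ]:
model = WeightNetCifar10()
m1 = F.max_pool2d(model.fc1.weight.data.unsqueeze(0), (n, m)).squeeze(0)
m2 = F.max_pool2d(model.fc2.weight.data.unsqueeze(0), (2, n)).squeeze(0)
print(torch.cat((m1, m2.t()), 1).numel())  # 430; dx - 430 = 30 ones-fill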
In [ ]:
args.dataset = 'cifar'
args.epochs = 20
run(NetCifar10)
run(WeightNetCifar10)
In [211]:
all_acc_cifar_origin = all_acc[12:12+20]
all_acc_cifar_weightnet = all_acc[32:32+20]

t = np.arange(0, 20)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ weight feedback')

ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()

Dimension matching is pretty ugly, since all we want is to feed a signal back into the input. A handy trick is random matrix projection: flatten the weights and multiply them by a fixed random matrix to reach any target dimension. The same trick was used to measure the intrinsic dimension of reinforcement learning tasks.

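Concretely, for the MNIST net the flattened fc1 and fc2 weights span (160 + 26 + 10) × 50 = 9800 values, and a fixed Xavier-initialized matrix P of shape (9800, 26) maps them onto the dx = 26 extra input slots. A minimal sketch of the projection in isolation (the vector w is a stand-in for the flattened weights):

In [ ]:
dx = 26
in_dim = (4*4*10 + dx + 10) * 50                            # 9800 weight values
P = torch.nn.init.xavier_uniform_(torch.empty(in_dim, dx))  # fixed, never trained
w = torch.randn(1, in_dim)                                  # stand-in for real weights
print(torch.matmul(w, P).shape)                             # torch.Size([1, 26])
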
In [212]:
class WeightNetMnistP(nn.Module):
    def __init__(self):
        super(WeightNetMnistP, self).__init__()
        dx = 26
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4*4*10 + dx, 50)
        self.fc2 = nn.Linear(50, 10)
        in_dim = (4*4*10 + dx + 10) * 50  # all fc weights, flattened
        out_dim = dx

        # Fixed random projection; register as a buffer so .to(device) moves it.
        P = torch.empty(in_dim, out_dim)
        self.register_buffer('P', torch.nn.init.xavier_uniform_(P))

    def forward(self, x):
        dx = 26
        # Flatten both fc weight matrices and project them down to dx values.
        me = torch.cat((self.fc1.weight.data, self.fc2.weight.data.t()), 1)

        size = (4*4*10 + dx + 10) * 50
        me = me.view(-1, int(size))
        me = torch.matmul(me, self.P)

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)

        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
In [ ]:
args.dataset = 'mnist'
args.epochs = 6
run(WeightNetMnistP)
In [215]:
t = np.arange(0, 6)

all_acc_mnist_weightnetP = all_acc[52:52+6]

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ weight feedback')
line3 = ax.plot(t, all_acc_mnist_weightnetP, label='w/ weight feedback RMP') # random matrix projection
ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()
In [217]:
class WeightNetCifar10P(nn.Module):
    def __init__(self):
        super(WeightNetCifar10P, self).__init__()
        dx = 460
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50 + dx, 500)
        self.fc2 = nn.Linear(500, 10)
        in_dim = (5*5*50 + dx + 10) * 500  # all fc weights, flattened
        out_dim = dx

        # Fixed random projection; register as a buffer so .to(device) moves it.
        P = torch.empty(in_dim, out_dim)
        self.register_buffer('P', torch.nn.init.xavier_uniform_(P))

    def forward(self, x):
        dx = 460
        # Flatten both fc weight matrices and project them down to dx values.
        me = torch.cat((self.fc1.weight.data, self.fc2.weight.data.t()), 1)
        size = (5*5*50 + dx + 10) * 500
        me = me.view(-1, int(size))
        me = torch.matmul(me, self.P)

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)

        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
In [ ]:
args.dataset = 'cifar'
args.epochs = 20
run(WeightNetCifar10P)
In [219]:
all_acc_cifar_origin = all_acc[12:12+20]
all_acc_cifar_weightnet = all_acc[32:32+20]
all_acc_cifar_weightnetP = all_acc[58:58+20]

t = np.arange(0, 20)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ weight feedback')
line3 = ax.plot(t, all_acc_cifar_weightnetP, label='w/ weight feedback RMP')

ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()