# Weight Net: a first take on feeding a neural net's weights into its input [part 1]

In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


### The idea is simple: why not feed the neural net's weights directly back into its input?

As a first take, I need to match the dimension of the weights to that of the input. This can be done in two ways:

1. shrink the dimension by applying max pooling, or
2. project the high-dimensional weight vector down to a lower dimension.

This is a typical conv net for the MNIST and CIFAR10 tasks:

In [152]:
class Net(nn.Module):
    """Baseline conv net without weight feedback.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed a
    two-layer fully connected head that emits log-probabilities over 10
    classes. The flatten step expects a 4x4 feature map with 10 channels,
    which is what a single-channel 28x28 input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 10, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        # conv -> ReLU -> 2x2 max pool, applied once per convolution stage.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        flat = x.view(-1, 4 * 4 * 10)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


The fully connected layers are a good place to start. Let's brute-force integers that make the dimensions match: specifically, set a reasonable range for the weight dimension, say bigger than 100, and compute the possible pooling kernel sizes:

• x: dimension of the weight vector fed back into the input
• m: 1st dimension's max pooling kernel size
• n: 2nd dimension's max pooling kernel size

$\frac{x+160}{m} \cdot \frac{50}{n} + \frac{50}{n} \cdot \frac{10}{2} = x$

In [153]:
# fc1 dimension: a, b
# fc2 dimension: b, c
a = 160
b = 50
c = 10
dx = 10
# Exhaustive search: for every candidate extra-input width (_dx) and every
# pooling kernel pair (m, n) that divides the weight dimensions exactly,
# compute the pooled weight count (dw) and report the combinations where it
# is strictly smaller than _dx, so the difference can be padded.
for _dx in range(10, 40):
    for n in range(10, b):
        for m in range(10, a):
            _mm = (_dx + a) / m
            _nn = b / n
            if not (_mm.is_integer() and _nn.is_integer()):
                continue
            dw = (_mm + c / 2) * _nn
            if dw < _dx:
                print("dw = %d, dx = %d, dx-dw = %d, m = %d, n = %d" % (dw, _dx, _dx - dw, m, n))

dw = 14, dx = 16, dx-dw = 2, m = 88, n = 25
dw = 16, dx = 17, dx-dw = 1, m = 59, n = 25
dw = 14, dx = 18, dx-dw = 4, m = 89, n = 25
dw = 18, dx = 20, dx-dw = 2, m = 45, n = 25
dw = 16, dx = 20, dx-dw = 4, m = 60, n = 25
dw = 14, dx = 20, dx-dw = 6, m = 90, n = 25
dw = 14, dx = 22, dx-dw = 8, m = 91, n = 25
dw = 16, dx = 23, dx-dw = 7, m = 61, n = 25
dw = 18, dx = 24, dx-dw = 6, m = 46, n = 25
dw = 14, dx = 24, dx-dw = 10, m = 92, n = 25
dw = 20, dx = 25, dx-dw = 5, m = 37, n = 25
dw = 22, dx = 26, dx-dw = 4, m = 31, n = 25
dw = 16, dx = 26, dx-dw = 10, m = 62, n = 25
dw = 14, dx = 26, dx-dw = 12, m = 93, n = 25
dw = 18, dx = 28, dx-dw = 10, m = 47, n = 25
dw = 14, dx = 28, dx-dw = 14, m = 94, n = 25
dw = 28, dx = 29, dx-dw = 1, m = 21, n = 25
dw = 24, dx = 29, dx-dw = 5, m = 27, n = 25
dw = 16, dx = 29, dx-dw = 13, m = 63, n = 25
dw = 20, dx = 30, dx-dw = 10, m = 38, n = 25
dw = 14, dx = 30, dx-dw = 16, m = 95, n = 25
dw = 26, dx = 32, dx-dw = 6, m = 24, n = 25
dw = 22, dx = 32, dx-dw = 10, m = 32, n = 25
dw = 18, dx = 32, dx-dw = 14, m = 48, n = 25
dw = 16, dx = 32, dx-dw = 16, m = 64, n = 25
dw = 14, dx = 32, dx-dw = 18, m = 96, n = 25
dw = 14, dx = 34, dx-dw = 20, m = 97, n = 25
dw = 20, dx = 35, dx-dw = 15, m = 39, n = 25
dw = 16, dx = 35, dx-dw = 19, m = 65, n = 25
dw = 35, dx = 36, dx-dw = 1, m = 98, n = 10
dw = 24, dx = 36, dx-dw = 12, m = 28, n = 25
dw = 18, dx = 36, dx-dw = 18, m = 49, n = 25
dw = 14, dx = 36, dx-dw = 22, m = 98, n = 25
dw = 35, dx = 38, dx-dw = 3, m = 99, n = 10
dw = 32, dx = 38, dx-dw = 6, m = 18, n = 25
dw = 28, dx = 38, dx-dw = 10, m = 22, n = 25
dw = 22, dx = 38, dx-dw = 16, m = 33, n = 25
dw = 16, dx = 38, dx-dw = 22, m = 66, n = 25
dw = 14, dx = 38, dx-dw = 24, m = 99, n = 25


Let's use dw = 22, dx = 26, dx-dw = 4, m = 31, n = 25

The only things we need to change are the extra input width at initialization and the concatenation in forward:

In [196]:
m = 31
n = 25

a = 160
b = 50
c = 10
dx = 26
dw = 22


class WeightNetMnist(nn.Module):
    """MNIST conv net whose fc1 input is extended with its own pooled weights.

    On every forward pass the fc1 and fc2 weight matrices are max-pooled,
    flattened into ``POOLED_WEIGHTS`` values, padded with ones up to
    ``EXTRA_INPUTS`` values and concatenated to the flattened conv features
    of every sample in the batch.

    With the constants below, fc1.weight (60 x 186) pooled by a (25, 31)
    kernel gives 2 x 6 values and fc2.weight (10 x 60) pooled by (2, 25)
    gives 5 x 2 values: 22 in total, padded with 4 ones to 26.

    NOTE(review): the hidden width here is 60 while the kernel search used
    b = 50; floor pooling yields the same 22 values either way. Confirm
    which width was intended.
    """

    # Snapshot of the module-level constants at class-definition time, so a
    # later cell that rebinds m / n / dx / dw cannot silently change (or
    # break) this model.
    POOL_M = m
    POOL_N = n
    EXTRA_INPUTS = dx
    POOLED_WEIGHTS = dw

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 10 + self.EXTRA_INPUTS, 60)
        self.fc2 = nn.Linear(60, 10)

    def _pooled_weights(self):
        """Return the max-pooled fc weights as a (1, POOLED_WEIGHTS) tensor."""
        # detach(): the weights are fed back as plain values, so no gradient
        # flows through the feedback path.
        w1 = self.fc1.weight.detach().unsqueeze(0)
        w1 = F.max_pool2d(w1, (self.POOL_N, self.POOL_M)).squeeze(0)
        w2 = self.fc2.weight.detach().unsqueeze(0)
        w2 = F.max_pool2d(w2, (2, self.POOL_N)).squeeze(0)
        pooled = torch.cat((w1, w2.t()), 1)
        return pooled.view(-1, int(self.POOLED_WEIGHTS))

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1x28x28 images."""
        me = self._pooled_weights()

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)

        # Pad with ones up to the extra input width. The padding must live on
        # the same device / dtype as the weights, otherwise torch.cat fails
        # once the model has been moved to the GPU.
        fill = torch.ones(
            [1, int(self.EXTRA_INPUTS - self.POOLED_WEIGHTS)],
            dtype=me.dtype, device=me.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)



Let's set up the experiment infrastructure:

In [189]:
from types import SimpleNamespace

# Hyper-parameters and run options shared by every experiment in this file.
args = SimpleNamespace(
    batch_size=64,
    test_batch_size=1000,
    epochs=6,
    lr=0.01,
    momentum=0.5,
    no_cuda=False,
    seed=1,
    dataset="mnist",
    log_interval=40,
    save_model=False,
)

# Use the GPU only when it is both allowed and actually present.
use_cuda = (not args.no_cuda) and torch.cuda.is_available()

torch.manual_seed(args.seed)

if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options that only make sense with a GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Global result logs; they keep growing across successive runs.
all_loss = []       # one entry per training batch
all_test_loss = []  # evaluation losses
all_acc = []        # sliced per run by the plotting cells

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training pass of ``model`` over ``train_loader``.

    Args:
        args: namespace providing ``log_interval`` (batches between log lines).
        model: the network to train; it is switched to training mode.
        device: device every batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches. It must support
            ``len()`` and expose ``.dataset`` (as a
            ``torch.utils.data.DataLoader`` does) for the progress message.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: epoch number, used only in the progress message.

    Side effects:
        Appends the loss of every batch to the module-level ``all_loss`` list
        and prints a progress line every ``args.log_interval`` batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # backward() accumulates into .grad, so the gradients of the previous
        # batch must be cleared before computing new ones.
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        all_loss.append(loss.item())
        if batch_idx % args.log_interval == 0:
            # The format string has five placeholders: epoch, samples seen,
            # dataset size, percentage of batches done, and the batch loss.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def evaluate(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and log loss and accuracy.

    Args:
        args: experiment namespace (unused here; kept parallel to ``train``).
        model: the network to evaluate; it is switched to eval mode.
        device: device every batch is moved to before the forward pass.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` (as a ``torch.utils.data.DataLoader`` does).

    Side effects:
        Appends the average loss to the module-level ``all_test_loss`` and the
        accuracy in percent to ``all_acc``, and prints a summary line.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # evaluation only; no graph needed
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    accuracy = 100. * correct / n_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples, accuracy))

    all_test_loss.append(test_loss)
    # One accuracy entry per evaluation; the plotting cells slice all_acc
    # assuming one entry per epoch per run.
    all_acc.append(accuracy)


def run(modelClass):
    """Train and evaluate a fresh ``modelClass()`` on ``args.dataset``.

    Builds the train / test loaders for the dataset selected in the
    module-level ``args`` ("mnist" or "cifar"), trains with SGD for
    ``args.epochs`` epochs, evaluates after every epoch and optionally saves
    the final weights.

    Args:
        modelClass: zero-argument callable returning an ``nn.Module``.

    Raises:
        ValueError: if ``args.dataset`` is neither "mnist" nor "cifar".
    """
    if args.dataset == "mnist":
        dataset_cls = datasets.MNIST
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    elif args.dataset == "cifar":
        dataset_cls = datasets.CIFAR10
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        raise ValueError("unknown dataset: %r" % (args.dataset,))

    # NOTE(review): the training-set constructor arguments (train=True,
    # download=True) were reconstructed; confirm against the original cell.
    train_loader = torch.utils.data.DataLoader(
        dataset_cls('../data', train=True, download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset_cls('../data', train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = modelClass().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        evaluate(args, model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

In [ ]:
# Train the baseline first, then the weight-feedback variant, both on the
# dataset currently selected in args.dataset.
run(Net)
run(WeightNetMnist)

In [202]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Assumes all_acc holds one accuracy entry per epoch per run, in run order:
# entries 0-5 belong to run(Net) and 6-11 to run(WeightNetMnist).
mnist_epochs = 6
all_acc_mnist_origin = all_acc[0:mnist_epochs]
# The end index was 13, which takes 7 entries as soon as all_acc holds more
# than 12 values and then no longer matches the 6-point x axis.
all_acc_mnist_weightnet = all_acc[mnist_epochs:2 * mnist_epochs]

t = np.arange(0, mnist_epochs)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ Weight feedback')

ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()


This is only slightly better; we need to run against a larger dataset:

In [206]:
m = 45
n = 50


class NetCifar10(nn.Module):
    """Baseline conv net for 3-channel inputs, without weight feedback.

    Same layout as the MNIST baseline but wider: 20 and 50 conv channels and
    a 700-unit hidden layer. The flatten step expects a 5x5 feature map with
    50 channels, which is what a 3x32x32 input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(5 * 5 * 50, 700)
        self.fc2 = nn.Linear(700, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2, stride=2)
        flat = features.view(-1, 5 * 5 * 50)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

class WeightNetCifar10(nn.Module):
    """CIFAR-10 conv net whose fc1 input is extended with its own pooled weights.

    On every forward pass fc1.weight (500 x 1710) is max-pooled with a
    (50, 45) kernel to 10 x 38 values and fc2.weight (10 x 500) with a
    (2, 50) kernel to 5 x 10 values. The resulting 430 values are padded with
    ones to ``DX`` = 460 and concatenated to the conv features of every
    sample in the batch.
    """

    FLAT = 5 * 5 * 50  # flattened conv feature size
    HIDDEN = 500       # fc1 output width
    CLASSES = 10       # fc2 output width
    DX = 460           # extra fc1 inputs reserved for the weight feedback
    # Pooling kernel sizes. They are fixed here rather than read from the
    # module-level m / n, which other cells rebind; only this pair matches
    # the 460-wide extra input.
    POOL_M = 45
    POOL_N = 50

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(self.FLAT + self.DX, self.HIDDEN)
        self.fc2 = nn.Linear(self.HIDDEN, self.CLASSES)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 3x32x32 images."""
        # detach(): the weights are fed back as plain values, so no gradient
        # flows through the feedback path.
        m1 = self.fc1.weight.detach().unsqueeze(0)
        m1 = F.max_pool2d(m1, (self.POOL_N, self.POOL_M)).squeeze(0)
        m2 = self.fc2.weight.detach().unsqueeze(0)
        m2 = F.max_pool2d(m2, (2, self.POOL_N)).squeeze(0)

        me = torch.cat((m1, m2.t()), 1)

        # Number of pooled values: (38 + 5) * 10 = 430, in exact integer
        # arithmetic.
        size = (((self.FLAT + self.DX) // self.POOL_M) + self.CLASSES // 2) \
            * (self.HIDDEN // self.POOL_N)
        me = me.view(-1, size)

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, self.FLAT)

        # Pad with ones up to DX (30 values). The padding must live on the
        # same device / dtype as the weights, otherwise torch.cat fails once
        # the model has been moved to the GPU.
        fill = torch.ones([1, self.DX - size], dtype=me.dtype, device=me.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [ ]:
# Switch the shared configuration to CIFAR with 20 epochs, then train the
# baseline followed by the weight-feedback variant.
args.dataset = 'cifar'
args.epochs = 20
run(NetCifar10)
run(WeightNetCifar10)

In [211]:
# Slice the 20-entry accuracy windows of the two CIFAR runs out of the
# shared log: entries 12-31 and 32-51.
all_acc_cifar_origin = all_acc[12:32]
all_acc_cifar_weightnet = all_acc[32:52]

t = np.arange(20)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ Weight feedback')

ax.set_xlabel('epochs')
ax.set_ylabel('accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()

52


Dimension matching is pretty ugly. Since all we want is to feed a signal back into the input, a handy trick is random matrix projection. This trick was used to measure the intrinsic dimension of reinforcement learning tasks.

In [212]:
class WeightNetMnistP(nn.Module):
    """MNIST conv net fed with a random projection of its own fc weights.

    The fc1 and fc2 weight matrices are concatenated, flattened and
    multiplied by a fixed random matrix ``P`` to obtain ``dx`` values, which
    are appended to the conv features of every sample in the batch.

    Args:
        dx: number of extra fc1 inputs carrying the projected weights
            (defaults to 26; expected to be a positive integer).
    """

    def __init__(self, dx=26):
        super().__init__()
        self.dx = dx
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 10 + dx, 50)
        self.fc2 = nn.Linear(50, 10)
        # Total number of fc weights: (4*4*10 + dx + 10) * 50.
        in_dim = self.fc1.weight.numel() + self.fc2.weight.numel()
        out_dim = dx

        P = torch.nn.init.xavier_uniform_(torch.empty(in_dim, out_dim))
        # Register as a buffer, not a parameter: P stays fixed during
        # training, but follows the module through .to(device) and is stored
        # in state_dict. As a plain attribute it stayed on the CPU and broke
        # the matmul once the model was moved to the GPU.
        self.register_buffer("P", P)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1x28x28 images."""
        # detach(): the weights are fed back as plain values, so no gradient
        # flows through the feedback path.
        me = torch.cat(
            (self.fc1.weight.detach(), self.fc2.weight.detach().t()), 1)
        me = me.view(1, -1)
        me = torch.matmul(me, self.P)

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)

        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [ ]:
# Back to MNIST with 6 epochs for the random-projection variant.
args.dataset = 'mnist'
args.epochs = 6
run(WeightNetMnistP)

In [215]:
# Compare the two earlier MNIST curves with the random-projection run, whose
# 6 accuracy entries sit at positions 52-57 of the shared log.
all_acc_mnist_weightnetP = all_acc[52:58]

t = np.arange(6)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ Weight feedback')
line3 = ax.plot(t, all_acc_mnist_weightnetP, label='w/ Weight feedback RMP')  # random matrix projection
ax.set_xlabel('epochs')
ax.set_ylabel('accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()

In [217]:
class WeightNetCifar10P(nn.Module):
    """CIFAR-10 conv net fed with a random projection of its own fc weights.

    The fc1 and fc2 weight matrices are concatenated, flattened and
    multiplied by a fixed random matrix ``P`` to obtain ``dx`` values, which
    are appended to the conv features of every sample in the batch.

    Args:
        dx: number of extra fc1 inputs carrying the projected weights
            (defaults to 460; expected to be a positive integer).
    """

    def __init__(self, dx=460):
        super().__init__()
        self.dx = dx
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5 * 5 * 50 + dx, 500)
        self.fc2 = nn.Linear(500, 10)
        # Total number of fc weights: (5*5*50 + dx + 10) * 500.
        in_dim = self.fc1.weight.numel() + self.fc2.weight.numel()
        out_dim = dx

        # With the default dx this is an 860000 x 460 float32 matrix,
        # i.e. roughly 1.6 GB.
        P = torch.nn.init.xavier_uniform_(torch.empty(in_dim, out_dim))
        # Register as a buffer, not a parameter: P stays fixed during
        # training, but follows the module through .to(device) and is stored
        # in state_dict. As a plain attribute it stayed on the CPU and broke
        # the matmul once the model was moved to the GPU.
        self.register_buffer("P", P)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 3x32x32 images."""
        # detach(): the weights are fed back as plain values, so no gradient
        # flows through the feedback path.
        me = torch.cat(
            (self.fc1.weight.detach(), self.fc2.weight.detach().t()), 1)
        me = me.view(1, -1)
        me = torch.matmul(me, self.P)

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5 * 5 * 50)

        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [ ]:
# CIFAR again, 20 epochs, for the random-projection variant.
args.dataset = 'cifar'
args.epochs = 20
run(WeightNetCifar10P)

In [219]:
# 20-entry accuracy windows of the three CIFAR runs in the shared log:
# baseline at 12-31, pooled feedback at 32-51, random projection at 58-77.
all_acc_cifar_origin = all_acc[12:32]
all_acc_cifar_weightnet = all_acc[32:52]
all_acc_cifar_weightnetP = all_acc[58:78]

t = np.arange(20)

fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ Weight feedback')
line3 = ax.plot(t, all_acc_cifar_weightnetP, label='w/ Weight feedback RMP')

ax.set_xlabel('epochs')
ax.set_ylabel('accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()