Introducing...

We (or I) introduce ConvTwist, a replacement for (and, I argue, an improvement on) the good old convolutional layer that is ubiquitous in the Deep Learning models used in Computer Vision (rightfully called ConvNets or CNNs). Famously introduced to image classification by Yann LeCun some 30 years ago, the convolutional layer sparked the Deep Learning/Artificial Intelligence revolution with AlexNet in 2012. Rapid improvements to the architecture followed, most notably ResNet in 2015. Attention has recently shifted somewhat away from image classification, but convolutional layers are still the bread and butter of any Computer Vision model. What more can be said about convolutional layers, one might ask? The answer: a little bit of mathematics.

For what it's worth, this work has not been peer-reviewed or published at any conference. If you'd like to give it a "review" on Twitter, please feel free to do so.

Without further ado, here is one (crude) implementation of ConvTwist in PyTorch. You can easily swap it in for the 3x3 Conv2d layers in your model and train from scratch to see whether it gives any improvement. Early success has been shown on ResNet50 with the Imagenette/Imagewoof benchmarks. (Help with testing on other models/datasets is greatly appreciated.)

import torch
import torch.nn as nn
import numpy as np

class ConvTwist(nn.Module):  # replacing 3x3 Conv2d
    def __init__(self, ni, nf, stride=1, init_max=1.5):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv_x = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv_y = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        self.symmetrize(self.conv_x.weight)  # make conv_x a "first-order operator" by symmetrizing it
        self.conv_y.weight.data = self.conv_x.weight.data.transpose(2,3).flip(3)  # make conv_y a 90 degree rotation of conv_x
        # per-output-channel offsets of the coordinate grids (see XX and YY in forward)
        self.center_x = nn.Parameter(torch.empty(nf).uniform_(-init_max, init_max))
        self.center_y = nn.Parameter(torch.empty(nf).uniform_(-init_max, init_max))

    def symmetrize(self, conv_wt):
        # w <- (w - w rotated by 180 degrees) / 2, done outside autograd:
        # opposite entries become negatives of each other and the centre becomes 0
        conv_wt.data = (conv_wt.data - conv_wt.data.flip(2).flip(3)) / 2

    def forward(self, inpt):
        out = self.conv(inpt)
        _, _, h, w = out.size()
        # coordinate grids in [-1, 1), shifted per output channel; shape (nf, h, w)
        XX = torch.from_numpy(np.indices((1,h,w))[2]*2/w).type(out.dtype).to(out.device) - 1 + self.center_x.view(-1,1,1)
        YY = torch.from_numpy(np.indices((1,h,w))[1]*2/h).type(out.dtype).to(out.device) - 1 + self.center_y.view(-1,1,1)
        # re-impose the constraint, which an optimizer step may have broken
        self.symmetrize(self.conv_x.weight)
        self.symmetrize(self.conv_y.weight)
        return out + XX * self.conv_x(inpt) + YY * self.conv_y(inpt)

Let's take a look at the (initial) weights in such a ConvTwist model:

model = ConvTwist(3,1)
for name, param in model.named_parameters():
    print(name, param)
center_x Parameter containing:
tensor([1.0942], requires_grad=True)
center_y Parameter containing:
tensor([-0.9627], requires_grad=True)
conv.weight Parameter containing:
tensor([[[[ 0.0346,  0.1070,  0.1897],
          [ 0.1223,  0.1070,  0.0603],
          [ 0.1833,  0.1848, -0.1324]],

         [[-0.0733, -0.1502,  0.0553],
          [-0.0819, -0.0296, -0.0263],
          [-0.1566, -0.0072, -0.0485]],

         [[-0.0580, -0.1849,  0.1489],
          [-0.0274, -0.0409,  0.0025],
          [ 0.1501,  0.1230,  0.1433]]]], requires_grad=True)
conv_x.weight Parameter containing:
tensor([[[[ 0.0982,  0.0611,  0.0239],
          [-0.0116,  0.0000,  0.0116],
          [-0.0239, -0.0611, -0.0982]],

         [[ 0.0945, -0.0347,  0.0114],
          [ 0.0003,  0.0000, -0.0003],
          [-0.0114,  0.0347, -0.0945]],

         [[-0.0063, -0.0120,  0.0594],
          [-0.0686,  0.0000,  0.0686],
          [-0.0594,  0.0120,  0.0063]]]], requires_grad=True)
conv_y.weight Parameter containing:
tensor([[[[-0.0239, -0.0116,  0.0982],
          [-0.0611,  0.0000,  0.0611],
          [-0.0982,  0.0116,  0.0239]],

         [[-0.0114,  0.0003,  0.0945],
          [ 0.0347,  0.0000, -0.0347],
          [-0.0945, -0.0003,  0.0114]],

         [[-0.0594, -0.0686, -0.0063],
          [ 0.0120,  0.0000, -0.0120],
          [ 0.0063,  0.0686,  0.0594]]]], requires_grad=True)

If you take a look at the conv_x and conv_y weights, you'll notice that each 3x3 kernel is antisymmetric about its center: entries at opposite ends of the square have the same magnitude but opposite signs, and the middle one is always 0. That's the effect of "symmetrizing", which is done at every forward pass. You can also check that the conv_y weights are initialized to be a 90 degree rotation of the conv_x weights.
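Both properties can be checked without reading numbers off the printout. The sketch below repeats the two initialization steps from `__init__` on a random weight tensor (outside the class, so it runs on its own). `torch.rot90` with `k=-1` is used only as an independent description of a clockwise quarter turn; the original code does not call it.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 3, 3, 3)  # (out_channels, in_channels, 3, 3), the layout of Conv2d.weight

# Same arithmetic as ConvTwist.symmetrize: average w with minus its 180-degree rotation
wx = (w - w.flip(2).flip(3)) / 2
# Same arithmetic as the conv_y initialization: transpose each kernel, then mirror it left-right
wy = wx.transpose(2, 3).flip(3)

# Opposite entries cancel exactly, so every centre entry is 0
assert torch.all(wx + wx.flip(2).flip(3) == 0)
assert torch.all(wx[:, :, 1, 1] == 0)
# conv_y's kernels are conv_x's turned a quarter turn (clockwise as printed)
assert torch.equal(wy, torch.rot90(wx, k=-1, dims=(2, 3)))
# ...and a rotated antisymmetric kernel is still antisymmetric
assert torch.all(wy + wy.flip(2).flip(3) == 0)
print(wx[0, 0])
print(wy[0, 0])
```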

Why do I initialize the weights this way, and what are XX and YY? I'll try to explain everything later. For now it's important to note that ConvTwist is a lot bigger than a standard Conv2d layer, though not as much as it appears: the code holds three full 3x3 convolutions, but because conv_x and conv_y are constrained by the "symmetrizing", an implementation that stored only their free entries would have about twice as many trainable weights as a single Conv2d layer.
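That count can be made concrete with a small sketch, assuming an example layer with 64 input and 64 output channels. The "symmetrizing" is a linear projection on each 3x3 kernel, so the number of free entries it leaves is the rank of that projection. `antisymmetrize` is a hypothetical stand-alone copy of `ConvTwist.symmetrize` that returns its result instead of writing into `.data`.

```python
import torch

def antisymmetrize(w):
    # Same arithmetic as ConvTwist.symmetrize, returned instead of stored
    return (w - w.flip(-2).flip(-1)) / 2

# Apply the projection to the 9 one-hot 3x3 kernels; the rank of the result
# is the dimension of the space of "symmetrized" kernels
basis = torch.eye(9, dtype=torch.float64).reshape(9, 3, 3)
free_per_kernel = int(torch.linalg.matrix_rank(antisymmetrize(basis).reshape(9, 9)))
print(free_per_kernel)  # 4: the centre is forced to 0, the other 8 entries come in +/- pairs

ni, nf = 64, 64                                   # example channel counts (assumed)
plain = 9 * ni * nf                               # one 3x3 Conv2d without bias
as_written = 3 * 9 * ni * nf + 2 * nf             # conv, conv_x, conv_y, center_x, center_y
effective = (9 + 2 * free_per_kernel) * ni * nf + 2 * nf
print(plain, as_written, effective, round(effective / plain, 2))
```

The ratio of `effective` to `plain` is just under 2, which is where "about twice as many" comes from.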

Let's first revisit (classical) convolution.

What do convolutions actually do?

It is by now common knowledge that the convolution operator captures the spatial relations between pixels (local features), which makes it particularly suited to image-related learning tasks. Moreover, it uses far fewer weights than a generic linear map (a fully connected layer). The classic storyline goes like this: the first layer learns line patterns, the next one learns how those lines come together, and the deeper you go, the "higher-level" the structure the convolutions learn to capture. It is also common knowledge that this is just a heuristic and should not be taken too literally. In addition, over the years we have learned that it is better to avoid kernels larger than 3x3 (and to drop the biases) and to go deeper (insert more layers) instead.
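Two of these claims are easy to check numerically. The sketch below uses sizes chosen only for illustration: it compares a 3-to-16-channel 3x3 convolution on a 32x32 image with a dense linear map between the same input and output tensors (the dense layer is counted, not instantiated, since it would need about 50 million weights), and it shows that two stacked 3x3 convolutions already respond to a 5x5 neighbourhood, with 18 weights instead of 25 per input/output channel pair.

```python
import torch
import torch.nn as nn

# Weight counts (illustrative sizes): 3 -> 16 channels on a 32x32 image
c_in, c_out, h, w = 3, 16, 32, 32
conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
conv_weights = conv.weight.numel()                  # c_out * c_in * 3 * 3
dense_weights = (c_in * h * w) * (c_out * h * w)    # a full linear map between the same tensors
print(conv_weights, dense_weights)                  # 432 50331648

# Two stacked 3x3 convolutions respond to a 5x5 neighbourhood
stack = nn.Sequential(
    nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False),
    nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False),
)
with torch.no_grad():
    for layer in stack:
        layer.weight.fill_(1.0)        # positive weights, so nothing cancels out
    impulse = torch.zeros(1, 1, 9, 9)
    impulse[0, 0, 4, 4] = 1.0          # a single bright pixel
    response = stack(impulse)[0, 0]
support = (response != 0).nonzero()
rows = int(support[:, 0].max() - support[:, 0].min()) + 1
cols = int(support[:, 1].max() - support[:, 1].min()) + 1
print(rows, cols)                      # 5 5, from 2 * 9 = 18 weights instead of 25
```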

What is perhaps less well known is that different 3x3 kernels have rather intuitive meanings in terms of what they do to the image as a whole. For example, the Gaussian kernel from image processing "blurs" the image. A little experiment shows this:

G = torch.Tensor([[1,2,1],[2,4,2],[1,2,1]]) / 16
conv = nn.Conv2d(1,1,kernel_size=3,padding=1,bias=False)
conv.weight.data[:,:] = G
for name, param in conv.named_parameters():
    print(name, param)
    
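import cv2
image = cv2.imread('./data/arturito.jpeg')  # OpenCV loads images as BGR arrays
mono = image[:,:,1]/255  # keep one channel (green), scaled to [0, 1]
t = torch.from_numpy(mono)[None,None,:,:].type(torch.float32)
with torch.no_grad():  # inference only: don't record the repeated passes for autograd
    for _ in range(10):
        t = conv(t)
res = t[0,0].detach().numpy()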
    
%matplotlib inline
import matplotlib.pyplot as plt
def display(orig, res):
    fig, axs = plt.subplots(1, 2, figsize=(32, 16))
    axs = axs.ravel()
    axs[0].axis('off')
    axs[0].imshow(orig)
    axs[1].axis('off')
    axs[1].imshow(res)

display(mono, res)
weight Parameter containing:
tensor([[[[0.0625, 0.1250, 0.0625],
          [0.1250, 0.2500, 0.1250],
          [0.0625, 0.1250, 0.0625]]]], requires_grad=True)
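The blurring can also be quantified. G is the outer product of [1, 2, 1]/4 with itself, and [1, 2, 1]/4 is [1, 1]/2 convolved with itself, so ten passes of G should act like a single 21x21 binomial kernel, with each pass adding 0.5 (in pixels squared) to the variance along each axis. A sketch that checks this on a single bright pixel; the image size and the number of passes are arbitrary choices:

```python
import math
import torch
import torch.nn.functional as F

G = torch.tensor([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]], dtype=torch.float64) / 16
size, n = 41, 10
center = size // 2
t = torch.zeros(1, 1, size, size, dtype=torch.float64)
t[0, 0, center, center] = 1.0                        # a single bright pixel
for _ in range(n):
    t = F.conv2d(t, G.view(1, 1, 3, 3), padding=1)   # same operation as the Conv2d above
resp = t[0, 0]

row_profile = resp.sum(dim=1)                        # total brightness in each row
offsets = torch.arange(size, dtype=torch.float64) - center
var_y = ((row_profile * offsets ** 2).sum() / row_profile.sum()).item()
print(var_y)                                         # about n * 0.5 = 5

# The profile is the binomial distribution with 2n trials, centred on the original pixel
binom = torch.tensor([math.comb(2 * n, k) for k in range(2 * n + 1)], dtype=torch.float64) / 2 ** (2 * n)
print(torch.allclose(row_profile[center - n:center + n + 1], binom))  # True
```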

To illustrate the effect of other 3x3 kernels, it is best to choose a kernel close to the "identity" and to apply it many times. Try the following:

A = torch.Tensor([[-1,0,1],[-2,0,2],[-1,0,1]])
B = torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]])
# These are the Sobel operators for edge detection (in traditional CV)
I = torch.Tensor([[0,0,0],[0,1,0],[0,0,0]])
conv = nn.Conv2d(1,1,kernel_size=3,padding=1,bias=False)
conv.weight.data[:,:] = I + 0.01 * A  # or B
for name, param in conv.named_parameters():
    print(name, param)
    
t = torch.from_numpy(mono)[None,None,:,:].type(torch.float32)
with torch.no_grad():  # 300 passes: skip autograd bookkeeping to save memory
    for _ in range(300):
        t = conv(t)
res = t[0,0].detach().numpy()
display(mono, res)
weight Parameter containing:
tensor([[[[-0.0100,  0.0000,  0.0100],
          [-0.0200,  1.0000,  0.0200],
          [-0.0100,  0.0000,  0.0100]]]], requires_grad=True)

Do you see how the image gets shifted in the x- or y-direction? What does it all mean? Without defining terms, one would like to say that

the operators A and B generate translations in the x- and y-directions.

You can also check for yourself that A+B generates translation in the 45 degree direction. Similarly, you can "mix up" translation with blurring. In fact, any 3x3 convolution is a combination of (tiny bits of) blurring, translations, and possibly other transformations. That's what convolution does to the image as a whole. By the way, this "holistic" point of view has the advantage that one can now picture the entire CNN with the same standard picture of a neural net made of nodes and edges: each node is now a whole image (or feature map), and each edge is a convolution operator.
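One way to make "A generates translation in x" concrete, sketched under the assumption that the image content stays away from the border: because Conv2d computes a cross-correlation, one pass of I + εA should move the intensity-weighted centroid by exactly -8ε pixels along the columns and leave it unchanged along the rows, while I + εB should move it by +8ε pixels along the rows (counted downwards). With ε = 0.01 and 300 passes that is 24 pixels, and I + ε(A+B) moves it diagonally. The Gaussian blob is a synthetic stand-in for the photo, and `centroid`, `blob` and `apply_many` are hypothetical helpers written for this check.

```python
import torch
import torch.nn.functional as F

A = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], dtype=torch.float64)
B = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]], dtype=torch.float64)
I = torch.tensor([[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], dtype=torch.float64)

def centroid(img):
    # intensity-weighted mean (row, column) position of a 2D tensor
    h, w = img.shape
    rows = torch.arange(h, dtype=img.dtype).view(h, 1)
    cols = torch.arange(w, dtype=img.dtype).view(1, w)
    mass = img.sum()
    return ((img * rows).sum() / mass).item(), ((img * cols).sum() / mass).item()

def blob(h=128, w=128, r0=64.0, c0=64.0, sigma=4.0):
    # a smooth test image: one Gaussian spot, far from the border
    rows = torch.arange(h, dtype=torch.float64).view(h, 1)
    cols = torch.arange(w, dtype=torch.float64).view(1, w)
    return torch.exp(-((rows - r0) ** 2 + (cols - c0) ** 2) / (2 * sigma ** 2))

def apply_many(kernel, img, n):
    # apply the same 3x3 Conv2d-style filter n times, with zero padding
    t = img.view(1, 1, *img.shape)
    for _ in range(n):
        t = F.conv2d(t, kernel.view(1, 1, 3, 3), padding=1)
    return t[0, 0]

eps, n = 0.01, 300
img = blob()
r0, c0 = centroid(img)
shifts = {}
for label, K in [("A", A), ("B", B), ("A+B", A + B)]:
    r, c = centroid(apply_many(I + eps * K, img, n))
    shifts[label] = (r - r0, c - c0)
    print(label, "row shift %.3f, column shift %.3f" % shifts[label])
```

The predicted shifts are (0, -24) for A, (+24, 0) for B and (+24, -24) for A+B; the sign convention simply reflects rows being counted downwards.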

What such convolution operators are missing are rotation and scaling: it is well known that CNNs are not rotation- or scale-invariant in the way they are translation-invariant. There have been attempts at making neural networks invariant or equivariant to rotation and scaling, but to my limited understanding they are more like "forced" fixes than modifications of the convolutional layer itself. For people with the right kind of mathematical background, a more natural solution may come to mind immediately: the (infinitesimal) generators of rotation and scaling (the Lie algebra of a Lie group, with the exponential map), the flow of a vector field, or solving a first-order linear partial differential equation by the method of characteristics.
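The asymmetry between translations and rotations can be seen in a small numerical check. Two assumptions keep the comparison exact: circular padding (so shifting the image does not push pixels off the border) and a square image rotated by 90 degrees. A random 3x3 kernel commutes with shifts but not with rotations; the Gaussian kernel G from above commutes with both, because it is itself unchanged by a quarter turn.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 1, 32, 32)

def circular_conv(kernel):
    # a single 3x3 Conv2d with wrap-around padding and a given kernel
    conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="circular", bias=False)
    with torch.no_grad():
        conv.weight.copy_(kernel.view(1, 1, 3, 3))
    return conv

def shift(t):
    return torch.roll(t, shifts=(5, -3), dims=(2, 3))

def rotate(t):
    return torch.rot90(t, k=1, dims=(2, 3))

random_conv = circular_conv(torch.randn(3, 3))
gauss_conv = circular_conv(torch.tensor([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]]) / 16)

with torch.no_grad():
    results = {
        "random kernel, shift": torch.allclose(random_conv(shift(x)), shift(random_conv(x)), atol=1e-6),
        "random kernel, rotation": torch.allclose(random_conv(rotate(x)), rotate(random_conv(x)), atol=1e-6),
        "Gaussian kernel, rotation": torch.allclose(gauss_conv(rotate(x)), rotate(gauss_conv(x)), atol=1e-6),
    }
for name, ok in results.items():
    print(name, ok)  # True, False, True
```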

Let me first show how ConvTwist works before getting back to the mathematics, which I'll explain as simply as I can (without using any of the jargon above). Only then will we go through the details of the code.

ConvTwist at work (no training)

In addition to a standard Conv2d layer, conv, ConvTwist feeds the input to two other 3x3 Conv2d layers, conv_x and conv_y. Let's see what they do:

model = ConvTwist(1,1)
model.conv.weight.data[:,:] = I
model.center_x.data[:] = 0.
model.center_y.data[:] = 0.
model.conv_x.weight.data[:,:] = 0.01 * B  # or -A
model.conv_y.weight.data[:,:] = 0.01 * A  # or B
for name, param in model.named_parameters():
    print(name, param)
    
t = torch.from_numpy(mono)[None,None,:,:].type(torch.float32)
with torch.no_grad():  # 500 passes: skip autograd bookkeeping to save memory
    for _ in range(500):
        t = model(t)
res = t[0,0].detach().numpy()
display(mono, res)
center_x Parameter containing:
tensor([0.], requires_grad=True)
center_y Parameter containing:
tensor([0.], requires_grad=True)
conv.weight Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]]], requires_grad=True)
conv_x.weight Parameter containing:
tensor([[[[ 0.0100,  0.0200,  0.0100],
          [ 0.0000,  0.0000,  0.0000],
          [-0.0100, -0.0200, -0.0100]]]], requires_grad=True)
conv_y.weight Parameter containing:
tensor([[[[-0.0100,  0.0000,  0.0100],
          [-0.0200,  0.0000,  0.0200],
          [-0.0100,  0.0000,  0.0100]]]], requires_grad=True)
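What this configuration does can be checked on a synthetic image. Each forward pass computes f + XX * conv(f, εB) + YY * conv(f, εA) with ε = 0.01. A moment calculation like the one for plain translations (assuming the content stays away from the border) says the centroid offset from the point (h/2, w/2), written as a complex number u + iv with u along the columns and v along the rows, gets multiplied by 1 + iθ per pass, where θ = 16ε/N on an N x N image: a rotation about the image centre. The sketch re-implements only that forward pass; `twist_step` and `centroid` are hypothetical helpers written for this check, not part of the original code.

```python
import math
import torch
import torch.nn.functional as F

A = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], dtype=torch.float64)
B = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]], dtype=torch.float64)

def twist_step(f, eps=0.01):
    # Forward pass of the set-up above: conv = identity,
    # center_x = center_y = 0, conv_x = eps*B, conv_y = eps*A
    _, _, h, w = f.shape
    XX = torch.arange(w, dtype=f.dtype).view(1, w) * 2 / w - 1   # same grid as ConvTwist.forward
    YY = torch.arange(h, dtype=f.dtype).view(h, 1) * 2 / h - 1
    fx = F.conv2d(f, (eps * B).view(1, 1, 3, 3), padding=1)
    fy = F.conv2d(f, (eps * A).view(1, 1, 3, 3), padding=1)
    return f + XX * fx + YY * fy

def centroid(f):
    # intensity-weighted mean (row, column) position
    img = f[0, 0]
    h, w = img.shape
    rows = torch.arange(h, dtype=img.dtype).view(h, 1)
    cols = torch.arange(w, dtype=img.dtype).view(1, w)
    mass = img.sum()
    return ((img * rows).sum() / mass).item(), ((img * cols).sum() / mass).item()

N, eps, n_steps = 64, 0.01, 200
rows = torch.arange(N, dtype=torch.float64).view(N, 1)
cols = torch.arange(N, dtype=torch.float64).view(1, N)
# Gaussian blob 10 pixels to the right of the centre point (32, 32)
f = torch.exp(-((rows - 32) ** 2 + (cols - 42) ** 2) / (2 * 3.0 ** 2)).view(1, 1, N, N)

for _ in range(n_steps):
    f = twist_step(f, eps)

r, c = centroid(f)
theta = 16 * eps / N
# predicted complex offset: real part = columns, imaginary part = rows
z = complex(42 - N / 2, 32 - N / 2) * (1 + 1j * theta) ** n_steps
print("measured  row %.4f, col %.4f" % (r, c))
print("predicted row %.4f, col %.4f" % (N / 2 + z.imag, N / 2 + z.real))
print("turned by %.3f rad" % math.atan2(r - N / 2, c - N / 2))  # n_steps * atan(theta), about 0.5
```

On the photo the same operator turns the whole picture about its centre, which is the behaviour the 500-pass run above is meant to show.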

The math behind all this

There are various ways to discuss the phenomenon that a "local operator", applied many times, yields a global transformation. But the first step is invariably to think of the image not as a collection of discrete data (pixels) but as a continuous object (a function of two continuous variables), with the pixels as "samples" of the function that are good enough to approximate whatever computations we want to perform on it.

For functions of continuous variables, the first thing you might want to do is take (partial) derivatives in the x- or y-direction. How would you approximate them if you don't have full information about the function, only its samples at discrete points?

$$ \frac{\partial f}{\partial x}(a,b)=\lim_{h\to 0}\frac{f(a+h,b)-f(a,b)}{h}$$
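With samples on a grid of spacing h, one can simply stop short of the limit and keep h small but finite. The kernels A and B from earlier are such difference quotients in disguise: each takes a central difference (f(x+h) - f(x-h))/(2h) and averages it over three neighbouring rows (or columns) with weights 1, 2, 1, so A applied to f is about 8h times the x-derivative, and B applied to f is about -8h times the y-derivative, with y counted downwards along the rows as in the Conv2d layout. A sketch on a smooth test function of my own choosing:

```python
import torch
import torch.nn.functional as F

n_pts = 101
h = 1.0 / (n_pts - 1)                                  # grid spacing
coords = torch.linspace(0, 1, n_pts, dtype=torch.float64)
Y, X = torch.meshgrid(coords, coords, indexing="ij")   # Y varies along rows, X along columns
f = torch.sin(2 * X) * torch.cos(Y)
df_dx = 2 * torch.cos(2 * X) * torch.cos(Y)
df_dy = -torch.sin(2 * X) * torch.sin(Y)

A = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], dtype=torch.float64)
B = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]], dtype=torch.float64)

def correlate(kernel, img):
    # Conv2d-style cross-correlation without padding: the one-pixel border is dropped
    return F.conv2d(img.view(1, 1, *img.shape), kernel.view(1, 1, 3, 3))[0, 0]

err_x = (correlate(A, f) / (8 * h) - df_dx[1:-1, 1:-1]).abs().max().item()
err_y = (correlate(B, f) / (-8 * h) - df_dy[1:-1, 1:-1]).abs().max().item()
print(err_x, err_y)  # both small; halving h should cut them by about 4
```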

print(np.indices((1,7,10))[1])
print(np.indices((1,7,10))[2])
[[[0 0 0 0 0 0 0 0 0 0]
  [1 1 1 1 1 1 1 1 1 1]
  [2 2 2 2 2 2 2 2 2 2]
  [3 3 3 3 3 3 3 3 3 3]
  [4 4 4 4 4 4 4 4 4 4]
  [5 5 5 5 5 5 5 5 5 5]
  [6 6 6 6 6 6 6 6 6 6]]]
[[[0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]
  [0 1 2 3 4 5 6 7 8 9]]]
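These index arrays are what ConvTwist.forward turns into XX and YY: multiplying by 2/w (or 2/h) and subtracting 1 rescales the column and row indices to run from -1 to 1 - 2/w (or 1 - 2/h), so the grid is off-centre by half a pixel, and center_x and center_y then shift it per output channel. The sketch repeats that construction for h = 7, w = 10 with the centers set to 0, next to a torch-only variant that avoids the NumPy round trip (the variant is a suggestion, not the original code).

```python
import numpy as np
import torch

h, w = 7, 10
# Construction used in ConvTwist.forward, with center_x = center_y = 0
XX_np = torch.from_numpy(np.indices((1, h, w))[2] * 2 / w) - 1
YY_np = torch.from_numpy(np.indices((1, h, w))[1] * 2 / h) - 1

# Hypothetical torch-only equivalent: build one axis and broadcast it
XX_t = (torch.arange(w, dtype=torch.float64) * 2 / w - 1).expand(1, h, w)
YY_t = (torch.arange(h, dtype=torch.float64) * 2 / h - 1).view(h, 1).expand(1, h, w)

print(XX_np[0, 0])     # column coordinate: -1.0, -0.8, ..., 0.8
print(YY_np[0, :, 0])  # row coordinate: -1.0 up to 1 - 2/7
assert torch.allclose(XX_np, XX_t) and torch.allclose(YY_np, YY_t)
```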