
Using EMLP in PyTorch

So maybe you haven’t yet realized that Jax is the best way of doing deep learning – that’s ok!

You can use EMLP and the equivariant linear layers in PyTorch. Simply replace import emlp.nn as nn with import emlp.nn.pytorch as nn.

If you’re using a GPU (which we recommend), you will want to set the environment variable below so that Jax doesn’t preallocate all of the GPU memory and leave none for PyTorch. Note that if a GPU is visible (via CUDA_VISIBLE_DEVICES), the PyTorch EMLP modules must also be run on the GPU.

[1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
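
If you are working in a plain Python script rather than a notebook, the same flag can equivalently be set from Python (or exported in your shell) before Jax is first imported, for example:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # must be set before the first Jax import (emlp imports Jax internally)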
[2]:
import torch
import emlp.nn.pytorch as nn
[3]:
from emlp.reps import T,V
from emlp.groups import SO13

repin = 4*V  # Set up some example data representations
repout = V**0
G = SO13()  # The Lorentz group

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
[4]:
x = torch.randn(5,repin(G).size()).to(device) # generate some random data

model = nn.EMLP(repin,repout,G).to(device) # initialize the model

model(x)
[4]:
tensor([[-0.0039],
        [-0.0041],
        [-0.0042],
        [-0.0040],
        [-0.0039]], device='cuda:0', grad_fn=<AddmmBackward>)

The model is a standard PyTorch module.

[5]:
model
[5]:
EMLP(
  (network): Sequential(
    (0): EMLPBlock(
      (linear): Linear(in_features=16, out_features=419, bias=True)
      (bilinear): BiLinear()
      (nonlinearity): GatedNonlinearity()
    )
    (1): EMLPBlock(
      (linear): Linear(in_features=384, out_features=419, bias=True)
      (bilinear): BiLinear()
      (nonlinearity): GatedNonlinearity()
    )
    (2): EMLPBlock(
      (linear): Linear(in_features=384, out_features=419, bias=True)
      (bilinear): BiLinear()
      (nonlinearity): GatedNonlinearity()
    )
    (3): Linear(in_features=384, out_features=1, bias=True)
  )
)
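
Since it is a regular torch.nn.Module, the usual PyTorch machinery applies. For example (a quick, generic illustration; the filename is just a placeholder):

n_params = sum(p.numel() for p in model.parameters())  # count the trainable parameters
print(f"{n_params} trainable parameters")
torch.save(model.state_dict(),"emlp_so13.pt")  # save/load the weights like any other module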

Example Training Loop

OK, so what about training, autograd, and all that? As you can see, the training loop is very similar to the objax one in Constructing Equivariant Models.

[6]:
import torch
import emlp.nn.pytorch as nn
from emlp.groups import SO13
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from emlp.datasets import ParticleInteraction

trainset = ParticleInteraction(300) # Initialize training dataset with 300 examples
testset = ParticleInteraction(1000)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BS=500
lr=3e-3
NUM_EPOCHS=500

model = nn.EMLP(trainset.rep_in,trainset.rep_out,group=SO13(),num_layers=3,ch=384).to(device)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)

def loss(x, y):
    yhat = model(x.to(device))
    return ((yhat-y.to(device))**2).mean()

def train_op(x, y):
    optimizer.zero_grad()
    lossval = loss(x,y)
    lossval.backward()
    optimizer.step()
    return lossval

trainloader = DataLoader(trainset,batch_size=BS,shuffle=True)
testloader = DataLoader(testset,batch_size=BS,shuffle=True)

test_losses = []
train_losses = []
for epoch in tqdm(range(NUM_EPOCHS)):
    train_losses.append(np.mean([train_op(*mb).cpu().data.numpy() for mb in trainloader]))
    if not epoch%10:
        with torch.no_grad():
            test_losses.append(np.mean([loss(*mb).cpu().data.numpy() for mb in testloader]))

OK, so it’s not nearly as fast as in Jax (maybe 15x slower), but hey, you said you wanted PyTorch.

[7]:
import matplotlib.pyplot as plt
plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')
plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')
plt.legend()
plt.yscale('log')
[Figure: train and test loss curves over training epochs, log-scale y-axis.]

Bonus: Try out model=nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device) and see how well it performs on this problem.
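
For reference, only the model construction needs to change; a minimal sketch that reuses the loss, optimizer setup, and training loop from above:

model = nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device)  # non-equivariant baseline
optimizer = torch.optim.Adam(model.parameters(),lr=lr)  # then rerun the training loop as before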

Converting Jax functions to PyTorch functions (how it works)

You can also use the underlying equivariant bases \(Q\in \mathbb{R}^{n\times r}\) and projection operators \(P = QQ^\top\) in PyTorch.

Since these objects are implicitly defined through LinearOperators, it is not as straightforward as simply calling torch.from_numpy(Q). However, there is a way to use these operators within PyTorch code while preserving the gradients of the operation. We provide the function torchify_fn (imported from emlp.nn.pytorch below) to do this.

[ ]:
import jax
import jax.numpy as jnp
from emlp.reps import V
from emlp.groups import S

For example, let’s set up a representation of \(S_4\) consisting of three vectors and one matrix.

[8]:
W = V(S(4))
rep = 3*W + W**2

First we compute the equivariant basis and equivariant projector linear operators, and then wrap them as functions.

[9]:
Q = (rep>>rep).equivariant_basis()
P = (rep>>rep).equivariant_projector()
[10]:
applyQ = lambda v: Q@v
applyP = lambda v: P@v

We can convert any pure Jax function into a PyTorch function by applying torchify_fn. Instead of taking Jax objects as inputs and outputting Jax objects, the torchified functions take in PyTorch objects and output PyTorch objects.

[11]:
from emlp.nn.pytorch import torchify_fn
applyQ_torch = torchify_fn(applyQ)
applyP_torch = torchify_fn(applyP)

As you would hope, gradients are correctly propagated whether you use the original Jax functions or the torchified PyTorch functions.

[12]:
x_torch = torch.arange(Q.shape[-1]).float().cuda()
x_torch.requires_grad=True
x_jax  = jnp.asarray(x_torch.cpu().data.numpy())
[13]:
Qx1 = applyQ(x_jax)
Qx2 = applyQ_torch(x_torch)
print("jax output: ",Qx1[:5])
print("torch output: ",Qx2[:5])
jax output:  [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]
torch output:  tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',
       grad_fn=<SliceBackward>)

The outputs match; note that the torch outputs will live on whatever the default Jax device is. Similarly, the gradients of the two objects also match:

[14]:
torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]
[14]:
tensor([-2.8704,  2.7858, -2.8704,  2.7858, -2.8704], device='cuda:0')
[15]:
jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]
[15]:
DeviceArray([-2.8703732,  2.7858496, -2.8703732,  2.7858496, -2.8703732], dtype=float32)

So you can safely use these torchified functions within your model, and still compute the gradients correctly.

We use torchify_fn on the projection operators to convert EMLP to PyTorch.
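
Concretely, the idea looks roughly like the sketch below. This is not the library’s exact implementation, and the ProjectedLinear class and its details are purely illustrative: an unconstrained weight is stored as an ordinary torch.nn.Parameter and projected onto the equivariant subspace with a torchified projector on every forward pass.

import torch
from emlp.reps import V
from emlp.groups import S
from emlp.nn.pytorch import torchify_fn

class ProjectedLinear(torch.nn.Module):
    """Illustrative equivariant linear layer (bias omitted for brevity)."""
    def __init__(self, rep_in, rep_out):
        super().__init__()
        rep_W = rep_in >> rep_out                # representation of linear maps rep_in -> rep_out
        P = rep_W.equivariant_projector()        # LinearOperator P = QQ^T
        self.proj = torchify_fn(lambda w: P@w)   # PyTorch-callable projection onto the equivariant subspace
        self.weight = torch.nn.Parameter(torch.randn(rep_out.size(),rep_in.size()))

    def forward(self, x):
        # project the flattened weight at each forward pass so gradients flow through the projection
        W = self.proj(self.weight.reshape(-1)).reshape(self.weight.shape)
        return x@W.t()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rep = 3*V(S(4)) + V(S(4))**2                     # the example representation from above
layer = ProjectedLinear(rep,rep).to(device)      # keep the layer on the same device as Jax (see the note at the top)
out = layer(torch.randn(2,rep.size()).to(device))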