# Using EMLP in PyTorch¶

So maybe you haven’t yet realized that Jax is the best way of doing deep learning – that’s ok!

You can use EMLP and the equivariant linear layers *in PyTorch*. Simply replace `import emlp.nn as nn`

with `import emlp.nn.pytorch as nn`

.

If you’re using a GPU (which we recommend), you will want to set the environment variable so that Jax doesn’t steal all of the GPU memory from PyTorch. Note that if a GPU is visible under `CUDA_VISIBLE_DEVICES`

, you must use the PyTorch EMLP on the GPU.

```
[1]:
```

```
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
```

```
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
```

```
[2]:
```

```
import torch
import emlp.nn.pytorch as nn
```

```
[3]:
```

```
from emlp.reps import T,V
from emlp.groups import SO13
repin= 4*V # Setup some example data representations
repout = V**0
G = SO13() # The lorentz group
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

```
[4]:
```

```
x = torch.randn(5,repin(G).size()).to(device) # generate some random data
model = nn.EMLP(repin,repout,G).to(device) # initialize the model
model(x)
```

```
[4]:
```

```
tensor([[-0.0039],
[-0.0041],
[-0.0042],
[-0.0040],
[-0.0039]], device='cuda:0', grad_fn=<AddmmBackward>)
```

The model is a standard pytorch module.

```
[5]:
```

```
model
```

```
[5]:
```

```
EMLP(
(network): Sequential(
(0): EMLPBlock(
(linear): Linear(in_features=16, out_features=419, bias=True)
(bilinear): BiLinear()
(nonlinearity): GatedNonlinearity()
)
(1): EMLPBlock(
(linear): Linear(in_features=384, out_features=419, bias=True)
(bilinear): BiLinear()
(nonlinearity): GatedNonlinearity()
)
(2): EMLPBlock(
(linear): Linear(in_features=384, out_features=419, bias=True)
(bilinear): BiLinear()
(nonlinearity): GatedNonlinearity()
)
(3): Linear(in_features=384, out_features=1, bias=True)
)
)
```

## Example Training Loop¶

Ok what about training and autograd and all that? As you can see the training loop is very similar to the objax one in Constructing Equivariant Models.

```
[6]:
```

```
import torch
import emlp.nn.pytorch as nn
from emlp.groups import SO13
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from emlp.datasets import ParticleInteraction
trainset = ParticleInteraction(300) # Initialize dataset with 1000 examples
testset = ParticleInteraction(1000)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BS=500
lr=3e-3
NUM_EPOCHS=500
model = nn.EMLP(trainset.rep_in,trainset.rep_out,group=SO13(),num_layers=3,ch=384).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=lr)
def loss(x, y):
yhat = model(x.to(device))
return ((yhat-y.to(device))**2).mean()
def train_op(x, y):
optimizer.zero_grad()
lossval = loss(x,y)
lossval.backward()
optimizer.step()
return lossval
trainloader = DataLoader(trainset,batch_size=BS,shuffle=True)
testloader = DataLoader(testset,batch_size=BS,shuffle=True)
test_losses = []
train_losses = []
for epoch in tqdm(range(NUM_EPOCHS)):
train_losses.append(np.mean([train_op(*mb).cpu().data.numpy() for mb in trainloader]))
if not epoch%10:
with torch.no_grad():
test_losses.append(np.mean([loss(*mb).cpu().data.numpy() for mb in testloader]))
```

Ok so it’s not nearly as fast as in Jax (maybe 15x slower), but hey you said you wanted PyTorch

```
[7]:
```

```
import matplotlib.pyplot as plt
plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')
plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')
plt.legend()
plt.yscale('log')
```

Bonus: Try out `model=nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device)`

and see how well it performs on this problem.

## Converting Jax functions to PyTorch functions (how it works)¶

You can use the underlying equivariant bases \(Q\in \mathbb{R}^{n\times r}\) and projection operators \(P = QQ^\top\) in pytorch also.

Since these objects are implicitly defined through `LinearOperators`

, it is not as straightforward as simply calling `torch.from_numpy(Q)`

. However, there is a way to use these operators within PyTorch code while preserving any gradients of the operation. We provide the function `emlp.reps.pytorch_support.torchify_fn`

to do this.

```
[ ]:
```

```
import jax
import jax.numpy as jnp
from emlp.reps import V
from emlp.groups import S
```

For example, let’s setup a representation \(S_4\) consisting of three vectors and one matrix.

```
[8]:
```

```
W =V(S(4))
rep = 3*W+W**2
```

First we compute the equivariant basis and equivariant projector linear operators, and then wrap them as functions.

```
[9]:
```

```
Q = (rep>>rep).equivariant_basis()
P = (rep>>rep).equivariant_projector()
```

```
[10]:
```

```
applyQ = lambda v: Q@v
applyP = lambda v: P@v
```

We can convert any pure pytorch function into a jax function by applying `torchify_fn`

. Now instead of taking jax objects as inputs and outputing jax objects, these functions take in PyTorch objects and output PyTorch objects.

```
[11]:
```

```
from emlp.nn.pytorch import torchify_fn
applyQ_torch = torchify_fn(applyQ)
applyP_torch = torchify_fn(applyP)
```

As you would hope, gradients are correctly propagated whether you use the original Jax functions or the torchified pytorch functions.

```
[12]:
```

```
x_torch = torch.arange(Q.shape[-1]).float().cuda()
x_torch.requires_grad=True
x_jax = jnp.asarray(x_torch.cpu().data.numpy())
```

```
[13]:
```

```
Qx1 = applyQ(x_jax)
Qx2 = applyQ_torch(x_torch)
print("jax output: ",Qx1[:5])
print("torch output: ",Qx2[:5])
```

```
jax output: [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]
torch output: tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',
grad_fn=<SliceBackward>)
```

The outputs match, and note that the torch outputs will be on whichever is the default jax device. Similarly, the gradients of the two objects also match:

```
[14]:
```

```
torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]
```

```
[14]:
```

```
tensor([-2.8704, 2.7858, -2.8704, 2.7858, -2.8704], device='cuda:0')
```

```
[15]:
```

```
jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]
```

```
[15]:
```

```
DeviceArray([-2.8703732, 2.7858496, -2.8703732, 2.7858496, -2.8703732], dtype=float32)
```

So you can safely use these torchified functions within your model, and still compute the gradients correctly.

We use this `torchify_fn`

on the projection operators to convert EMLP to pytorch.