{
"cells": [
{
"cell_type": "markdown", "metadata": {}, "source": [
"# Using EMLP in PyTorch"
]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"So maybe you haven't yet realized that Jax is the best way of doing deep learning – that's OK!\n",
"\n",
"You can use EMLP and the equivariant linear layers _in PyTorch_. Simply replace `import emlp.nn as nn` with `import emlp.nn.pytorch as nn`.\n",
"\n",
"If you're using a GPU (which we recommend), you will want to set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false` so that Jax doesn't preallocate most of the GPU memory and leave too little for PyTorch. Note that if a GPU is visible under `CUDA_VISIBLE_DEVICES`, you must use the PyTorch EMLP on the GPU (the Jax computations inside the model run on Jax's default device, which will then be the GPU)."

]

}, {

"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [
{
"name": "stdout", "output_type": "stream", "text": [
"env: XLA_PYTHON_CLIENT_PREALLOCATE=false\n"
]
}
], "source": [
"%env XLA_PYTHON_CLIENT_PREALLOCATE=false"

]

}, {
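"cell_type": "markdown", "metadata": {}, "source": [
"The `%env` line above is IPython magic. In a plain Python script, a minimal equivalent (a sketch using only the standard `os` module) is to set the variable through `os.environ` before Jax is first imported, because Jax reads it when it initializes its GPU backend. Importing `emlp` also imports Jax, so the assignment has to come before that import as well."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"# Must run before jax is imported (directly or via emlp); otherwise Jax may already\n",
"# have preallocated most of the GPU memory.\n",
"os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
"\n",
"import torch\n",
"import emlp.nn.pytorch as nn  # imports jax internally"
]
}, {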

"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
"import torch\n", "import emlp.nn.pytorch as nn"

]

}, {

"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"from emlp.reps import T,V\n", "from emlp.groups import SO13\n", "\n", "repin = 4*V # Set up some example data representations\n", "repout = V**0 # Scalar output\n", "G = SO13() # The Lorentz group\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"

]

}, {

"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [

{

"data": {
"text/plain": [
"tensor([[-0.0039],\n", "        [-0.0041],\n", "        [-0.0042],\n", "        [-0.0040],\n", "        [-0.0039]], device='cuda:0', grad_fn=<AddmmBackward>)"
]
}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"
}
], "source": [
"x = torch.randn(5,repin(G).size()).to(device) # generate some random data\n", "\n", "model = nn.EMLP(repin,repout,G).to(device) # initialize the model\n", "\n", "model(x)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"The model is a standard PyTorch module (a `torch.nn.Module`)."

]

}, {

"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [
{
"data": {
"text/plain": [
"EMLP(\n",
"  (network): Sequential(\n",
"    (0): EMLPBlock(\n",
"      (linear): Linear(in_features=16, out_features=419, bias=True)\n",
"      (bilinear): BiLinear()\n",
"      (nonlinearity): GatedNonlinearity()\n",
"    )\n",
"    (1): EMLPBlock(\n",
"      (linear): Linear(in_features=384, out_features=419, bias=True)\n",
"      (bilinear): BiLinear()\n",
"      (nonlinearity): GatedNonlinearity()\n",
"    )\n",
"    (2): EMLPBlock(\n",
"      (linear): Linear(in_features=384, out_features=419, bias=True)\n",
"      (bilinear): BiLinear()\n",
"      (nonlinearity): GatedNonlinearity()\n",
"    )\n",
"    (3): Linear(in_features=384, out_features=1, bias=True)\n",
"  )\n",
")"
]
}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"
}
], "source": [
"model"

]

}, {
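"cell_type": "markdown", "metadata": {}, "source": [
"Because it is an ordinary `torch.nn.Module`, the usual PyTorch utilities should apply to it. A short sketch (the file name `emlp_model.pt` is only an example) that counts the trainable parameters and saves and restores the weights through `state_dict`:"
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Count trainable parameters, as for any torch.nn.Module\n",
"num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f'{num_params} trainable parameters')\n",
"\n",
"# Save and reload the weights with the standard state_dict mechanism\n",
"torch.save(model.state_dict(), 'emlp_model.pt')  # example path\n",
"model.load_state_dict(torch.load('emlp_model.pt', map_location=device))"
]
}, {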

"cell_type": "markdown", "metadata": {}, "source": [
"## Example Training Loop"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"OK, what about training, autograd and all that? As you can see, the training loop is very similar to the Objax one in [Constructing Equivariant Models](https://equivariant-mlp.readthedocs.io/en/latest/notebooks/2building_a_model.html)."

]

}, {

"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [


], "source": [
"import torch\n",
"import emlp.nn.pytorch as nn\n",
"from emlp.groups import SO13\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from torch.utils.data import DataLoader\n",
"from emlp.datasets import ParticleInteraction\n",
"\n",
"trainset = ParticleInteraction(300) # Initialize dataset with 300 examples\n",
"testset = ParticleInteraction(1000)\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"BS=500\n",
"lr=3e-3\n",
"NUM_EPOCHS=500\n",
"\n",
"model = nn.EMLP(trainset.rep_in,trainset.rep_out,group=SO13(),num_layers=3,ch=384).to(device)\n",
"\n",
"optimizer = torch.optim.Adam(model.parameters(),lr=lr)\n",
"\n",
"def loss(x, y):\n",
"    yhat = model(x.to(device))\n",
"    return ((yhat-y.to(device))**2).mean()\n",
"\n",
"def train_op(x, y):\n",
"    optimizer.zero_grad()\n",
"    lossval = loss(x,y)\n",
"    lossval.backward()\n",
"    optimizer.step()\n",
"    return lossval\n",
"\n",
"trainloader = DataLoader(trainset,batch_size=BS,shuffle=True)\n",
"testloader = DataLoader(testset,batch_size=BS,shuffle=False) # no need to shuffle for evaluation\n",
"\n",
"test_losses = []\n",
"train_losses = []\n",
"for epoch in tqdm(range(NUM_EPOCHS)):\n",
"    train_losses.append(np.mean([train_op(*mb).detach().cpu().numpy() for mb in trainloader]))\n",
"    if not epoch%10:\n",
"        with torch.no_grad():\n",
"            test_losses.append(np.mean([loss(*mb).cpu().numpy() for mb in testloader]))"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"OK, so it's not nearly as fast as in Jax (perhaps 15x slower), but hey, you said you wanted PyTorch."

]

}, {

"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [


], "source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')\n",
"plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')\n",
"plt.legend()\n",
"plt.yscale('log')"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"Bonus: Try out `model=nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device)` and see how well it performs on this problem."

]

}, {
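"cell_type": "markdown", "metadata": {}, "source": [
"A sketch of one way to run the bonus experiment. It reuses `loss`, `train_op`, the data loaders, `lr`, `NUM_EPOCHS` and `test_losses` from the training cell above; because `loss` and `train_op` read the global `model` and `optimizer`, rebinding those two names is enough. It assumes `nn.MLP` accepts the arguments from the bonus suggestion, it overwrites the trained EMLP stored in `model`, and it runs a second full training loop of `NUM_EPOCHS` epochs."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Non-equivariant baseline (assumes nn.MLP takes the arguments suggested above)\n",
"model = nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(),lr=lr)\n",
"\n",
"mlp_test_losses = []\n",
"for epoch in tqdm(range(NUM_EPOCHS)):\n",
"    for mb in trainloader:\n",
"        train_op(*mb)  # uses the rebound global model and optimizer\n",
"    if not epoch%10:\n",
"        with torch.no_grad():\n",
"            mlp_test_losses.append(np.mean([loss(*mb).cpu().numpy() for mb in testloader]))\n",
"\n",
"# Compare the test curves of the EMLP run above and the MLP baseline\n",
"plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='EMLP test loss')\n",
"plt.plot(np.arange(0,NUM_EPOCHS,10),mlp_test_losses,label='MLP test loss')\n",
"plt.legend()\n",
"plt.yscale('log')"
]
}, {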

"cell_type": "markdown", "metadata": {}, "source": [
"## Converting Jax functions to PyTorch functions (how it works)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"You can also use the underlying equivariant bases $Q\\in \\mathbb{R}^{n\\times r}$ and projection operators $P = QQ^\\top$ in PyTorch.\n",
"\n",
"Since these objects are implicitly defined through `LinearOperator`s, it is not as straightforward as simply calling `torch.from_numpy(Q)`. However, there is a way to use these operators within PyTorch code while preserving the gradients of the operation. We provide the function `torchify_fn` in `emlp.nn.pytorch` to do this."

]

}, {
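"cell_type": "markdown", "metadata": {}, "source": [
"To give a rough idea of how such a bridge between the two autograd systems can work, here is a minimal sketch built on `torch.autograd.Function` and `jax.vjp`. It is not the emlp implementation: it handles a single tensor input and output only, it copies data through host memory with NumPy, and the names `JaxFunction` and `torchify_single` are hypothetical."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import torch\n",
"\n",
"class JaxFunction(torch.autograd.Function):\n",
"    # Hypothetical helper: run a Jax function forward, backpropagate with jax.vjp\n",
"    @staticmethod\n",
"    def forward(ctx, x, jax_fn):\n",
"        x_jax = jnp.asarray(x.detach().cpu().numpy())  # torch -> numpy -> jax\n",
"        y_jax, vjp_fn = jax.vjp(jax_fn, x_jax)  # evaluate and keep the pullback\n",
"        ctx.vjp_fn, ctx.device = vjp_fn, x.device\n",
"        return torch.from_numpy(np.array(y_jax)).to(x.device)\n",
"\n",
"    @staticmethod\n",
"    def backward(ctx, grad_y):\n",
"        (grad_x,) = ctx.vjp_fn(jnp.asarray(grad_y.detach().cpu().contiguous().numpy()))\n",
"        # One gradient per forward input: the tensor gets grad_x, jax_fn gets None\n",
"        return torch.from_numpy(np.array(grad_x)).to(ctx.device), None\n",
"\n",
"def torchify_single(jax_fn):\n",
"    return lambda x: JaxFunction.apply(x, jax_fn)\n",
"\n",
"# Usage on a toy Jax function; the gradient of sum(2*sin(x)) is 2*cos(x)\n",
"double_sin = torchify_single(lambda v: 2.0 * jnp.sin(v))\n",
"x_demo = torch.linspace(0., 1., 4, requires_grad=True)\n",
"double_sin(x_demo).sum().backward()\n",
"print(torch.allclose(x_demo.grad, 2 * torch.cos(x_demo.detach()), atol=1e-6))"
]
}, {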

"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import jax\n", "import jax.numpy as jnp\n", "from emlp.reps import V\n", "from emlp.groups import S"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"For example, let's set up a representation of $S_4$ consisting of three vectors and one matrix."

]

}, {

"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"W = V(S(4))\n", "rep = 3*W+W**2"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"First we compute the equivariant basis and the equivariant projector as linear operators, and then wrap them as functions."

]

}, {

"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
"Q = (rep>>rep).equivariant_basis()\n", "P = (rep>>rep).equivariant_projector()"

]

}, {
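"cell_type": "markdown", "metadata": {}, "source": [
"As a quick check on the dimensions: `rep` has dimension $3\\cdot 4 + 4^2 = 28$, so linear maps `rep>>rep` live in a space of dimension $n = 28^2 = 784$. We would therefore expect $P$ to be $784\\times 784$ and $Q$ to be $784\\times r$, where $r$ is the number of independent equivariant maps."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"print(rep.size())  # dimension of the representation\n",
"print(P.shape)  # projector acting on flattened rep>>rep weights\n",
"print(Q.shape)  # (n, r): n = rep.size()**2, r = number of basis elements"
]
}, {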

"cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
"applyQ = lambda v: Q@v\n", "applyP = lambda v: P@v"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"We can convert any pure Jax function into a PyTorch function by applying `torchify_fn`. Instead of taking Jax objects as inputs and outputting Jax objects, the converted functions take PyTorch tensors as inputs and output PyTorch tensors."

]

}, {

"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
"from emlp.nn.pytorch import torchify_fn\n", "applyQ_torch = torchify_fn(applyQ)\n", "applyP_torch = torchify_fn(applyP)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"As you would hope, gradients are correctly propagated whether you use the original Jax functions or the torchified PyTorch functions."

]

}, {

"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
"x_torch = torch.arange(Q.shape[-1]).float().to(device) # one entry per basis element (r of them)\n", "x_torch.requires_grad_(True)\n", "x_jax = jnp.asarray(x_torch.detach().cpu().numpy()) # the same values as a Jax array"

]

}, {

"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [
{
"name": "stdout", "output_type": "stream", "text": [
"jax output:  [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]\n", "torch output:  tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',\n", "       grad_fn=<SliceBackward>)\n"
]
}
], "source": [
"Qx1 = applyQ(x_jax)\n", "Qx2 = applyQ_torch(x_torch)\n", "print(\"jax output: \",Qx1[:5])\n", "print(\"torch output: \",Qx2[:5])"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"The outputs match. Note that the torch outputs will be on whichever device is the default Jax device. Similarly, the gradients of the two functions also match:"

]

}, {

"cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [
{
"data": {
"text/plain": [
"tensor([-2.8704, 2.7858, -2.8704, 2.7858, -2.8704], device='cuda:0')"
]
}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"
}
], "source": [
"torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]"

]

}, {

"cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-2.8703732, 2.7858496, -2.8703732, 2.7858496, -2.8703732], dtype=float32)"
]
}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"
}
], "source": [
"jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [
"So you can safely use these torchified functions within your model and still compute the gradients correctly."

]

}, {
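"cell_type": "markdown", "metadata": {}, "source": [
"For instance, here is a small sketch of an equivariant linear layer from `rep` to `rep` that uses the torchified basis directly: it learns coefficients $c\\in\\mathbb{R}^r$ and builds the weight matrix by reshaping $Qc$. The class name `BasisLinear` and the row-major (output, input) reshape are our assumptions for illustration, not emlp's own layer; for this $S_4$ example the representation matrices are orthogonal, so the transposed reshape would also give an equivariant map. As noted above, the output of `applyQ_torch` lands on Jax's default device, so `device` should match it."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class BasisLinear(torch.nn.Module):\n",
"    # Hypothetical equivariant map rep -> rep with weight matrix reshape(Q @ c)\n",
"    def __init__(self, d, r):\n",
"        super().__init__()\n",
"        self.d = d\n",
"        self.coeffs = torch.nn.Parameter(torch.randn(r) / r**0.5)  # c in R^r\n",
"\n",
"    def forward(self, x):\n",
"        W = applyQ_torch(self.coeffs).reshape(self.d, self.d)  # torchified Q @ c\n",
"        return x @ W.T\n",
"\n",
"layer = BasisLinear(rep.size(), Q.shape[-1]).to(device)\n",
"out = layer(torch.randn(8, rep.size(), device=device))\n",
"out.sum().backward()  # gradients reach the coefficients through torchify_fn\n",
"print(layer.coeffs.grad.shape)"
]
}, {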

"cell_type": "markdown", "metadata": {}, "source": [
"We use `torchify_fn` on the equivariant projection operators in this way to convert EMLP to PyTorch."

]

}
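, {
"cell_type": "markdown", "metadata": {}, "source": [
"A minimal sketch of the projector-based version of the same idea: unconstrained weights are mapped onto the equivariant subspace by the torchified projector $P$. Since $P$ is a projection, applying it a second time should leave the result unchanged up to floating-point error, which gives a convenient sanity check (the tolerance `atol=1e-4` is our own choice)."
]
}, {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"w = torch.randn(P.shape[-1], device=device, requires_grad=True)  # unconstrained weights\n",
"w_eq = applyP_torch(w)  # projected onto the equivariant subspace\n",
"print(torch.allclose(applyP_torch(w_eq), w_eq, atol=1e-4))  # P(Pw) should equal Pw\n",
"\n",
"W_mat = w_eq.reshape(rep.size(), rep.size())  # equivariant weight matrix for rep -> rep\n",
"W_mat.pow(2).sum().backward()  # gradients flow back to the unconstrained w\n",
"print(w.grad.shape)"
]
}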

], "metadata": {
"kernelspec": {
"display_name": "Python 3", "language": "python", "name": "python3"
}, "language_info": {
"codemirror_mode": {
"name": "ipython", "version": 3
}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5"
}
}, "nbformat": 4, "nbformat_minor": 4

}