{
“cells”: [
{

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“# Using EMLP in PyTorch”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“So maybe you haven’t yet realized that JAX is the best way of doing deep learning. That’s OK!\n”,
“\n”,
“You can use EMLP and the equivariant linear layers _in PyTorch_. Simply replace `import emlp.nn as nn` with `import emlp.nn.pytorch as nn`.\n”,
“\n”,
“If you’re using a GPU (which we recommend), set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false` so that JAX doesn’t preallocate all of the GPU memory and leave none for PyTorch. Note that if a GPU is visible under `CUDA_VISIBLE_DEVICES`, you must run the PyTorch EMLP on the GPU.”

]

}, {

“cell_type”: “code”, “execution_count”: 1, “metadata”: {}, “outputs”: [

{

“name”: “stdout”, “output_type”: “stream”, “text”: [

“env: XLA_PYTHON_CLIENT_PREALLOCATE=false\n”

]

}

], “source”: [

“%env XLA_PYTHON_CLIENT_PREALLOCATE=false”

]
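}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“The `%env` magic above only exists inside IPython and Jupyter. As a minimal sketch, assuming you run EMLP from a plain Python script instead, the same variable can be set through `os.environ`. The assumption here is that this runs before anything imports JAX, since a value set later may be ignored once the GPU backend has started.”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: [

“import os\n”,
“\n”,
“# Assumption: nothing has imported jax yet in this process.\n”,
“# setdefault keeps any value that was already exported in the shell.\n”,
“os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')\n”,
“\n”,
“import torch\n”,
“import emlp.nn.pytorch as nn”

]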

}, {

“cell_type”: “code”, “execution_count”: 2, “metadata”: {}, “outputs”: [], “source”: [

“import torch\n”,
“import emlp.nn.pytorch as nn”

]

}, {

“cell_type”: “code”, “execution_count”: 3, “metadata”: {}, “outputs”: [], “source”: [

“from emlp.reps import T,V\n”,
“from emlp.groups import SO13\n”,
“\n”,
“repin = 4*V  # Set up some example data representations\n”,
“repout = V**0\n”,
“G = SO13()  # The Lorentz group\n”,
“\n”,
“device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')”

]

}, {

“cell_type”: “code”, “execution_count”: 4, “metadata”: {}, “outputs”: [

{
“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “3da7e6d57cf64559b02621650bfd2c92”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=10: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “63623850b71a4526ae0bd3777ea644b9”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=20: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “6114c51884874685b10d93b4700658d3”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=40: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“text/plain”: [

“tensor([[-0.0039],\n”,
“        [-0.0041],\n”,
“        [-0.0042],\n”,
“        [-0.0040],\n”,
“        [-0.0039]], device='cuda:0', grad_fn=<AddmmBackward>)”

]

}, “execution_count”: 4, “metadata”: {}, “output_type”: “execute_result”

}

], “source”: [

“x = torch.randn(5,repin(G).size()).to(device)  # generate some random data\n”,
“\n”,
“model = nn.EMLP(repin,repout,G).to(device)  # initialize the model\n”,
“\n”,
“model(x)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“The model is a standard PyTorch module.”

]

}, {

“cell_type”: “code”, “execution_count”: 5, “metadata”: {}, “outputs”: [

{
“data”: {
“text/plain”: [

“EMLP(\n”,
“  (network): Sequential(\n”,
“    (0): EMLPBlock(\n”,
“      (linear): Linear(in_features=16, out_features=419, bias=True)\n”,
“      (bilinear): BiLinear()\n”,
“      (nonlinearity): GatedNonlinearity()\n”,
“    )\n”,
“    (1): EMLPBlock(\n”,
“      (linear): Linear(in_features=384, out_features=419, bias=True)\n”,
“      (bilinear): BiLinear()\n”,
“      (nonlinearity): GatedNonlinearity()\n”,
“    )\n”,
“    (2): EMLPBlock(\n”,
“      (linear): Linear(in_features=384, out_features=419, bias=True)\n”,
“      (bilinear): BiLinear()\n”,
“      (nonlinearity): GatedNonlinearity()\n”,
“    )\n”,
“    (3): Linear(in_features=384, out_features=1, bias=True)\n”,
“  )\n”,
“)”

]

}, “execution_count”: 5, “metadata”: {}, “output_type”: “execute_result”

}

], “source”: [

“model”

]
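}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Because the model is an ordinary `torch.nn.Module`, the usual PyTorch tooling should apply to it unchanged. The cell below is a minimal sketch, assuming the `model` defined above; the names `n_params` and `state` are ours, introduced only for illustration.”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: [

“# Count the trainable parameters, as for any other torch.nn.Module\n”,
“n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n”,
“print(f'trainable parameters: {n_params}')\n”,
“\n”,
“# The weights are exposed through the usual state_dict\n”,
“state = model.state_dict()\n”,
“print(list(state.keys())[:4])”

]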

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Example Training Loop”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“OK, what about training and autograd and all that? As you can see, the training loop is very similar to the Objax one in [Constructing Equivariant Models](https://equivariant-mlp.readthedocs.io/en/latest/notebooks/2building_a_model.html).”

]

}, {

“cell_type”: “code”, “execution_count”: 6, “metadata”: {}, “outputs”: [

{
“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “38cbce4a562c4340871a9191255820cd”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

” 0%| | 0/500 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}

], “source”: [

“import torch\n”,
“import emlp.nn.pytorch as nn\n”,
“from emlp.groups import SO13\n”,
“import numpy as np\n”,
“from tqdm.auto import tqdm\n”,
“from torch.utils.data import DataLoader\n”,
“from emlp.datasets import ParticleInteraction\n”,
“\n”,
“trainset = ParticleInteraction(300)  # Initialize dataset with 300 examples\n”,
“testset = ParticleInteraction(1000)\n”,
“\n”,
“device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n”,
“BS=500\n”,
“lr=3e-3\n”,
“NUM_EPOCHS=500\n”,
“\n”,
“model = nn.EMLP(trainset.rep_in,trainset.rep_out,group=SO13(),num_layers=3,ch=384).to(device)\n”,
“\n”,
“optimizer = torch.optim.Adam(model.parameters(),lr=lr)\n”,
“\n”,
“def loss(x, y):\n”,
“    yhat = model(x.to(device))\n”,
“    return ((yhat-y.to(device))**2).mean()\n”,
“\n”,
“def train_op(x, y):\n”,
“    optimizer.zero_grad()\n”,
“    lossval = loss(x,y)\n”,
“    lossval.backward()\n”,
“    optimizer.step()\n”,
“    return lossval\n”,
“\n”,
“trainloader = DataLoader(trainset,batch_size=BS,shuffle=True)\n”,
“testloader = DataLoader(testset,batch_size=BS,shuffle=True)\n”,
“\n”,
“test_losses = []\n”,
“train_losses = []\n”,
“for epoch in tqdm(range(NUM_EPOCHS)):\n”,
“    train_losses.append(np.mean([train_op(*mb).cpu().data.numpy() for mb in trainloader]))\n”,
“    if not epoch%10:\n”,
“        with torch.no_grad():\n”,
“            test_losses.append(np.mean([loss(*mb).cpu().data.numpy() for mb in testloader]))”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“OK, so it’s not nearly as fast as in JAX (maybe 15x slower), but hey, you said you wanted PyTorch.”

]

}, {

“cell_type”: “code”, “execution_count”: 7, “metadata”: {}, “outputs”: [

{
“data”: {

“text/plain”: [

“<Figure size 432x288 with 1 Axes>”

]

}, “metadata”: {

“needs_background”: “light”

}, “output_type”: “display_data”

}

], “source”: [

“import matplotlib.pyplot as plt\n”,
“plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')\n”,
“plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')\n”,
“plt.legend()\n”,
“plt.yscale('log')”

]
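}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“To put a single number next to the curves, one option is to average the loss over the whole test set once training is done. This is a minimal sketch, assuming the `loss` function and `testloader` from the training cell; `evaluate` is a hypothetical helper written for this example, not part of EMLP.”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: [

“@torch.no_grad()\n”,
“def evaluate(loader):\n”,
“    # Hypothetical helper: mean squared error over a loader,\n”,
“    # weighting each minibatch by its number of examples\n”,
“    total, count = 0.0, 0\n”,
“    for x, y in loader:\n”,
“        total += loss(x, y).item() * len(x)\n”,
“        count += len(x)\n”,
“    return total / count\n”,
“\n”,
“print('final test MSE:', evaluate(testloader))”

]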

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Bonus: Try out `model = nn.MLP(trainset.rep_in,trainset.rep_out,group=SO13()).to(device)`, a standard non-equivariant MLP baseline, and see how well it performs on this problem.”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Converting JAX functions to PyTorch functions (how it works)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“You can also use the underlying equivariant bases $Q\\in \\mathbb{R}^{n\\times r}$ and projection operators $P = QQ^\\top$ in PyTorch.\n”,
“\n”,
“Since these objects are implicitly defined through `LinearOperator`s, it is not as straightforward as simply calling `torch.from_numpy(Q)`. However, there is a way to use these operators within PyTorch code while preserving any gradients of the operation. We provide the function `emlp.reps.pytorch_support.torchify_fn` to do this.”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: [

“import jax\n”,
“import jax.numpy as jnp\n”,
“from emlp.reps import V\n”,
“from emlp.groups import S”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“For example, let’s set up a representation of $S_4$ consisting of three vectors and one matrix.”

]

}, {

“cell_type”: “code”, “execution_count”: 8, “metadata”: {}, “outputs”: [], “source”: [

“W = V(S(4))\n”,
“rep = 3*W+W**2”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“First we compute the equivariant basis and equivariant projector linear operators, and then wrap them as functions.”

]

}, {

“cell_type”: “code”, “execution_count”: 9, “metadata”: {}, “outputs”: [], “source”: [

“Q = (rep>>rep).equivariant_basis()\n”,
“P = (rep>>rep).equivariant_projector()”

]

}, {

“cell_type”: “code”, “execution_count”: 10, “metadata”: {}, “outputs”: [], “source”: [

“applyQ = lambda v: Q@v\n”,
“applyP = lambda v: P@v”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“We can convert any pure JAX function into a PyTorch function by applying `torchify_fn`. Now, instead of taking JAX objects as inputs and outputting JAX objects, these functions take in PyTorch objects and output PyTorch objects.”

]

}, {

“cell_type”: “code”, “execution_count”: 11, “metadata”: {}, “outputs”: [], “source”: [

“from emlp.nn.pytorch import torchify_fn\n”,
“applyQ_torch = torchify_fn(applyQ)\n”,
“applyP_torch = torchify_fn(applyP)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“As you would hope, gradients are correctly propagated whether you use the original JAX functions or the torchified PyTorch functions.”

]

}, {

“cell_type”: “code”, “execution_count”: 12, “metadata”: {}, “outputs”: [], “source”: [

“x_torch = torch.arange(Q.shape[-1]).float().to(device)  # use the device selected earlier instead of hard-coding CUDA\n”,
“x_torch.requires_grad = True\n”,
“x_jax = jnp.asarray(x_torch.cpu().data.numpy())”

]

}, {

“cell_type”: “code”, “execution_count”: 13, “metadata”: {}, “outputs”: [

{

“name”: “stdout”, “output_type”: “stream”, “text”: [

“jax output:  [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]\n”,
“torch output:  tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',\n”,
“       grad_fn=<SliceBackward>)\n”

]

}

], “source”: [

“Qx1 = applyQ(x_jax)\n”,
“Qx2 = applyQ_torch(x_torch)\n”,
“print(\"jax output: \",Qx1[:5])\n”,
“print(\"torch output: \",Qx2[:5])”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“The outputs match. Note that the torch outputs will be on whichever device is the default JAX device. Similarly, the gradients of the two objects also match:”

]

}, {

“cell_type”: “code”, “execution_count”: 14, “metadata”: {}, “outputs”: [

{
“data”: {
“text/plain”: [

“tensor([-2.8704, 2.7858, -2.8704, 2.7858, -2.8704], device='cuda:0')”

]

}, “execution_count”: 14, “metadata”: {}, “output_type”: “execute_result”

}

], “source”: [

“torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]”

]

}, {

“cell_type”: “code”, “execution_count”: 15, “metadata”: {}, “outputs”: [

{
“data”: {
“text/plain”: [

“DeviceArray([-2.8703732, 2.7858496, -2.8703732, 2.7858496, -2.8703732], dtype=float32)”

]

}, “execution_count”: 15, “metadata”: {}, “output_type”: “execute_result”

}

], “source”: [

“jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“So you can safely use these torchified functions within your model, and still compute the gradients correctly.”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“We use this `torchify_fn` on the projection operators to convert EMLP to PyTorch.”

]
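}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“To illustrate the idea, here is a minimal sketch of a linear layer whose free parameters are projected onto the equivariant subspace on every forward pass. It assumes the `rep`, `applyP_torch`, and `device` objects defined above. `ProjectedLinear` is a hypothetical class written for this example and is not the EMLP implementation; it also assumes that the flattened weight reshapes row-major into an (out, in) matrix.”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: [

“n = rep.size()  # dimension of the representation 3*W + W**2\n”,
“\n”,
“class ProjectedLinear(torch.nn.Module):\n”,
“    # Hypothetical illustration, not the EMLP implementation\n”,
“    def __init__(self):\n”,
“        super().__init__()\n”,
“        # Unconstrained parameters, one per entry of an n x n matrix\n”,
“        self.w = torch.nn.Parameter(torch.randn(n * n))\n”,
“\n”,
“    def forward(self, x):\n”,
“        # Project onto the equivariant subspace, then apply as a matrix\n”,
“        W = applyP_torch(self.w).reshape(n, n)\n”,
“        return x @ W.T\n”,
“\n”,
“layer = ProjectedLinear().to(device)\n”,
“out = layer(torch.randn(5, n, device=device))\n”,
“out.sum().backward()  # gradients flow back through the torchified projection\n”,
“print(out.shape, layer.w.grad.shape)”

]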

}

], “metadata”: {

“kernelspec”: {

“display_name”: “Python 3”, “language”: “python”, “name”: “python3”

}, “language_info”: {

“codemirror_mode”: {

“name”: “ipython”, “version”: 3

}, “file_extension”: “.py”, “mimetype”: “text/x-python”, “name”: “python”, “nbconvert_exporter”: “python”, “pygments_lexer”: “ipython3”, “version”: “3.8.5”

}

}, “nbformat”: 4, “nbformat_minor”: 4

}