{
"cells": [
{

"cell_type": "markdown", "metadata": {}, "source": [

"# Constructing Equivariant Models"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"Previously we showed examples of finding equivariant bases for different groups and representations; now we'll show how these bases can be assembled into equivariant neural networks such as EMLP.\n", "\n", "We will first give high-level examples showing how the EMLP model can be applied to different groups and input-output types, and then go to a lower level to show how models like EMLP are constructed from equivariant layers that make use of the equivariant bases."

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"## Using EMLP with different groups and representations (high level)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"![ex 2.13](imgs/EMLP_fig.png)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"A basic EMLP is a sequence of EMLP layers (each containing G-equivariant linear layers, bilinear layers incorporated with a shortcut connection, and gated nonlinearities). While our numerical equivariance solver can work with any finite-dimensional linear representation, for EMLP we restrict ourselves to _tensor_ representations."

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"By tensor representations, we mean all representations that can be formed by arbitrary combinations of $\\oplus$, $\\otimes$, and $^*$ (`+`, `*`, `.T`) of a base representation $\\rho$. This is useful because it simplifies the construction of our bilinear layer, which is a crucial ingredient for expressiveness and universality in EMLP.\n", "\n", "Following the $T_{(p,q)}=V^{\\otimes p}\\otimes (V^*)^{\\otimes q}$ notation in the paper, we provide a convenience function for constructing higher rank tensors."

]

}, {

"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [

{

"name": "stdout", "output_type": "stream", "text": [

"V⊗V⊗V*⊗V*⊗V*\n", "V²⊗V*³\n"

]

}

], "source": [

"from emlp.reps import V\n", "from emlp.groups import SO13\n", "\n", "def T(p,q=0):\n", "    return (V**p*V.T**q)\n", "\n", "print(T(2,3))\n", "print(T(2,3)(SO13()))"

]

}, {
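
"cell_type": "markdown", "metadata": {}, "source": [

"The cell above prints `T(2,3)` both unbound and bound to `SO13()`, the Lorentz group acting on a 4-dimensional V. The following minimal NumPy sketch makes the ⊗ and dual constructions concrete. It is not emlp's implementation. It computes the matrix by which a group element g acts on T(p,q): each copy of V is acted on by g, each copy of the dual `V*` is acted on by `inv(g).T`, and ⊗ becomes a Kronecker product. The helper `tensor_rep` is hypothetical, and the example uses a generic invertible 4x4 matrix in place of an actual Lorentz transformation.

```python
import numpy as np


def tensor_rep(g, p, q=0):
    # Matrix by which g acts on T(p,q): p copies of V (acted on by g) and
    # q copies of the dual V* (acted on by inv(g).T), combined with np.kron.
    g_dual = np.linalg.inv(g).T
    out = np.eye(1)
    for _ in range(p):
        out = np.kron(out, g)
    for _ in range(q):
        out = np.kron(out, g_dual)
    return out


rng = np.random.default_rng(0)
g = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # generic invertible, not orthogonal

# T(2,3) over a 4-dimensional V has dimension 4**5 = 1024.
print(tensor_rep(g, 2, 3).shape)  # (1024, 1024)

# The identity matrix, flattened, is invariant as a T(1,1) tensor ...
eye = np.eye(4).reshape(-1)
print(np.allclose(tensor_rep(g, 1, 1) @ eye, eye))  # True
# ... but not as a T(2,0) tensor, since g @ g.T != identity here.
print(np.allclose(tensor_rep(g, 2, 0) @ eye, eye))  # False
```

The last check shows why V and `V*` are tracked separately. When g is not orthogonal (as with Lorentz boosts), `inv(g).T` differs from g. As a result, the identity matrix is invariant as a T(1,1) tensor but not as a T(2,0) tensor."

]

}, {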

"cell_type": "markdown", "metadata": {}, "source": [

"Let's get started with a toy dataset: learning how an inertia matrix depends on the positions and masses of 5 point masses distributed in different ways. The data consists of (positions, masses) → (inertia matrix) pairs and has a $G=O(3)$ symmetry (3D rotations and reflections): if we rotate or reflect all the positions, the resulting inertia matrix is correspondingly transformed. The code below uses the rotation subgroup $G=SO(3)$; `O(3)` can be substituted to also enforce reflection symmetry."

]

}, {
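
"cell_type": "markdown", "metadata": {}, "source": [

"As a sanity check on this symmetry, the sketch below uses the textbook moment-of-inertia formula for point masses about the origin: I = Σᵢ mᵢ (‖xᵢ‖² Id − xᵢxᵢᵀ). It assumes this formula matches what the `Inertia` dataset computes (we have not checked its implementation), and `inertia_matrix` is a hypothetical helper. The check compares two orders of operation: transforming every position by an orthogonal Q and then computing the inertia matrix, versus computing it first and then transforming it as Q I Qᵀ. Both give the same result.

```python
import numpy as np


def inertia_matrix(masses, positions):
    # Moment-of-inertia tensor about the origin:
    # I = sum_i m_i * (|x_i|^2 * Id - outer(x_i, x_i))
    sq_norms = np.sum(positions**2, axis=1)                  # shape (n,)
    outers = np.einsum('ni,nj->nij', positions, positions)   # shape (n, 3, 3)
    per_point = sq_norms[:, None, None] * np.eye(3) - outers
    return np.einsum('n,nij->ij', masses, per_point)


rng = np.random.default_rng(0)
masses = rng.uniform(size=5)
positions = rng.normal(size=(5, 3))

# A random orthogonal matrix (a rotation, possibly combined with a reflection).
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))

transform_then_compute = inertia_matrix(masses, positions @ Q.T)  # x_i -> Q x_i
compute_then_transform = Q @ inertia_matrix(masses, positions) @ Q.T
print(np.allclose(transform_then_compute, compute_then_transform))  # True
```

Because ‖Qx‖ = ‖x‖ for any orthogonal Q, the check passes whether Q is a rotation or includes a reflection. This is why the full symmetry group here is O(3)."

]

}, {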

"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [

{

"name": "stdout", "output_type": "stream", "text": [

"Input type: 5V⁰+5V, output type: V²\n"

]

}

], "source": [

"from emlp.datasets import Inertia\n", "from emlp.groups import SO,O,S,Z\n", "\n", "trainset = Inertia(1000) # Initialize dataset with 1000 examples\n", "testset = Inertia(2000)\n", "G = SO(3)\n", "print(f\"Input type: {trainset.rep_in(G)}, output type: {trainset.rep_out(G)}\")"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"For convenience, we store the input and output types in the dataset. 5V⁰ are the $5$ mass values, 5V are the position vectors of those masses, and V² is the matrix type for the output, equivalent to $T_2$. To initialize the [EMLP](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLP), we just need these input and output representations, the symmetry group, and the size of the network as parametrized by the number of layers and the number of channels (the dimension of the feature representation)."

]

}, {
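
"cell_type": "markdown", "metadata": {}, "source": [

"To make these types concrete: 5V⁰+5V has dimension 5·1 + 5·3 = 20, and V² has dimension 9. The NumPy sketch below builds the matrices by which an orthogonal Q acts on these flattened spaces. It assumes the input vector lists the 5 masses first and then the 5 flattened positions, following the order in which the type is printed; we have not verified the dataset's actual memory layout. The direct sum becomes a block-diagonal matrix, and V⊗V becomes `np.kron(Q, Q)`.

```python
import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # a random orthogonal 3x3 matrix

# 5V⁰+5V: five scalars (unchanged) followed by five 3-vectors (each hit by Q).
# The direct sum becomes a block-diagonal matrix.
rho_in = np.block([
    [np.eye(5), np.zeros((5, 15))],
    [np.zeros((15, 5)), np.kron(np.eye(5), Q)],
])
# V²: a 3x3 matrix flattened row-major, acted on by the Kronecker product Q ⊗ Q.
rho_out = np.kron(Q, Q)
print(rho_in.shape, rho_out.shape)  # (20, 20) (9, 9)

# rho_out acting on a flattened matrix M is the same as Q @ M @ Q.T
M = rng.normal(size=(3, 3))
print(np.allclose(rho_out @ M.reshape(-1), (Q @ M @ Q.T).reshape(-1)))  # True

# rho_in leaves the masses alone and applies Q to every position vector
masses, positions = rng.uniform(size=5), rng.normal(size=(5, 3))
x = np.concatenate([masses, positions.reshape(-1)])
y = rho_in @ x
print(np.allclose(y[:5], masses), np.allclose(y[5:].reshape(5, 3), positions @ Q.T))  # True True
```"

]

}, {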

"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [

"import emlp\n", "model = emlp.nn.EMLP(trainset.rep_in,trainset.rep_out,group=G,num_layers=3,ch=384)\n", "# uncomment the following line to instead try the MLP baseline\n", "#model = emlp.nn.MLP(trainset.rep_in,trainset.rep_out,group=G,num_layers=3,ch=384)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"## Example Objax Training Loop"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"We build our EMLP model with [objax](https://objax.readthedocs.io/en/latest/) because we feel its object-oriented design makes building complicated layers easier. Below is a minimal training loop that you could use to train EMLP."

]

}, {

"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [

"BS=500\n", "lr=3e-3\n", "NUM_EPOCHS=500\n", "\n", "import objax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "opt = objax.optimizer.Adam(model.vars())\n", "\n", "@objax.Jit\n", "@objax.Function.with_vars(model.vars())\n", "def loss(x, y):\n", "    yhat = model(x)\n", "    return ((yhat-y)**2).mean()\n", "\n", "grad_and_val = objax.GradValues(loss, model.vars())\n", "\n", "@objax.Jit\n", "@objax.Function.with_vars(model.vars()+opt.vars())\n", "def train_op(x, y, lr):\n", "    g, v = grad_and_val(x, y)\n", "    opt(lr=lr, grads=g)\n", "    return v\n", "\n", "trainloader = DataLoader(trainset,batch_size=BS,shuffle=True)\n", "testloader = DataLoader(testset,batch_size=BS,shuffle=True)"

]

}, {

"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [

{
"data": {
"application/vnd.jupyter.widget-view+json": {

"model_id": "d8674d2ad786477f88ae916672683edb", "version_major": 2, "version_minor": 0

}, "text/plain": [

"  0%|          | 0/500 [00:00<?, ?it/s]"

]

}, "metadata": {}, "output_type": "display_data"

}

], "source": [

"test_losses = []\n", "train_losses = []\n", "for epoch in tqdm(range(NUM_EPOCHS)):\n", "    train_losses.append(np.mean([train_op(jnp.array(x),jnp.array(y),lr) for (x,y) in trainloader]))\n", "    if not epoch%10:\n", "        test_losses.append(np.mean([loss(jnp.array(x),jnp.array(y)) for (x,y) in testloader]))"

]

}, {

"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [

{
"data": {

"image/png": "", "text/plain": [

"<Figure size 432x288 with 1 Axes>"

]

}, "metadata": {

"needs_background": "light"

}, "output_type": "display_data"

}

], "source": [

"import matplotlib.pyplot as plt\n", "plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')\n", "plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')\n", "plt.legend()\n", "plt.yscale('log')"

]

}, {

"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [

"from jax import vmap\n", "\n", "def rel_err(a,b):\n", "    # relative RMS error, normalized by the RMS magnitudes of both arguments\n", "    return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))\n", "\n", "rin,rout = trainset.rep_in(G),trainset.rep_out(G)\n", "\n", "def equivariance_err(mb):\n", "    x,y = mb\n", "    x,y= jnp.array(x),jnp.array(y)\n", "    gs = G.samples(x.shape[0])  # one random group element per example\n", "    rho_gin = vmap(rin.rho_dense)(gs)\n", "    rho_gout = vmap(rout.rho_dense)(gs)\n", "    # compare model(rho_in(g) x) with rho_out(g) model(x)\n", "    y1 = model((rho_gin@x[...,None])[...,0],training=False)\n", "    y2 = (rho_gout@model(x,training=False)[...,None])[...,0]\n", "    return rel_err(y1,y2)"

]

}, {
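
"cell_type": "markdown", "metadata": {}, "source": [

"The same test can be run on toy functions to show what the metric distinguishes. This NumPy sketch re-implements `rel_err` together with a simplified `equivariance_err` for the vector representation of O(3). Both are hypothetical helpers: unlike the cell above, this `equivariance_err` takes the function and representation as arguments. A function that only rescales x by a function of ‖x‖ scores at round-off level, while a generic linear map does not.

```python
import numpy as np


def rel_err(a, b):
    # RMS of the difference, normalized by the sum of the RMS magnitudes
    return np.sqrt(((a - b)**2).mean()) / (np.sqrt((a**2).mean()) + np.sqrt((b**2).mean()))


def equivariance_err(f, rho, xs, rng):
    # Average rel_err between f(rho(g) x) and rho(g) f(x), one random
    # orthogonal g per example.
    errs = []
    for x in xs:
        g, _ = np.linalg.qr(rng.normal(size=(3, 3)))
        errs.append(rel_err(f(rho(g) @ x), rho(g) @ f(x)))
    return np.mean(errs)


def vector_rep(g):
    return g  # the standard representation V: vectors are multiplied by g


def radial_scaling(x):
    return 2.0 * x * np.linalg.norm(x)  # equivariant: only rescales x


A = np.random.default_rng(1).normal(size=(3, 3))


def generic_linear(x):
    return A @ x  # a random linear map, not equivariant


rng = np.random.default_rng(0)
xs = rng.normal(size=(100, 3))
print(f'{equivariance_err(radial_scaling, vector_rep, xs, rng):.1e}')  # round-off level
print(f'{equivariance_err(generic_linear, vector_rep, xs, rng):.1e}')  # much larger
```

By the triangle inequality, this normalization keeps `rel_err` between 0 and 1. A value near 1e-7, like the test error reported for the trained EMLP, therefore means agreement close to float32 round-off."

]

}, {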

"cell_type": "markdown", "metadata": {}, "source": [

"As expected, the network remains equivariant after training: the average relative equivariance error on the test set is at the level of floating-point round-off."

]

}, {

"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [

{

"name": "stdout", "output_type": "stream", "text": [

"Average test equivariance error 4.34e-07\n"

]

}

], "source": [

"print(f\"Average test equivariance error {np.mean([equivariance_err(mb) for mb in testloader]):.2e}\")"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"## Breaking EMLP down into equivariant layers (mid level)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"Internally, EMLP uses hidden representations that [uniformly allocate dimensions](https://emlp.readthedocs.io/en/latest/package/emlp.models.mlp.html#emlp.models.uniform_rep) among tensor representations of different ranks."

]

}, {

"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [

{

"name": "stdout", "output_type": "stream", "text": [

"122V⁰+40V+12V²+3V³+V⁴\n"

]

}

], "source": [

"from emlp.nn import uniform_rep\n", "r = uniform_rep(512,G)\n", "print(r)"

]

}, {
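
"cell_type": "markdown", "metadata": {}, "source": [

"Here is a quick arithmetic check on the printed allocation. For SO(3), the base representation V is 3-dimensional, so each copy of Vᵏ has dimension 3ᵏ. Summing over the printed multiplicities confirms that the total is 512, with each rank receiving a roughly equal share of about 100 dimensions. This even split is our reading of the numbers, not a description of how `uniform_rep` works internally.

```python
# Multiplicities read off the printed uniform_rep(512, G): rank k -> copies of V^k
multiplicities = {0: 122, 1: 40, 2: 12, 3: 3, 4: 1}
d = 3  # dimension of the base representation V for SO(3)

sizes = {rank: count * d**rank for rank, count in multiplicities.items()}
print(sizes)                # {0: 122, 1: 120, 2: 108, 3: 81, 4: 81}
print(sum(sizes.values()))  # 512
```

This also explains why only a single copy of V⁴ fits: one rank-4 tensor already uses 81 of the 512 dimensions."

]

}, {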

"cell_type": "markdown", "metadata": {}, "source": [

"Below is a trimmed-down version of EMLP, so you can see how it is built from the component layers `Linear`, `BiLinear`, and `GatedNonlinearity`. These layers can be constructed like ordinary objax modules, using the input and output representations."

]

}, {

"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [

"from objax.module import Module\n", "\n", "class EMLPBlock(Module):\n", "    \"\"\" Basic building block of EMLP consisting of G-Linear, biLinear,\n", "        and gated nonlinearity. \"\"\"\n", "    def __init__(self,rep_in,rep_out):\n", "        super().__init__()\n", "        rep_out_wgates = emlp.nn.gated(rep_out)\n", "        self.linear = emlp.nn.Linear(rep_in,rep_out_wgates)\n", "        self.bilinear = emlp.nn.BiLinear(rep_out_wgates,rep_out_wgates)\n", "        self.nonlinearity = emlp.nn.GatedNonlinearity(rep_out)\n", "    def __call__(self,x):\n", "        lin = self.linear(x)\n", "        preact = self.bilinear(lin) + lin\n", "        return self.nonlinearity(preact)\n", "\n", "class EMLP(Module):\n", "    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):\n", "        super().__init__()\n", "        reps = [rep_in(group)]+num_layers*[uniform_rep(ch,group)]\n", "        self.network = emlp.nn.Sequential(\n", "            *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],\n", "            emlp.nn.Linear(reps[-1],rep_out(group))\n", "        )\n", "    def __call__(self,x,training=True):\n", "        return self.network(x)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"The representations of the hidden layers (taking the place of the number of channels in a standard MLP) are by default given by the `uniform_rep` shown above. Unlike in this pedagogical implementation, in the [full EMLP](https://emlp.readthedocs.io/en/latest/package/emlp.nn.html#emlp.nn.EMLP) you can specify the hidden-layer representation directly by passing a representation to the `ch` argument, or even a list of representations to specify each hidden layer."

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"Note that since we are using the `GatedNonlinearity`, additional scalar gate channels need to be added to the output representation of the layer directly before the nonlinearity (in this case the `Linear` layer), which can be achieved with the `gated` function."

]

}, {
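
"cell_type": "markdown", "metadata": {}, "source": [

"A minimal NumPy sketch shows why the extra gate channels are needed. It is restricted to scalars and vectors of SO(3) or O(3). Applying an elementwise nonlinearity to the components of a vector breaks equivariance. Instead, each vector is rescaled by the sigmoid of its own invariant gate scalar, which the preceding layer must output. The layout, the function names, and the choice of swish for the scalar channels are assumptions for illustration, not a description of emlp's `GatedNonlinearity`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gated_nonlinearity(scalars, vectors, gates):
    # scalars: shape (s,); vectors: shape (k, 3), i.e. k copies of V;
    # gates: shape (k,), one extra invariant scalar per vector.
    out_scalars = scalars * sigmoid(scalars)            # swish on scalar channels
    out_vectors = vectors * sigmoid(gates)[:, None]     # rescale each vector by its gate
    return out_scalars, out_vectors


def naive_swish(v):
    return v * sigmoid(v)  # elementwise on vector components: NOT equivariant


rng = np.random.default_rng(0)
scalars, vectors, gates = rng.normal(size=4), rng.normal(size=(2, 3)), rng.normal(size=2)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix

s1, v1 = gated_nonlinearity(scalars, vectors @ Q.T, gates)  # transform, then apply
s2, v2 = gated_nonlinearity(scalars, vectors, gates)        # apply, then transform
print(np.allclose(s1, s2), np.allclose(v1, v2 @ Q.T))      # True True
print(np.allclose(naive_swish(vectors @ Q.T), naive_swish(vectors) @ Q.T))  # False
```

The gates are invariant, so scaling each vector by a function of them commutes with transforming the vectors. This is why the layer before the nonlinearity outputs `gated(rep_out)` rather than `rep_out`."

]

}, {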

"cell_type": "markdown", "metadata": {}, "source": [

"## The equivariant linear layers (low level)"

]

}, {

"cell_type": "markdown", "metadata": {}, "source": [

"At a lower level, the implementation of `Linear` is fairly straightforward. An unconstrained bias b and weight matrix w are initialized. The projection matrices $P_b$ and $P_w$, which project onto the symmetric (equivariant) subspace for each, are computed once at initialization. During the forward pass, the unconstrained parameters are reshaped to vectors, projected via these matrices, and reshaped back to their original sizes. The projected parameters are then applied to the input like a standard linear layer."

]

}, {
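
"cell_type": "markdown", "metadata": {}, "source": [

"The projection step can be illustrated on a small finite group. In emlp, `equivariant_projector()` is built from the numerical solver described earlier. This NumPy sketch swaps in a different technique: group averaging (the Reynolds operator), P(W) = mean over g of ρ_out(g) W ρ_in(g)⁻¹. Group averaging is a valid projector for any finite group small enough to enumerate. Here the group is S3 permuting the coordinates of R³, and `project_equivariant` is a hypothetical helper. The last lines build the same projector as a 9x9 matrix acting on the flattened weights, mirroring `Pw @ w.reshape(-1)` in the layer below.

```python
import itertools

import numpy as np

# S3 acting on R^3 by permutation matrices (orthogonal, so inv(g) == g.T).
perms = [np.eye(3)[list(p)] for p in itertools.permutations(range(3))]


def project_equivariant(w, group):
    # Group averaging: P(W) = mean_g rho(g) @ W @ inv(rho(g)), with rho_in = rho_out.
    return np.mean([g @ w @ g.T for g in group], axis=0)


rng = np.random.default_rng(0)
w = rng.normal(size=(3, 3))        # unconstrained weights
W = project_equivariant(w, perms)  # projected weights

print(all(np.allclose(g @ W, W @ g) for g in perms))  # True: W commutes with the group
print(np.allclose(project_equivariant(W, perms), W))  # True: projecting twice changes nothing

# The same projector as a 9x9 matrix on the flattened (row-major) weights,
# using vec(g @ w @ g.T) == kron(g, g) @ vec(w).
Pw = np.mean([np.kron(g, g) for g in perms], axis=0)
print(np.allclose((Pw @ w.reshape(-1)).reshape(3, 3), W))  # True
# Rank 2: the equivariant maps for this action are spanned by the identity
# and the all-ones matrix.
print(np.allclose(Pw @ Pw, Pw), np.linalg.matrix_rank(Pw))  # True 2
```

For continuous groups like SO(3), the sum over group elements is no longer finite, which is why a solver-based projector is used in practice."

]

}, {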

"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [

"from objax.variable import TrainVar\n", "from objax.nn.init import orthogonal\n", "\n", "class Linear(Module):\n", "    \"\"\" Basic equivariant Linear layer from repin to repout.\"\"\"\n", "    def __init__(self, repin, repout):\n", "        nin,nout = repin.size(),repout.size()\n", "        # unconstrained parameters\n", "        self.b = TrainVar(objax.random.uniform((nout,))/jnp.sqrt(nout))\n", "        self.w = TrainVar(orthogonal((nout, nin)))\n", "        self.rep_W = rep_W = repout*repin.T  # weights map repin to repout\n", "\n", "        self.Pb = repout.equivariant_projector() # the bias vector has representation repout\n", "        self.Pw = rep_W.equivariant_projector()\n", "\n", "    def __call__(self, x):\n", "        # project the flattened weights onto the equivariant subspace, then restore the shape\n", "        W = (self.Pw@self.w.value.reshape(-1)).reshape(*self.w.value.shape)\n", "        b = self.Pb@self.b.value\n", "        return x@W.T+b"

]

}

], "metadata": {

"kernelspec": {

"display_name": "Python 3", "language": "python", "name": "python3"

}, "language_info": {

"codemirror_mode": {

"name": "ipython", "version": 3

}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5"

}

}, "nbformat": 4, "nbformat_minor": 2

}