{
“cells”: [
{

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“# Implementing a New Group”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Implementing equivariance to a new group in our library is fairly straightforward.\n”, “You need to specify the discrete and continuous generators of the group in a given representation, $\rho(h_i)$ and $d\rho(A_k)$, and then call the base class `__init__` method. The two fields, `self.discrete_generators` and `self.lie_algebra`, should each be a sequence of square matrices. These can be specified either as dense arrays (such as `np.ndarray`s of size `(M,n,n)` and `(D,n,n)`) or as `LinearOperator` objects that implement matmul lazily. In general it’s possible to implement any matrix group, and we’ll go through a few illustrative examples. After checking out these examples, you can browse through the implementations for many other groups [here](https://github.com/mfinzi/equivariant-MLP/blob/master/emlp/groups.py).”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Discrete Group Example: Alternating Group $A_n$”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“The alternating group $A_n$ is a discrete group that contains all _even_ permutations from the permutation group $S_n$. There are many different generating sets that we could use; as an example, let’s choose the [generators](https://math.stackexchange.com/questions/1358030/set-of-generators-for-a-n-the-alternating-group) $(123),(124),\ldots,(12n)$, where each term is a cyclic permutation of those indices. To implement the group, all we need to do is specify these generators. The code below uses zero-based indices, so the generator $(12k)$ becomes the cycle on indices `(0,1,k-1)`.”

]

}, {

“cell_type”: “code”, “execution_count”: 1, “metadata”: {}, “outputs”: [], “source”: [

“#import logging; logging.getLogger().setLevel(logging.INFO)\n”,
“from emlp.groups import Group,S\n”,
“from emlp.reps import V,T,vis\n”,
“import numpy as np\n”,
“\n”,
“class Alt(Group):\n”,
“    """ The alternating group in n dimensions"""\n”,
“    def __init__(self,n):\n”,
“        assert n>2\n”,
“        self.discrete_generators = np.zeros((n-2,n,n))+np.eye(n) # init an array of n-2 identity matrices\n”,
“        for i in range(n-2):\n”,
“            ids = (0,1,i+2)\n”,
“            permed_ids = np.roll(ids,1) # cyclic permutation (0,1,i+2)->(i+2,0,1)\n”,
“            self.discrete_generators[i,ids] = self.discrete_generators[i,permed_ids]\n”,
“        super().__init__(n)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Higher-rank tensors for $A_n$ have additional equivariant solutions compared to the permutation group $S_n$ that contains it.”

]

}, {

“cell_type”: “code”, “execution_count”: 2, “metadata”: {}, “outputs”: [

{

“name”: “stdout”, “output_type”: “stream”, “text”: [

“T5 basis for Alt(5) of shape (3125, 63)\n”

]

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “f87807aeed484a1d91fdcb145f2d2b1c”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=10: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “b215800b4444471faf7cf95595a73a05”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=20: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “a3bfef6e39b7443688c4584e72c18a9e”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=40: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “fd350fa2428b4d1080d6a08abc1488a2”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=80: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“name”: “stdout”, “output_type”: “stream”, “text”: [

“T5 basis for S(5) of shape (3125, 52)\n”

]

}

], “source”: [

“print("T5 basis for Alt(5) of shape ",T(5)(Alt(5)).equivariant_basis().shape)\n”,
“print("T5 basis for S(5) of shape ",T(5)(S(5)).equivariant_basis().shape)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“You can verify the equivariance:”

]

}, {

“cell_type”: “code”, “execution_count”: 3, “metadata”: {}, “outputs”: [

{
“data”: {
“text/plain”: [

“‘Equivariance Error: 1.08e-06’”

]

}, “execution_count”: 3, “metadata”: {}, “output_type”: “execute_result”

}

], “source”: [

“import jax.numpy as jnp\n”,
“def rel_err(a,b):\n”,
“    # relative RMS error between a and b\n”,
“    return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))\n”,
“\n”,
“G = Alt(5)\n”,
“rep = T(5)(G)\n”,
“Q = rep.equivariant_basis()\n”,
“gQ = rep.rho(G.sample())@Q\n”,
“f"Equivariance Error: {rel_err(Q,gQ):.2e}"”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Continuous Group Example: Special Orthogonal Group $\mathrm{SO}(n)$”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Many Lie groups lie in the image of the exponential map. A classic example is the special orthogonal group $\mathrm{SO}(n)$ consisting of rotations in $n$ dimensions: $\mathrm{SO}(n) = \{R \in \mathbb{R}^{n\times n}: R^\top R=I, \mathrm{det}(R)=1\}$. Because this is a continuous group, we need to specify the Lie algebra, which can be found by differentiating the constraints at the identity or simply by looking it up on [Wikipedia](https://en.wikipedia.org/wiki/3D_rotation_group#Lie_algebra): $\mathfrak{so}(n) = T_\mathrm{id}\mathrm{SO}(n) = \{A\in \mathbb{R}^{n\times n}: A^\top=-A \}$. We can choose any basis for this $n(n-1)/2$ dimensional subspace of antisymmetric matrices. Since $\mathrm{exp}(\mathfrak{so}(n)) = \mathrm{SO}(n)$, this is all we need to specify.”

]

}, {

“cell_type”: “code”, “execution_count”: 4, “metadata”: {}, “outputs”: [], “source”: [

“class SO(Group):\n”,
“    def __init__(self,n):\n”,
“        """ The special orthogonal group SO(n) in n dimensions"""\n”,
“        self.lie_algebra = np.zeros(((n*(n-1))//2,n,n))\n”,
“        k=0\n”,
“        for i in range(n):\n”,
“            for j in range(i):\n”,
“                # k-th basis element: antisymmetric matrix with +1 at (i,j) and -1 at (j,i)\n”,
“                self.lie_algebra[k,i,j] = 1\n”,
“                self.lie_algebra[k,j,i] = -1\n”,
“                k+=1\n”,
“        super().__init__(n)”

]

}, {

“cell_type”: “code”, “execution_count”: 5, “metadata”: {}, “outputs”: [

{
“data”: {

“image/png”: “iVBORw0KGgoAAAANSUhEUgAAAV0AAAB+CAYAAACHx8KbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAE3klEQVR4nO3cPWtUaRzG4SdjCqPiC4IJVoZxBRFRsPELaGMti7WNYBHEyk9gJWIhbGMti7WNfgEbwSAiuDsklSSC+IIaC51sKVsk+E9ybg/xuto8h3PQ8ZdTOPfE6upqAyBj8KsfAOB3IroAQaILECS6AEGiCxAkugBBk+v9cPbOrd/y/5Ptmf1QOv9pYV/5HntHtd93H4fj8j26Nvl5onzNt939+0hNP6lfs3x2659js8a7vpfOD77sKN/j1JlR6fz802H5HtvBwtz1Nf9xeNMFCBJdgCDRBQgSXYAg0QUIEl2AINEFCBJdgCDRBQgSXYAg0QUIWnd7IWH4YKV0fnRxqqMn+aG6pXDk5OvyPRbb4fI1fdPHHYXW6lsKG9lROPCitjvx7kTtz+rYvfel86219ury/vI1VdUthaml2nvdykz/NkZaa+3oteKHam7tH3nTBQgSXYAg0QUIEl2AINEFCBJdgCDRBQgSXYAg0QUIEl2AINEFCBJdgKAtHby5eu5R+Zq77Xzp/MFntaGR1lp7e7rbYZbF5/0cr7l54X7p/I2Hlzp6ko0b/flX+Zphu9LBk/xfdcBmMP21dD4xXpNQHbCZOf6mfI+ll4dK50+dGZXvMX97A6tIa/CmCxAkugBBogsQJLoAQaILECS6AEGiCxAkugBBogsQJLoAQaILEDSxurr2d8hn79zqdrSg1b8HPf902NGTbFzi++L8vOoGyN3Htf2P1uobINX9j+pWQ2utjZd3lq/p2nbY/2itvgEymPlnzQ+IN12AINEFCBJdgCDRBQgSXYAg0QUIEl2AINEFCBJdgCDRBQgSXYAg0QUI+uWDN3003vW9dH7wZUf5Htth6GfvqP47++Nw3MGTbM6e2Q/laz4t7OvgSTZn+knt/PLZ+j0mP9eGfr7t/i0T0hbmrhu8AegD0QUIEl2AINEFCBJdgCDRBQgSXYAg0QUIEl2AINEFCBJdgKB1txfODS6Wvjj97+0NfJk74Ni996Xzry7v7+Q5NmNqqf77cWWmtnNw5OTr0vnF54dL51trbfhgpXR+dHGqfA9+zoEXtR2F1lp7d6J/WwqJzYkq2wsAPSG6AEGiCxAkugBBogsQJLoAQaILECS6AEGiCxAkugBBogsQJLoAQZPr/bA6YFMdnmitPj6xkZGOPg7YVFXHa1qrj+QstvqATVV1wObotfqHqo/DS30cXerjeE11dKm1+ue2OrrU2tYOL3nTBQgSXYAg0QUIEl2AINEFCBJdgCDRBQgSXYAg0QUIEl2AINEFCJpYXV37+9fjpT9KX84e/n1l0w/UhcH019L58fLOjp5k425euF++5sbDS6XzM8fflM4vvTxUOt9aa6fOjErn558Oy/fg5xx8Vt8xeXu6f3sNV889Kp2/+/h8R0/yw8Lc9TX/cL3pAgSJLkCQ6AIEiS5AkOgCBIkuQJDoAgSJLkCQ6AIEiS5AkOgCBIkuQNC6gzezd271b90iYPpJ7fzy2fo9Jj/Xxka+7e7fX0V1vKa1fg7YjHd9L18z+LKjgyfZnD2zH0rnPy3sK99j76j2nvZxOC7fYzsweAPQE6ILECS6AEGiCxAkugBBogsQJLoAQaILECS6AEGiCxAkugBB624vALC1vOkCBIkuQJDoAgSJLkCQ6AIEiS5A0H974uKhefIGoQAAAABJRU5ErkJggg==n”, “text/plain”: [

“<Figure size 432x288 with 1 Axes>”

]

}, “metadata”: {

“needs_background”: “light”

}, “output_type”: “display_data”

}

], “source”: [

“vis(V(SO(3))**3,V(SO(3))**2,cluster=False)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Lie Group with Multiple Connected Components Example: $\mathrm{O}(n)$”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“Lie groups that are not in the image of the exponential map can be a bit more complicated because they often need to be constructed with both continuous and discrete generators. A good example is the orthogonal group $\mathrm{O}(n)$ consisting of rotations and reflections: $\mathrm{O}(n) = \{R \in \mathbb{R}^{n\times n}: R^\top R=I\}$. The Lie algebra for $\mathrm{O}(n)$ is the same as for $\mathrm{SO}(n)$, $\mathfrak{o}(n)=\mathfrak{so}(n)$, and so $\mathrm{exp}(\mathfrak{o}(n)) = \mathrm{SO}(n) \ne \mathrm{O}(n)$. Instead, the orthogonal group has $2$ connected components, orthogonal matrices with $\mathrm{det}(R)=1$ and those with $\mathrm{det}(R)=-1$, and so we need a generator to traverse between the two components, such as $h = \begin{bmatrix}-1 & 0\\ 0 & I\end{bmatrix}$. We can reuse the Lie algebra implementation from $\mathrm{SO}(n)$ and implement the additional discrete generator below.”

]

}, {

“cell_type”: “code”, “execution_count”: 6, “metadata”: {}, “outputs”: [], “source”: [

“class O(SO):\n”,
“    def __init__(self,n):\n”,
“        """ The Orthogonal group O(n) in n dimensions"""\n”,
“        self.discrete_generators = np.eye(n)[None] # a single generator of shape (1,n,n)\n”,
“        self.discrete_generators[0,0,0]=-1 # reflection that flips the first coordinate\n”,
“        super().__init__(n)”

]

}, {

“cell_type”: “code”, “execution_count”: 7, “metadata”: {}, “outputs”: [

{
“data”: {

“image/png”: “iVBORw0KGgoAAAANSUhEUgAAAV0AAAB+CAYAAACHx8KbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAB5ElEQVR4nO3asQ3CUAxAQX7EahmBKRkhu+FMABLNS5G71o2rJxdeM/MAoLFdvQDAnYguQEh0AUKiCxASXYCQ6AKEnr+G+/byTwbwp+PzXt9mLl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCa2au3gHgNly6ACHRBQiJLkBIdAFCogsQEl2A0AmJ1gr3vZqh3gAAAABJRU5ErkJggg==n”, “text/plain”: [

“<Figure size 432x288 with 1 Axes>”

]

}, “metadata”: {

“needs_background”: “light”

}, “output_type”: “display_data”

}

], “source”: [

“vis(V(O(3))**3,V(O(3))**2,cluster=False); # Unlike SO(3), O(3) has no equivariant solutions for V^3->V^2: the combined tensor T5 has odd rank (odd parity)”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“## Accelerating the solver using lazy matrices”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“For larger representations our solver uses an iterative method that benefits from faster multiplies with the generators. Instead of specifying the generators using dense matrices, you can specify them as `LinearOperator` objects in a way that makes use of known structure (like sparsity, permutation, etc). These `LinearOperator` objects are modeled after [scipy Linear Operators](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html) but adapted to be compatible with JAX and with some additional features.\n”, “\n”, “Returning to the alternating group example, we can specify the generators as permutation operators directly. Many useful linear operators are already implemented in the library, and we recommend using them if available, but we will go through the minimum steps for implementing a new operator like a permutation as an example.\n”, “\n”, “Note that you need to be using quite large representations before any speedup is noticeable, because of the increased compile times with JAX (we are hoping to speed this up). So even with the fairly large examples below, the densely implemented generators are still faster.”

]

}, {

“cell_type”: “code”, “execution_count”: 9, “metadata”: {}, “outputs”: [], “source”: [

“from emlp.reps.linear_operator_base import LinearOperator\n”,
“import numpy as np\n”,
“\n”,
“class LazyPerm(LinearOperator):\n”,
“    def __init__(self,perm):\n”,
“        self.perm=perm\n”,
“        self.shape = (len(perm),len(perm))\n”,
“    def _matmat(self,V):\n”,
“        return V[self.perm] # permute the rows of V\n”,
“    def _matvec(self,V):\n”,
“        return V[self.perm]\n”,
“    def _adjoint(self):\n”,
“        return LazyPerm(np.argsort(self.perm)) # the adjoint of a permutation is its inverse\n”

]

}, {

“cell_type”: “code”, “execution_count”: 10, “metadata”: {}, “outputs”: [], “source”: [

“class AltFast(Group):\n”,
“    """ The alternating group in n dimensions"""\n”,
“    def __init__(self,n):\n”,
“        assert n>2\n”,
“        perms = np.zeros((n-2,n)).astype(int)+np.arange(n)[None] # n-2 identity permutations\n”,
“        for i in range(n-2):\n”,
“            ids = (0,1,i+2)\n”,
“            permed_ids = np.roll(ids,1) # cyclic permutation (0,1,i+2)->(i+2,0,1)\n”,
“            perms[i,ids] = perms[i,permed_ids]\n”,
“        self.discrete_generators = [LazyPerm(perm) for perm in perms]\n”,
“        super().__init__(n)”

]

}, {

“cell_type”: “code”, “execution_count”: 11, “metadata”: {}, “outputs”: [], “source”: [

“#import logging; logging.getLogger().setLevel(logging.INFO)”

]

}, {

“cell_type”: “code”, “execution_count”: 12, “metadata”: {}, “outputs”: [

{
“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “f3e9336bb668455ebd55a49fe9dbaa44”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=10: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“name”: “stdout”, “output_type”: “stream”, “text”: [

“16.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n”

]

}

], “source”: [

“%timeit -n1 -r1 T(2)(Alt(100)).equivariant_basis()”

]

}, {

“cell_type”: “code”, “execution_count”: 13, “metadata”: {}, “outputs”: [

{
“data”: {
“application/vnd.jupyter.widget-view+json”: {

“model_id”: “2839a2cfa0e44cf0afc7dc015aa1d420”, “version_major”: 2, “version_minor”: 0

}, “text/plain”: [

“Krylov Solving for Equivariant Subspace r<=10: 0%| | 0/100 [00:00<?, ?it/s]”

]

}, “metadata”: {}, “output_type”: “display_data”

}, {

“name”: “stdout”, “output_type”: “stream”, “text”: [

“29.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n”

]

}

], “source”: [

“%timeit -n1 -r1 T(2)(AltFast(100)).equivariant_basis()”

]

}, {

“cell_type”: “code”, “execution_count”: null, “metadata”: {}, “outputs”: [], “source”: []

}

], “metadata”: {

“kernelspec”: {

“display_name”: “Python 3”, “language”: “python”, “name”: “python3”

}, “language_info”: {

“codemirror_mode”: {

“name”: “ipython”, “version”: 3

}, “file_extension”: “.py”, “mimetype”: “text/x-python”, “name”: “python”, “nbconvert_exporter”: “python”, “pygments_lexer”: “ipython3”, “version”: “3.8.5”

}

}, “nbformat”: 4, “nbformat_minor”: 2

}