
Using EMLP with Flax

Using EMLP with Flax is very similar to using it with Objax or Haiku. Just make sure to import from the Flax implementation, emlp.nn.flax.

[1]:
from jax import random
import numpy as np
import emlp.nn.flax as nn # import from the flax implementation
from emlp.reps import T,V
from emlp.groups import SO

repin = 4*V # Set up some example data representations
repout = V
G = SO(3)

x = np.random.randn(5,repin(G).size()) # generate some random data
[2]:
model = nn.EMLP(repin,repout,G)

key = random.PRNGKey(42)
params = model.init(key, x) # Initialize the parameters from a PRNG key and example inputs

y = model.apply(params,  x) # Forward pass with inputs x and parameters

And indeed, the parameters of the model are registered as expected.

[3]:
list(params['params'].keys())
[3]:
['modules_0', 'modules_1', 'modules_2', 'modules_3']
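
To make the representations above more concrete, here is a minimal NumPy-only sketch of the equivariance property that a map from repin = 4*V to repout = V is meant to satisfy under G = SO(3). It does not call EMLP. The helper random_rotation and the matrix W are hypothetical stand-ins, and the sketch assumes that 4*V is laid out as four consecutive 3-vectors, so a rotation R acts on the input as a block-diagonal matrix with four copies of R.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_rotation(rng):
    """Hypothetical helper: sample a 3x3 rotation matrix via QR decomposition."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))   # normalize column signs of the orthogonal factor
    if np.linalg.det(q) < 0:      # flip one axis so that det = +1 (SO(3), not just O(3))
        q[:, 0] = -q[:, 0]
    return q

R = random_rotation(rng)

# Assumed action on 4*V: four copies of R on the diagonal (a 12x12 matrix).
rho_in = np.kron(np.eye(4), R)
# Action on V: the rotation matrix itself.
rho_out = R

# Stand-in equivariant linear map (NOT the EMLP model): a weighted sum
# of the four input vectors, written as a 3x12 matrix.
c = rng.standard_normal(4)
W = np.kron(c[None, :], np.eye(3))

x = rng.standard_normal((5, 12))  # same shape as the example batch above

# Equivariance: transforming the input and then applying the map should
# match applying the map and then transforming the output.
lhs = (x @ rho_in.T) @ W.T
rhs = (x @ W.T) @ rho_out.T
max_err = np.abs(lhs - rhs).max()
print(max_err)
```

The printed error should be at the level of floating-point round-off. The same kind of comparison can be applied to the output of model.apply once the corresponding representation matrices for repin and repout are available.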