Using EMLP with Flax

Using EMLP with Flax is much the same as using it with Objax or Haiku; just make sure to import from the Flax implementation, emlp.nn.flax.

from jax import random
import numpy as np
import emlp.nn.flax as nn  # import from the Flax implementation
from emlp.reps import T, V
from emlp.groups import SO

repin = 4*V  # set up some example input and output representations
repout = V
G = SO(3)

x = np.random.randn(5, repin(G).size())  # generate a batch of 5 random inputs
model = nn.EMLP(repin, repout, G)

key = random.PRNGKey(42)
params = model.init(key, x)  # initialize the parameters from a sample input

y = model.apply(params, x)  # forward pass with inputs x and parameters
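For intuition about what the representations in this setup mean, here is a minimal NumPy-only sketch (no EMLP involved) that spells out the SO(3) case by hand. Under SO(3), V is 3-dimensional, so 4*V has size 4*3 = 12, which is the value of repin(G).size(), and a rotation R acts on an input by rotating each of its four 3-vectors, i.e. by the block-diagonal matrix diag(R, R, R, R). The helper random_rotation and the weighted-sum map f are hypothetical, chosen only for illustration; EMLP constructs its equivariant layers differently.

```python
import numpy as np

def random_rotation(rng):
    # Hypothetical helper: draw a random 3x3 rotation (an element of SO(3))
    # from the QR decomposition of a Gaussian matrix.
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so q is uniquely determined
    if np.linalg.det(q) < 0:     # flip one column to land in SO(3), not just O(3)
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
R = random_rotation(rng)

# 4*V: four stacked 3-vectors, acted on block-diagonally by diag(R, R, R, R).
rho_in = np.kron(np.eye(4), R)   # 12x12 input representation matrix
rho_out = R                      # V: the output is a single 3-vector

x = rng.standard_normal((5, 12))  # batch of 5 inputs, each of size 12

# Hypothetical equivariant linear map 4*V -> V: a weighted sum of the four vectors.
w = rng.standard_normal(4)
def f(x):
    return np.einsum('i,bij->bj', w, x.reshape(-1, 4, 3))

# Equivariance: rotating the inputs and then applying f gives the same result
# as applying f and then rotating the output.
lhs = f(x @ rho_in.T)
rhs = f(x) @ rho_out.T
print(np.allclose(lhs, rhs))
```

Any equivariant model from 4*V to V, including the EMLP above, must satisfy the same identity f(ρ_in(g)x) = ρ_out(g)f(x); the point of EMLP is that it guarantees this for nonlinear networks, not just for a single hand-built linear map.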

And indeed, the parameters of the model are registered as expected, with one entry per submodule of the network:

list(params['params'].keys())

['modules_0', 'modules_1', 'modules_2', 'modules_3']