Using EMLP with Flax
Using EMLP with Flax is very similar to using it with Objax or Haiku. Just make sure to import from the Flax implementation, emlp.nn.flax.
[1]:
from jax import random
import numpy as np
import emlp.nn.flax as nn # import from the flax implementation
from emlp.reps import T,V
from emlp.groups import SO
repin = 4*V  # Set up some example data representations
repout = V
G = SO(3)
x = np.random.randn(5,repin(G).size()) # generate some random data
[2]:
model = nn.EMLP(repin,repout,G)
key = random.PRNGKey(42)
params = model.init(key, x)  # initialize parameters from a PRNG key and an example input
y = model.apply(params, x)  # forward pass with inputs x and parameters params
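Since EMLP is built to be equivariant, a quick sanity check on a forward pass like the one above is to verify f(ρ_in(g)x) ≈ ρ_out(g)f(x) for a random group element g. The sketch below is a minimal, dependency-free illustration under stated assumptions: it uses plain NumPy, and the function toy_equivariant_fn is a hypothetical hand-written SO(3)-equivariant map from 4*V to V standing in for the trained model (with 4*V treated as four stacked 3-vectors). The helpers random_rotation, act_on_4V and equivariance_error are likewise illustrative names, not part of the emlp API.

```python
import numpy as np

def random_rotation(rng):
    """Sample a random 3x3 rotation matrix (orthogonal, det = +1) via QR."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:     # flip one column to land in SO(3) rather than O(3)
        q[:, 0] = -q[:, 0]
    return q

def act_on_4V(x, R):
    """Apply R to each of the four 3-vectors packed into every row of x (shape (batch, 12))."""
    vecs = x.reshape(x.shape[0], 4, 3)
    return (vecs @ R.T).reshape(x.shape[0], 12)

def toy_equivariant_fn(x, weights):
    """Hypothetical stand-in for a 4*V -> V model: invariant-weighted sum of the input vectors."""
    vecs = x.reshape(x.shape[0], 4, 3)                    # (batch, 4, 3)
    norms = np.sum(vecs**2, axis=-1, keepdims=True)       # rotation-invariant, (batch, 4, 1)
    coeffs = np.tanh(norms * weights[None, :, None])      # invariant scalar per vector
    return np.sum(coeffs * vecs, axis=1)                  # (batch, 3), rotates like V

def equivariance_error(fn, x, R):
    """Max abs difference between fn(g.x) and g.fn(x) for the rotation R."""
    lhs = fn(act_on_4V(x, R))
    rhs = np.asarray(fn(x)) @ R.T
    return np.max(np.abs(lhs - rhs))

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 12))   # 5 inputs of size 4*3, matching repin = 4*V
weights = rng.standard_normal(4)
R = random_rotation(rng)
err = equivariance_error(lambda z: toy_equivariant_fn(z, weights), x, R)
print(err)  # should be at floating-point round-off level
```

To apply the same check to the Flax model, one would pass lambda z: model.apply(params, z) as fn and build the input and output group actions from the EMLP representation objects instead of the hand-rolled act_on_4V and R.T used here; the exact emlp calls for obtaining those matrices are not shown in this sketch.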
As expected, the model's parameters are registered under the 'params' collection:
[3]:
list(params['params'].keys())
[3]:
['modules_0', 'modules_1', 'modules_2', 'modules_3']
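Beyond listing the top-level keys, it can help to see how many parameters each submodule holds. In JAX this is usually done with jax.tree_util.tree_leaves; the sketch below is a dependency-free equivalent that walks any nested mapping of arrays (Mapping rather than dict, so it would also accept Flax's FrozenDict). The toy_params tree is hypothetical: its inner names and shapes are made up to mimic the {'params': {'modules_i': ...}} layout printed above, not taken from EMLP.

```python
from collections.abc import Mapping

import numpy as np

def count_params(tree):
    """Recursively count array elements in a nested mapping of arrays."""
    if isinstance(tree, Mapping):
        return sum(count_params(v) for v in tree.values())
    return int(np.size(tree))

def summarize(params):
    """Per-submodule parameter counts for a Flax-style {'params': {...}} tree."""
    return {name: count_params(sub) for name, sub in params['params'].items()}

# Hypothetical toy tree; the 'linear', 'w', 'b' names and all shapes are illustrative only.
toy_params = {'params': {
    'modules_0': {'linear': {'w': np.zeros((8, 12)), 'b': np.zeros(8)}},
    'modules_1': {'linear': {'w': np.zeros((3, 8)), 'b': np.zeros(3)}},
}}
print(summarize(toy_params))  # {'modules_0': 104, 'modules_1': 27}
```

With the real model, summarize(params) would report one count per entry in the list above.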