{
“cells”: [
{

“cell_type”: “markdown”, “metadata”: {}, “source”: [

“# Using EMLP with Flax”

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

"Using EMLP with [Flax](https://github.com/google/flax) is very similar to using it with Objax or Haiku. Just make sure to import from the Flax implementation, `emlp.nn.flax`."

]

}, {

"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [

"from jax import random\n", "import numpy as np\n", "import emlp.nn.flax as nn # import from the flax implementation\n", "from emlp.reps import T,V\n", "from emlp.groups import SO\n", "\n", "repin = 4*V # Setup some example data representations\n", "repout = V\n", "G = SO(3)\n", "\n", "rng = np.random.default_rng(0) # seeded so Restart & Run All reproduces the same example data\n", "x = rng.standard_normal((5, repin(G).size())) # generate some random data: 5 samples of size repin(G).size()"

]

}, {

"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [

"model = nn.EMLP(repin,repout,G)\n", "\n", "key = random.PRNGKey(42) # PRNG key used to initialize the parameters\n", "params = model.init(key, x) # initialize parameters, using x as a sample input\n", "\n", "y = model.apply(params, x) # Forward pass with inputs x and parameters"

]

}, {

“cell_type”: “markdown”, “metadata”: {}, “source”: [

"As expected, the model's parameters are registered under the `'params'` collection of the dictionary returned by `model.init`."

]

}, {

"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [

{

"output_type": "execute_result", "data": {

"text/plain": [

"['modules_0', 'modules_1', 'modules_2', 'modules_3']"

]

}, "metadata": {}, "execution_count": 3

}

], "source": [

"list(params['params']) # iterating the 'params' collection yields its top-level names"

]

}

], “metadata”: {

“kernelspec”: {

“name”: “python3”, “display_name”: “Python 3.8.5 64-bit (‘freshenv’: conda)”

}, “language_info”: {

“codemirror_mode”: {

“name”: “ipython”, “version”: 3

}, “file_extension”: “.py”, “mimetype”: “text/x-python”, “name”: “python”, “nbconvert_exporter”: “python”, “pygments_lexer”: “ipython3”, “version”: “3.8.5”

}, “interpreter”: {

“hash”: “ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5”

}

}, “nbformat”: 4, “nbformat_minor”: 4

}