{
"cells": [
{
"cell_type": "markdown", "metadata": {}, "source": [
"# Using EMLP with Flax"
]
}, {
"cell_type": "markdown", "metadata": {}, "source": [
"Using EMLP with [Flax](https://github.com/google/flax) is very similar to using it with Objax or Haiku. Just make sure to import from the Flax implementation, `emlp.nn.flax`."
]
}, {
"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"from jax import random\n",
"import numpy as np\n",
"import emlp.nn.flax as nn  # import from the Flax implementation\n",
"from emlp.reps import T, V\n",
"from emlp.groups import SO\n",
"\n",
"repin = 4*V  # input representation: four copies of the vector representation V\n",
"repout = V   # output representation: a single vector\n",
"G = SO(3)    # symmetry group\n",
"\n",
"x = np.random.randn(5, repin(G).size())  # random batch of 5 inputs"
]
}, {
"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
"model = nn.EMLP(repin, repout, G)\n",
"\n",
"key = random.PRNGKey(42)\n",
"params = model.init(key, x)  # initialize parameters from an example input\n",
"\n",
"y = model.apply(params, x)  # forward pass with inputs x and parameters params"
]
}, {
"cell_type": "markdown", "metadata": {}, "source": [
"Indeed, the model's parameters are registered as expected under `params['params']`:"
]
}, {
"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [
{
"output_type": "execute_result", "data": {
"text/plain": [
"['modules_0', 'modules_1', 'modules_2', 'modules_3']"
]
}, "metadata": {}, "execution_count": 3
}
], "source": [
"list(params['params'].keys())"
]
}
], "metadata": {
"kernelspec": {
"name": "python3", "display_name": "Python 3.8.5 64-bit ('freshenv': conda)"
}, "language_info": {
"codemirror_mode": {
"name": "ipython", "version": 3
}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5"
}, "interpreter": {
"hash": "ec74566b76234e57f2cd5bb0818dcd91369c1a3af290381c3b6efeb6aea6cdd5"
}
}, "nbformat": 4, "nbformat_minor": 4
}