Layers and Models

class emlp.nn.Linear(repin, repout)[source]

Bases: objax.nn.layers.Linear

Basic equivariant Linear layer from repin to repout.
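
A minimal usage sketch (not from the library docs; assumes the emlp.reps/emlp.groups API and a hypothetical SO(3) example):

    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep_in = 4 * V(G)       # four vector channels, total dim 12
    rep_out = V(G) ** 2     # one rank-2 tensor channel, dim 9

    layer = nn.Linear(rep_in, rep_out)
    x = jnp.ones((5, rep_in.size()))   # batch of 5 flattened rep_in vectors
    y = layer(x)                       # equivariant map, shape (5, rep_out.size())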

class emlp.nn.BiLinear(repin, repout)[source]

Bases: objax.module.Module

Cheap bilinear layer (adds parameters for each part of the input, which can be interpreted as a linear map from that part of the input to the output representation).
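
A construction sketch in the same hypothetical SO(3) setup as above; the call convention (a single batched input, as for Linear) is an assumption:

    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep_in = 4 * V(G) + V(G) ** 2
    rep_out = 4 * V(G)

    bilinear = nn.BiLinear(rep_in, rep_out)
    x = jnp.ones((5, rep_in.size()))
    y = bilinear(x)   # quadratic in x: a learned x-dependent linear map applied to x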

emlp.nn.gated(ch_rep)[source]

Returns the rep with an additional scalar ‘gate’ for each of the non-scalar and non-regular reps in the input. Intended as the output representation of the linear (and/or bilinear) layers placed directly before a GatedNonlinearity(), which consumes the scalar gates.

Parameters

ch_rep (Rep) – representation to add gates to

Return type

Rep
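
A sketch of the rep transformation (sizes assume one gate per non-scalar channel, per the docstring above):

    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep = 10 * V(G) ** 0 + 5 * V(G)   # 10 scalars + 5 vectors, dim 25
    grep = nn.gated(rep)
    print(rep.size(), grep.size())    # scalars need no gate; expect 25, 30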

class emlp.nn.GatedNonlinearity(rep)[source]

Bases: objax.module.Module

Gated nonlinearity. Requires the input to include the additional gate scalars, one for every non-regular and non-scalar rep. Applies swish to the regular and scalar reps.
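
The intended pairing with gated(), as a sketch (hypothetical reps; not from the library docs):

    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep_in = 10 * V(G) ** 0 + 5 * V(G)
    rep_hidden = 10 * V(G) ** 0 + 5 * V(G)

    lin = nn.Linear(rep_in, nn.gated(rep_hidden))   # emit values + gate scalars
    nonlin = nn.GatedNonlinearity(rep_hidden)       # consume the gates
    x = jnp.ones((5, rep_in.size()))
    h = nonlin(lin(x))                              # shape (5, rep_hidden.size())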

class emlp.nn.EMLPBlock(rep_in, rep_out)[source]

Bases: objax.module.Module

Basic building block of EMLP, consisting of a G-equivariant Linear layer, a BiLinear layer, and a GatedNonlinearity.
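
A sketch of a standalone block (hypothetical reps, same assumed setup as the examples above):

    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep_in, rep_out = 4 * V(G), 4 * V(G) + V(G) ** 2
    block = nn.EMLPBlock(rep_in, rep_out)
    x = jnp.ones((5, rep_in.size()))
    h = block(x)   # shape (5, rep_out.size())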

emlp.nn.uniform_rep(ch, group)[source]

A heuristic method for allocating a given number of channels (ch) into tensor types. Attempts to distribute the channels evenly across the different tensor types. Useful for hands-off layer construction.

Parameters
  • ch (int) – total number of channels

  • group (Group) – symmetry group

Returns

The direct sum representation with dim(V)=ch

Return type

SumRep
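
For example (the exact allocation is heuristic, so the printed decomposition is indicative only):

    import emlp.nn as nn
    from emlp.groups import SO

    rep = nn.uniform_rep(384, SO(3))
    print(rep)         # a direct sum of tensor reps (scalars, vectors, matrices, ...)
    print(rep.size())  # 384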

class emlp.nn.EMLP(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: objax.module.Module

Equivariant MultiLayer Perceptron. If the ch argument is an int, the hands-off uniform_rep heuristic is used; if it is a representation, that representation is used for every hidden layer. Individual layer representations can be set explicitly with a list of ints or a list of representations, rather than using the same one for each hidden layer.

Parameters
  • rep_in (Rep) – input representation

  • rep_out (Rep) – output representation

  • group (Group) – symmetry group

  • ch (int or list[int] or Rep or list[Rep]) – number of channels in the hidden layers

  • num_layers (int) – number of hidden layers

Returns

The EMLP objax module.

Return type

Module
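
A usage sketch for a hypothetical O(5)-invariant regression task (shapes are made up; the constructor call mirrors the signature above):

    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import O

    G = O(5)
    rep_in = 2 * V(G)      # two 5-vectors
    rep_out = V(G) ** 0    # one invariant scalar

    model = nn.EMLP(rep_in, rep_out, group=G, ch=128, num_layers=2)
    baseline = nn.MLP(rep_in, rep_out, group=G, ch=128, num_layers=2)  # shapes only

    x = jnp.ones((32, rep_in.size()))
    y = model(x)   # shape (32, 1)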

class emlp.nn.MLP(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: objax.module.Module

Standard baseline MLP. Representations and group are used for shapes only.

class emlp.nn.Standardize(model, ds_stats)[source]

Bases: objax.module.Module

A convenience module that wraps a given module, normalizes its input using the dataset x mean and std stats, and unnormalizes its output using the dataset y mean and std stats.

Parameters
  • model (Module) – model to wrap

  • ds_stats ((μx,σx,μy,σy) or (μx,σx)) – tuple of the normalization stats

Returns

Wrapped model with input normalization (and output unnormalization)

Return type

Module
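
Continuing the hypothetical EMLP example above, with made-up dataset statistics (the call convention is assumed to mirror the wrapped model):

    import jax.numpy as jnp

    mu_x, sig_x = jnp.zeros(rep_in.size()), jnp.ones(rep_in.size())
    mu_y, sig_y = jnp.zeros(rep_out.size()), jnp.ones(rep_out.size())

    wrapped = nn.Standardize(model, (mu_x, sig_x, mu_y, sig_y))
    y = wrapped(x)   # input normalized, output unnormalized internally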

class emlp.nn.MLPode(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: objax.module.Module

Baseline (non-equivariant) MLP for neural ODE dynamics modeling. Same args as MLP.

class emlp.nn.EMLPode(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: emlp.nn.objax.EMLP

Neural ODE Equivariant MLP. Same args as EMLP.
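
A sketch of integrating the learned vector field, assuming the module maps a state z to dz/dt (ignoring the time argument here is an assumption, not the documented interface):

    import jax.numpy as jnp
    from jax.experimental.ode import odeint
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    dynamics = nn.EMLPode(V(G), V(G), group=G, ch=64, num_layers=2)

    z0 = jnp.ones(3)
    ts = jnp.linspace(0.0, 1.0, 10)
    traj = odeint(lambda z, t: dynamics(z), z0, ts)   # integrate the learned field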

class emlp.nn.MLPH(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: objax.module.Module

Baseline (non-equivariant) MLP modeling a Hamiltonian for HNN. Same args as MLP.

class emlp.nn.EMLPH(rep_in, rep_out, group, ch=384, num_layers=3)[source]

Bases: emlp.nn.objax.EMLP

EMLP modeling a Hamiltonian for an HNN (Hamiltonian Neural Network). Same args as EMLP.
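
A sketch of the standard HNN construction around it, assuming the module outputs a scalar Hamiltonian H(z) for a state z = (q, p):

    import jax
    import jax.numpy as jnp
    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep_z = V(G) + V(G)                          # z = (q, p), two 3-vectors
    H = nn.EMLPH(rep_z, V(G) ** 0, group=G, ch=64, num_layers=2)

    def hamiltonian_dynamics(z):
        gH = jax.grad(lambda z: H(z).sum())(z)   # dH/dz = (dH/dq, dH/dp)
        gq, gp = gH[:3], gH[3:]
        return jnp.concatenate([gp, -gq])        # dq/dt = dH/dp, dp/dt = -dH/dq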

emlp.nn.gate_indices(ch_rep)[source]

Indices of the scalars, including the additional scalar gates added by gated(ch_rep).

Parameters

ch_rep (Rep) –

Return type

Array
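
A sketch (the exact index convention is an assumption based on the docstring above):

    import emlp.nn as nn
    from emlp.reps import V
    from emlp.groups import SO

    G = SO(3)
    rep = 2 * V(G) ** 0 + V(G)    # two scalars + one vector
    print(nn.gate_indices(rep))   # scalars index themselves; the vector's entries
                                  # point at its appended gate scalar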