Source code for emlp.nn.objax

import jax
import jax.numpy as jnp
import objax.nn as nn
import objax.functional as F
import numpy as np
from emlp.reps import T,Rep,Scalar
from emlp.reps import bilinear_weights
from emlp.reps.product_sum_reps import SumRep
import collections
from emlp.utils import Named,export
import scipy as sp
import scipy.special
import random
import logging
from objax.variable import TrainVar, StateVar
from objax.nn.init import kaiming_normal, xavier_normal
from objax.module import Module
import objax
from objax.nn.init import orthogonal
from scipy.special import binom
from jax import jit,vmap
from functools import lru_cache as cache

def Sequential(*args):
    """ Wrapped to mimic pytorch syntax"""
    return nn.Sequential(args)

@export
class Linear(nn.Linear):
    """ Basic equivariant Linear layer from repin to repout."""
    def __init__(self, repin, repout):
        nin,nout = repin.size(),repout.size()
        super().__init__(nin,nout)
        self.b = TrainVar(objax.random.uniform((nout,))/jnp.sqrt(nout))
        self.w = TrainVar(orthogonal((nout, nin)))
        self.rep_W = rep_W = repout*repin.T
        rep_bias = repout
        self.Pw = rep_W.equivariant_projector()
        self.Pb = rep_bias.equivariant_projector()
        logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")

    def __call__(self, x): # (cin) -> (cout)
        logging.debug(f"linear in shape: {x.shape}")
        W = (self.Pw@self.w.value.reshape(-1)).reshape(*self.w.value.shape)
        b = self.Pb@self.b.value
        out = x@W.T+b
        logging.debug(f"linear out shape:{out.shape}")
        return out

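# Usage sketch (illustrative, not part of the module). The group SO comes from emlp.groups,
# as in the library's examples; T is already imported above from emlp.reps:
#
#     from emlp.groups import SO
#     G = SO(3)
#     layer = Linear(T(1)(G), T(1)(G))   # equivariant linear map from vectors to vectors
#     x = np.random.randn(5, 3)          # batch of 5 vectors
#     y = layer(x)                       # shape (5, 3); W and b are projected by Pw and Pb
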
@export
class BiLinear(Module):
    """ Cheap bilinear layer (adds parameters for each part of the input which can be
        interpreted as a linear map from a part of the input to the output representation)."""
    def __init__(self, repin, repout):
        super().__init__()
        Wdim, weight_proj = bilinear_weights(repout,repin)
        self.weight_proj = jit(weight_proj)
        self.w = TrainVar(objax.random.normal((Wdim,)))
        logging.info(f"BiW components: dim:{Wdim}")

    def __call__(self, x, training=True):
        # compatible with non sumreps? need to check
        W = self.weight_proj(self.w.value,x)
        out = .1*(W@x[...,None])[...,0]
        return out

@export
def gated(ch_rep: Rep) -> Rep:
    """ Returns the rep with an additional scalar 'gate' for each of the nonscalar
        and non-regular reps in the input. To be used as the output for linear (and or bilinear)
        layers directly before a :func:`GatedNonlinearity` to produce its scalar gates. """
    if isinstance(ch_rep,SumRep):
        return ch_rep+sum([Scalar(rep.G) for rep in ch_rep if rep!=Scalar and not rep.is_permutation])
    else:
        return ch_rep+Scalar(ch_rep.G) if not ch_rep.is_permutation else ch_rep

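# Illustrative sketch: for non-permutation reps such as SO(3) vectors, gated() appends one
# scalar gate per rep, while scalars and permutation/regular reps pass through unchanged.
# Assumes emlp.groups.SO, as in the library examples:
#
#     from emlp.groups import SO
#     G = SO(3)
#     rep = 2*T(0)(G) + 3*T(1)(G)   # 2 scalars + 3 vectors, size 2 + 9 = 11
#     gated(rep).size()             # 11 + 3 gate scalars = 14
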
@export
class GatedNonlinearity(Module):
    """ Gated nonlinearity. Requires input to have the additional gate scalars
        for every non-regular and non-scalar rep. Applies swish to regular and scalar reps. """
    def __init__(self,rep):
        super().__init__()
        self.rep = rep

    def __call__(self,values):
        # For scalar and regular (permutation) reps the gate index points back at the channel
        # itself, so sigmoid(gate)*value reduces to swish; other reps use their appended gate.
        gate_scalars = values[..., gate_indices(self.rep)]
        activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()]
        return activations

@export
class EMLPBlock(Module):
    """ Basic building block of EMLP consisting of G-Linear, biLinear,
        and gated nonlinearity. """
    def __init__(self,rep_in,rep_out):
        super().__init__()
        self.linear = Linear(rep_in,gated(rep_out))
        self.bilinear = BiLinear(gated(rep_out),gated(rep_out))
        self.nonlinearity = GatedNonlinearity(rep_out)

    def __call__(self,x):
        lin = self.linear(x)
        preact = self.bilinear(lin)+lin
        return self.nonlinearity(preact)

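# Usage sketch (illustrative): a block takes vectors of size rep_in.size() and returns vectors
# of size rep_out.size(), with the gate channels produced and consumed internally.
# Assumes emlp.groups.SO:
#
#     from emlp.groups import SO
#     G = SO(3)
#     rep = uniform_rep(24, G)
#     block = EMLPBlock(rep_in=rep, rep_out=rep)
#     x = np.random.randn(7, rep.size())   # batch of 7
#     h = block(x)                         # shape (7, rep.size())
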
def uniform_rep_general(ch,*rep_types):
    """ Adds all combinations of (powers of) rep_types up to a total size of ch channels. """
    raise NotImplementedError

@export
def uniform_rep(ch,group):
    """ A heuristic method for allocating a given number of channels (ch)
        into tensor types. Attempts to distribute the channels evenly across
        the different tensor types. Useful for hands-off layer construction.

        Args:
            ch (int): total number of channels
            group (Group): symmetry group

        Returns:
            SumRep: The direct sum representation with dim(V)=ch
        """
    d = group.d
    Ns = np.zeros((lambertW(ch,d)+1,),int) # number of tensors of each rank
    while ch>0:
        max_rank = lambertW(ch,d) # compute the max rank tensor that can fit
        Ns[:max_rank+1] += np.array([d**(max_rank-r) for r in range(max_rank+1)],dtype=int)
        ch -= (max_rank+1)*d**max_rank # compute leftover channels
    sum_rep = sum([binomial_allocation(nr,r,group) for r,nr in enumerate(Ns)])
    sum_rep,perm = sum_rep.canonicalize()
    return sum_rep

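# Hand-traced sketch of the allocation (treat the exact counts as illustrative): for ch=512
# and a group with d=3 such as SO(3), the loop above assigns roughly 122 rank-0, 40 rank-1,
# 12 rank-2, 3 rank-3 and 1 rank-4 tensors, i.e. 122*1 + 40*3 + 12*9 + 3*27 + 1*81 = 512
# channels, before binomial_allocation splits each total rank over (p,q):
#
#     from emlp.groups import SO
#     rep = uniform_rep(512, SO(3))
#     rep.size()   # 512
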
def lambertW(ch,d):
    """ Returns the maximum rank r such that (r+1)*d**r <= ch,
        i.e. the solution to x*d^x = ch rounded down."""
    max_rank = 0
    while (max_rank+1)*d**max_rank <= ch:
        max_rank += 1
    max_rank -= 1
    return max_rank

def binomial_allocation(N,rank,G):
    """ Allocates N tensors of total rank r=(p+q) into T(k,r-k) for k=0,1,...,r
        to match the binomial distribution. For orthogonal representations there is
        no distinction between p and q, so this op is equivalent to N*T(rank)."""
    if N==0: return 0
    n_binoms = N//(2**rank)
    n_leftover = N%(2**rank)
    even_split = sum([n_binoms*int(binom(rank,k))*T(k,rank-k,G) for k in range(rank+1)])
    ps = np.random.binomial(rank,.5,n_leftover)
    ragged = sum([T(int(p),rank-int(p),G) for p in ps])
    out = even_split+ragged
    return out

def uniform_allocation(N,rank):
    """ Uniformly allocates N tensors of total rank r=(p+q) into T(k,r-k) for k=0,1,...,r.
        For orthogonal representations there is no distinction between p and q,
        so this op is equivalent to N*T(rank)."""
    if N==0: return 0
    even_split = sum((N//(rank+1))*T(k,rank-k) for k in range(rank+1))
    ragged = sum(random.sample([T(k,rank-k) for k in range(rank+1)],N%(rank+1)))
    return even_split+ragged

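# Worked values (hand-checked against the loop in lambertW above): lambertW(512,3) == 4,
# since 5*3**4 = 405 <= 512 but 6*3**5 = 1458 > 512; similarly lambertW(5,3) == 0 and
# lambertW(6,3) == 1.
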
@export
class EMLP(Module,metaclass=Named):
    """ Equivariant MultiLayer Perceptron.

        If the input ch argument is an int, uses the hands-off uniform_rep heuristic.
        If the ch argument is a representation, uses this representation for the hidden layers.
        Individual layer representations can be set explicitly by using a list of ints or a
        list of representations, rather than using the same for each hidden layer.

        Args:
            rep_in (Rep): input representation
            rep_out (Rep): output representation
            group (Group): symmetry group
            ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
            num_layers (int): number of hidden layers

        Returns:
            Module: the EMLP objax module."""
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        super().__init__()
        logging.info("Initing EMLP (objax)")
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group
        # Parse ch as a single int, a sequence of ints, a single Rep, or a sequence of Reps
        if isinstance(ch,int):
            middle_layers = num_layers*[uniform_rep(ch,group)]
        elif isinstance(ch,Rep):
            middle_layers = num_layers*[ch(group)]
        else:
            middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
        reps = [self.rep_in]+middle_layers
        logging.info(f"Reps: {reps}")
        self.network = Sequential(
            *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
            Linear(reps[-1],self.rep_out)
        )

    def __call__(self,x,training=True):
        return self.network(x)

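# Usage sketch (illustrative, mirroring the library's documented examples; assumes
# emlp.groups.SO is importable):
#
#     from emlp.groups import SO
#     G = SO(3)
#     repin  = 4*T(1) + T(2)   # four vectors and one matrix: 4*3 + 9 = 21 dims
#     repout = T(1)            # a single vector: 3 dims
#     model = EMLP(repin, repout, group=G, ch=256, num_layers=3)
#     x = np.random.randn(10, repin(G).size())
#     y = model(x)             # shape (10, 3), equivariant to rotations of the inputs
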
def swish(x):
    return jax.nn.sigmoid(x)*x

def MLPBlock(cin,cout):
    return Sequential(nn.Linear(cin,cout),swish)  # ,nn.BatchNorm0D(cout,momentum=.9),swish

@export
class MLP(Module,metaclass=Named):
    """ Standard baseline MLP. Representations and group are used for shapes only. """
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        super().__init__()
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group
        chs = [self.rep_in.size()] + num_layers*[ch]
        cout = self.rep_out.size()
        logging.info("Initing MLP")
        self.net = Sequential(
            *[MLPBlock(cin,cout) for cin,cout in zip(chs,chs[1:])],
            nn.Linear(chs[-1],cout)
        )

    def __call__(self,x,training=True):
        y = self.net(x)
        return y

@export
class Standardize(Module):
    """ A convenience module to wrap a given module, normalize its input by some dataset
        x mean and std stats, and unnormalize its output by the dataset y mean and std stats.

        Args:
            model (Module): model to wrap
            ds_stats ((μx,σx,μy,σy) or (μx,σx)): tuple of the normalization stats

        Returns:
            Module: Wrapped model with input normalization (and output unnormalization)"""
    def __init__(self,model,ds_stats):
        super().__init__()
        self.model = model
        self.ds_stats = ds_stats

    def __call__(self,x,training):
        if len(self.ds_stats)==2:
            muin,sin = self.ds_stats
            return self.model((x-muin)/sin,training=training)
        else:
            muin,sin,muout,sout = self.ds_stats
            y = sout*self.model((x-muin)/sin,training=training)+muout
            return y

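# Usage sketch (illustrative; ds_stats is a placeholder for statistics computed from a
# training set, e.g. per-feature means and standard deviations):
#
#     # ds_stats = (muin, sin, muout, sout)
#     wrapped = Standardize(model, ds_stats)
#     y = wrapped(x, training=False)   # standardizes x, then un-standardizes the prediction
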
# Networks for ODE dynamics
@export
class MLPode(Module,metaclass=Named):
    """ Baseline (non-equivariant) MLP dynamics network for Neural ODEs;
        __call__(z,t) returns the predicted time derivative at state z. """
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        super().__init__()
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group
        chs = [self.rep_in.size()] + num_layers*[ch]
        cout = self.rep_out.size()
        logging.info("Initing MLP")
        self.net = Sequential(
            *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])],
            nn.Linear(chs[-1],cout)
        )

    def __call__(self,z,t):
        return self.net(z)

@export
class EMLPode(EMLP):
    """ Neural ODE Equivariant MLP. Same args as EMLP."""
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        # note: EMLP.__init__ is not called; the representations and network are rebuilt here
        logging.info("Initing EMLP")
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group
        # Parse ch as a single int, a sequence of ints, a single Rep, or a sequence of Reps
        if isinstance(ch,int):
            middle_layers = num_layers*[uniform_rep(ch,group)]
        elif isinstance(ch,Rep):
            middle_layers = num_layers*[ch(group)]
        else:
            middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
        reps = [self.rep_in]+middle_layers
        logging.info(f"Reps: {reps}")
        self.network = Sequential(
            *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
            Linear(reps[-1],self.rep_out)
        )

    def __call__(self,z,t):
        return self.network(z)

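# Usage sketch (illustrative): an EMLPode instance plays the role of the dynamics function
# z, t -> dz/dt, e.g. for an integrator such as jax.experimental.ode.odeint. Assumes
# emlp.groups.SO:
#
#     from emlp.groups import SO
#     G = SO(3)
#     dynamics = EMLPode(T(1), T(1), group=G, ch=128, num_layers=2)
#     z = jnp.ones((5, 3))       # batch of 5 states, each an SO(3) vector
#     dzdt = dynamics(z, 0.0)    # shape (5, 3)
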
# Networks for hamiltonian dynamics (need to sum for batched Hamiltonian grads)
@export
class MLPH(Module,metaclass=Named):
    """ Baseline (non-equivariant) MLP modeling a Hamiltonian H(x). The scalar outputs are
        summed over the batch so that batched gradients dH/dx can be taken in one jax.grad call. """
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        super().__init__()
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group
        chs = [self.rep_in.size()] + num_layers*[ch]
        cout = self.rep_out.size()
        logging.info("Initing MLP")
        self.net = Sequential(
            *[Sequential(nn.Linear(cin,cout),swish) for cin,cout in zip(chs,chs[1:])],
            nn.Linear(chs[-1],cout)
        )

    def H(self,x):
        y = self.net(x).sum()
        return y

    def __call__(self,x):
        return self.H(x)

@export
class EMLPH(EMLP):
    """ Equivariant EMLP modeling a Hamiltonian for HNN. Same args as EMLP."""
    def H(self,x):
        y = self.network(x)
        return y.sum()

    def __call__(self,x):
        return self.H(x)

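# Usage sketch (illustrative): because H sums the per-example energies over the batch,
# jax.grad with respect to the input returns the batched gradients dH/dx directly.
# Assumes emlp.groups.SO:
#
#     from emlp.groups import SO
#     G = SO(3)
#     hnn = EMLPH(2*T(1), T(0), group=G, ch=128, num_layers=2)   # H: (q,p) in R^6 -> R
#     x = jnp.ones((8, 6))          # batch of 8 phase-space points
#     dH = jax.grad(hnn.H)(x)       # shape (8, 6), one gradient per example
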
@export
@cache(maxsize=None)
def gate_indices(ch_rep: Rep) -> jnp.ndarray:
    """ Indices for scalars, and also additional scalar gates
        added by gated(sumrep)."""
    channels = ch_rep.size()
    perm = ch_rep.perm
    indices = np.arange(channels)

    if not isinstance(ch_rep,SumRep): # If just a single rep, only one scalar at end
        # integer dtype so the result can be used directly as an index array
        return indices if ch_rep.is_permutation else ch_rep.size()*np.ones(ch_rep.size(),dtype=int)

    num_nonscalars = 0
    i = 0
    for rep in ch_rep:
        if rep!=Scalar and not rep.is_permutation:
            indices[perm[i:i+rep.size()]] = channels+num_nonscalars
            num_nonscalars += 1
        i += rep.size()
    return indices
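# Illustrative sketch: for a single non-permutation rep such as T(1)(SO(3)) (size 3),
# gate_indices returns [3, 3, 3], i.e. every channel is gated by the single scalar that
# gated() appends at index 3; for permutation and scalar reps each channel gates itself.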