graph2mat.bindings.e3nn.modules.preprocessing

Classes

E3nnEdgeMessageBlock(irreps)

This is basically MACE's RealAgnosticResidualInteractionBlock, but only up to the part where it computes the partial m_ji messages.

E3nnInteraction(irreps[, avg_num_neighbors])

Basically MACE's RealAgnosticResidualInteractionBlock, without reshapes.

class graph2mat.bindings.e3nn.modules.preprocessing.E3nnEdgeMessageBlock(irreps: Dict[str, Irreps])

Bases: Module

This is basically MACE’s RealAgnosticResidualInteractionBlock, but only up to the part where it computes the partial m_ji messages.

It computes a “message” for each edge in the graph. Messages are directional: the message for edge (i, j) generally differs from the message for edge (j, i).
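To make the directionality concrete, here is a toy sketch in plain Python; it is not the graph2mat or e3nn implementation, which uses equivariant tensor products. It assumes a hypothetical message function that scales the sender's scalar feature by the edge vector r_ji = pos[i] - pos[j], and the convention that in edge (i, j) node j sends and node i receives.

```python
from typing import Dict, List, Tuple

Vector = Tuple[float, float, float]


def edge_message(sender_feat: float, edge_vec: Vector) -> Vector:
    """Hypothetical message: the sender's scalar feature times the edge vector."""
    return tuple(sender_feat * c for c in edge_vec)


def directed_messages(
    pos: List[Vector], feats: List[float], edges: List[Tuple[int, int]]
) -> Dict[Tuple[int, int], Vector]:
    """One message per directed edge (i, j); j is the sender, i the receiver."""
    out: Dict[Tuple[int, int], Vector] = {}
    for i, j in edges:
        r_ji = tuple(pi - pj for pi, pj in zip(pos[i], pos[j]))
        out[(i, j)] = edge_message(feats[j], r_ji)
    return out


msgs = directed_messages(
    pos=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    feats=[2.0, 3.0],
    edges=[(0, 1), (1, 0)],
)
# Reversing the edge changes both the sender and the sign of r_ji.
print(msgs[(0, 1)], msgs[(1, 0)])  # (-3.0, 0.0, 0.0) (2.0, 0.0, 0.0)
```

Both the sender's features and the orientation of the edge vector change when the edge is reversed, which is why the two directions cannot share one message.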

This module can be used as the preprocessing step for edges. It produces no node outputs (its forward returns None in place of node features), so it has no effect when used as the preprocessing step for nodes.
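A sketch of how a caller might treat the (None, edge_messages) output. The dispatcher `apply_preprocessing` and the stub `edge_only_step` are hypothetical and not part of graph2mat; the sketch only assumes that a None node output means "leave the node features unchanged".

```python
from typing import Any, Callable, Optional, Tuple

# A preprocessing step maps (data, node_feats) -> (node_out, edge_out).
PreprocessFn = Callable[[Any, Any], Tuple[Optional[Any], Optional[Any]]]


def apply_preprocessing(preprocess: PreprocessFn, data: Any, node_feats: Any):
    """Hypothetical dispatcher: a None node output keeps the input node features."""
    new_node_feats, edge_feats = preprocess(data, node_feats)
    if new_node_feats is None:
        new_node_feats = node_feats
    return new_node_feats, edge_feats


def edge_only_step(data: Any, node_feats: Any):
    """Stub mimicking the (None, edge_messages) output of an edge-only block."""
    return None, ["m_01", "m_10"]


nodes, edges = apply_preprocessing(edge_only_step, data=None, node_feats=[1.0, 2.0])
print(nodes, edges)  # [1.0, 2.0] ['m_01', 'm_10']
```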

__init__(irreps: Dict[str, Irreps]) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data: TorchBasisMatrixData, node_feats: Tensor) → Tuple[None, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of this method: calling the instance runs the registered hooks, while calling forward() directly silently ignores them.
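The difference the note describes can be mimicked without PyTorch. `ToyModule` below is a hypothetical stand-in for `torch.nn.Module` that reproduces only the forward-hook behaviour: `__call__` runs the registered hooks after `forward`, while calling `forward` directly skips them. With the real classes, this means writing `block(data, node_feats)` rather than `block.forward(data, node_feats)`.

```python
from typing import Any, Callable, List

Hook = Callable[[Any, tuple, Any], Any]


class ToyModule:
    """Toy stand-in for torch.nn.Module's forward-hook handling only."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return 2 * x

    def __call__(self, x: float) -> float:
        out = self.forward(x)
        # As in torch, a hook gets (module, args, output) and may replace the output.
        for hook in self._forward_hooks:
            result = hook(self, (x,), out)
            if result is not None:
                out = result
        return out


seen: List[float] = []
m = ToyModule()
m.register_forward_hook(lambda mod, args, out: seen.append(out))
m(3.0)          # runs forward, then the hook
m.forward(3.0)  # runs forward only; the hook is skipped
print(seen)     # [6.0]
```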

class graph2mat.bindings.e3nn.modules.preprocessing.E3nnInteraction(irreps: Dict[str, Irreps], avg_num_neighbors: float = 10)

Bases: Module

Basically MACE’s RealAgnosticResidualInteractionBlock, without reshapes.

This module takes a graph and returns updated states for the nodes.

This module can be used as the preprocessing step for both nodes and edges.
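The role of the avg_num_neighbors argument can be illustrated with a scalar toy version of the aggregation in MACE's interaction block, assuming E3nnInteraction uses it the same way: messages arriving at each receiver are summed and then divided by the constant avg_num_neighbors rather than by each node's actual degree. The dictionary layout and the (receiver, sender) key convention are assumptions for illustration; the real block works on e3nn irreps tensors.

```python
from typing import Dict, List, Tuple


def aggregate_messages(
    num_nodes: int,
    messages: Dict[Tuple[int, int], float],
    avg_num_neighbors: float = 10.0,
) -> List[float]:
    """Sum messages m_ji at each receiver i, then divide by a fixed constant.

    Keys are (receiver, sender) pairs; this layout is a hypothetical choice.
    """
    agg = [0.0] * num_nodes
    for (receiver, _sender), m in messages.items():
        agg[receiver] += m
    # Normalize by a dataset-level average degree, not each node's own degree,
    # so nodes with more neighbors still receive a larger total.
    return [a / avg_num_neighbors for a in agg]


msgs = {(0, 1): 4.0, (0, 2): 6.0, (1, 0): 5.0}
print(aggregate_messages(3, msgs, avg_num_neighbors=2.0))  # [5.0, 2.5, 0.0]
```

Dividing by a constant keeps the update a smooth function of the neighbor set: adding one neighbor changes the sum by one scaled message instead of rescaling every existing contribution, as a per-node mean would.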

__init__(irreps: Dict[str, Irreps], avg_num_neighbors: float = 10) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data: TorchBasisMatrixData, node_feats: Tensor) → Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of this method: calling the instance runs the registered hooks, while calling forward() directly silently ignores them.