graph2mat.bindings.e3nn.modules.graph2mat

Classes

E3nnGraph2Mat(unique_basis, irreps[, ...])

Extension of TorchGraph2Mat to deal with irreps.

class graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat(unique_basis: Sequence[PointBasis], irreps: Dict[str, Irreps], preprocessing_nodes: Type[Module] | None = None, preprocessing_nodes_kwargs: dict = {}, preprocessing_edges: Type[Module] | None = None, preprocessing_edges_kwargs: dict = {}, preprocessing_edges_reuse_nodes: bool = True, node_operation: Type = graph2mat.bindings.e3nn.modules.node_operations.E3nnSimpleNodeBlock, node_operation_kwargs: dict = {}, edge_operation: Type = graph2mat.bindings.e3nn.modules.edge_operations.E3nnSimpleEdgeBlock, edge_operation_kwargs: dict = {}, symmetric: bool = False, blocks_symmetry: str = 'ij', self_blocks_symmetry: str | None = None, matrix_block_cls: Type[MatrixBlock] = graph2mat.bindings.e3nn.modules.matrixblock.E3nnIrrepsMatrixBlock, **kwargs)

Bases: TorchGraph2Mat

Extension of TorchGraph2Mat to deal with irreps.

Parameters:
  • unique_basis

    Basis of the point types that the function should be able to handle. It can either be a list of the unique PointBasis objects or a BasisTableWithEdges object.

    Note that when using the function, each graph does not need to contain all the point types.

  • irreps

    Dictionary containing the irreps of all the possible features that the model has to deal with.

    The only required key is: “node_feats_irreps”.

    The remaining keys depend on what the preprocessing and block producing functions use.

  • preprocessing_nodes

    A module that preprocesses the node features before passing them to the node block producing functions. This is \(p_n\) in the sketch.

    It should be a class with an __init__ method that receives the initialization arguments and a __call__ method that receives the data to process. The data will be the same as that passed to Graph2Mat.

    It can output either a single array (the updated node features) or a tuple (updated node features, edge messages). In the second case, the edge messages are disregarded; the tuple form is accepted only so that the same preprocessing functions can be reused for both node and edge processing.

  • preprocessing_nodes_kwargs – Initialization arguments passed directly to the preprocessing_nodes class.

  • preprocessing_edges

    A module that preprocesses the edge features before passing them to the edge block producing functions. This is \(p_e\) in the sketch.

    It should be a class with an __init__ method that receives the initialization arguments and a __call__ method that receives the data to process. The data will be the same as that passed to Graph2Mat.

    It can output either a single array (the updated node features) or a tuple (updated node features, edge messages). In the second case, the updated node features can be None.

  • preprocessing_edges_kwargs – Initialization arguments passed directly to the preprocessing_edges class.

  • preprocessing_edges_reuse_nodes

    If there is a preprocessing function for edges and it only returns edge messages, this determines whether the original (non-updated) node features should also be passed to the edge block producing functions.

    It has no effect if there is no edge preprocessing function or the edge preprocessing function returns both node features and edge messages.

  • node_operation

    The operation used to compute the values for matrix blocks corresponding to self interactions (nodes). These are the \(f_n\) functions in the sketch.

    It should be a class with an __init__ method that receives the initialization arguments (such as i_basis, j_basis and symmetry) and a __call__ method that receives the data to process. It will receive the node features for the node blocks that the operation must compute.

  • node_operation_kwargs – Initialization arguments for the node_operation class.

  • edge_operation

    The operation used to compute the values for matrix blocks corresponding to interactions between different nodes (edges). These are the \(f_e\) functions in the sketch.

    It should be a class with an __init__ method that receives the initialization arguments (such as i_basis, j_basis and symmetry) and a __call__ method that receives the data to process. It will receive:

    • Node features as a tuple: (feats_senders, feats_receivers)

    • Edge messages as a tuple: (edge_message_ij, edge_message_ji)

    Each item in the tuples is an array with length n_edges.

    The operation does not need to handle permutation of the nodes. If the matrix is symmetric, permutation of nodes should lead to the transposed block, but this is handled by Graph2Mat.

  • edge_operation_kwargs – Initialization arguments for the edge_operation class.

  • symmetric

    Whether the matrix is symmetric. If it is, the edge blocks for two edges connecting the same pair of atoms in opposite directions are computed only once: the block for the opposite direction is the transpose of the computed block.

    This also determines the symmetry argument passed to the node_operation on initialization.

  • blocks_symmetry – The symmetry that each block (both edge and node blocks) must obey. If the blocks must be symmetric for example, this should be set to “ij=ji”.

  • self_blocks_symmetry

    The symmetry that node blocks must obey. If this is None:

    • If symmetric is False, node blocks are assumed to have the same symmetry as the other blocks, as specified by the blocks_symmetry parameter.

    • If symmetric is True, node blocks are assumed to be symmetric.

  • matrix_block_cls – Class that wraps matrix block operations.

  • **kwargs – Additional arguments passed to the Graph2Mat class.
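The rules for what preprocessing_nodes and preprocessing_edges may return, together with the effect of preprocessing_edges_reuse_nodes, can be summarized in a small pure-Python sketch. The helpers split_preprocessing_output, resolve_node_inputs and resolve_edge_inputs are hypothetical and written only for illustration; they are not part of graph2mat, and they model how a preprocessing output is interpreted, not how Graph2Mat actually wires its modules together.

```python
from typing import Any, Optional, Tuple


def split_preprocessing_output(output: Any) -> Tuple[Optional[Any], Optional[Any]]:
    """Normalize a preprocessing output to (node_feats, edge_messages).

    A single (non-tuple) output is read as the updated node features;
    a 2-tuple is read as (updated node features, edge messages).
    """
    if isinstance(output, tuple):
        node_feats, edge_messages = output
        return node_feats, edge_messages
    return output, None


def resolve_node_inputs(nodes_output: Any) -> Any:
    """Node features handed to the node operations (f_n).

    Any edge messages produced by the node preprocessing are disregarded.
    """
    node_feats, _ignored_edge_messages = split_preprocessing_output(nodes_output)
    return node_feats


def resolve_edge_inputs(
    edges_output: Any, original_node_feats: Any, reuse_nodes: bool = True
) -> Tuple[Optional[Any], Optional[Any]]:
    """(node_feats, edge_messages) handed to the edge operations (f_e).

    If the edge preprocessing returned only edge messages (node features
    set to None), the original node features are reused when reuse_nodes
    is True, mirroring preprocessing_edges_reuse_nodes.
    """
    node_feats, edge_messages = split_preprocessing_output(edges_output)
    if node_feats is None and reuse_nodes:
        node_feats = original_node_feats
    return node_feats, edge_messages
```

For example, resolve_edge_inputs((None, messages), feats) gives (feats, messages), while passing reuse_nodes=False leaves the node features as None.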
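The symmetry-related parameters (symmetric, blocks_symmetry and self_blocks_symmetry) can likewise be illustrated with plain Python. In this sketch, effective_self_blocks_symmetry, transpose and edge_blocks_for_pair are hypothetical names, blocks are plain lists of lists, and it is assumed that "symmetric" is spelled "ij=ji", as in the blocks_symmetry description. It shows the documented behaviour, not Graph2Mat's internal implementation.

```python
from typing import Callable, Dict, List, Optional, Tuple

Block = List[List[float]]


def effective_self_blocks_symmetry(
    symmetric: bool,
    blocks_symmetry: str = "ij",
    self_blocks_symmetry: Optional[str] = None,
) -> str:
    """Symmetry that node (self) blocks end up obeying."""
    if self_blocks_symmetry is not None:
        return self_blocks_symmetry
    if symmetric:
        # Assumption: symmetric blocks are denoted "ij=ji".
        return "ij=ji"
    return blocks_symmetry


def transpose(block: Block) -> Block:
    return [list(row) for row in zip(*block)]


def edge_blocks_for_pair(
    compute_block: Callable[[int, int], Block], i: int, j: int, symmetric: bool
) -> Dict[Tuple[int, int], Block]:
    """Blocks for the edges i->j and j->i.

    With symmetric=True the j->i block is the transpose of the i->j block,
    so the edge operation only runs once for the pair.
    """
    block_ij = compute_block(i, j)
    if symmetric:
        block_ji = transpose(block_ij)
    else:
        block_ji = compute_block(j, i)
    return {(i, j): block_ij, (j, i): block_ji}
```

Counting the calls made to compute_block shows that a symmetric matrix needs only one block computation per pair of atoms.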

Examples

This example uses custom node and edge operations that print the arguments they receive, which will help you understand what an operation gets so that you can tune it to your needs:

import torch
from e3nn import o3

from graph2mat import PointBasis
from graph2mat.bindings.e3nn import E3nnGraph2Mat

# Build a basis set
basis = [
    PointBasis("A", R=2, basis=[1], basis_convention="cartesian"),
    PointBasis("B", R=5, basis=[2, 1], basis_convention="cartesian")
]

# Define the custom operation that just prints the arguments
class CustomOperation(torch.nn.Module):

    def __init__(self, node_feats_irreps, irreps_out):
        # torch.nn.Module must be initialized before use
        super().__init__()
        print("INITIALIZING OPERATION")
        print("INPUT NODE FEATS IRREPS:", node_feats_irreps)
        print("IRREPS_OUT:", irreps_out)
        print("")

    def forward(self, *args, **kwargs):
        # Node and edge operations receive different inputs,
        # so accept and print whatever is passed.
        print(args, kwargs)

        # This return will create an error. Instead, you should
        # produce something with irreps irreps_out.
        return args

# Initialize the module
g2m = E3nnGraph2Mat(
    unique_basis=basis,
    irreps={"node_feats_irreps": o3.Irreps("2x0e + 1x1o")},
    symmetric=True,
    node_operation=CustomOperation,
    edge_operation=CustomOperation,
)

print("SUMMARY")
print(g2m.summary)
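To see the input layout described for edge_operation without building a full model, here is a toy, pure-Python stand-in. ToyEdgeOperation and its constructor arguments (i_basis_size, j_basis_size) are hypothetical; real edge operations are torch modules initialized with arguments such as i_basis, j_basis and symmetry, and they must respect irreps, which this toy ignores entirely. It only demonstrates receiving node features as (feats_senders, feats_receivers) and edge messages as (edge_message_ij, edge_message_ji), each with one entry per edge, and producing one block per edge.

```python
from typing import List, Sequence, Tuple

Block = List[List[float]]


class ToyEdgeOperation:
    """Hypothetical, non-equivariant stand-in for an edge operation (f_e).

    Each edge block is the outer product of the sender and receiver
    feature vectors, shifted by the sum of the two edge messages.
    """

    def __init__(self, i_basis_size: int, j_basis_size: int):
        self.i_basis_size = i_basis_size
        self.j_basis_size = j_basis_size

    def __call__(
        self,
        node_feats: Tuple[Sequence[Sequence[float]], Sequence[Sequence[float]]],
        edge_messages: Tuple[Sequence[float], Sequence[float]],
    ) -> List[Block]:
        feats_senders, feats_receivers = node_feats
        message_ij, message_ji = edge_messages
        # Every item in both tuples has one entry per edge.
        n_edges = len(feats_senders)
        assert all(len(x) == n_edges for x in (feats_receivers, message_ij, message_ji))

        blocks = []
        for s, r, m_ij, m_ji in zip(feats_senders, feats_receivers, message_ij, message_ji):
            shift = m_ij + m_ji
            blocks.append(
                [[s[a] * r[b] + shift for b in range(self.j_basis_size)]
                 for a in range(self.i_basis_size)]
            )
        return blocks
```

The operation only computes the block for the given direction of each edge; as noted above, producing the transposed block for the reverse direction of a symmetric matrix is left to Graph2Mat.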

See also

Graph2Mat

The class that E3nnGraph2Mat extends. Its documentation contains a more detailed explanation of the inner workings of the class.

__init__(unique_basis: Sequence[PointBasis], irreps: Dict[str, Irreps], preprocessing_nodes: Type[Module] | None = None, preprocessing_nodes_kwargs: dict = {}, preprocessing_edges: Type[Module] | None = None, preprocessing_edges_kwargs: dict = {}, preprocessing_edges_reuse_nodes: bool = True, node_operation: Type = graph2mat.bindings.e3nn.modules.node_operations.E3nnSimpleNodeBlock, node_operation_kwargs: dict = {}, edge_operation: Type = graph2mat.bindings.e3nn.modules.edge_operations.E3nnSimpleEdgeBlock, edge_operation_kwargs: dict = {}, symmetric: bool = False, blocks_symmetry: str = 'ij', self_blocks_symmetry: str | None = None, matrix_block_cls: Type[MatrixBlock] = graph2mat.bindings.e3nn.modules.matrixblock.E3nnIrrepsMatrixBlock, **kwargs)