graph2mat.Graph2Mat
- class graph2mat.Graph2Mat(unique_basis: BasisTableWithEdges | Sequence[PointBasis], preprocessing_nodes: Type | None = None, preprocessing_nodes_kwargs: dict = {}, preprocessing_edges: Type | None = None, preprocessing_edges_kwargs: dict = {}, preprocessing_edges_reuse_nodes: bool = True, node_operation: Type = None, node_operation_kwargs: dict = {}, edge_operation: Type = None, edge_operation_kwargs: dict = {}, symmetric: bool = False, blocks_symmetry: str = 'ij', self_blocks_symmetry: str | None = None, matrix_block_cls: Type[MatrixBlock] = graph2mat.core.modules.matrixblock.MatrixBlock, numpy: ModuleType | None = None, self_interactions_list: Callable = list, interactions_dict: Callable = dict)
Bases: Generic[ArrayType]
Converts a graph to a sparse matrix.
The matrix that this module computes has variable size, which corresponds to the size of the graph. It is built by applying a convolution of its functions over the edges and nodes of the graph.
High level architecture overview
Graph2mat builds the matrix \(M_{\nu\mu}\) block by block. We define a block as a region of the matrix where the rows are all the basis functions of a given point, and the columns are all the basis functions of another given point. That is, given two points \((i, j)\):
\[M_{ij} = \{ M_{\nu\mu} \;:\; \nu \in i, \; \mu \in j \}\]
The bases of points \(i\) and \(j\) therefore determine the shape of the block \(M_{ij}\), so a different function is needed to produce each kind of block. By their origin, there are two clearly different types of blocks, which may also obey different symmetries, so we classify the blocks into two categories:
Self interaction blocks (\(f_n\)): These blocks encode the interactions between basis functions of the same point. They correspond to nodes in the graph, are always square, and are located on the diagonal of the matrix. If the matrix is symmetric, these blocks must also be symmetric.
Interaction blocks (\(f_e\)): All the remaining blocks, which contain interactions between basis functions of different points. They correspond to edges in the graph. Even if the matrix is symmetric, these blocks do not need to be symmetric. For each pair of points \((i, j)\), there are two blocks: \(M_{ij}\) and \(M_{ji}\). However, if the matrix is symmetric, one block is the transpose of the other (\(M_{ji} = M_{ij}^T\)), so only one of them needs to be computed/predicted.
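The block decomposition and the symmetry argument above can be illustrated with a small NumPy sketch. The basis sizes are hypothetical, every pair of points is treated as connected (in a real graph, blocks for unconnected pairs are simply absent), and symmetrizing the self interaction blocks with \((A + A^T)/2\) is one illustrative way to satisfy the symmetry requirement, not necessarily what Graph2Mat's operations do.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical basis sizes for three points (e.g. point types A, B, A).
basis_sizes = [1, 3, 1]
# Offset of each point's first basis function in the full matrix: [0, 1, 4, 5].
offsets = np.concatenate([[0], np.cumsum(basis_sizes)])
n_points = len(basis_sizes)


def block_slice(i, j):
    """Index of block M_ij: rows = basis of point i, columns = basis of point j."""
    return (slice(offsets[i], offsets[i + 1]), slice(offsets[j], offsets[j + 1]))


M = np.zeros((offsets[-1], offsets[-1]))

# Self interaction blocks (f_n): square, on the diagonal, symmetrized here.
for i in range(n_points):
    raw = rng.normal(size=(basis_sizes[i], basis_sizes[i]))
    M[block_slice(i, i)] = (raw + raw.T) / 2

# Interaction blocks (f_e): compute only M_ij for i < j; M_ji is its transpose.
for i in range(n_points):
    for j in range(i + 1, n_points):
        M_ij = rng.normal(size=(basis_sizes[i], basis_sizes[j]))
        M[block_slice(i, j)] = M_ij
        M[block_slice(j, i)] = M_ij.T

print(M.shape)              # (5, 5)
print(np.allclose(M, M.T))  # True
```

Note that the interaction block between points 0 and 1 has shape (1, 3), so it cannot be symmetric by itself; only the full matrix is.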
Node features from the graph are passed to the block producing functions. Each block producing function only receives the features that correspond to the blocks that it needs to produce, as depicted in the sketch.
Optionally, one can pass preprocessing functions \((p_n, p_e)\) that update the graph before passing it to the node/edge block producing functions. The edge preprocessing function can also return edge-wise messages.
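The output rules for preprocessing functions, as described for the preprocessing_nodes, preprocessing_edges and preprocessing_edges_reuse_nodes parameters, can be sketched as follows. This is a hypothetical illustration: the class ScalePreprocessing, its __call__(node_feats, edge_index) signature and the helpers node_block_inputs/edge_block_inputs are invented names, not part of graph2mat. In Graph2Mat, __call__ receives the same data that was passed to Graph2Mat, whose exact keyword names are not shown here.

```python
import numpy as np


class ScalePreprocessing:
    """Hypothetical preprocessing module following the documented protocol:
    __init__ takes configuration, __call__ takes the data and returns either
    updated node features or a tuple (updated node features, edge messages)."""

    def __init__(self, factor=2.0, return_edge_messages=False):
        self.factor = factor
        self.return_edge_messages = return_edge_messages

    def __call__(self, node_feats, edge_index):
        new_feats = node_feats * self.factor
        if not self.return_edge_messages:
            return new_feats
        senders, receivers = edge_index
        # One message per directed edge: sum of the two endpoint features.
        return new_feats, new_feats[senders] + new_feats[receivers]


def node_block_inputs(preprocessing_output):
    """Node features seen by the node block functions."""
    if isinstance(preprocessing_output, tuple):
        # Edge messages from a node preprocessing function are disregarded.
        return preprocessing_output[0]
    return preprocessing_output


def edge_block_inputs(node_feats, preprocessing_output, reuse_nodes=True):
    """(node features, edge messages) seen by the edge block functions,
    given the output of an edge preprocessing function."""
    if not isinstance(preprocessing_output, tuple):
        # A single array is interpreted as updated node features.
        return preprocessing_output, None
    new_feats, messages = preprocessing_output
    if new_feats is None and reuse_nodes:
        # Only edge messages were returned: reuse the original node features.
        new_feats = node_feats
    return new_feats, messages


node_feats = np.array([[1.0], [2.0], [3.0]])
edge_index = np.array([[0, 1], [1, 0]])  # directed edges 0 -> 1 and 1 -> 0

p_n = ScalePreprocessing(factor=2.0, return_edge_messages=True)
print(node_block_inputs(p_n(node_feats, edge_index)).ravel())  # [2. 4. 6.]

feats, messages = edge_block_inputs(node_feats, (None, np.array([[6.0], [6.0]])))
print(feats is node_feats)  # True
```

Returning a tuple from both kinds of preprocessing is what allows one class to serve as \(p_n\) and \(p_e\): the node path simply drops the second element.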
Note
Graph2Mat itself is not a learnable module. If you are doing machine learning, the only learnable parameters will be in the node/edge operations \(f\) and the preprocessing functions \(p\). Graph2Mat is just a skeleton that lets you quickly experiment with different functions.
Warning
In practice, you will very likely need an extension like TorchGraph2Mat or E3nnGraph2Mat to do machine learning, as those set the appropriate defaults and add some extra functionality that is specific to each framework.
- Parameters:
unique_basis –
Basis of the point types that the function should be able to handle. It can either be a list of the unique PointBasis objects or a BasisTableWithEdges object.
Note that when using the function, each graph does not need to contain all the point types.
preprocessing_nodes –
A module that preprocesses the node features before passing them to the node block producing functions. This is \(p_n\) in the sketch.
It should be a class with an __init__ method that receives the initialization arguments and a __call__ method that receives the data to process. The data will be the same that has been passed to Graph2Mat.
It can output either a single array (the updated node features) or a tuple (updated node features, edge messages). In the second case, the edge messages are disregarded; this is only so that the same preprocessing functions can be reused for node and edge processing.
preprocessing_nodes_kwargs – Initialization arguments passed directly to the preprocessing_nodes class.
preprocessing_edges –
A module that preprocesses the edge features before passing them to the edge block producing functions. This is \(p_e\) in the sketch.
It should be a class with an __init__ method that receives the initialization arguments and a __call__ method that receives the data to process. The data will be the same that has been passed to Graph2Mat.
It can output either a single array (the updated node features) or a tuple (updated node features, edge messages). In the second case, the updated node features can be None.
preprocessing_edges_kwargs – Initialization arguments passed directly to the preprocessing_edges class.
preprocessing_edges_reuse_nodes –
If there is a preprocessing function for edges and it only returns edge messages, whether the original (not updated) node features should also be passed to the edge block producing functions.
It has no effect if there is no edge preprocessing function or the edge preprocessing function returns both node features and edge messages.
node_operation –
The operation used to compute the values for matrix blocks corresponding to self interactions (nodes). These are the \(f_n\) functions in the sketch.
It should be a class with an __init__ method that receives the initialization arguments (such as i_basis, j_basis and symmetry) and a __call__ method that receives the data to process. It will receive the node features for the node blocks that the operation must compute.
node_operation_kwargs – Initialization arguments for the node_operation class.
edge_operation –
The operation used to compute the values for matrix blocks corresponding to interactions between different nodes (edges). These are the \(f_e\) functions in the sketch.
It should be a class with an __init__ method that receives the initialization arguments (such as i_basis, j_basis and symmetry) and a __call__ method that receives the data to process. It will receive:
Node features as a tuple: (feats_senders, feats_receiver)
Edge messages as a tuple: (edge_message_ij, edge_message_ji)
Each item in the tuples is an array with length n_edges.
The operation does not need to handle permutation of the nodes. If the matrix is symmetric, permuting the nodes should lead to the transposed block, but this is handled by Graph2Mat.
edge_operation_kwargs – Initialization arguments for the edge_operation class.
symmetric –
Whether the matrix is symmetric. If it is, the edge blocks for the two opposite directions of an edge between the same two points are computed only once (the block for the opposite direction is the transposed block).
This also determines the symmetry argument passed to the node_operation on initialization.
blocks_symmetry – The symmetry that each block (both edge and node blocks) must obey. For example, if the blocks must be symmetric, this should be set to “ij=ji”.
self_blocks_symmetry –
The symmetry that node blocks must obey. If this is None:
If symmetric is False, self interaction blocks are assumed to have the same symmetry as the other blocks, which is specified in the blocks_symmetry parameter.
If symmetric is True, self interaction blocks are assumed to be symmetric.
matrix_block_cls – Class that wraps matrix block operations.
numpy – Module used as numpy. This can, for example, be set to torch. If None, numpy is used.
self_interactions_list –
Wrapper for the list of self interaction functions (\(f_n\), node blocks).
This is used, for example, in torch to convert the list of functions to a torch.nn.ModuleList.
interactions_dict –
Wrapper for the dictionary of interaction functions (\(f_e\), edge blocks).
This is used, for example, in torch to convert the dictionary of functions to a torch.nn.ModuleDict.
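The rules that decide the symmetry of node blocks when self_blocks_symmetry is None can be written as a small function. This is a sketch of the documented behaviour, not code taken from graph2mat; the function name is hypothetical, and using the string “ij=ji” (the example given under blocks_symmetry) to represent "symmetric" is an assumption.

```python
def resolve_self_blocks_symmetry(symmetric, blocks_symmetry="ij", self_blocks_symmetry=None):
    """Hypothetical helper: symmetry that node (self interaction) blocks must
    obey, following the rules documented for self_blocks_symmetry."""
    if self_blocks_symmetry is not None:
        # An explicitly provided value is used as is.
        return self_blocks_symmetry
    if symmetric:
        # Symmetric matrix: node blocks are symmetric ("ij=ji" is an assumed label).
        return "ij=ji"
    # Non-symmetric matrix: node blocks share the symmetry of all other blocks.
    return blocks_symmetry


print(resolve_self_blocks_symmetry(symmetric=True))   # ij=ji
print(resolve_self_blocks_symmetry(symmetric=False))  # ij
```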
Examples
This is an example of how to use it with custom node and edge operations, which will allow you to understand what the operation receives so that you can tune it to your needs:
```python
from graph2mat import Graph2Mat, PointBasis

# Build a basis set
basis = [
    PointBasis("A", R=2, basis=[1], basis_convention="cartesian"),
    PointBasis("B", R=5, basis=[2, 1], basis_convention="cartesian"),
]

# Define the custom operation that just prints the arguments
class CustomOperation:
    def __init__(self, i_basis, j_basis, symmetry):
        print("INITIALIZING OPERATION")
        print("I_BASIS", i_basis)
        print("J_BASIS", j_basis)
        print("SYMMETRY", symmetry)
        print()

    def __call__(self, **kwargs):
        print(kwargs)
        return kwargs

# Initialize the module
g2m = Graph2Mat(
    unique_basis=basis,
    symmetric=True,
    node_operation=CustomOperation,
    edge_operation=CustomOperation,
)

print("SUMMARY")
print(g2m.summary)
```
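Once you know what the operations receive, a toy edge operation consuming the tuples described under edge_operation might look like the following. This is a hypothetical sketch that runs without graph2mat: the class name, the keyword names node_feats and edge_messages, and treating i_basis/j_basis as plain block dimensions (Graph2Mat passes its own basis objects) are all assumptions. The printing example above is the way to check the real arguments.

```python
import numpy as np


class OuterProductEdgeOperation:
    """Hypothetical edge operation: each block is the outer product of projected
    sender and receiver features, giving one (i_size, j_size) block per edge."""

    def __init__(self, i_basis, j_basis, symmetry, n_feats=4, seed=0):
        rng = np.random.default_rng(seed)
        # Assumption: i_basis and j_basis are the block dimensions (plain ints).
        self.W_i = rng.normal(size=(n_feats, i_basis))
        self.W_j = rng.normal(size=(n_feats, j_basis))
        self.symmetry = symmetry

    def __call__(self, node_feats, edge_messages=None):
        feats_senders, feats_receivers = node_feats
        left = feats_senders @ self.W_i      # (n_edges, i_size)
        right = feats_receivers @ self.W_j   # (n_edges, j_size)
        blocks = left[:, :, None] * right[:, None, :]  # (n_edges, i_size, j_size)
        if edge_messages is not None:
            message_ij, _ = edge_messages
            # Toy choice: scale each block by a scalar message for its direction.
            blocks = blocks * message_ij.reshape(-1, 1, 1)
        return blocks


op = OuterProductEdgeOperation(i_basis=1, j_basis=3, symmetry="ij", n_feats=4)
n_edges = 5
feats = np.ones((n_edges, 4))
messages = (np.full(n_edges, 2.0), np.full(n_edges, 2.0))
out = op(node_feats=(feats, feats), edge_messages=messages)
print(out.shape)  # (5, 1, 3)
```

The operation only handles one direction of each edge; as noted under edge_operation, producing the transposed block for the opposite direction is left to Graph2Mat.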
Methods
forward(data, node_feats[, ...]) – Computes the matrix elements.
Attributes
summary – High level summary of the architecture of the module.
basis_table – The table holding all information about the basis.
self_interactions – List of self interaction functions (which compute node blocks).
interactions – Dictionary of interaction functions (which compute edge blocks).
- basis_table: BasisTableWithEdges
The table holding all information about the basis. This is an internal table created by the module from unique_basis, but it should probably be equal to the basis table that you use to process your data.
- forward(data: BasisMatrixData, node_feats: ArrayType, preprocessing_nodes_kwargs: dict = {}, preprocessing_edges_kwargs: dict = {}, node_kwargs: Dict[str, ArrayType] = {}, edge_kwargs: Dict[str, ArrayType] = {}, global_kwargs: dict = {}, node_operation_node_kwargs: Dict[str, ArrayType] = {}, node_operation_global_kwargs: dict = {}, edge_operation_node_kwargs: Dict[str, ArrayType] = {}, edge_operation_global_kwargs: dict = {}) → Tuple[ArrayType, ArrayType]
Computes the matrix elements.
Note
Edges are assumed to be sorted in a very specific way:
Opposite directions of the same edge should come consecutively.
The direction that has a positive edge type should come first. The “positive” direction of an edge {i, j} between point types “type_i” and “type_j” is the direction from the smaller point type to the bigger point type.
Sorted by edge type within the same structure. That is, edges where the same two species interact should be grouped within each structure in the batch. These groups should be ordered by edge type.
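For a single structure, the ordering rules above can be sketched as a small Python function. The edge type numbering used here (a positive id per unordered pair of point types, starting at 1, negated for the opposite direction) and the tie-break for edges between points of the same type are assumptions for illustration; BasisMatrixData may number edge types differently.

```python
def sort_edges(edges, point_types):
    """Hypothetical helper: order directed edges of one structure so that
    opposite directions are consecutive, the positive direction comes first,
    and edges are grouped and ordered by edge type.

    edges: list of directed (i, j) edges containing both directions of each edge.
    point_types: integer point type of each node.
    """
    # Assumed numbering: each unordered pair of point types (a, b), a <= b,
    # gets a positive id starting at 1.
    n_types = max(point_types) + 1
    pair_ids = {}
    for a in range(n_types):
        for b in range(a, n_types):
            pair_ids[(a, b)] = len(pair_ids) + 1

    def positive_direction(i, j):
        # From the smaller point type to the bigger one. For equal types,
        # the smaller node index goes first (an assumed tie-break).
        return (i, j) if (point_types[i], i) <= (point_types[j], j) else (j, i)

    def edge_type(edge):
        return pair_ids[(point_types[edge[0]], point_types[edge[1]])]

    undirected = {positive_direction(i, j) for i, j in edges}
    # Group by edge type; node indices give a deterministic order within a group.
    ordered = sorted(undirected, key=lambda e: (edge_type(e), e))

    sorted_edges, edge_types = [], []
    for i, j in ordered:
        t = edge_type((i, j))
        # Positive direction first, opposite direction right after it.
        sorted_edges += [(i, j), (j, i)]
        edge_types += [t, -t]
    return sorted_edges, edge_types


point_types = [1, 0, 1]
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1)]
print(sort_edges(edges, point_types))
# ([(1, 0), (0, 1), (1, 2), (2, 1), (0, 2), (2, 0)], [2, -2, 2, -2, 3, -3])
```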
This is all taken care of by BasisMatrixData, so if you use it you don’t need to worry about it.
- Parameters:
data – The data object containing the graph information. It can also be a dictionary that mocks the BasisMatrixData object with the appropriate keys.
node_kwargs (Dict[str, ArrayType] = {},) –
Arguments to pass to node and edge operations that are node-wise. Tensors should have shape (n_nodes, …).
If you want to pass a node-wise argument only to node operations or only to edge operations, pass it in node_operation_node_kwargs or edge_operation_node_kwargs, respectively.
The arguments passed here will be added to both node_operation_node_kwargs and edge_operation_node_kwargs. See those parameters for more information on how they are used.
If a key is present in both node_kwargs and *_operation_node_kwargs, the value in *_operation_node_kwargs will be used.
edge_kwargs (Dict[str, ArrayType] = {},) –
Arguments to pass to edge operations that are edge-wise. Tensors should have shape (n_edges, …).
The module will filter and organize them to pass a tuple (type X, type -X) for edge operation X. That is, the tuple will contain both directions of the edge.
NOTE: One could consider passing edge-wise arguments to the node operations and aggregating them into node-wise arguments. However, this module only organizes and reshapes node-wise and edge-wise arguments, so any such aggregation must be done outside of this module.
global_kwargs (dict = {},) – Arguments to pass to node and edge operations that are global (i.e. neither node-wise nor edge-wise). They are used by the operations as provided.
node_operation_node_kwargs (Dict[str, ArrayType] = {}) –
Arguments to pass to node operations that are node-wise. Tensors should have shape (n_nodes, …).
The module will filter them to contain only the values for nodes of type X before passing them to the function for node type X.
node_operation_global_kwargs (dict = {},) – Arguments to pass to node operations that are global. They will be passed to each function as provided.
edge_operation_node_kwargs (Dict[str, ArrayType] = {},) –
Arguments to pass to edge operations that are node-wise. Tensors should have shape (n_nodes, …).
The module will filter and organize them to pass a tuple (type X, type Y) for edge operation X -> Y.
edge_operation_global_kwargs (dict = {},) – Arguments to pass to edge operations that are global. They will be passed to each function as provided.
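How the node-wise and edge-wise keyword arguments are filtered and organized can be sketched with NumPy indexing. These helpers are hypothetical and only mirror the documented rules: operation-specific values override node_kwargs, node operations get only the values of their node type, and edge operations get (sender, receiver) tuples. The edge_index argument is assumed to already contain only the edges of one edge type, and splitting edge-wise values into (X, -X) assumes the edge ordering described in the note above.

```python
import numpy as np


def node_operation_inputs(node_types, node_kwargs, node_operation_node_kwargs, node_type):
    """Node-wise kwargs that the function for `node_type` would receive."""
    # Values in node_operation_node_kwargs take precedence over node_kwargs.
    merged = {**node_kwargs, **node_operation_node_kwargs}
    mask = node_types == node_type
    return {key: value[mask] for key, value in merged.items()}


def edge_operation_inputs(edge_index, node_kwargs, edge_operation_node_kwargs):
    """Node-wise kwargs as (senders, receivers) tuples for one edge type X -> Y."""
    merged = {**node_kwargs, **edge_operation_node_kwargs}
    senders, receivers = edge_index
    return {key: (value[senders], value[receivers]) for key, value in merged.items()}


def edge_wise_inputs(edge_kwargs):
    """Edge-wise kwargs of one edge type split into (direction X, direction -X),
    assuming both directions are consecutive with the positive one first."""
    return {key: (value[0::2], value[1::2]) for key, value in edge_kwargs.items()}


node_types = np.array([0, 1, 0])
charge = np.array([0.1, 0.2, 0.3])
spin = np.array([1.0, -1.0, 1.0])

node_in = node_operation_inputs(
    node_types, {"charge": charge, "spin": spin}, {"spin": np.zeros(3)}, node_type=0
)
edge_in = edge_operation_inputs(np.array([[0, 2], [1, 1]]), {"charge": charge}, {})
dist_in = edge_wise_inputs({"dist": np.array([1.0, 1.0, 2.0, 2.0])})
print(node_in["charge"], node_in["spin"])  # [0.1 0.3] [0. 0.]
```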
- Returns:
node_labels – All the node blocks, flattened and concatenated.
edge_blocks – All the edge blocks, flattened and concatenated.
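To make "flattened and concatenated" concrete, the NumPy sketch below flattens a few hypothetical node blocks into one 1D array and recovers them from their shapes. Row-major (C order) flattening, as done by ravel, is an assumption about the layout; check it against your data processing before relying on it.

```python
import numpy as np

# Hypothetical node blocks for three nodes with basis sizes 1, 3 and 1.
blocks = [np.full((1, 1), 1.0), np.arange(9.0).reshape(3, 3), np.full((1, 1), 2.0)]

# Flattened and concatenated: one 1D array with all block values in order.
node_labels = np.concatenate([b.ravel() for b in blocks])
print(node_labels.shape)  # (11,)

# Recovering the individual blocks requires knowing each block's shape.
shapes = [b.shape for b in blocks]
sizes = [int(np.prod(s)) for s in shapes]
pieces = np.split(node_labels, np.cumsum(sizes)[:-1])
recovered = [p.reshape(s) for p, s in zip(pieces, shapes)]
```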
- interactions: Dict[Tuple[int, int], MatrixBlock]
Dictionary of interaction functions (which compute edge blocks).
- self_interactions: List[MatrixBlock]
List of self interaction functions (which compute node blocks).