graph2mat.BasisMatrixDataBase
- class graph2mat.BasisMatrixDataBase(edge_index: ndarray | None = None, neigh_isc: ndarray | None = None, node_attrs: ndarray | None = None, positions: ndarray | None = None, shifts: ndarray | None = None, cell: ndarray | None = None, nsc: ndarray | None = None, point_labels: ndarray | None = None, edge_labels: ndarray | None = None, labels_point_filter: ndarray | None = None, labels_edge_filter: ndarray | None = None, point_types: ndarray | None = None, edge_types: ndarray | None = None, edge_type_nlabels: ndarray | None = None, data_processor: MatrixDataProcessor = None, metadata: Dict[str, Any] | None = None, already_basis: bool = False)[source]
Bases: Generic[ArrayType]
Stores a graph with the preprocessed data for one or multiple configurations.
Warning
This class just implements the generic functionality and you should not use it directly. Depending on the type of arrays that you use to store the data, you should use the corresponding subclass, e.g. BasisMatrixData for numpy arrays, or TorchBasisMatrixData for torch tensors.
The differences between this class and BasisConfiguration are:
- This class stores examples as graphs, while BasisConfiguration just stores the raw data.
- This class might store a batch of examples inside the same graph. Different dataset examples are just graph clusters that are not connected with each other.
This class is the main interface between the data and the models.
The class accepts positions, cell, and displacements in cartesian coordinates, but they are converted to the convention specified by the data processor (e.g. spherical harmonics), and stored in this way.
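The batching idea mentioned above (several examples stored as disconnected clusters of one graph) can be sketched with plain numpy. This is only an illustration: `batch_edge_indices` is a hypothetical helper and not part of graph2mat, and it assumes that batching amounts to offsetting each example's point indices before concatenating.

```python
import numpy as np


def batch_edge_indices(edge_indices, n_points):
    """Concatenate per-example edge indices into one disconnected graph.

    Hypothetical helper for illustration only (not a graph2mat function).
    """
    # Each example's point indices are offset by the number of points before it.
    offsets = np.cumsum([0, *n_points[:-1]])
    shifted = [ei + off for ei, off in zip(edge_indices, offsets)]
    return np.concatenate(shifted, axis=1)


# Two toy examples: a 2-point graph and a 3-point graph.
ei_a = np.array([[0, 1], [1, 0]])
ei_b = np.array([[0, 1, 2], [1, 2, 0]])

batched = batch_edge_indices([ei_a, ei_b], n_points=[2, 3])
```

No edge in `batched` links a point of the first example to a point of the second, so the two examples remain separate clusters of the same graph.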
- Parameters:
edge_index (graph2mat.core.data.processing.ArrayType) – Shape (2, n_edges). Array with point pairs (their index in the configuration) that form an edge.
neigh_isc (graph2mat.core.data.processing.ArrayType) – Shape (n_edges,). Array with the index of the supercell where the second point of each edge is located. This follows the conventions in sisl.
node_attrs (graph2mat.core.data.processing.ArrayType) – Shape (n_points, n_node_feats). Inputs for each point in the configuration.
positions (graph2mat.core.data.processing.ArrayType) – Shape (n_points, 3). Cartesian coordinates of each point in the configuration.
shifts (graph2mat.core.data.processing.ArrayType) – Shape (n_edges, 3). Cartesian shift of the second atom in each edge with respect to its image in the primary cell. E.g. if the second atom is in the primary cell, the shift will be [0,0,0].
cell (graph2mat.core.data.processing.ArrayType) – Shape (3,3). Lattice vectors of the unit cell in cartesian coordinates.
nsc (graph2mat.core.data.processing.ArrayType) – Shape (3,). Number of auxiliary cells required in each direction to account for all neighbor interactions.
point_labels (graph2mat.core.data.processing.ArrayType) – Shape (n_point_labels,). The elements of the target matrix that correspond to interactions within the same node. This is flattened to deal with the fact that each block might have a different shape. All values for a given block come consecutively and in row-major order.
edge_labels (graph2mat.core.data.processing.ArrayType) – Shape (n_edge_labels,). The elements of the target matrix that correspond to interactions between different nodes. This is flattened to deal with the fact that each block might have a different shape. All values for a given block come consecutively and in row-major order. NOTE: These should be sorted by edge type.
point_types (graph2mat.core.data.processing.ArrayType) – Shape (n_points,). The type of each point (index in the basis table).
edge_types (graph2mat.core.data.processing.ArrayType) – Shape (n_edges,). The type of each edge as defined by the basis table.
edge_type_nlabels (graph2mat.core.data.processing.ArrayType) – Shape (n_edge_types,). Edge labels are sorted by edge type. This array contains the number of labels for each edge type.
data_processor – Data processor associated with this data.
metadata (Dict[str, Any]) – Contains any extra metadata that might be useful for the model or to postprocess outputs, for example.
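As a rough illustration of the Cartesian `shifts` parameter described above, the sketch below derives shifts from integer supercell offsets. The `sc_offsets` array is a hypothetical input (it is not a constructor argument), and the sketch assumes that the rows of `cell` are the lattice vectors.

```python
import numpy as np

# Assumption: each row of `cell` is one lattice vector, in Cartesian coordinates.
cell = np.array([
    [4.0, 0.0, 0.0],
    [0.0, 5.0, 0.0],
    [0.0, 0.0, 6.0],
])

# Hypothetical integer supercell offsets of the second point of three edges.
sc_offsets = np.array([
    [0, 0, 0],   # second point in the primary cell
    [1, 0, 0],   # one cell along the first lattice vector
    [0, -1, 1],  # one cell back along the second, one forward along the third
])

# Cartesian shift = integer combination of lattice vectors.
shifts = sc_offsets @ cell
```

The first edge stays inside the primary cell, so its shift is [0, 0, 0], as stated in the description of `shifts`.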
See also
BasisMatrixData
The subclass that uses numpy arrays to store the data.
graph2mat.bindings.torch.TorchBasisMatrixData
The subclass that uses torch tensors to store the data.
Methods
convert_to(out_format[, threshold])
copy([cls]) – Copy data object, optionally to a different class.
ensure_numpy(array) – This function might be implemented by subclasses to convert from their output to numpy arrays.
from_config(config, data_processor[, nsc]) – Creates a basis matrix data object from a configuration.
is_edge_attr(key)
is_node_attr(key)
new(obj, data_processor[, labels]) – Creates a new basis matrix data object.
node_types_subgraph(node_types) – Returns a subgraph with only the nodes of the given types.
Returns object that provides data as numpy arrays.
process_input_array(key, array) – This function might be implemented by subclasses to e.g. convert the array to a torch tensor.
Attributes
Number of nodes in the configuration
edge_index – Shape (2, n_edges).
neigh_isc – Shape (n_edges,).
node_attrs – Shape (n_points, n_node_feats).
positions – Shape (n_points, 3).
shifts – Shape (n_edges, 3).
cell – Shape (3,3).
Total number of auxiliary cells.
nsc – Number of auxiliary cells required in each direction to account for all neighbor interactions.
point_labels – Shape (n_point_labels,).
edge_labels – Shape (n_edge_labels,).
point_types – Shape (n_points,).
edge_types – Shape (n_edges,).
edge_type_nlabels – Shape (n_edge_types,).
metadata – Contains any extra metadata that might be useful for the model or to postprocess outputs, for example.
- cell: ArrayType
Shape (3,3). Lattice vectors of the unit cell, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.
- copy(cls: type[BasisMatrixDataBase] | None = None)[source]
Copy data object, optionally to a different class.
Note that this does not copy the data arrays unless necessary (e.g. conversion from torch tensors to numpy arrays). It only creates a new container for the same data.
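The "new container, same data" behaviour described above can be illustrated with a toy class. `ToyData` is hypothetical and is not the graph2mat implementation. The sketch only shows what it means for a copy to share its arrays with the original.

```python
import numpy as np


class ToyData:
    """Hypothetical container used for illustration (not a graph2mat class)."""

    def __init__(self, positions):
        self.positions = positions

    def copy(self):
        # Build a new container that references the same array object.
        return type(self)(self.positions)


original = ToyData(np.zeros((2, 3)))
clone = original.copy()

# The containers are distinct objects, but the array is shared,
# so an in-place edit is visible through both of them.
clone.positions[0, 0] = 1.0
```

With this kind of copy, replacing an attribute on the clone leaves the original untouched, while in-place edits to a shared array affect both containers.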
- Parameters:
cls – The class of the new object to create. If None, the object is copied into an object of the same class.
- edge_index: ArrayType
Shape (2, n_edges). Array with point pairs (their index in the configuration) that form an edge.
- edge_labels: ArrayType
Shape (n_edge_labels,). The elements of the target matrix that correspond to interactions between different nodes. This is flattened to deal with the fact that each block might have a different shape.
All values for a given block come consecutively and in row-major order.
- edge_type_nlabels: ArrayType
Shape (n_edge_types,). Edge labels are sorted by edge type. This array contains the number of labels for each edge type.
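Because edge labels are sorted by edge type, the counts in `edge_type_nlabels` are enough to recover the labels of each type from the flat array. The following is a minimal numpy sketch with made-up values, not code from graph2mat.

```python
import numpy as np

# Made-up flat edge labels, already sorted by edge type.
edge_labels = np.arange(10.0)
# Made-up number of labels for each of three edge types.
edge_type_nlabels = np.array([4, 0, 6])

# Split points are the cumulative counts, excluding the final total.
split_points = np.cumsum(edge_type_nlabels)[:-1]
labels_per_type = np.split(edge_labels, split_points)
```

Here `labels_per_type[1]` is empty because the second edge type has no labels in this toy example.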
- edge_types: ArrayType
Shape (n_edges,). The type of each edge as defined by the basis table, i.e. a BasisTableWithEdges.
- ensure_numpy(array: Any) → ndarray [source]
This function might be implemented by subclasses to convert from their output to numpy arrays.
This is called by post processing utilities so that they can be sure they are dealing with numpy arrays.
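The hook pattern behind `ensure_numpy` and `process_input_array` can be sketched with toy classes. Both `ToyBase` and `ToyListData` are hypothetical and only mimic the idea that a subclass converts arrays on the way in and back to numpy on the way out. They are not the graph2mat classes, and the list-based storage is chosen purely to avoid extra dependencies.

```python
import numpy as np


class ToyBase:
    """Hypothetical base class: stores inputs unchanged."""

    def __init__(self, **arrays):
        for key, array in arrays.items():
            # Every input passes through the subclass hook before being stored.
            setattr(self, key, self.process_input_array(key, array))

    def process_input_array(self, key, array):
        return array

    def ensure_numpy(self, array):
        return np.asarray(array)


class ToyListData(ToyBase):
    """Hypothetical subclass that stores data as nested Python lists."""

    def process_input_array(self, key, array):
        return np.asarray(array).tolist()


data = ToyListData(positions=np.zeros((2, 3)))
# Post-processing code can call ensure_numpy without knowing the storage type.
as_numpy = data.ensure_numpy(data.positions)
```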
- classmethod from_config(config: BasisConfiguration, data_processor: MatrixDataProcessor, nsc=None) → BasisMatrixData [source]
Creates a basis matrix data object from a configuration.
- Parameters:
config – The configuration from which to create the basis matrix data object.
data_processor – The data processor that contains all the information needed to convert the configuration into the basis matrix data object. E.g. it contains the basis table.
- metadata: Dict[str, Any]
Contains any extra metadata that might be useful for the model or to postprocess outputs, for example. It includes the data processor.
- neigh_isc: ArrayType
Shape (n_edges,). Array with the index of the supercell where the second point of each edge is located. This follows the conventions in sisl.
- classmethod new(obj: BasisConfiguration | Geometry | SparseOrbital | str | Path, data_processor: MatrixDataProcessor, labels: bool = True, **kwargs) → BasisMatrixData [source]
Creates a new basis matrix data object.
If obj is a configuration, the from_config method is called. Otherwise, we try to first create a configuration from the provided arguments and then call the from_config method.
- Parameters:
obj – The object to convert into this class.
data_processor – If obj is not a configuration, the data processor is needed to understand how to create the basis matrix data object. In any case, the data processor is needed to convert from configuration to basis matrix data ready for models to use (e.g. because it contains the basis table).
See also
OrbitalConfiguration.new
The method called to initialize a configuration if obj is not a configuration.
from_config
The method called to initialize the basis matrix data object.
- node_attrs: ArrayType
Shape (n_points, n_node_feats). Inputs for each point in the configuration.
- node_types_subgraph(node_types: ndarray) → BasisMatrixData [source]
Returns a subgraph with only the nodes of the given types.
If the BasisMatrixData has labels (i.e. a matrix), this function will raise an error because we don’t support filtering labels yet.
- Parameters:
node_types – Array with the node types to keep.
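What filtering a graph by node type involves can be sketched with numpy alone. `filter_nodes_by_type` is a hypothetical helper, not the graph2mat implementation. It assumes that an edge is kept only when both of its points are kept, and that the remaining points are renumbered consecutively.

```python
import numpy as np


def filter_nodes_by_type(point_types, edge_index, keep_types):
    """Keep only points of the given types and the edges between them.

    Hypothetical helper for illustration only (not a graph2mat function).
    """
    keep = np.isin(point_types, keep_types)

    # Map old point indices to new consecutive ones (-1 marks dropped points).
    new_index = np.full(len(point_types), -1)
    new_index[keep] = np.arange(keep.sum())

    # An edge survives only if both of its points survive.
    edge_keep = keep[edge_index[0]] & keep[edge_index[1]]
    return point_types[keep], new_index[edge_index[:, edge_keep]]


point_types = np.array([0, 1, 0, 2])
edge_index = np.array([[0, 0, 2, 3],
                       [1, 2, 0, 2]])

sub_types, sub_edges = filter_nodes_by_type(point_types, edge_index, [0, 2])
```

Labels are left out of this sketch on purpose, in line with the statement above that filtering labels is not supported.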
- nsc: ArrayType
Number of auxiliary cells required in each direction to account for all neighbor interactions.
- point_labels: ArrayType
Shape (n_point_labels,). The elements of the target matrix that correspond to interactions within the same node. This is flattened to deal with the fact that each block might have a different shape.
All values for a given block come consecutively and in row-major order.
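The flattened, row-major layout described above can be illustrated with two toy blocks of different shapes. This is a numpy sketch with made-up values. How graph2mat determines the block shapes (presumably from the basis of each point) is not shown here.

```python
import numpy as np

# Made-up matrix blocks for two points with different basis sizes.
blocks = [
    np.array([[1.0, 2.0],
              [3.0, 4.0]]),  # 2x2 block
    np.array([[5.0]]),       # 1x1 block
]

# Flatten each block in row-major (C) order and store them consecutively.
flat_labels = np.concatenate([block.ravel(order="C") for block in blocks])

# Recover the blocks from the flat array, given their shapes.
shapes = [block.shape for block in blocks]
sizes = [rows * cols for rows, cols in shapes]
starts = np.cumsum([0, *sizes[:-1]])
recovered = [
    flat_labels[start:start + size].reshape(shape)
    for start, size, shape in zip(starts, sizes, shapes)
]
```

A single 1D array can hold blocks of any shape this way, which a regular stacked array could not.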
- point_types: ArrayType
Shape (n_points,). The type of each point (index in the basis table, i.e. a BasisTableWithEdges).
- positions: ArrayType
Shape (n_points, 3). Coordinates of each point in the configuration, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.
- process_input_array(key: str, array: ndarray) → Any [source]
This function might be implemented by subclasses to e.g. convert the array to a torch tensor.
- shifts: ArrayType
Shape (n_edges, 3). Shift of the second atom in each edge with respect to its image in the primary cell, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.