graph2mat.BasisMatrixDataBase

class graph2mat.BasisMatrixDataBase(edge_index: ndarray | None = None, neigh_isc: ndarray | None = None, node_attrs: ndarray | None = None, positions: ndarray | None = None, shifts: ndarray | None = None, cell: ndarray | None = None, nsc: ndarray | None = None, point_labels: ndarray | None = None, edge_labels: ndarray | None = None, labels_point_filter: ndarray | None = None, labels_edge_filter: ndarray | None = None, point_types: ndarray | None = None, edge_types: ndarray | None = None, edge_type_nlabels: ndarray | None = None, data_processor: MatrixDataProcessor = None, metadata: Dict[str, Any] | None = None, already_basis: bool = False)

Bases: Generic[ArrayType]

Stores a graph with the preprocessed data for one or multiple configurations.

Warning

This class just implements the generic functionality and you should not use it directly. Depending on the type of arrays that you use to store the data, you should use the corresponding subclass. E.g. BasisMatrixData for numpy arrays, or TorchBasisMatrixData for torch tensors.

The differences between this class and BasisConfiguration are:

  • This class stores examples as graphs, while BasisConfiguration just stores the raw data.

  • This class might store a batch of examples inside the same graph. Different dataset examples are just clusters of nodes within that graph that are not connected to each other.
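As a minimal sketch of how several examples can share one graph, assume each example comes with its own node count and `edge_index` (two lists of node indices). Offsetting each example's indices by the number of nodes that precede it yields a single graph whose clusters stay disconnected. The helper `batch_edge_index` is hypothetical and not part of graph2mat.

```python
from typing import List, Tuple

EdgeIndex = Tuple[List[int], List[int]]  # (first points, second points): shape (2, n_edges)


def batch_edge_index(examples: List[Tuple[int, EdgeIndex]]) -> Tuple[int, EdgeIndex]:
    """Merge (n_nodes, edge_index) pairs into one graph (hypothetical helper).

    Each example's node indices are shifted by the number of nodes in the
    examples before it, so no edge ever connects two different examples.
    """
    senders: List[int] = []
    receivers: List[int] = []
    offset = 0
    for n_nodes, (src, dst) in examples:
        senders.extend(i + offset for i in src)
        receivers.extend(j + offset for j in dst)
        offset += n_nodes
    return offset, (senders, receivers)


# Example 1: 2 nodes, edges 0->1 and 1->0. Example 2: 3 nodes, edges 0->2 and 2->0.
total, (src, dst) = batch_edge_index([(2, ([0, 1], [1, 0])), (3, ([0, 2], [2, 0]))])
print(total, src, dst)  # 5 [0, 1, 2, 4] [1, 0, 4, 2]
```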

This class is the main interface between the data and the models.

The class accepts positions, cell and shifts in cartesian coordinates, but converts them to the convention specified by the data processor (e.g. spherical harmonics) and stores them in that form.
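If, as assumed here, the change of convention is a linear map (such as a reordering of the axes), converting positions, cell and shifts with the same map keeps them mutually consistent: a shift computed from integer cell offsets is the same whether it is converted after being computed or computed from an already converted cell. The sketch below uses an illustrative (y, z, x) axis ordering; that particular ordering is an assumption, not taken from graph2mat.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

# Illustrative change of basis: new axes are (y, z, x). This ordering is an
# assumption for the example, not necessarily what a data processor uses.
AXES = [1, 2, 0]


def to_convention(v: Vector) -> Vector:
    """Reorder the components of a cartesian vector."""
    return [v[a] for a in AXES]


def shift_from_offset(offset: List[int], cell: Matrix) -> Vector:
    """Shift = n1*a1 + n2*a2 + n3*a3, with the lattice vectors as rows of cell."""
    return [sum(n * cell[i][k] for i, n in enumerate(offset)) for k in range(3)]


cell = [[4.0, 0.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 6.0]]  # rows are lattice vectors
offset = [1, -1, 0]

# Compute the shift in cartesian coordinates, then convert it...
a = to_convention(shift_from_offset(offset, cell))
# ...or convert each lattice vector first and compute the shift afterwards.
b = shift_from_offset(offset, [to_convention(row) for row in cell])
print(a, b)  # [-5.0, 0.0, 3.0] [-5.0, 0.0, 3.0]
```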

Parameters:
  • edge_index (graph2mat.core.data.processing.ArrayType) – Shape (2, n_edges). Array with point pairs (their index in the configuration) that form an edge.

  • neigh_isc (graph2mat.core.data.processing.ArrayType) – Shape (n_edges,). Array with the index of the supercell where the second point of each edge is located. This follows the conventions in sisl.

  • node_attrs (graph2mat.core.data.processing.ArrayType) – Shape (n_points, n_node_feats). Inputs for each point in the configuration.

  • positions (graph2mat.core.data.processing.ArrayType) – Shape (n_points, 3). Cartesian coordinates of each point in the configuration.

  • shifts (graph2mat.core.data.processing.ArrayType) – Shape (n_edges, 3). Cartesian shift of the second point in each edge with respect to its image in the primary cell. E.g. if the second point is in the primary cell, the shift is [0, 0, 0].

  • cell (graph2mat.core.data.processing.ArrayType) – Shape (3,3). Lattice vectors of the unit cell in cartesian coordinates.

  • nsc (graph2mat.core.data.processing.ArrayType) – Shape (3,). Number of auxiliary cells required in each direction to account for all neighbor interactions.

  • point_labels (graph2mat.core.data.processing.ArrayType) –

    Shape (n_point_labels,). The elements of the target matrix that correspond to interactions within the same node. This array is flattened because blocks may have different shapes.

    All values for a given block come consecutively and in row-major order.

  • edge_labels (graph2mat.core.data.processing.ArrayType) –

    Shape (n_edge_labels,). The elements of the target matrix that correspond to interactions between different nodes. This array is flattened because blocks may have different shapes.

    All values for a given block come consecutively and in row-major order.

    NOTE: These should be sorted by edge type.

  • point_types (graph2mat.core.data.processing.ArrayType) – Shape (n_points,). The type of each point (index in the basis table).

  • edge_types (graph2mat.core.data.processing.ArrayType) – Shape (n_edges,). The type of each edge as defined by the basis table.

  • edge_type_nlabels (graph2mat.core.data.processing.ArrayType) – Shape (n_edge_types,). Edge labels are sorted by edge type. This array contains the number of labels for each edge type.

  • data_processor – Data processor associated with this data.

  • metadata (Dict[str, Any]) – Any extra metadata that might be useful, for example, to the model or for postprocessing outputs.
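The shape conventions listed above can be checked mechanically. The sketch below is a hypothetical helper working on plain nested lists (it is not a graph2mat function), applied to a toy system of 3 points and 4 directed edges.

```python
from typing import Any, Dict, List


def shape(array: Any) -> tuple:
    """Shape of a rectangular nested list, following the first element down."""
    dims = []
    while isinstance(array, list):
        dims.append(len(array))
        array = array[0] if array else None
    return tuple(dims)


def check_shapes(data: Dict[str, Any], n_node_feats: int) -> List[str]:
    """Names of the arrays whose shape differs from the documented one."""
    n_points = len(data["positions"])
    n_edges = len(data["edge_index"][0])
    expected = {
        "edge_index": (2, n_edges),
        "neigh_isc": (n_edges,),
        "node_attrs": (n_points, n_node_feats),
        "positions": (n_points, 3),
        "shifts": (n_edges, 3),
        "cell": (3, 3),
        "nsc": (3,),
        "point_types": (n_points,),
        "edge_types": (n_edges,),
    }
    return [key for key, exp in expected.items() if shape(data[key]) != exp]


# Toy chain of 3 points along x with 4 directed edges, all within the primary cell.
data = {
    "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]],
    "neigh_isc": [0, 0, 0, 0],
    "node_attrs": [[1.0], [0.0], [1.0]],
    "positions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
    "shifts": [[0.0, 0.0, 0.0] for _ in range(4)],
    "cell": [[3.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]],
    "nsc": [3, 1, 1],
    "point_types": [0, 1, 0],
    "edge_types": [0, 1, 1, 0],  # placeholder values; real ones come from the basis table
}
print(check_shapes(data, n_node_feats=1))  # []

data["shifts"] = data["shifts"][:3]  # drop one shift on purpose
print(check_shapes(data, n_node_feats=1))  # ['shifts']
```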

See also

BasisMatrixData

The subclass that uses numpy arrays to store the data.

graph2mat.bindings.torch.TorchBasisMatrixData

The subclass that uses torch tensors to store the data.

Methods

convert_to(out_format[, threshold])

copy([cls])

Copy data object, optionally to a different class.

ensure_numpy(array)

This function might be implemented by subclasses to convert from their output to numpy arrays.

from_config(config, data_processor[, nsc])

Creates a basis matrix data object from a configuration.

is_edge_attr(key)

is_node_attr(key)

new(obj, data_processor[, labels])

Creates a new basis matrix data object.

node_types_subgraph(node_types)

Returns a subgraph with only the nodes of the given types.

numpy_arrays()

Returns object that provides data as numpy arrays.

process_input_array(key, array)

This function might be implemented by subclasses to e.g. convert the array to a torch tensor.

Attributes

num_nodes

Number of nodes in the configuration.

edge_index

Shape (2, n_edges).

neigh_isc

Shape (n_edges,).

node_attrs

Shape (n_points, n_node_feats).

positions

Shape (n_points, 3).

shifts

Shape (n_edges, 3).

cell

Shape (3,3).

n_supercells

Total number of auxiliary cells.

nsc

Number of auxiliary cells required in each direction to account for all neighbor interactions.

point_labels

Shape (n_point_labels,).

edge_labels

Shape (n_edge_labels,).

point_types

Shape (n_points,).

edge_types

Shape (n_edges,).

edge_type_nlabels

Shape (n_edge_types,).

labels_point_filter

labels_edge_filter

metadata

Any extra metadata that might be useful, for example, to the model or for postprocessing outputs.

cell: ArrayType

Shape (3,3). Lattice vectors of the unit cell, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.

convert_to(out_format: str, threshold: float | None = None, **kwargs)
copy(cls: type[BasisMatrixDataBase] | None = None)

Copy data object, optionally to a different class.

Note that this does not copy the data arrays unless necessary (e.g. conversion from torch tensors to numpy arrays). It only creates a new container for the same data.

Parameters:

cls – The class of the new object to create. If None, the object is copied into an object of the same class.
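The difference between copying the container and copying the arrays can be illustrated with a plain dataclass: `dataclasses.replace` builds a new object whose fields still refer to the same lists. `Graph` here is a hypothetical stand-in, not graph2mat's class; by analogy, in-place changes to an array of a copied object may also show up in the original.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class Graph:
    """Hypothetical stand-in for a data container."""

    positions: List[List[float]]
    point_types: List[int]


original = Graph(positions=[[0.0, 0.0, 0.0]], point_types=[0])
# New container, same underlying lists: no data is duplicated.
shallow = dataclasses.replace(original)
print(shallow is original, shallow.positions is original.positions)  # False True

# Mutating an array through the copy is therefore visible in the original.
shallow.positions[0][0] = 1.5
print(original.positions[0][0])  # 1.5
```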

edge_index: ArrayType

Shape (2, n_edges). Array with point pairs (their index in the configuration) that form an edge.

edge_labels: ArrayType

Shape (n_edge_labels,). The elements of the target matrix that correspond to interactions between different nodes. This array is flattened because blocks may have different shapes.

All values for a given block come consecutively and in row-major order.
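To recover the individual blocks from the flattened labels, walk the array with a running offset, taking `rows * cols` values per block and splitting them row by row. The block shapes are assumed to be known here (they follow from the basis of the two points of each edge); `unflatten_blocks` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

Block = List[List[float]]


def unflatten_blocks(labels: Sequence[float], shapes: Sequence[Tuple[int, int]]) -> List[Block]:
    """Split flattened labels into blocks stored consecutively in row-major order."""
    blocks: List[Block] = []
    start = 0
    for rows, cols in shapes:
        flat = labels[start:start + rows * cols]
        # Row-major: row r holds elements r*cols .. r*cols + cols - 1 of the block.
        blocks.append([list(flat[r * cols:(r + 1) * cols]) for r in range(rows)])
        start += rows * cols
    if start != len(labels):
        raise ValueError(f"shapes account for {start} labels, got {len(labels)}")
    return blocks


# A 1x3 block followed by a 2x2 block.
labels = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
blocks = unflatten_blocks(labels, [(1, 3), (2, 2)])
print(blocks)  # [[[1.0, 2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]
```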

edge_type_nlabels: ArrayType

Shape (n_edge_types,). Edge labels are sorted by edge type. This array contains the number of labels for each edge type.
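Since the edge labels are sorted by edge type, `edge_type_nlabels` is enough to slice out the labels of each type: the chunk for the t-th entry starts at the sum of the counts of all previous entries. A minimal sketch using `itertools.accumulate` (the helper name is hypothetical):

```python
from itertools import accumulate
from typing import Dict, List, Sequence


def labels_by_edge_type(
    edge_labels: Sequence[float], edge_type_nlabels: Sequence[int]
) -> Dict[int, List[float]]:
    """Slice edge labels (sorted by edge type) into one chunk per entry of edge_type_nlabels."""
    if sum(edge_type_nlabels) != len(edge_labels):
        raise ValueError("edge_type_nlabels does not add up to the number of labels")
    # Start of each chunk: 0, n0, n0 + n1, ...
    starts = [0, *accumulate(edge_type_nlabels)]
    return {
        t: list(edge_labels[starts[t]:starts[t + 1]])
        for t in range(len(edge_type_nlabels))
    }


chunks = labels_by_edge_type([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 0, 4])
print(chunks)  # {0: [0.1, 0.2], 1: [], 2: [0.3, 0.4, 0.5, 0.6]}
```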

edge_types: ArrayType

Shape (n_edges,). The type of each edge as defined by the basis table, i.e. a BasisTableWithEdges.

ensure_numpy(array: Any) → ndarray

This function might be implemented by subclasses to convert from their output to numpy arrays.

This is called by post processing utilities so that they can be sure they are dealing with numpy arrays.

classmethod from_config(config: BasisConfiguration, data_processor: MatrixDataProcessor, nsc=None) → BasisMatrixData

Creates a basis matrix data object from a configuration.

Parameters:
  • config – The configuration from which to create the basis matrix data object.

  • data_processor – The data processor that contains all the information needed to convert the configuration into the basis matrix data object. E.g. it contains the basis table.

is_edge_attr(key: str) → bool
is_node_attr(key: str) → bool
labels_edge_filter: ndarray
labels_point_filter: ndarray
metadata: Dict[str, Any]

Any extra metadata that might be useful, for example, to the model or for postprocessing outputs. It includes the data processor.

n_supercells: int

Total number of auxiliary cells.

neigh_isc: ArrayType

Shape (n_edges,). Array with the index of the supercell where the second point of each edge is located. This follows the conventions in sisl.
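One way to read `neigh_isc` is as an index into a table of integer supercell offsets; the cartesian shift of an edge is then that offset combined with the lattice vectors. The offset table below is written by hand for nsc = (3, 1, 1) as an illustration: index 0 is the primary cell (zero shift), but the order of the other entries is an assumption and may not match sisl's (in sisl, the table is the lattice's `sc_off` array).

```python
from typing import List, Sequence

# Hypothetical offset table for nsc = (3, 1, 1). Index 0 is the primary cell;
# the order of the remaining entries is made up for this example.
SC_OFF = [[0, 0, 0], [1, 0, 0], [-1, 0, 0]]


def edge_shifts(neigh_isc: Sequence[int], cell: Sequence[Sequence[float]]) -> List[List[float]]:
    """Cartesian shift of the second point of each edge: SC_OFF[isc] times the cell."""
    shifts = []
    for isc in neigh_isc:
        n = SC_OFF[isc]
        # Rows of cell are the lattice vectors: shift = n1*a1 + n2*a2 + n3*a3.
        shifts.append([sum(n[i] * cell[i][k] for i in range(3)) for k in range(3)])
    return shifts


cell = [[2.5, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
shifts = edge_shifts([0, 1, 2], cell)
print(shifts)  # [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [-2.5, 0.0, 0.0]]
```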

classmethod new(obj: BasisConfiguration | Geometry | SparseOrbital | str | Path, data_processor: MatrixDataProcessor, labels: bool = True, **kwargs) → BasisMatrixData

Creates a new basis matrix data object.

If obj is a configuration, the from_config method is called directly. Otherwise, it first tries to create a configuration from the provided arguments and then calls the from_config method.

Parameters:
  • obj – The object to convert into this class.

  • data_processor – If obj is not a configuration, the data processor is needed to know how to create the basis matrix data object. In any case, it is needed to convert the configuration into basis matrix data that models can use (e.g. because it contains the basis table).

See also

OrbitalConfiguration.new

The method called to initialize a configuration if obj is not a configuration.

from_config

The method called to initialize the basis matrix data object.

node_attrs: ArrayType

Shape (n_points, n_node_feats). Inputs for each point in the configuration.

node_types_subgraph(node_types: ndarray) → BasisMatrixData

Returns a subgraph with only the nodes of the given types.

If the BasisMatrixData has labels (i.e. a matrix), this function will raise an error because we don’t support filtering labels yet.

Parameters:

node_types – Array with the node types to keep.
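Conceptually, extracting such a subgraph means keeping the nodes whose type is in `node_types`, dropping every edge that touches a removed node, and renumbering the surviving nodes so their indices are contiguous again. A minimal sketch of that bookkeeping on plain lists; `subgraph_by_type` is hypothetical, and graph2mat's own implementation may differ (it also has to filter per-edge arrays such as `neigh_isc` and `shifts`, which is what the returned edge indices are for).

```python
from typing import Dict, Iterable, List, Sequence, Tuple


def subgraph_by_type(
    point_types: Sequence[int],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    node_types: Iterable[int],
) -> Tuple[List[int], Tuple[List[int], List[int]], List[int]]:
    """Return (kept point types, renumbered edge_index, indices of the kept edges)."""
    keep = set(node_types)
    # Map old node index -> new node index, for kept nodes only (in original order).
    new_index: Dict[int, int] = {}
    for old, t in enumerate(point_types):
        if t in keep:
            new_index[old] = len(new_index)
    kept_types = [point_types[old] for old in new_index]
    src: List[int] = []
    dst: List[int] = []
    kept_edges: List[int] = []
    for e, (i, j) in enumerate(zip(*edge_index)):
        if i in new_index and j in new_index:  # both ends must survive
            src.append(new_index[i])
            dst.append(new_index[j])
            kept_edges.append(e)
    return kept_types, (src, dst), kept_edges


sub = subgraph_by_type([0, 1, 0, 1], ([0, 1, 2, 2, 3], [2, 0, 0, 3, 1]), [0])
print(sub)  # ([0, 0], ([0, 1], [1, 0]), [0, 2])
```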

nsc: ArrayType

Number of auxiliary cells required in each direction to account for all neighbor interactions.
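Following sisl's conventions, each entry of `nsc` is odd and counts the cell images along that lattice vector, centered on the primary cell; `nsc = [3, 1, 1]` therefore means offsets -1, 0 and +1 along the first lattice vector only. The total number of cells is the product of the entries, which is presumably what `n_supercells` reports. A minimal sketch enumerating the offsets (the order it produces is not sisl's):

```python
from itertools import product
from math import prod
from typing import List, Sequence


def supercell_offsets(nsc: Sequence[int]) -> List[List[int]]:
    """All integer cell offsets for odd nsc, centered on the primary cell."""
    if any(n % 2 == 0 for n in nsc):
        raise ValueError("each entry of nsc must be odd")
    ranges = [range(-(n // 2), n // 2 + 1) for n in nsc]
    return [list(offset) for offset in product(*ranges)]


nsc = [3, 1, 1]
offsets = supercell_offsets(nsc)
print(len(offsets), prod(nsc))  # 3 3
print(offsets)  # [[-1, 0, 0], [0, 0, 0], [1, 0, 0]]
```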

num_nodes: int | None

Number of nodes in the configuration.

numpy_arrays() → NumpyArraysProvider

Returns object that provides data as numpy arrays.

point_labels: ArrayType

Shape (n_point_labels,). The elements of the target matrix that correspond to interactions within the same node. This array is flattened because blocks may have different shapes.

All values for a given block come consecutively and in row-major order.
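Because a point block couples the basis of a point with itself, it is square: a point whose type has `n` basis functions contributes `n * n` labels, and each point's block starts at the running sum of the previous block sizes. A minimal sketch, assuming the number of basis functions per type is known (the `basis_size` mapping and the helper are hypothetical):

```python
from itertools import accumulate
from typing import Dict, List, Sequence, Tuple


def point_block_offsets(
    point_types: Sequence[int], basis_size: Dict[int, int]
) -> Tuple[List[int], int]:
    """Start of each point's block in point_labels, and the expected total length."""
    sizes = [basis_size[t] ** 2 for t in point_types]  # square n x n blocks
    starts = [0, *accumulate(sizes)]
    return starts[:-1], starts[-1]


# Hypothetical basis: type 0 has 1 function (e.g. one s orbital), type 1 has 4 (e.g. s + p).
starts, n_point_labels = point_block_offsets([0, 1, 0], {0: 1, 1: 4})
print(starts, n_point_labels)  # [0, 1, 17] 18
```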

point_types: ArrayType

Shape (n_points,). The type of each point (index in the basis table, i.e. a BasisTableWithEdges).

positions: ArrayType

Shape (n_points, 3). Coordinates of each point in the configuration, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.

process_input_array(key: str, array: ndarray) → Any

This function might be implemented by subclasses to e.g. convert the array to a torch tensor.
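`process_input_array` and `ensure_numpy` act as hooks: a base class can pass arrays through unchanged, and a subclass overrides them to convert inputs into its own array type and back. The pattern is sketched below with hypothetical classes; to stay within the standard library, the subclass stores tuples instead of torch tensors and converts back to lists instead of numpy arrays. That inputs are passed through `process_input_array` on construction is an assumption of this sketch.

```python
from typing import Any, List


class DataBase:
    """Hypothetical base class: stores inputs as given and returns them unchanged."""

    def __init__(self, **arrays: Any):
        for key, array in arrays.items():
            # Every input goes through the hook, so subclasses control storage.
            setattr(self, key, self.process_input_array(key, array))

    def process_input_array(self, key: str, array: Any) -> Any:
        return array

    def ensure_numpy(self, array: Any) -> Any:
        return array


class TupleData(DataBase):
    """Hypothetical subclass that stores every array as an immutable tuple."""

    def process_input_array(self, key: str, array: List[Any]) -> tuple:
        return tuple(array)

    def ensure_numpy(self, array: tuple) -> List[Any]:
        # Stand-in for converting back to the common format used in post-processing.
        return list(array)


data = TupleData(point_types=[0, 1, 0])
print(data.point_types, data.ensure_numpy(data.point_types))  # (0, 1, 0) [0, 1, 0]
```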

shifts: ArrayType

Shape (n_edges, 3). Shift of the second point in each edge with respect to its image in the primary cell, in the convention specified by the data processor (e.g. spherical harmonics). IMPORTANT: This is not necessarily in cartesian coordinates.