graph2mat.bindings.torch.TorchBasisMatrixData

class graph2mat.bindings.torch.TorchBasisMatrixData(*args, **kwargs)

Bases: BasisMatrixDataBase[Tensor], Data

Extension of BasisMatrixDataBase for use with PyTorch.

This class only implements the conversion of NumPy arrays to torch tensors and back. All other functionality is inherited from BasisMatrixDataBase.
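The NumPy to torch round trip described above can be pictured with the sketch below. It does not reproduce graph2mat's code: `numpy_to_torch` and `torch_to_numpy` are hypothetical helpers, and only standard NumPy and PyTorch calls (`torch.from_numpy`, `Tensor.to`, `Tensor.detach`, `Tensor.cpu`, `Tensor.numpy`) are used. The attribute values are invented; only their shapes follow the attribute list on this page.

```python
from typing import Optional

import numpy as np
import torch


# Hypothetical helpers (NOT part of graph2mat) showing the kind of
# NumPy <-> torch conversion that TorchBasisMatrixData adds.
def numpy_to_torch(array: np.ndarray, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Wrap a NumPy array as a tensor, optionally casting it to `dtype`."""
    # from_numpy shares memory with the contiguous array; .to() copies only if the dtype changes.
    tensor = torch.from_numpy(np.ascontiguousarray(array))
    return tensor if dtype is None else tensor.to(dtype)


def torch_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor back to NumPy, detaching it from autograd and moving it to the CPU."""
    return tensor.detach().cpu().numpy()


# Toy attributes with the documented shapes (values are made up).
positions = np.arange(12, dtype=np.float64).reshape(4, 3)  # (n_points, 3)
neigh_isc = np.array([0, 4, 4], dtype=np.int64)             # (n_edges,)

positions_t = numpy_to_torch(positions, dtype=torch.float32)
neigh_isc_t = numpy_to_torch(neigh_isc)

print(positions_t.dtype, neigh_isc_t.dtype)  # torch.float32 torch.int64
print(torch_to_numpy(positions_t).shape)     # (4, 3)
```

Using `torch.from_numpy` avoids a copy when no dtype change is requested, while `detach().cpu()` keeps the reverse conversion safe for tensors that require gradients or live on a GPU.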

Please refer to the documentation of BasisMatrixDataBase for more information.

See also

graph2mat.BasisMatrixDataBase

The class that does the heavy lifting of the data processing.

Attributes

neigh_isc

Shape (n_edges,). Index of the auxiliary cell in which the neighbor of each edge lies.

positions

Shape (n_points, 3). Cartesian coordinates of each point.

shifts

Shape (n_edges, 3). Cartesian shift of the neighbor in each edge with respect to its image in the unit cell.

cell

Shape (3, 3). Lattice vectors of the unit cell.

n_supercells

Total number of auxiliary cells.

nsc

Number of auxiliary cells required in each direction to account for all neighbor interactions.

point_labels

Shape (n_point_labels,).

edge_labels

Shape (n_edge_labels,).

point_types

Shape (n_points,). Type index of each point.

edge_types

Shape (n_edges,). Type index of each edge.

edge_type_nlabels

Shape (n_edge_types,). Number of labels for each edge type.

labels_point_filter

labels_edge_filter

metadata

Any extra metadata that may be useful, for example to the model or when postprocessing outputs.
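To make the documented shapes concrete, here is a toy periodic system built with plain NumPy. All values are invented, and two relations are assumptions rather than statements about graph2mat: that `n_supercells` equals the product of the entries of `nsc` (the sisl convention, where offsets along each direction range from `-(nsc // 2)` to `nsc // 2`), and that each row of `cell` is a lattice vector, so that Cartesian `shifts` follow from integer cell offsets as `offsets @ cell`. The `offsets` array is hypothetical; graph2mat stores the neighbor's cell as the index `neigh_isc` instead.

```python
import numpy as np

# Toy periodic system (all values are hypothetical).
cell = np.diag([5.0, 5.0, 10.0])  # (3, 3); assumed: one lattice vector per row
positions = np.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])  # (n_points, 3), Cartesian coordinates
point_types = np.array([0, 1, 1])  # (n_points,)

# nsc: auxiliary cells needed in each direction; assumed n_supercells = prod(nsc).
nsc = np.array([3, 3, 1])
n_supercells = int(np.prod(nsc))

# Hypothetical integer offset (in lattice vectors) of each edge's neighbor image.
offsets = np.array([[0, 0, 0],
                    [1, 0, 0],
                    [0, -1, 0],
                    [0, 0, 0]])  # (n_edges, 3)
# Assumed sisl convention: every offset must satisfy |offset| <= nsc // 2.
assert np.all(np.abs(offsets) <= nsc // 2)

shifts = offsets @ cell               # (n_edges, 3), Cartesian shift of each neighbor image
edge_types = np.array([0, 1, 1, 2])  # (n_edges,)

print(n_supercells)   # 9
print(shifts.shape)   # (4, 3)
```

In this picture `nsc` only has to be large enough to cover the largest offset that any edge needs, which is why a 2D-periodic toy system gets `nsc = [3, 3, 1]`.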