graph2mat.bindings.torch.TorchBasisMatrixDataset

class graph2mat.bindings.torch.TorchBasisMatrixDataset(input_data: Sequence[BasisConfiguration | Path | str | Geometry], data_processor: MatrixDataProcessor, data_cls: Type[TorchBasisMatrixData] = graph2mat.bindings.torch.data.data.TorchBasisMatrixData, load_labels: bool = True)

Bases: Dataset

Stores all configuration info of a dataset.

Given its arguments, it holds all the information needed to generate the TorchBasisMatrixData objects. The objects are created on the fly as they are requested and are not stored by this class.

torch_geometric’s data loader can be used out of the box to load data from this dataset.

Parameters:
  • input_data – A list of input data. Each item can be of any type that the new method of data_cls can convert into a data object.

  • data_processor – A data processor object that is passed to data_cls.new to assist with the creation of the data objects from the input_data.

  • data_cls – The class of the data objects that will be generated from this dataset. Must have a new method that takes an item of input_data and the data_processor as arguments to create a new object. The new method also receives a labels argument specifying whether matrix labels should be loaded for the configurations.

  • load_labels – Whether to load the matrix labels or not.
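
The contract described by these parameters can be illustrated with a minimal, library-free sketch. ToyData and ToyLazyDataset are hypothetical stand-ins, not part of graph2mat, and the keyword names used when calling new (data_processor, labels) are assumptions based on the parameter descriptions above rather than the confirmed signature.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class ToyData:
    """Hypothetical stand-in for a data class such as TorchBasisMatrixData."""

    source: Any
    has_labels: bool

    @classmethod
    def new(cls, obj: Any, data_processor: Any, labels: bool = True) -> "ToyData":
        # A real data class would use the processor to build its graph and
        # matrix data here; this toy version only records its inputs.
        return cls(source=obj, has_labels=labels)


class ToyLazyDataset:
    """Hypothetical dataset that stores only the inputs and builds items on request."""

    def __init__(
        self,
        input_data: Sequence[Any],
        data_processor: Any,
        data_cls: type = ToyData,
        load_labels: bool = True,
    ):
        self.input_data = input_data
        self.data_processor = data_processor
        self.data_cls = data_cls
        self.load_labels = load_labels

    def __len__(self) -> int:
        return len(self.input_data)

    def __getitem__(self, index: int) -> Any:
        # A new object is built on every access; nothing is cached.
        return self.data_cls.new(
            self.input_data[index],
            data_processor=self.data_processor,
            labels=self.load_labels,
        )


# Placeholder inputs: the file names and the None processor are illustrative only.
dataset = ToyLazyDataset(["a.xyz", "b.xyz"], data_processor=None, load_labels=False)
first = dataset[0]
again = dataset[0]
```

In this sketch, requesting the same index twice yields two equal but distinct objects, which is the trade-off of on-the-fly creation: low memory use at the cost of repeated processing. The wrappers listed under "See also" address that cost.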

See also

InMemoryData

A wrapper for a dataset that loads all data into memory.

RotatingPoolData

A wrapper for a dataset that continuously loads data into a smaller pool.

Examples

from graph2mat import BasisConfiguration, MatrixDataProcessor
from graph2mat.bindings.torch import TorchBasisMatrixDataset

# Initialize basis configurations (substitute ... by appropriate arguments)
config_1 = BasisConfiguration(...)
config_2 = BasisConfiguration(...)

# Initialize data processor (substitute ... by appropriate arguments)
processor = MatrixDataProcessor(...)

# Initialize dataset
dataset = TorchBasisMatrixDataset([config_1, config_2], processor)

# Import the loader class from torch_geometric
from torch_geometric.loader import DataLoader

# Create a data loader from this dataset
loader = DataLoader(dataset, batch_size=2)

Methods