graph2mat.tools.lightning.MatrixDataModule
- class graph2mat.tools.lightning.MatrixDataModule(out_matrix: Literal['density_matrix', 'hamiltonian', 'energy_density_matrix', 'dynamical_matrix'] | None = None, basis_files: str | None = None, no_basis: dict | None = None, basis_table: BasisTableWithEdges | None = None, root_dir: str = '.', train_runs: str | None = None, val_runs: str | None = None, test_runs: str | None = None, predict_structs: str | None = None, runs_json: str | None = None, symmetric_matrix: bool = False, sub_point_matrix: bool = True, batch_size: int = 5, loader_threads: int = 1, copy_root_to_tmp: bool = False, store_in_memory: bool = False, rotating_pool_size: int | None = None, initial_node_feats: str = 'OneHotZ')[source]
Bases: LightningDataModule
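For orientation, a minimal instantiation sketch based on the signature above. The keyword names are taken from the signature; the concrete paths and glob patterns (basis/*.ion.xml, train/*, val/*) are hypothetical placeholders for your own dataset layout, not values prescribed by graph2mat:

    import pytorch_lightning as pl

    from graph2mat.tools.lightning import MatrixDataModule

    # Keyword names come from the class signature; the paths/globs below
    # are hypothetical placeholders.
    datamodule = MatrixDataModule(
        out_matrix="density_matrix",
        basis_files="basis/*.ion.xml",  # hypothetical basis file pattern
        root_dir="dataset",
        train_runs="train/*",           # hypothetical run globs
        val_runs="val/*",
        symmetric_matrix=True,
        batch_size=5,
    )

    # The datamodule is then handed to a Trainer together with a compatible
    # LightningModule (called `model` here, not defined in this document).
    trainer = pl.Trainer(max_epochs=10)
    # trainer.fit(model, datamodule=datamodule)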
Methods
- predict_dataloader(): An iterable or collection of iterables specifying prediction samples.
- prepare_data(): Use this to download and prepare data.
- setup(stage): Called at the beginning of fit (train + validate), validate, test, or predict.
- teardown(stage): Called at the end of fit (train + validate), validate, test, or predict.
- test_dataloader(): An iterable or collection of iterables specifying test samples.
- train_dataloader(): An iterable or collection of iterables specifying training samples.
- val_dataloader(): An iterable or collection of iterables specifying validation samples.
- predict_dataloader()[source]
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
It’s recommended that all data downloads and preparation happen in prepare_data.
Used by:
- Trainer.predict()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
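Lightning invokes this hook itself during prediction, so user code normally just passes the datamodule to the trainer. A hedged sketch, assuming a trained `model` and a hypothetical predict_structs glob:

    import pytorch_lightning as pl

    # Trainer.predict() calls predict_dataloader() internally; `model` is a
    # compatible trained LightningModule (an assumption of this sketch).
    datamodule = MatrixDataModule(predict_structs="structs/*")  # hypothetical glob
    predictions = pl.Trainer().predict(model, datamodule=datamodule)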
- prepare_data()[source]
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state to the model (use setup instead) since this is NOT called on every device.
Example:
    def prepare_data(self):
        # good
        download_data()
        tokenize()
        etc()

        # bad
        self.split = data_split
        self.some_state = some_other_state()
In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node):
1. Once per node. This is the default and is only called on LOCAL_RANK=0.
2. Once in total. Only called on GLOBAL_RANK=0.
Example:
    # DEFAULT
    # called once per node on LOCAL_RANK=0 of that node
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = True


    # call on GLOBAL_RANK=0 (great for shared file systems)
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = False
This is called before requesting the dataloaders:
    model.prepare_data()
    initialize_distributed()
    model.setup(stage)
    model.train_dataloader()
    model.val_dataloader()
    model.test_dataloader()
    model.predict_dataloader()
- setup(stage: str)[source]
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = something_else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- teardown(stage: str)[source]
Called at the end of fit (train + validate), validate, test, or predict.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
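As a generic illustration of the setup/teardown pairing (this is not the behavior of MatrixDataModule itself), a datamodule might allocate per-stage scratch space in setup and release it in teardown:

    import shutil
    import tempfile

    from pytorch_lightning import LightningDataModule

    class ScratchDataModule(LightningDataModule):
        # Hypothetical datamodule, for illustration only.

        def setup(self, stage: str):
            # Allocate scratch space when the stage begins.
            self._tmpdir = tempfile.mkdtemp(prefix=f"{stage}-")

        def teardown(self, stage: str):
            # Release it again when the stage ends.
            shutil.rmtree(self._tmpdir, ignore_errors=True)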
- test_dataloader()[source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
For data processing use the following pattern:
- download in prepare_data
- process and split in setup
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data.
Used by:
- Trainer.test()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step, you don’t need to implement this method.
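For reference, a minimal sketch of the generic contract of this hook in any LightningDataModule; the random tensor dataset is a placeholder, not graph2mat data:

    import torch
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, TensorDataset

    class ToyDataModule(LightningDataModule):
        def setup(self, stage: str):
            # Placeholder data; a real datamodule would load and split here.
            self.test_dataset = TensorDataset(torch.randn(100, 8))

        def test_dataloader(self):
            # A single DataLoader; a list or dict of them is also allowed.
            return DataLoader(self.test_dataset, batch_size=5)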
- train_dataloader()[source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data
- process and split in setup
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data.
Used by:
- Trainer.fit()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
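The reload behavior mentioned above is configured on the Trainer, not on the datamodule; a short sketch (the value 2 is illustrative):

    import pytorch_lightning as pl

    # With a positive value, train_dataloader() is rebuilt every n epochs
    # (e.g. to resample data); by default it is built once per fit.
    trainer = pl.Trainer(reload_dataloaders_every_n_epochs=2)
    # trainer.fit(model, datamodule=datamodule)  # model/datamodule as above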
- val_dataloader()[source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data.
Used by:
- Trainer.fit()
- Trainer.validate()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step, you don’t need to implement this method.
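Finally, validation can also be run on its own; a sketch reusing the `model` and `datamodule` assumed in the earlier examples:

    import pytorch_lightning as pl

    # val_dataloader() is consumed during both fit and validate.
    pl.Trainer().validate(model, datamodule=datamodule)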