graph2mat.tools.lightning.callbacks
PyTorch Lightning uses callbacks that can be plugged into its main loop.
Callbacks are independent of each other and can be attached to any of the main loops (training, testing, predicting…). They provide extra functionality such as writing predictions to disk or tracking progress.
This module implements callbacks that are particularly useful for the matrix learning process.
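The snippet below is a minimal sketch of how these callbacks could be attached to a pytorch_lightning.Trainer. The model and datamodule are hypothetical placeholders; only the callback classes come from this module.

```python
import pytorch_lightning as pl

from graph2mat.tools.lightning.callbacks import (
    MatrixWriter,
    PlotMatrixError,
    SamplewiseMetricsLogger,
)

# Callbacks are independent of each other, so any subset can be passed.
callbacks = [
    MatrixWriter(output_file="predicted_matrices"),
    PlotMatrixError(store_in_logger=True),
    SamplewiseMetricsLogger(output_file="sample_metrics.csv"),
]

trainer = pl.Trainer(callbacks=callbacks, max_epochs=10)
# trainer.fit(model, datamodule=datamodule)  # placeholders for your objects
```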
Classes

| Class | Description |
| --- | --- |
| MatrixWriter | Callback to write predicted matrices to disk. |
| PlotMatrixError | Adds plots of MAE and RMSE for each entry of the matrix. |
| SamplewiseMetricsLogger | Creates a CSV file with multiple metrics for each sample of the dataset. |
- class graph2mat.tools.lightning.callbacks.MatrixWriter(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])[source]
Bases: Callback
Callback to write predicted matrices to disk.
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the predict batch ends.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the test batch ends.
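A minimal usage sketch, assuming a trained model and a datamodule that provides a predict dataloader (both placeholders); the output_file value is likewise only an illustrative name:

```python
import pytorch_lightning as pl
from graph2mat.tools.lightning.callbacks import MatrixWriter

# Restricting `splits` to "predict" writes matrices only during predict
# runs; the default also includes the train, val and test splits.
writer = MatrixWriter(output_file="predicted_matrices", splits=["predict"])

trainer = pl.Trainer(callbacks=[writer])
# trainer.predict(model, datamodule=datamodule)  # placeholders
```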
- class graph2mat.tools.lightning.callbacks.PlotMatrixError(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)[source]
Bases: Callback
Adds plots of MAE and RMSE for each entry of the matrix. Only works if the matrix has the same format for every data point, as in molecular dynamics data.
- __init__(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)[source]
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]
Called when the test batch ends.
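A usage sketch under the same placeholder assumptions as above; split="test" matches the on_test_batch_end hook this callback implements:

```python
import pytorch_lightning as pl
from graph2mat.tools.lightning.callbacks import PlotMatrixError

# Store the MAE/RMSE plots in the experiment logger instead of showing
# them interactively.
error_plots = PlotMatrixError(split="test", show=False, store_in_logger=True)

trainer = pl.Trainer(callbacks=[error_plots])
# trainer.test(model, datamodule=datamodule)  # placeholders
```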
- class graph2mat.tools.lightning.callbacks.SamplewiseMetricsLogger(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')[source]
Bases: Callback
Creates a CSV file with multiple metrics for each sample of the dataset.
This callback is needed because otherwise metrics are computed and logged only on a per-batch basis.
Each row of the CSV file is the computation of all metrics for a single sample on a given epoch. Therefore, the CSV file contains the following columns: [sample_name, …metrics…, split_key, epoch_index].
- Parameters:
metrics (Sequence[Type[OrbitalMatrixMetric]]) – List of metrics to compute.
splits (Sequence[str], optional) – List of splits for which to compute the metrics. Can be any combination of “train”, “val”, “test”.
output_file (Union[str, Path], optional) – Path to the output CSV file.
- __init__(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')[source]
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_epoch_end(trainer, pl_module)[source]
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

```python
class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```
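To close, a sketch of how the per-sample CSV could be produced and then inspected. The model and datamodule are placeholders, and the metric column names depend on the OrbitalMatrixMetric classes you pass; only sample_name, split_key and epoch_index are fixed by the column layout described above.

```python
import pandas as pd
import pytorch_lightning as pl
from graph2mat.tools.lightning.callbacks import SamplewiseMetricsLogger

# Default metrics, but only logged for the "val" split.
metrics_logger = SamplewiseMetricsLogger(
    splits=["val"], output_file="sample_metrics.csv"
)
trainer = pl.Trainer(callbacks=[metrics_logger])
# trainer.fit(model, datamodule=datamodule)  # placeholders

# Each row holds all metrics for one sample on one epoch, so the file can
# be sliced per epoch or aggregated per sample with ordinary pandas calls.
df = pd.read_csv("sample_metrics.csv")
last_epoch = df[df["epoch_index"] == df["epoch_index"].max()]
first_metric = df.columns[1]  # first metric column (right after sample_name)
print(last_epoch.nlargest(5, first_metric))  # worst samples for that metric
```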