graph2mat.tools.lightning.callbacks
PyTorch Lightning uses callbacks that can be plugged into the main loop.
Callbacks are independent of each other and can be included in any of the main loops (training, testing, prediction, and so on). They provide extra functionality such as writing predictions to disk or tracking progress.
This module implements callbacks that are particularly useful for the matrix learning process.
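To make the idea of independent, pluggable callbacks concrete, here is a toy model of a loop that calls the same hook on every registered callback. It is a pure-Python sketch under stated assumptions: `ToyLoop`, `ToyCallback`, `BatchCounter` and `OutputCollector` are hypothetical names, and the simplified hook signature is not the real PyTorch Lightning one (which also receives `trainer` and `pl_module`).

```python
from typing import Any, Callable, Iterable, List


class ToyCallback:
    """Hypothetical base class: every hook is a no-op unless a subclass overrides it."""

    def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
        pass


class BatchCounter(ToyCallback):
    """Counts processed batches (a stand-in for progress tracking)."""

    def __init__(self) -> None:
        self.n_batches = 0

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.n_batches += 1


class OutputCollector(ToyCallback):
    """Keeps every step output (a stand-in for writing predictions to disk)."""

    def __init__(self) -> None:
        self.outputs: List[Any] = []

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.outputs.append(outputs)


class ToyLoop:
    """Hypothetical loop that notifies each callback, in order, after every batch."""

    def __init__(self, callbacks: Iterable[ToyCallback]) -> None:
        self.callbacks = list(callbacks)

    def test(self, batches: Iterable[Any], step: Callable[[Any], Any]) -> None:
        for batch_idx, batch in enumerate(batches):
            outputs = step(batch)
            # Callbacks never call each other; the loop notifies each one independently.
            for callback in self.callbacks:
                callback.on_test_batch_end(outputs, batch, batch_idx)


counter, collector = BatchCounter(), OutputCollector()
ToyLoop([counter, collector]).test([1, 2, 3], step=lambda b: b * 10)
print(counter.n_batches, collector.outputs)
```

Dropping either callback from the list leaves the other unaffected, which is the independence property described above.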
Classes
MatrixWriter – Callback to write predicted matrices to disk.
PlotMatrixError – Callback to create plots of MAE and RMSE for each entry of the matrix.
SamplewiseMetricsLogger – Creates a CSV file with multiple metrics for each sample of the dataset.
- class graph2mat.tools.lightning.callbacks.MatrixWriter(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])
Bases: Callback
Callback to write predicted matrices to disk.
- __init__(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])
Callback to write produced matrices to disk.
- Parameters:
output_file – Path to the output file. The structures may store, as metadata, the path from which they were read. In that case, if output_file is a relative path, it is interpreted relative to the directory from which the structure was read. If output_file contains the string “$name$”, it is replaced by the name of the directory from which the structure was read.
splits – List of splits for which to write the matrices. Can be any combination of “train”, “val”, “test”, “predict”.
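The two path rules for output_file can be sketched as follows. `resolve_output_path` is a hypothetical helper written only to illustrate the documented behaviour (relative paths and the “$name$” placeholder); it is not graph2mat's implementation, and it assumes the directory the structure was read from is already known.

```python
from pathlib import Path
from typing import Optional


def resolve_output_path(output_file: str, structure_dir: Optional[str]) -> Path:
    """Hypothetical illustration of the documented output_file rules."""
    if structure_dir is None:
        # No path metadata on the structure: use output_file as given.
        return Path(output_file)

    source_dir = Path(structure_dir)
    # "$name$" becomes the name of the directory the structure was read from.
    out = Path(output_file.replace("$name$", source_dir.name))
    # Relative paths are interpreted relative to that same directory.
    if not out.is_absolute():
        out = source_dir / out
    return out


print(resolve_output_path("prediction.out", "/data/run_01"))
print(resolve_output_path("/results/$name$.out", "/data/run_01"))
```

With an absolute path plus “$name$”, every structure writes to a shared directory while still getting a distinct file name.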
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)
Called when the predict batch ends.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)
Called when the test batch ends.
- class graph2mat.tools.lightning.callbacks.PlotMatrixError(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)
Bases: Callback
Callback to create plots of MAE and RMSE for each entry of the matrix.
It only works if the matrices of all structures have exactly the same shape. This is the case, for example, when the structures are all snapshots from a molecular dynamics simulation.
This callback is then useful to visualize how the errors are distributed within the matrix.
- Parameters:
split – Split for which to plot the errors. If None, the callback will be used for both validation and test splits.
show – Whether to show the plots or not.
store_in_logger – Whether to store the plots in the logger or not.
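The per-entry MAE and RMSE that this callback plots can be illustrated with NumPy. This sketch only shows the arithmetic, assuming predictions and targets are dense arrays; `entrywise_errors` is a hypothetical helper, and the callback's actual data handling and plotting are not reproduced. Stacking the matrices is where the same-shape requirement comes from: `np.stack` raises a `ValueError` when the shapes differ.

```python
import numpy as np


def entrywise_errors(predictions, targets):
    """Per-entry MAE and RMSE over a set of equally shaped matrices (hypothetical helper)."""
    # Shape (n_samples, n_rows, n_cols); np.stack raises ValueError if shapes differ.
    pred = np.stack(predictions)
    true = np.stack(targets)
    diff = pred - true
    mae = np.abs(diff).mean(axis=0)           # average over samples, entry by entry
    rmse = np.sqrt((diff ** 2).mean(axis=0))  # same reduction on squared errors
    return mae, rmse


# Synthetic data: 10 "snapshots" of a 4x4 matrix with small random prediction errors.
rng = np.random.default_rng(0)
targets = [rng.normal(size=(4, 4)) for _ in range(10)]
predictions = [t + rng.normal(scale=0.1, size=t.shape) for t in targets]
mae, rmse = entrywise_errors(predictions, targets)
print(mae.shape, rmse.shape)  # (4, 4) (4, 4)
```

Each result has the shape of a single matrix, so it can be shown directly as a heat map of where the errors concentrate.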
- __init__(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the test batch ends.
- class graph2mat.tools.lightning.callbacks.SamplewiseMetricsLogger(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')
Bases: Callback
Creates a CSV file with multiple metrics for each sample of the dataset.
This callback is needed because, otherwise, the metrics are only computed and logged on a per-batch basis.
Each row of the CSV file contains all metrics computed for a single sample at a given epoch. Therefore, the CSV file contains the following columns: [sample_name, …metrics…, split_key, epoch_index]
- Parameters:
metrics (Sequence[Type[OrbitalMatrixMetric]]) – List of metrics to compute.
splits (Sequence[str], optional) – List of splits for which to compute the metrics. Can be any combination of “train”, “val”, “test”.
output_file (Union[str, Path], optional) – Path to the output CSV file.
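The row layout described above can be sketched with the standard csv module. The metric names (node_rmse, edge_rmse), the sample names and the numbers are made-up placeholders, and `write_sample_metrics` is a hypothetical helper, not the callback's code; the point is that each sample keeps its own row instead of being folded into a single per-batch value.

```python
import csv
import io


def write_sample_metrics(fh, rows, metric_names):
    """Write one row per (sample, epoch) using the documented column order."""
    fieldnames = ["sample_name", *metric_names, "split_key", "epoch_index"]
    writer = csv.DictWriter(fh, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)


# Placeholder values for two samples from the same validation batch at epoch 3.
rows = [
    {"sample_name": "sample_0", "node_rmse": 0.02, "edge_rmse": 0.05, "split_key": "val", "epoch_index": 3},
    {"sample_name": "sample_1", "node_rmse": 0.40, "edge_rmse": 0.31, "split_key": "val", "epoch_index": 3},
]

buffer = io.StringIO()  # stands in for the output CSV file
write_sample_metrics(buffer, rows, ["node_rmse", "edge_rmse"])
print(buffer.getvalue())
```

A per-batch log of the same data would only record a mean node_rmse of 0.21, hiding that sample_1 is the outlier.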
- __init__(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)
Called when the train batch ends.
Note
Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
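With gradient accumulation, the loss is divided by accumulate_grad_batches so that the gradients summed over the accumulated batches correspond to an average. A minimal arithmetic sketch, assuming that normalization is a plain division (the helper `raw_loss` is hypothetical), shows how to recover the raw per-batch loss from the value seen in this hook:

```python
def raw_loss(normalized_loss: float, accumulate_grad_batches: int) -> float:
    """Undo the division by accumulate_grad_batches (hypothetical helper)."""
    return normalized_loss * accumulate_grad_batches


accumulate_grad_batches = 4
raw_losses = [0.8, 1.2, 1.0, 0.6]  # made-up losses returned by training_step
normalized = [loss / accumulate_grad_batches for loss in raw_losses]

# Summing the normalized values over one accumulation window gives the mean raw loss.
print(sum(normalized))
print([raw_loss(v, accumulate_grad_batches) for v in normalized])
```

Multiplying back is only needed if a callback should report the loss on the same scale as training_step.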
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

```python
import pytorch_lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss here
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```