graph2mat.tools.lightning.callbacks

PyTorch Lightning uses callbacks that can be plugged into the main loop.

Callbacks are independent of each other and can be included in any of the main loops (training, testing, predicting, …). They provide extra functionality such as writing predictions to disk or tracking progress.

This module implements callbacks that may be particularly useful for the matrix learning process.
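
For context, a minimal sketch of how these callbacks are plugged into a PyTorch Lightning Trainer (the model, datamodule and file names below are placeholders, not objects defined in this module):

import lightning as L  # or `import pytorch_lightning as L`, depending on your setup

from graph2mat.tools.lightning.callbacks import (
    MatrixWriter,
    PlotMatrixError,
    SamplewiseMetricsLogger,
)

# Placeholders: any LightningModule / LightningDataModule used for matrix
# learning; they are not defined in this module.
model = ...
datamodule = ...

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[
        # File names here are arbitrary examples.
        MatrixWriter(output_file="predicted_matrices", splits=["test", "predict"]),
        PlotMatrixError(split="val", store_in_logger=True),
        SamplewiseMetricsLogger(output_file="sample_metrics.csv"),
    ],
)
trainer.fit(model, datamodule=datamodule)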

Classes

MatrixWriter(output_file[, splits])

Callback to write predicted matrices to disk.

PlotMatrixError([split, show, store_in_logger])

Adds plots of MAE and RMSE for each entry of the matrix.

SamplewiseMetricsLogger([metrics, splits, ...])

Creates a CSV file with multiple metrics for each sample of the dataset.

class graph2mat.tools.lightning.callbacks.MatrixWriter(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])[source]

Bases: Callback

Callback to write predicted matrices to disk.

__init__(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])[source]
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the predict batch ends.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the test batch ends.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the validation batch ends.
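
A minimal usage sketch for MatrixWriter, writing matrices only during prediction (the file name, model and dataloader are placeholders; the on-disk format is handled by the callback and is not documented here):

import lightning as L

from graph2mat.tools.lightning.callbacks import MatrixWriter

# Only write matrices during the predict loop; by default all of
# train/val/test/predict trigger writing.
writer = MatrixWriter(output_file="matrices.out", splits=["predict"])

# Placeholders for a trained LightningModule and a dataloader of structures.
model = ...
predict_dataloader = ...

trainer = L.Trainer(callbacks=[writer])
trainer.predict(model, dataloaders=predict_dataloader)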

class graph2mat.tools.lightning.callbacks.PlotMatrixError(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)[source]

Bases: Callback

Adds plots of MAE and RMSE for each entry of the matrix. This only works if the matrix has the same format for every data point, as in molecular dynamics data.

__init__(split: Literal[None, 'val', 'test'] = None, show: bool = False, store_in_logger: bool = True)[source]
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

on_validation_epoch_start(trainer, pl_module)[source]

Called when the val epoch begins.
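
A minimal sketch of attaching PlotMatrixError to the test loop (the model and datamodule are placeholders; as noted above, the plots assume the matrix has the same format for every data point):

import lightning as L

from graph2mat.tools.lightning.callbacks import PlotMatrixError

# Plot MAE/RMSE per matrix entry for the test split, keeping the figures in
# the logger instead of showing them interactively.
error_plots = PlotMatrixError(split="test", show=False, store_in_logger=True)

# Placeholders for the model and datamodule being evaluated.
model = ...
datamodule = ...

trainer = L.Trainer(callbacks=[error_plots])
trainer.test(model, datamodule=datamodule)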

class graph2mat.tools.lightning.callbacks.SamplewiseMetricsLogger(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')[source]

Bases: Callback

Creates a CSV file with multiple metrics for each sample of the dataset.

This callback is needed because otherwise the metrics are computed and logged on a per-batch basis.

Each row of the CSV file is a computation of all metrics for a single sample on a given epoch. Therefore, the CSV file contains the following columns: [sample_name, …metrics…, split_key, epoch_index].

Parameters:
  • metrics (Sequence[Type[OrbitalMatrixMetric]]) – List of metrics to compute.

  • splits (Sequence[str], optional) – List of splits for which to compute the metrics. Can be any combination of “train”, “val”, “test”.

  • output_file (Union[str, Path], optional) – Path to the output CSV file.
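
A sketch of how these parameters might be combined; the concrete OrbitalMatrixMetric subclasses shipped with graph2mat are not listed on this page, so the default metrics are kept:

from graph2mat.tools.lightning.callbacks import SamplewiseMetricsLogger

# Compute per-sample metrics only for the validation and test loops and
# write them to a custom CSV file. Leaving `metrics` unset keeps the defaults.
metrics_logger = SamplewiseMetricsLogger(
    splits=["val", "test"],
    output_file="per_sample_metrics.csv",
)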

__init__(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')[source]
close_file_handle()[source]
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

on_validation_epoch_start(trainer, pl_module)[source]

Called when the val epoch begins.

open_file_handle()[source]
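
Finally, a sketch of how the resulting CSV file could be inspected after a run, assuming the column layout described above (the metric column name used for sorting is hypothetical and depends on the configured metrics):

import pandas as pd

# Read the per-sample metrics written by SamplewiseMetricsLogger.
df = pd.read_csv("sample_metrics.csv")

# Average the metric columns per split and epoch to follow how the error
# evolves during training.
evolution = df.groupby(["split_key", "epoch_index"]).mean(numeric_only=True)
print(evolution)

# Pick the worst samples of the last validation epoch according to one of
# the metric columns ("some_metric" is a stand-in for a real column name).
last_epoch = df["epoch_index"].max()
val = df[(df["split_key"] == "val") & (df["epoch_index"] == last_epoch)]
worst = val.sort_values("some_metric", ascending=False).head(10)
print(worst[["sample_name", "some_metric"]])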