graph2mat.tools.lightning.SamplewiseMetricsLogger

class graph2mat.tools.lightning.SamplewiseMetricsLogger(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')

Bases: Callback

Creates a CSV file with multiple metrics for each sample of the dataset.

This callback is needed because, without it, metrics are only computed and logged on a per-batch basis, so the values for individual samples are not available.
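A toy illustration of why per-sample values matter, using made-up error values (not graph2mat output): a per-batch average can hide a single badly predicted sample.

```python
# Hypothetical per-sample errors for one batch of three samples (made-up values).
sample_errors = {"sample_a": 0.01, "sample_b": 0.02, "sample_c": 0.90}

# What a per-batch log would record: a single averaged number.
batch_mean = sum(sample_errors.values()) / len(sample_errors)

# What a per-sample log keeps: every value, so outliers stay visible.
worst_sample = max(sample_errors, key=sample_errors.get)

print(f"batch mean: {batch_mean:.2f}")
print(f"worst sample: {worst_sample} ({sample_errors[worst_sample]:.2f})")
```

The batch mean of about 0.31 looks moderate, while the per-sample view shows that one sample is far worse than the others.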

Each row of the CSV file contains all metrics computed for a single sample at a given epoch. The file therefore has the following columns: [sample_name, …metrics…, split_key, epoch_index], with one column per metric in place of …metrics….
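A minimal sketch of the row layout described above, written with Python's standard csv module. The metric column names (node_mae, edge_mae) and all values are hypothetical placeholders; the real columns depend on the metrics passed to the callback, and the actual implementation may write the file differently.

```python
import csv
import io

# Hypothetical metric columns; in practice there is one column per configured metric.
metric_names = ["node_mae", "edge_mae"]
header = ["sample_name", *metric_names, "split_key", "epoch_index"]

# One row per (sample, epoch): sample name, every metric value, split, epoch.
rows = [
    ["sample_0", 0.012, 0.034, "train", 0],
    ["sample_1", 0.020, 0.041, "train", 0],
    ["sample_0", 0.015, 0.030, "val", 0],
]

buffer = io.StringIO()  # stands in for the output CSV file
writer = csv.writer(buffer)
writer.writerow(header)
writer.writerows(rows)

# Reading it back: rows can be filtered by split and epoch.
buffer.seek(0)
records = list(csv.DictReader(buffer))
val_rows = [r for r in records if r["split_key"] == "val"]
print(val_rows[0]["sample_name"], val_rows[0]["node_mae"])
```

Because the split and epoch are stored on every row, the same sample can appear many times, once per split and epoch in which it was evaluated.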

Parameters:
  • metrics (Sequence[Type[OrbitalMatrixMetric]], optional) – List of metric classes to compute.

  • splits (Sequence[str], optional) – List of splits for which to compute the metrics. Can be any combination of “train”, “val” and “test”. Defaults to all three.

  • output_file (Union[str, Path], optional) – Path to the output CSV file. Defaults to 'sample_metrics.csv'.

Methods

close_file_handle()

on_test_batch_end(trainer, pl_module, ...[, ...])

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_train_batch_end(trainer, pl_module, ...)

Called when the train batch ends.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

on_validation_batch_end(trainer, pl_module, ...)

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

open_file_handle()
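Within each epoch, the hooks listed above fire in a fixed order: epoch start, one batch-end call per batch, then epoch end. The following is a hypothetical, framework-free sketch (not graph2mat's implementation) of how a per-sample logger could use that order together with the `splits` parameter: it ignores splits that were not requested, collects one row per sample at each batch end, and hands the rows over at epoch end. The class name ToySamplewiseLogger, the batch format and the metric function are all assumptions made for illustration.

```python
from typing import Callable, Dict, List, Sequence


class ToySamplewiseLogger:
    """Hypothetical stand-in for a per-sample metrics callback (no Lightning needed)."""

    def __init__(
        self,
        metric: Callable[[float, float], float],
        splits: Sequence[str] = ("train", "val", "test"),
    ):
        self.metric = metric
        self.splits = set(splits)
        self.epoch_rows: List[list] = []
        self.finished: Dict[str, List[list]] = {}

    def on_epoch_start(self, split: str) -> None:
        # Start a fresh buffer for this split and epoch.
        self.epoch_rows = []

    def on_batch_end(self, split: str, epoch: int, batch: Dict[str, list]) -> None:
        if split not in self.splits:
            return  # split not requested: record nothing
        # One row per sample in the batch, not one per batch.
        for name, pred, target in zip(batch["names"], batch["preds"], batch["targets"]):
            self.epoch_rows.append([name, self.metric(pred, target), split, epoch])

    def on_epoch_end(self, split: str, epoch: int) -> None:
        if split in self.splits:
            self.finished[f"{split}/{epoch}"] = self.epoch_rows


cb = ToySamplewiseLogger(metric=lambda p, t: abs(p - t), splits=["val"])
batch = {"names": ["s0", "s1"], "preds": [1.0, 2.5], "targets": [1.0, 2.0]}

for split in ("train", "val"):
    cb.on_epoch_start(split)
    cb.on_batch_end(split, epoch=0, batch=batch)
    cb.on_epoch_end(split, epoch=0)

print(cb.finished)
```

Only the "val" split produces rows here, because it is the only split requested; the "train" hooks still fire but return early.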

close_file_handle()

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
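A small worked example of that normalization with made-up numbers, assuming it is a division by accumulate_grad_batches (which is how Lightning scales losses for gradient accumulation): if accumulate_grad_batches is k, the callback sees the training_step loss divided by k, so multiplying by k recovers the unscaled value. This is plain arithmetic, not a Lightning call.

```python
# Hypothetical numbers: the loss returned by training_step and the Trainer setting.
raw_loss = 0.8
accumulate_grad_batches = 4

# What on_train_batch_end would see in outputs["loss"] under this normalization.
normalized_loss = raw_loss / accumulate_grad_batches

# Undo the normalization inside a callback if the unscaled value is needed.
recovered_loss = normalized_loss * accumulate_grad_batches
print(normalized_loss, recovered_loss)
```

With accumulate_grad_batches left at its default of 1, the two values coincide.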

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

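import torch
import pytorch_lightning as L  # the newer `import lightning as L` exposes the same classes


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        # cache of per-step losses, consumed at the end of each epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # detach so the cached values do not keep the autograd graph alive
        self.training_step_outputs.append(loss.detach())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()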

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

open_file_handle()
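open_file_handle() and close_file_handle() are listed here without a description. One common way to manage such a handle, shown as a hypothetical sketch rather than graph2mat's actual code, is to open the file lazily, write the header only when the file is new or empty, append rows across epochs, and make closing safe to call more than once. The class name ToyCSVHandle, the write_row helper, the header and the file name are all assumptions.

```python
import csv
import os
import tempfile
from pathlib import Path
from typing import Optional, TextIO


class ToyCSVHandle:
    """Hypothetical open/close helper for an append-only CSV log."""

    def __init__(self, output_file: str, header: list):
        self.output_file = Path(output_file)
        self.header = header
        self._fh: Optional[TextIO] = None

    def open_file_handle(self) -> None:
        if self._fh is not None:
            return  # already open: opening twice is a no-op
        is_new = not self.output_file.exists() or self.output_file.stat().st_size == 0
        self._fh = open(self.output_file, "a", newline="")
        if is_new:
            csv.writer(self._fh).writerow(self.header)  # header only once per file

    def write_row(self, row: list) -> None:
        self.open_file_handle()  # lazily (re)open before writing
        csv.writer(self._fh).writerow(row)

    def close_file_handle(self) -> None:
        if self._fh is not None:
            self._fh.close()  # flushes buffered rows to disk
            self._fh = None


path = os.path.join(tempfile.mkdtemp(), "sample_metrics.csv")
handle = ToyCSVHandle(path, ["sample_name", "mae", "split_key", "epoch_index"])
handle.write_row(["s0", 0.1, "val", 0])
handle.close_file_handle()
handle.write_row(["s0", 0.05, "val", 1])  # reopens in append mode, no second header
handle.close_file_handle()
handle.close_file_handle()  # safe to call again

with open(path, newline="") as f:
    print(list(csv.reader(f)))
```

Opening in append mode means rows from successive epochs accumulate in one file, which matches the one-row-per-sample-per-epoch layout described at the top of this page.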