graph2mat.tools.lightning.SamplewiseMetricsLogger
- class graph2mat.tools.lightning.SamplewiseMetricsLogger(metrics: Sequence[Type[OrbitalMatrixMetric]] | None = None, splits: Sequence = ['train', 'val', 'test'], output_file: str | Path = 'sample_metrics.csv')[source]
Bases: Callback
Creates a CSV file with multiple metrics for each sample of the dataset.
This callback is needed because otherwise the metrics are computed and logged on a per-batch basis.
Each row of the CSV file is a computation of all metrics for a single sample on a given epoch. Therefore, the CSV file contains the following columns: [sample_name, …metrics…, split_key, epoch_index]
- Parameters:
metrics (Sequence[Type[OrbitalMatrixMetric]]) – List of metrics to compute.
splits (Sequence[str], optional) – List of splits for which to compute the metrics. Can be any combination of “train”, “val”, “test”.
output_file (Union[str, Path], optional) – Path to the output CSV file.
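Because each row pairs one sample with its metric values for a given split and epoch, the resulting file is easy to analyze per sample after training. A minimal self-contained sketch of such an analysis (the metric column name matrix_error is a placeholder; the actual columns depend on the metrics passed to the callback, and the CSV text below is a stand-in for the real output file):

```python
import csv
import io

# Stand-in for the CSV written by SamplewiseMetricsLogger; real metric
# column names come from the metrics passed to the callback.
csv_text = """sample_name,matrix_error,split_key,epoch_index
A,0.30,val,0
B,0.10,val,0
A,0.20,val,1
B,0.05,val,1
"""

# Track the lowest error seen for each sample across epochs.
best = {}
for row in csv.DictReader(io.StringIO(csv_text)):
    err = float(row["matrix_error"])
    name = row["sample_name"]
    if name not in best or err < best[name]:
        best[name] = err

print(best)  # lowest per-sample error across epochs
```

The same per-sample rows also make it straightforward to spot individual structures whose error stagnates while the batch-averaged loss keeps improving.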
Methods

on_test_batch_end(trainer, pl_module, ...[, ...]): Called when the test batch ends.
on_test_epoch_end(trainer, pl_module): Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module): Called when the test epoch begins.
on_train_batch_end(trainer, pl_module, ...): Called when the train batch ends.
on_train_epoch_end(trainer, pl_module): Called when the train epoch ends.
on_train_epoch_start(trainer, pl_module): Called when the train epoch begins.
on_validation_batch_end(trainer, pl_module, ...): Called when the validation batch ends.
on_validation_epoch_end(trainer, pl_module): Called when the validation epoch ends.
on_validation_epoch_start(trainer, pl_module): Called when the validation epoch begins.
Attributes
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)[source]
Called when the train batch ends.
Note

The value outputs["loss"] here will be the value of the loss returned from training_step, normalized with respect to accumulate_grad_batches.
- on_train_epoch_end(trainer, pl_module)[source]
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

```python
class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```
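The caching pattern above can be exercised without a full Trainer. A minimal stand-in sketch (no Lightning dependency; the hook and method names mirror the real LightningModule/Callback API, but the classes and the driver loop below are plain-Python substitutes for illustration only):

```python
# Simulated sketch of the epoch-end caching pattern; no Lightning needed.
# Module and EpochMeanCallback are stand-ins, not Lightning classes.

class Module:
    def __init__(self):
        self.training_step_outputs = []
        self.logged = {}

    def training_step(self, batch):
        loss = sum(batch) / len(batch)  # stand-in for a real loss
        self.training_step_outputs.append(loss)  # cache the step output
        return loss

    def log(self, name, value):
        self.logged[name] = value


class EpochMeanCallback:
    def on_train_epoch_end(self, trainer, pl_module):
        outs = pl_module.training_step_outputs
        pl_module.log("training_epoch_mean", sum(outs) / len(outs))
        pl_module.training_step_outputs.clear()  # free up the memory


module = Module()
callback = EpochMeanCallback()
for batch in ([1.0, 3.0], [2.0, 4.0]):  # one "epoch" of two batches
    module.training_step(batch)
callback.on_train_epoch_end(trainer=None, pl_module=module)
print(module.logged)  # {'training_epoch_mean': 2.5}
```

Clearing the cache at epoch end matters: without it, outputs from earlier epochs would keep accumulating and skew the computed mean.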