graph2mat.tools.lightning.MatrixWriter

class graph2mat.tools.lightning.MatrixWriter(output_file: str, splits: Sequence = ['train', 'val', 'test', 'predict'])

Bases: Callback

Callback to write predicted matrices to disk.
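The constructor takes a destination (output_file) and the stages the callback should act on (splits, all four by default). The sketch below is a hypothetical, dependency-free stand-in and not graph2mat's implementation. The class name SplitFilteredWriter, the method write and the JSON-lines format (one record per accepted batch) are assumptions made for illustration, because this page does not document how output_file is interpreted or what the on-disk format is.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import List, Sequence


class SplitFilteredWriter:
    """Hypothetical stand-in for MatrixWriter; NOT graph2mat's implementation.

    Assumption for illustration: every accepted call appends one JSON line
    holding the split name, the batch index and the matrices.
    """

    def __init__(self, output_file: str,
                 splits: Sequence[str] = ("train", "val", "test", "predict")):
        self.output_file = Path(output_file)
        # Copy into a list so later changes to the caller's sequence have no effect.
        self.splits = list(splits)

    def write(self, split: str, batch_idx: int,
              matrices: List[List[List[float]]]) -> bool:
        """Append a record for `split`; return False if the split is disabled."""
        if split not in self.splits:
            return False
        record = {"split": split, "batch_idx": batch_idx, "matrices": matrices}
        with self.output_file.open("a", encoding="utf-8") as fh:
            fh.write(json.dumps(record) + "\n")
        return True


# Usage: only prediction batches are written.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "matrices.jsonl")
    writer = SplitFilteredWriter(path, splits=["predict"])
    writer.write("predict", 0, [[[1.0, 0.5], [0.5, 2.0]]])  # returns True
    writer.write("train", 0, [[[0.0]]])                      # returns False, skipped
    with open(path, encoding="utf-8") as fh:
        print(len(fh.readlines()))  # 1
```

The sketch uses a tuple as the default for splits instead of the list shown in the real signature, which avoids Python's shared-mutable-default pitfall; the accepted values are the same.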

Methods

on_predict_batch_end(trainer, pl_module, ...)

Called when the predict batch ends.

on_test_batch_end(trainer, pl_module, ...[, ...])

Called when the test batch ends.

on_train_batch_end(trainer, pl_module, ...)

Called when the train batch ends.

on_validation_batch_end(trainer, pl_module, ...)

Called when the validation batch ends.
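Each of these hooks fires at the end of a batch in one stage of a Lightning run. Judging only from the default value of splits, a plausible correspondence between hook names and split names is the one below. This mapping is an assumption, not something this page confirms, and the names HOOK_TO_SPLIT and is_enabled are hypothetical.

```python
from typing import Sequence

# Assumed correspondence between MatrixWriter hooks and entries of `splits`.
HOOK_TO_SPLIT = {
    "on_train_batch_end": "train",
    "on_validation_batch_end": "val",
    "on_test_batch_end": "test",
    "on_predict_batch_end": "predict",
}


def is_enabled(hook_name: str,
               splits: Sequence[str] = ("train", "val", "test", "predict")) -> bool:
    """Return True if the stage handled by `hook_name` is listed in `splits`."""
    return HOOK_TO_SPLIT[hook_name] in splits


print(is_enabled("on_validation_batch_end", splits=["test", "predict"]))  # False
print(is_enabled("on_predict_batch_end", splits=["test", "predict"]))     # True
```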


on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the predict batch ends.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the test batch ends.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches, i.e. divided by the number of batches over which gradients are accumulated.
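As a worked example of the normalization described in the note: with gradient accumulation over 4 batches, a training_step loss of 0.8 is reported as 0.2, and multiplying by the accumulation factor recovers the original value. The helper names normalized_loss and raw_loss are hypothetical and only restate the arithmetic; they are not part of graph2mat or Lightning.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int = 1) -> float:
    """Value reported as outputs["loss"]: the training_step loss divided by
    the number of batches over which gradients are accumulated."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be at least 1")
    return step_loss / accumulate_grad_batches


def raw_loss(reported_loss: float, accumulate_grad_batches: int = 1) -> float:
    """Undo the normalization to recover the loss returned by training_step."""
    return reported_loss * accumulate_grad_batches


# training_step returned 0.8 and gradients are accumulated over 4 batches.
reported = normalized_loss(0.8, 4)
print(reported)               # 0.2
print(raw_loss(reported, 4))  # 0.8
```

Without accumulation (accumulate_grad_batches=1) the reported value equals the training_step loss.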

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None)

Called when the validation batch ends.