graph2mat.tools.lightning.cli

Implements a custom CLI that slightly tweaks pytorch_lightning’s default.

Classes

OrbitalMatrixCLI([model_class, ...])

Custom pytorch_lightning CLI optimized for matrix learning.

SaveConfigSkipBasisTableCallback(parser, config)

class graph2mat.tools.lightning.cli.OrbitalMatrixCLI(model_class: type[LightningModule] | Callable[[...], LightningModule] | None = None, datamodule_class: type[LightningDataModule] | Callable[[...], LightningDataModule] | None = None, save_config_callback: type[SaveConfigCallback] | None = pytorch_lightning.cli.SaveConfigCallback, save_config_kwargs: dict[str, Any] | None = None, trainer_class: type[Trainer] | Callable[[...], Trainer] = pytorch_lightning.trainer.trainer.Trainer, trainer_defaults: dict[str, Any] | None = None, seed_everything_default: bool | int = True, parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: list[str] | dict[str, Any] | Namespace | None = None, run: bool = True, auto_configure_optimizers: bool = True)

Bases: LightningCLI

Custom pytorch_lightning CLI optimized for matrix learning.

Some of the default settings differ from those of LightningCLI.

The most relevant change concerns checkpoints: when a checkpoint is loaded through the ckpt_path key, all the options stored in the checkpoint file are used as defaults. You can therefore load a checkpoint file and use it without providing all the settings that were used to generate it, which raw pytorch_lightning would require.
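
The "stored options become defaults" idea can be illustrated with a minimal, standard-library sketch. This is not the graph2mat or pytorch_lightning implementation: the helper load_checkpoint_options, the in-memory "best.ckpt" entry, and the --lr and --hidden_size options are all hypothetical and exist only for illustration. The sketch also assumes that, because the stored values act as defaults, options given explicitly still take precedence.

```python
import argparse


def load_checkpoint_options(ckpt_path):
    """Hypothetical stand-in for reading the options stored in a checkpoint."""
    # A real checkpoint would be read from disk; a dict is used for illustration.
    fake_checkpoints = {"best.ckpt": {"lr": 0.005, "hidden_size": 64}}
    return fake_checkpoints.get(ckpt_path, {})


def build_parser():
    # Hypothetical options, not the real graph2mat CLI arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", default=None)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--hidden_size", type=int, default=32)
    return parser


def parse_with_checkpoint_defaults(argv):
    parser = build_parser()
    # First pass: only find out whether a checkpoint was requested.
    known, _ = parser.parse_known_args(argv)
    if known.ckpt_path is not None:
        # Stored options replace the built-in defaults, so explicitly
        # passed options still override them in the second pass.
        parser.set_defaults(**load_checkpoint_options(known.ckpt_path))
    # Second pass: parse everything against the updated defaults.
    return parser.parse_args(argv)
```

With this sketch, passing only --ckpt_path best.ckpt picks up the stored lr and hidden_size, while adding --lr 0.1 overrides the stored learning rate and keeps the stored hidden_size.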

add_arguments_to_parser(parser: LightningArgumentParser)

Implement to add extra arguments to the parser or link arguments.

Parameters:

parser – The parser object to which arguments can be added

before_instantiate_classes() → None

Implement to run some code before instantiating the classes.

parse_arguments(parser: LightningArgumentParser, args: list[str] | dict[str, Any] | Namespace | None) → None

Parses command-line arguments and stores them in self.config.

class graph2mat.tools.lightning.cli.SaveConfigSkipBasisTableCallback(parser: LightningArgumentParser, config: Namespace, config_filename: str = 'config.yaml', overwrite: bool = False, multifile: bool = False)

Bases: SaveConfigCallback

__init__(parser: LightningArgumentParser, config: Namespace, config_filename: str = 'config.yaml', overwrite: bool = False, multifile: bool = False) → None