graph2mat.bindings.torch.load

Functions

load_from_lit_ckpt(ckpt_file[, cpu, as_torch])

Load a model from a Lightning checkpoint file.

sanitize_checkpoint(checkpoint)

Makes sure that the checkpoint is compatible with the current version of e3nn_matrix.

graph2mat.bindings.torch.load.load_from_lit_ckpt(ckpt_file: Path | str, cpu: bool = True, as_torch: bool = False) → Tuple[Module, MatrixDataProcessor]

Load a model from a Lightning checkpoint file.

Parameters:
  • ckpt_file (Union[Path, str]) – Path to the checkpoint file.

  • cpu (bool, optional) – If True, the model is loaded on the CPU even if it was on the GPU when saved. Defaults to True.

  • as_torch (bool, optional) – If True, the model is returned as a bare torch.nn.Module; otherwise it is returned as a Lightning module. Defaults to False.

Returns:

  • torch.nn.Module – The model

  • MatrixDataProcessor – The processor to use for processing inputs and outputs.
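
A minimal usage sketch based on the signature above, not an official recipe. The helper name load_for_inference and the checkpoint path are hypothetical; only load_from_lit_ckpt and its cpu and as_torch parameters come from this page. It assumes graph2mat is installed with its torch bindings.

```python
from pathlib import Path
from typing import Union


def load_for_inference(ckpt_file: Union[Path, str]):
    """Hypothetical helper: load a checkpoint on the CPU as a bare torch module."""
    ckpt_path = Path(ckpt_file)
    # Fail early with a clear message if the file is missing.
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Imported here so the helper can be defined without graph2mat installed.
    from graph2mat.bindings.torch.load import load_from_lit_ckpt

    # cpu=True keeps the model on the CPU even if it was saved from a GPU;
    # as_torch=True asks for the bare torch.nn.Module.
    model, processor = load_from_lit_ckpt(ckpt_path, cpu=True, as_torch=True)

    # Switch to evaluation mode (standard torch.nn.Module method).
    model.eval()
    return model, processor
```

A call might look like `model, processor = load_for_inference("path/to/model.ckpt")`, where the path is a placeholder. The returned processor is the object documented above for processing inputs and outputs.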

graph2mat.bindings.torch.load.sanitize_checkpoint(checkpoint: dict) → dict

Makes sure that the checkpoint is compatible with the current version of e3nn_matrix.
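
A hedged sketch of how sanitize_checkpoint might be applied by hand. The helper name read_sanitized_checkpoint is hypothetical. Reading the file with torch.load is an assumption about how the checkpoint dict is obtained, based on Lightning checkpoints being ordinary torch-serialized dictionaries. What the function changes inside the dict is not described on this page.

```python
from pathlib import Path
from typing import Union


def read_sanitized_checkpoint(ckpt_file: Union[Path, str]) -> dict:
    """Hypothetical helper: read a checkpoint dict and sanitize it."""
    ckpt_path = Path(ckpt_file)
    # Fail early with a clear message if the file is missing.
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Imported here so the helper can be defined without torch or graph2mat.
    import torch
    from graph2mat.bindings.torch.load import sanitize_checkpoint

    # Assumption: the checkpoint is a pickled dict. weights_only=False allows
    # arbitrary pickled objects, so use it only with files you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Takes a dict and returns a dict, as in the signature above.
    return sanitize_checkpoint(checkpoint)
```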