graph2mat.bindings.torch.load
Functions
- Load a model from a Lightning checkpoint file.
- Makes sure that the checkpoint is compatible with the current version of e3nn_matrix.
- graph2mat.bindings.torch.load.load_from_lit_ckpt(ckpt_file: Path | str, cpu: bool = True, as_torch: bool = False) → Tuple[Module, MatrixDataProcessor]
Load a model from a Lightning checkpoint file.
- Parameters:
ckpt_file (Union[Path, str]) – Path to the checkpoint file.
cpu (bool, optional) – If True, the model is loaded on the CPU regardless of whether it was on the GPU when saved, by default True.
as_torch (bool, optional) – If True, the model is returned as the bare torch.nn.Module; otherwise, it is returned as a Lightning module, by default False.
- Returns:
torch.nn.Module – The loaded model (the bare torch.nn.Module if as_torch is True, otherwise the Lightning module).
MatrixDataProcessor – The processor to use for processing inputs and outputs.
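As a usage sketch, the helper below loads a checkpoint for CPU inference and unpacks the returned (model, processor) pair. It assumes graph2mat is installed and that the path points to a checkpoint produced by graph2mat's Lightning training. The helper name `load_model_for_inference` and its file-existence check are hypothetical illustrations, not part of graph2mat; only `load_from_lit_ckpt` and its `cpu`/`as_torch` parameters come from the signature above.

```python
from pathlib import Path
from typing import Union


def load_model_for_inference(ckpt_file: Union[Path, str]):
    """Hypothetical helper: load a graph2mat model for CPU inference.

    Returns the (model, processor) pair produced by load_from_lit_ckpt.
    """
    ckpt_file = Path(ckpt_file)
    # Fail early with a clear error instead of a deeper loading error.
    if not ckpt_file.is_file():
        raise FileNotFoundError(f"checkpoint not found: {ckpt_file}")

    # Deferred import: the helper can be defined even where graph2mat is absent.
    from graph2mat.bindings.torch.load import load_from_lit_ckpt

    # cpu=True loads the model on the CPU even if it was saved on a GPU;
    # as_torch=True returns the bare torch.nn.Module instead of the Lightning module.
    model, processor = load_from_lit_ckpt(ckpt_file, cpu=True, as_torch=True)

    # Switch off training-only behaviour (e.g. dropout) before predicting.
    model.eval()
    return model, processor


# Example (requires graph2mat and a real checkpoint file):
# model, processor = load_model_for_inference("path/to/model.ckpt")
```

Passing `as_torch=True` is convenient when the model is used outside of a Lightning `Trainer`; keep the default `as_torch=False` if you intend to continue training or evaluate through Lightning.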