graph2mat.models.mace
Classes
- MatrixMACE: Model that wraps a MACE model to produce a matrix output.
- class graph2mat.models.mace.MatrixMACE(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)
Bases: Module
Model that wraps a MACE model to produce a matrix output.
- Parameters:
mace – MACE model to wrap.
readout_per_interaction – If True, a separate readout is applied to the features of each message passing interaction. If False, the features of all interactions are concatenated and passed to a single readout.
graph2mat_cls – Class of the graph2mat model to use for the readouts.
**kwargs – Additional keyword arguments to pass to graph2mat_cls for initialization.
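The difference between the two `readout_per_interaction` modes, and how `**kwargs` reach `graph2mat_cls`, can be sketched in plain Python. This is a hypothetical illustration, not graph2mat code: `SumReadout`, `build_readouts` and `readout_features` are invented names, features are plain lists of floats rather than e3nn tensors, and how the real model combines the per-interaction outputs is not described here, so the sketch simply returns them as a list.

```python
from typing import Any, Callable, List, Sequence

# Hypothetical: features of one interaction, as a flat list of floats.
Features = List[float]


class SumReadout:
    """Hypothetical stand-in for graph2mat_cls; stores its init kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

    def __call__(self, feats: Features) -> float:
        # Toy "readout": scaled sum of the input features.
        return sum(feats) * self.kwargs.get("scale", 1.0)


def build_readouts(
    graph2mat_cls: Callable[..., Callable[[Features], float]],
    n_interactions: int,
    readout_per_interaction: bool,
    **kwargs: Any,
) -> list:
    """Create one readout per interaction, or a single shared one.

    Every readout is initialized with the same **kwargs.
    """
    n = n_interactions if readout_per_interaction else 1
    return [graph2mat_cls(**kwargs) for _ in range(n)]


def readout_features(
    interaction_features: Sequence[Features],
    readouts: Sequence[Callable[[Features], float]],
    readout_per_interaction: bool,
) -> List[float]:
    """Apply the readouts following the two documented strategies."""
    if readout_per_interaction:
        # Each interaction's features go through their own readout.
        if len(readouts) != len(interaction_features):
            raise ValueError("need one readout per interaction")
        return [ro(f) for ro, f in zip(readouts, interaction_features)]
    # Otherwise concatenate all interactions' features for one readout.
    if len(readouts) != 1:
        raise ValueError("need exactly one readout")
    concatenated = [x for f in interaction_features for x in f]
    return [readouts[0](concatenated)]


feats = [[1.0, 2.0], [3.0, 4.0]]  # two interactions
per_interaction = readout_features(
    feats, build_readouts(SumReadout, 2, True, scale=2.0), True
)
single = readout_features(
    feats, build_readouts(SumReadout, 2, False, scale=2.0), False
)
print(per_interaction, single)
```

With `readout_per_interaction=True` the sketch yields one output per interaction (`[6.0, 14.0]`); with `False` it yields a single output from the concatenated features (`[20.0]`).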
- __init__(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: TorchBasisMatrixData, compute_force: bool = False, **kwargs) → dict[str, Tensor]
Forward pass of the model.
- Parameters:
data – Input data.
compute_force – Passed directly to the compute_force argument of the MACE model.
**kwargs – Additional keyword arguments to pass to the MACE model for the forward pass.
- Returns:
output – The output of the MACE model, with the additional keys “node_labels” and “edge_labels” containing the output of Graph2Mat.
- Return type:
dict[str, Tensor]
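The documented contract of forward (arguments forwarded to MACE, MACE output extended with "node_labels" and "edge_labels") can be sketched as below. This is a hypothetical illustration under stated assumptions: `matrix_forward`, `fake_mace` and `fake_readout` are invented stand-ins, not graph2mat or MACE functions, and the assumption that the readout sees both the input data and the MACE output is ours, not the library's.

```python
from typing import Any, Callable, Dict, Tuple


def matrix_forward(
    mace: Callable[..., Dict[str, Any]],
    readout: Callable[[Any, Dict[str, Any]], Tuple[Any, Any]],
    data: Any,
    compute_force: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical sketch of the forward contract described above."""
    # compute_force and any extra kwargs go straight to the MACE model.
    output = dict(mace(data, compute_force=compute_force, **kwargs))
    # Assumption: the Graph2Mat readout uses the data and MACE output.
    node_labels, edge_labels = readout(data, output)
    # The MACE output is returned with the two extra keys added.
    output["node_labels"] = node_labels
    output["edge_labels"] = edge_labels
    return output


def fake_mace(data: Any, compute_force: bool = False, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in model: echoes what it received so forwarding is visible.
    return {"energy": 1.5, "compute_force_seen": compute_force, **kwargs}


def fake_readout(data: Any, mace_out: Dict[str, Any]) -> Tuple[list, list]:
    # Stand-in readout returning dummy node and edge labels.
    return [0.1, 0.2], [0.3]


out = matrix_forward(fake_mace, fake_readout, None, compute_force=True, extra_flag=True)
print(sorted(out))
```

The original MACE keys survive unchanged, which is why the return type is still a plain dictionary of tensors in the real model.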