graph2mat.models.mace
Classes
MatrixMACE – Model that wraps a MACE model to produce a matrix output.
- class graph2mat.models.mace.MatrixMACE(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)[source]
Bases: torch.nn.Module
Model that wraps a MACE model to produce a matrix output.
- Parameters:
mace – MACE model to wrap.
readout_per_interaction – If True, a separate readout is applied to the features of each message passing interaction. If False, the features of all interactions are concatenated and passed to a single readout.
graph2mat_cls – Class of the graph2mat model to use for the readouts.
**kwargs – Additional keyword arguments to pass to graph2mat_cls for initialization.
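The difference between the two readout_per_interaction modes can be sketched in plain Python. This is a minimal illustration, not graph2mat code: the helper readout_inputs is hypothetical, and features are modeled as nested lists (feats[node][channel]) rather than the e3nn irreps tensors MACE actually produces. It only shows which features each readout receives; how the outputs of separate per-interaction readouts are combined is not covered here.

```python
from typing import List

# Per-node feature vectors of one interaction: feats[node][channel].
Features = List[List[float]]


def readout_inputs(interaction_feats: List[Features], readout_per_interaction: bool) -> List[Features]:
    """Hypothetical helper: return the features each readout would receive."""
    if readout_per_interaction:
        # One readout per message passing interaction; each sees only its own features.
        return [feats for feats in interaction_feats]
    # A single readout: for each node, the features of all interactions are concatenated.
    n_nodes = len(interaction_feats[0])
    concatenated = [
        [x for feats in interaction_feats for x in feats[node]]
        for node in range(n_nodes)
    ]
    return [concatenated]


# Two interactions, two nodes, two channels per interaction.
feats = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(len(readout_inputs(feats, True)))  # number of readouts when readout_per_interaction=True
print(readout_inputs(feats, False))      # single readout input, 4 channels per node
```

With concatenation, the single readout sees features from every interaction at once, at the cost of a wider input; with per-interaction readouts, each input keeps the width of one interaction.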
- __init__(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule (docstring inherited from torch.nn.Module). The arguments are documented under MatrixMACE.
- forward(data: TorchBasisMatrixData, compute_force: bool = False, **kwargs) → dict[str, Tensor][source]
Forward pass of the model.
- Parameters:
data – Input data.
compute_force – Passed directly to the compute_force argument of the MACE model.
**kwargs – Additional keyword arguments to pass to the MACE model for the forward pass.
- Returns:
output – The output of the MACE model, with the additional keys “node_labels” and “edge_labels” containing the output of Graph2Mat.
- Return type:
dict[str, Tensor]
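The data flow described for forward can be sketched with stand-in callables. Everything here is hypothetical: matrix_forward, fake_mace and fake_readout are illustrative names, plain dicts and lists replace TorchBasisMatrixData and tensors, and the keys returned by fake_mace are made up. The sketch only encodes the documented contract: compute_force and extra keyword arguments go to the MACE model, and the result is the MACE output dict extended with "node_labels" and "edge_labels".

```python
from typing import Any, Callable, Dict, List, Tuple


def matrix_forward(
    mace: Callable[..., Dict[str, Any]],
    readout: Callable[[Dict[str, Any]], Tuple[Any, Any]],
    data: Dict[str, Any],
    compute_force: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical stand-in for MatrixMACE.forward (not the library code)."""
    # compute_force is passed straight through, as are any extra keyword arguments.
    output = dict(mace(data, compute_force=compute_force, **kwargs))
    # The readout (Graph2Mat in the real model) turns MACE outputs into matrix labels.
    node_labels, edge_labels = readout(output)
    # The MACE output is returned with the two extra keys added.
    output["node_labels"] = node_labels
    output["edge_labels"] = edge_labels
    return output


def fake_mace(data: Dict[str, Any], compute_force: bool = False, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical MACE stand-in; its output keys are made up for illustration."""
    out: Dict[str, Any] = {
        "energy": sum(data["x"]),
        "node_feats": list(data["x"]),
        "extra_kwargs": dict(kwargs),  # records what was forwarded
    }
    if compute_force:
        out["forces"] = [0.0 for _ in data["x"]]
    return out


def fake_readout(out: Dict[str, Any]) -> Tuple[List[float], List[float]]:
    """Hypothetical readout: one value per node and one for the (0, 1) edge."""
    feats = out["node_feats"]
    return [2.0 * f for f in feats], [feats[0] * feats[1]]


result = matrix_forward(fake_mace, fake_readout, {"x": [1.0, 3.0]}, compute_force=True)
print(result["node_labels"], result["edge_labels"])
```

Because the wrapper only adds keys, anything the MACE model already returns (for example forces when compute_force is True) remains available alongside the matrix labels.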