graph2mat.models.mace

Classes

MatrixMACE(mace[, readout_per_interaction, ...])

Model that wraps a MACE model to produce a matrix output.

class graph2mat.models.mace.MatrixMACE(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)

Bases: Module

Model that wraps a MACE model to produce a matrix output.

Parameters:
  • mace – MACE model to wrap.

  • readout_per_interaction – If True, a separate readout is applied to the features of each message passing interaction. If False, the features of all interactions are concatenated and passed to a single readout.

  • graph2mat_cls – Class of the graph2mat model to use for the readouts.

  • **kwargs – Additional keyword arguments to pass to graph2mat_cls for initialization.
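To make the two readout_per_interaction modes concrete, here is a minimal sketch in plain Python. It is not the graph2mat implementation. Node features are lists instead of tensors, `toy_readout` is a hypothetical stand-in for a Graph2Mat readout, and summing the per-interaction predictions is an assumption, since this page does not say how separate readouts are combined.

```python
# Hypothetical sketch: per-node features from each MACE interaction are plain
# lists here; the real model uses torch tensors and e3nn irreps.
from typing import Callable, List

Features = List[List[float]]  # one feature vector per node


def concat_interactions(per_interaction: List[Features]) -> Features:
    """Concatenate, node by node, the features produced by every interaction."""
    n_nodes = len(per_interaction[0])
    return [
        [x for feats in per_interaction for x in feats[node]]
        for node in range(n_nodes)
    ]


def apply_readouts(
    per_interaction: List[Features],
    readouts: List[Callable[[Features], List[float]]],
    readout_per_interaction: bool,
) -> List[float]:
    if readout_per_interaction:
        # One readout per interaction. ASSUMPTION: the per-interaction
        # predictions are summed; the docs do not state how they are combined.
        assert len(readouts) == len(per_interaction)
        outputs = [readout(feats) for readout, feats in zip(readouts, per_interaction)]
        return [sum(values) for values in zip(*outputs)]
    # A single readout sees the concatenated features of all interactions.
    assert len(readouts) == 1
    return readouts[0](concat_interactions(per_interaction))


def toy_readout(feats: Features) -> List[float]:
    # Hypothetical stand-in for a Graph2Mat readout: one number per node.
    return [sum(node_feats) for node_feats in feats]


layer1 = [[1.0, 2.0], [3.0, 4.0]]  # interaction 1: 2 nodes, 2 features each
layer2 = [[10.0], [20.0]]          # interaction 2: 2 nodes, 1 feature each
print(concat_interactions([layer1, layer2]))  # [[1.0, 2.0, 10.0], [3.0, 4.0, 20.0]]
print(apply_readouts([layer1, layer2], [toy_readout], False))  # [13.0, 27.0]
```

With this additive toy readout both modes give identical results; that is a property of the toy, not of Graph2Mat. The practical difference is that the single readout sees all interactions' features at once, while separate readouts each see only one interaction.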

__init__(mace: MACE, readout_per_interaction: bool = False, graph2mat_cls: type[Graph2Mat] = graph2mat.bindings.e3nn.modules.graph2mat.E3nnGraph2Mat, **kwargs)

Initialize the wrapper. The parameters are the same as those documented for the class.
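The following sketch illustrates how the constructor parameters documented for the class can interact: graph2mat_cls is instantiated once, or once per interaction, and **kwargs are forwarded to every instance. `ToyMACE`, `ToyGraph2Mat`, `ToyMatrixMACE`, the `feature_dims` attribute, and the `in_dim` argument are hypothetical stand-ins, not graph2mat or MACE APIs. Sizing each readout by the width of its input features is likewise an assumption of the sketch.

```python
# Hypothetical sketch of the constructor pattern. None of these classes are
# graph2mat or MACE APIs; they only mimic how the parameters are used.
from typing import Any, List, Type


class ToyGraph2Mat:
    def __init__(self, in_dim: int, **kwargs: Any) -> None:
        self.in_dim = in_dim   # width of the node features this readout consumes
        self.options = kwargs  # whatever the user forwarded via **kwargs


class ToyMACE:
    def __init__(self, feature_dims: List[int]) -> None:
        # Width of the node features produced by each interaction.
        self.feature_dims = feature_dims


class ToyMatrixMACE:
    def __init__(
        self,
        mace: ToyMACE,
        readout_per_interaction: bool = False,
        graph2mat_cls: Type[ToyGraph2Mat] = ToyGraph2Mat,
        **kwargs: Any,
    ) -> None:
        self.mace = mace
        self.readout_per_interaction = readout_per_interaction
        if readout_per_interaction:
            # One readout per interaction, each sized to that interaction.
            self.readouts = [graph2mat_cls(dim, **kwargs) for dim in mace.feature_dims]
        else:
            # A single readout over the concatenated features.
            self.readouts = [graph2mat_cls(sum(mace.feature_dims), **kwargs)]


toy_mace = ToyMACE([8, 8, 4])
print(len(ToyMatrixMACE(toy_mace).readouts))                                # 1
print(len(ToyMatrixMACE(toy_mace, readout_per_interaction=True).readouts))  # 3
```

Passing graph2mat_cls as a class rather than an instance lets the wrapper decide how many readouts to build, while **kwargs keeps the wrapper agnostic about the readout's own options.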

forward(data: TorchBasisMatrixData, compute_force: bool = False, **kwargs) → dict[str, Tensor]

Forward pass of the model.

Parameters:
  • data – Input data, a TorchBasisMatrixData instance.

  • compute_force – Passed directly to the compute_force argument of the MACE model.

  • **kwargs – Additional keyword arguments to pass to the MACE model for the forward pass.

Returns:

The output of the MACE model, with the additional keys “node_labels” and “edge_labels” containing the output of Graph2Mat.

Return type:

dict[str, Tensor]
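The wrapper pattern of forward can be sketched in plain Python: call the wrapped model, passing compute_force and any extra keyword arguments straight through, then return its output with "node_labels" and "edge_labels" added. `matrix_forward`, `toy_mace`, and `toy_readout` are hypothetical. Here the readout simply receives the wrapped model's output dictionary; what the real Graph2Mat readout consumes is not described on this page.

```python
# Hypothetical sketch of the forward wrapper pattern; the callables are
# illustrative stand-ins, not graph2mat or MACE APIs.
from typing import Any, Callable, Dict, Tuple

Output = Dict[str, Any]


def matrix_forward(
    mace: Callable[..., Output],
    readout: Callable[[Output], Tuple[Any, Any]],
    data: Any,
    compute_force: bool = False,
    **kwargs: Any,
) -> Output:
    # compute_force and any extra keyword arguments go straight to the wrapped model.
    output = mace(data, compute_force=compute_force, **kwargs)
    node_labels, edge_labels = readout(output)
    # Keep everything the wrapped model returned and add the matrix predictions.
    return {**output, "node_labels": node_labels, "edge_labels": edge_labels}


def toy_mace(data, compute_force=False, **kwargs):
    # Hypothetical model: an "energy" and, optionally, zero "forces".
    out = {"energy": sum(data)}
    if compute_force:
        out["forces"] = [0.0 for _ in data]
    return out


def toy_readout(output):
    # Hypothetical readout returning (node_labels, edge_labels).
    return [output["energy"]], []


result = matrix_forward(toy_mace, toy_readout, [1.0, 2.0], compute_force=True)
print(sorted(result))  # ['edge_labels', 'energy', 'forces', 'node_labels']
```

Returning the wrapped model's dictionary with extra keys means downstream code that already reads MACE outputs keeps working, while matrix-learning code reads the two new keys.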