MACE+Graph2Mat
This notebook will show you how to integrate a MACE model with Graph2Mat through the Python API. Note that you can also use MACE+Graph2Mat through the Command Line Interface (CLI).
Prerequisites
Before reading this notebook, make sure you have read the notebook on computing a matrix and the notebook on batching, which introduce the basic concepts of graph2mat that we assume are already known. We will also use exactly the same setup as in the batching notebook, with the only difference that we will add target matrices to each structure.
[1]:
import numpy as np
import pandas as pd
import torch
# To load plotly templates for sisl visualization
import sisl.viz
from e3nn import o3
from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset, TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat
from graph2mat.tools.viz import plot_basis_matrix
Generating a dataset
We generate a dataset here just as we have done in the other notebooks.
[2]:
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")
# The basis table.
table = BasisTableWithEdges([point_1, point_2])
# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

positions = np.array([[0, 0, 0], [6.0, 0, 0], [9, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)
configs = [config1, config2]
dataset = TorchBasisMatrixDataset(configs, data_processor=processor)
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=2)
data = next(iter(loader))
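If you want a quick look at what the batched object contains, you can simply print it; it is a regular torch_geometric batch, so the exact fields shown depend on your graph2mat and torch_geometric versions:

# Inspect the batch: a torch_geometric batch containing both configurations
print(data)
print("Total number of nodes in the batch:", data.num_nodes)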
Initializing a MACE model
We will now initialize a normal MACE model.
Note that you must have MACE installed, which you can do with:
pip install mace_torch
[3]:
from mace.modules import MACE, RealAgnosticResidualInteractionBlock
num_interactions = 3
hidden_irreps = o3.Irreps("1x0e + 1x1o")
mace_model = MACE(
    r_max=10,
    num_bessel=10,
    num_polynomial_cutoff=10,
    max_ell=2,  # 1,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    num_interactions=num_interactions,
    num_elements=2,
    hidden_irreps=hidden_irreps,
    MLP_irreps=o3.Irreps("2x0e"),
    atomic_energies=torch.tensor([0, 0]),
    avg_num_neighbors=2,
    atomic_numbers=[0, 1],
    correlation=2,
    gate=None,
)
cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/mace/modules/blocks.py:312: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
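Just as a quick check of the model size, you can count its parameters with plain PyTorch (this is not specific to MACE or Graph2Mat):

# Count the trainable parameters of the MACE model (standard PyTorch)
n_params = sum(p.numel() for p in mace_model.parameters() if p.requires_grad)
print(f"The MACE model has {n_params} trainable parameters")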
Now, we can pass our data through the MACE model. MACE outputs many things, but we are just interested in the node features, which we can get from the "node_feats" key.
[4]:
mace_output = mace_model(data)
mace_output["node_feats"]
[4]:
tensor([[ 9.0013e-02, 0.0000e+00, 0.0000e+00, 1.1366e-03, -6.8797e-02,
0.0000e+00, 0.0000e+00, -3.2418e-04, -6.9800e-02],
[-9.0031e-01, 0.0000e+00, 0.0000e+00, 1.0144e-03, 5.5003e-01,
0.0000e+00, 0.0000e+00, 5.0510e-04, 4.2952e-01],
[ 1.0378e-01, 0.0000e+00, 0.0000e+00, 2.6012e-03, -7.8224e-02,
0.0000e+00, 0.0000e+00, -7.5807e-04, -7.9285e-02],
[-8.9246e-01, 0.0000e+00, 0.0000e+00, 7.0245e-04, 5.4509e-01,
0.0000e+00, 0.0000e+00, 3.3149e-04, 4.2565e-01],
[ 1.0681e-01, 0.0000e+00, 0.0000e+00, -3.7303e-03, -8.0749e-02,
0.0000e+00, 0.0000e+00, 1.1233e-03, -8.1784e-02],
[-8.9810e-01, 0.0000e+00, 0.0000e+00, -1.7167e-03, 5.4874e-01,
0.0000e+00, 0.0000e+00, -8.3481e-04, 4.2851e-01]],
grad_fn=<CatBackward0>)
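MACE returns a dictionary with several quantities; you can list its keys to see everything that is available (the exact set of keys depends on your MACE version):

# Everything returned by the MACE forward pass
print(list(mace_output.keys()))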
Our Graph2Mat model will take these node features and convert them to a matrix. Therefore, we need to know the irreps of these features, and then we can initialize the Graph2Mat module.
[5]:
# MACE outputs as node features the hidden irreps for each interaction, except
# in the last interaction, where it computes just scalar features.
mace_out_irreps = hidden_irreps * (num_interactions - 1) + str(hidden_irreps[0])
# Initialize the matrix model with this information
matrix_model = E3nnGraph2Mat(
    unique_basis=table,
    irreps=dict(node_feats_irreps=mace_out_irreps),
    symmetric=True,
    # We would need to also implement passing the edge information in order to use
    # preprocessing_edges. As shown later, graph2mat can do this automatically for you.
    preprocessing_edges=None,
)
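As a sanity check (not required, just illustrative), you can verify that the total dimension of the irreps we declared matches the width of the node features MACE actually produced:

# The total dimension of the declared irreps should match the number of
# feature channels that MACE outputs per node.
assert mace_out_irreps.dim == mace_output["node_feats"].shape[-1]
print(mace_out_irreps, "-> total dimension:", mace_out_irreps.dim)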
Now, we can use the matrix model, passing the node features computed by MACE:
[6]:
node_labels, edge_labels = matrix_model(data=data, node_feats=mace_output["node_feats"])
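The model returns two flat tensors of matrix entries, one for the node (point) blocks and one for the edge blocks. If you want to see their sizes before reconstructing the matrices:

# Flat arrays of predicted matrix entries for the batch
print("node_labels shape:", node_labels.shape)
print("edge_labels shape:", edge_labels.shape)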
And plot the obtained matrices:
[7]:
matrices = processor.matrix_from_data(
    data,
    predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()
(Two interactive plotly figures: the predicted matrix for each of the two configurations.)
Using MatrixMACE
If you don’t want to handle the details of interfacing MACE with Graph2Mat, you can also use MatrixMACE, which takes a MACE model and wraps it so that it also outputs the node_labels and edge_labels corresponding to a matrix.
Internally, it just initializes an E3nnGraph2Mat layer. However, it can handle the interaction between MACE and Graph2Mat in more complex cases, such as adding an extra preprocessing step for edges, which needs some extra inputs from MACE.
[8]:
from graph2mat.models import MatrixMACE
from graph2mat.bindings.e3nn import E3nnEdgeMessageBlock
[9]:
matrix_mace_model = MatrixMACE(
    mace_model,
    unique_basis=table,
    readout_per_interaction=True,
    edge_hidden_irreps=o3.Irreps("10x0e + 10x1o + 10x2e"),
    symmetric=True,
)
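This is also where the E3nnEdgeMessageBlock imported above would come into play. As a sketch only (it assumes that MatrixMACE forwards extra keyword arguments such as preprocessing_edges to the underlying E3nnGraph2Mat layer; check the MatrixMACE documentation of your graph2mat version for the exact argument names), a variant with an edge preprocessing step could look like this:

# Sketch (not run here): a MatrixMACE with an edge preprocessing step.
# preprocessing_edges is assumed to be forwarded to the internal E3nnGraph2Mat,
# mirroring the explicit E3nnGraph2Mat initialization shown earlier.
matrix_mace_edges = MatrixMACE(
    mace_model,
    unique_basis=table,
    readout_per_interaction=True,
    edge_hidden_irreps=o3.Irreps("10x0e + 10x1o + 10x2e"),
    symmetric=True,
    preprocessing_edges=E3nnEdgeMessageBlock,
)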
The output of this model is MACE’s output plus the node_labels and edge_labels for the predicted matrix:
[10]:
out = matrix_mace_model(data)
out
[10]:
{'energy': tensor([0.8221, 1.8579], grad_fn=<SumBackward1>),
'node_energy': tensor([-0.0717, 0.9773, -0.0835, 0.9688, -0.0858, 0.9749],
grad_fn=<SumBackward1>),
'contributions': tensor([[ 0.0000, 0.0000, 0.9226, 0.2668, -0.3673],
[ 0.0000, 0.0000, 2.1987, 0.6706, -1.0115]],
grad_fn=<StackBackward0>),
'forces': None,
'edge_forces': None,
'virials': None,
'stress': None,
'atomic_virials': None,
'atomic_stresses': None,
'displacement': tensor([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]),
'hessian': None,
'node_feats': tensor([[ 9.0013e-02, 0.0000e+00, 0.0000e+00, 1.1366e-03, -6.8797e-02,
0.0000e+00, 0.0000e+00, -3.2418e-04, -6.9800e-02],
[-9.0031e-01, 0.0000e+00, 0.0000e+00, 1.0144e-03, 5.5003e-01,
0.0000e+00, 0.0000e+00, 5.0510e-04, 4.2952e-01],
[ 1.0378e-01, 0.0000e+00, 0.0000e+00, 2.6012e-03, -7.8224e-02,
0.0000e+00, 0.0000e+00, -7.5807e-04, -7.9285e-02],
[-8.9246e-01, 0.0000e+00, 0.0000e+00, 7.0245e-04, 5.4509e-01,
0.0000e+00, 0.0000e+00, 3.3149e-04, 4.2565e-01],
[ 1.0681e-01, 0.0000e+00, 0.0000e+00, -3.7303e-03, -8.0749e-02,
0.0000e+00, 0.0000e+00, 1.1233e-03, -8.1784e-02],
[-8.9810e-01, 0.0000e+00, 0.0000e+00, -1.7167e-03, 5.4874e-01,
0.0000e+00, 0.0000e+00, -8.3481e-04, 4.2851e-01]],
grad_fn=<CatBackward0>),
'node_labels': tensor([ 1.1421e-03, -3.3615e-03, -1.7687e-02, 0.0000e+00, 0.0000e+00,
-3.0053e-04, -1.7687e-02, 2.2053e-01, 0.0000e+00, 0.0000e+00,
-1.5540e-04, 0.0000e+00, 0.0000e+00, 6.2363e-02, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.2363e-02,
0.0000e+00, -3.0053e-04, -1.5540e-04, 0.0000e+00, 0.0000e+00,
6.2363e-02, 1.5125e-03, -3.2697e-03, -1.7366e-02, 0.0000e+00,
0.0000e+00, -2.0601e-04, -1.7366e-02, 2.1663e-01, 0.0000e+00,
0.0000e+00, -1.0374e-04, 0.0000e+00, 0.0000e+00, 6.1270e-02,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
6.1270e-02, 0.0000e+00, -2.0601e-04, -1.0374e-04, 0.0000e+00,
0.0000e+00, 6.1270e-02, 1.6016e-03, -3.3589e-03, -1.7606e-02,
0.0000e+00, 0.0000e+00, 5.0701e-04, -1.7606e-02, 2.1948e-01,
0.0000e+00, 0.0000e+00, 2.5915e-04, 0.0000e+00, 0.0000e+00,
6.2062e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 6.2062e-02, 0.0000e+00, 5.0701e-04, 2.5915e-04,
0.0000e+00, 0.0000e+00, 6.2062e-02], grad_fn=<MeanBackward1>),
'edge_labels': tensor([ 2.8645e-05, -3.5231e-05, 0.0000e+00, 0.0000e+00, -1.6517e-05,
-1.0098e-05, -1.6300e-05, 0.0000e+00, 0.0000e+00, 3.1383e-05,
3.3831e-05, -4.1323e-05, 0.0000e+00, 0.0000e+00, 1.9075e-05,
-1.0373e-05, -1.6752e-05, 0.0000e+00, 0.0000e+00, -3.1856e-05,
-9.8223e-07, 7.5770e-07, 0.0000e+00, 0.0000e+00, 9.5773e-08,
7.5781e-07, 2.8189e-06, 0.0000e+00, 0.0000e+00, -2.6045e-08,
0.0000e+00, 0.0000e+00, 6.0607e-07, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0607e-07, 0.0000e+00,
-9.6107e-08, 2.6127e-08, 0.0000e+00, 0.0000e+00, 1.0831e-06],
grad_fn=<MeanBackward1>)}
You can of course plot the predicted matrices:
[11]:
matrices = processor.matrix_from_data(data, predictions=out)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()
(Two interactive plotly figures: the predicted matrix for each of the two configurations.)
Summary and next steps
In this notebook we learned how to interface MACE with Graph2Mat.
The next steps could be:
Train a MACE+Graph2Mat model, following the training steps shown elsewhere in the documentation and simply replacing the model there with the MACE+Graph2Mat model presented here.
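As a very rough sketch of what such a training loop could look like (illustrative only: the zero targets below stand in for the flat entries of the real target matrices, which in practice come from the dataset, and the loss and optimizer choices are just examples):

import torch

optimizer = torch.optim.Adam(matrix_mace_model.parameters(), lr=1e-3)

for batch in loader:
    out = matrix_mace_model(batch)

    # Illustrative placeholders: in a real training run these would be the
    # flat entries of the target matrices stored in the dataset.
    target_node_labels = torch.zeros_like(out["node_labels"])
    target_edge_labels = torch.zeros_like(out["edge_labels"])

    loss = torch.nn.functional.mse_loss(out["node_labels"], target_node_labels)
    loss = loss + torch.nn.functional.mse_loss(out["edge_labels"], target_edge_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()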