MACE+Graph2Mat

This notebook shows how to integrate a MACE model with Graph2Mat through the Python API. Note that you can also use MACE+Graph2Mat through the Command Line Interface (CLI).

Prerequisites

Before reading this notebook, make sure you have read the notebook on computing a matrix and the notebook on batching. They introduce the basic graph2mat concepts that we assume are already known here. We will also use exactly the same setup as in the batching notebook, with the only difference that we will add target matrices to each structure.

[1]:
import numpy as np
import pandas as pd
import torch

# To load plotly templates for sisl visualization
import sisl.viz

from e3nn import o3

from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset, TorchBasisMatrixData

from graph2mat.bindings.e3nn import E3nnGraph2Mat

from graph2mat.tools.viz import plot_basis_matrix

Generating a dataset

We generate a dataset here just as we have done in the other notebooks.

[2]:
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")

# The basis table.
table = BasisTableWithEdges([point_1, point_2])

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

positions = np.array([[0, 0, 0], [6.0, 0, 0], [9, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=[point_1, point_2],
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

configs = [config1, config2]

dataset = TorchBasisMatrixDataset(configs, data_processor=processor)

from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=2)

data = next(iter(loader))
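
As a quick check on the setup above, the following sketch computes how large the matrix of each configuration should be. It only assumes that every point contributes one row and column per basis function, so that "0e" gives 1 function and "2x0e + 1o" gives 2 + 3 = 5. The variable names are illustrative and not part of graph2mat.

```python
# Number of basis functions per point type, from the irreps in the cell above:
# "0e" -> one scalar; "2x0e + 1o" -> two scalars plus one l=1 shell (3 components).
basis_size = {"A": 1, "B": 2 * 1 + 1 * 3}

# Point types of the two configurations defined above.
structures = {
    "config1": ["A", "B", "A"],
    "config2": ["B", "A", "B"],
}

# The matrix of a structure is square, with one row/column per basis function.
matrix_dims = {
    name: sum(basis_size[point_type] for point_type in point_types)
    for name, point_types in structures.items()
}

print(matrix_dims)
```

Under these assumptions the first configuration has a 7x7 matrix and the second an 11x11 matrix.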

Initializing a MACE model

We will now initialize a normal MACE model.

Note that you must have MACE installed, which you can do with:

pip install mace_torch
[3]:
from mace.modules import MACE, RealAgnosticResidualInteractionBlock

num_interactions = 3
hidden_irreps = o3.Irreps("1x0e + 1x1o")

mace_model = MACE(
    r_max=10,
    num_bessel=10,
    num_polynomial_cutoff=10,
    max_ell=2,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    num_interactions=num_interactions,
    num_elements=2,
    hidden_irreps=hidden_irreps,
    MLP_irreps=o3.Irreps("2x0e"),
    atomic_energies=torch.tensor([0, 0]),
    avg_num_neighbors=2,
    atomic_numbers=[0, 1],
    correlation=2,
    gate=None,
)

Now we can pass our data through the MACE model. MACE outputs many things, but we are only interested in the node features, which are stored under the "node_feats" key.

[4]:
mace_output = mace_model(data)
mace_output["node_feats"]
[4]:
tensor([[ 9.0013e-02,  0.0000e+00,  0.0000e+00,  1.1366e-03, -6.8797e-02,
          0.0000e+00,  0.0000e+00, -3.2418e-04, -6.9800e-02],
        [-9.0031e-01,  0.0000e+00,  0.0000e+00,  1.0144e-03,  5.5003e-01,
          0.0000e+00,  0.0000e+00,  5.0510e-04,  4.2952e-01],
        [ 1.0378e-01,  0.0000e+00,  0.0000e+00,  2.6012e-03, -7.8224e-02,
          0.0000e+00,  0.0000e+00, -7.5807e-04, -7.9285e-02],
        [-8.9246e-01,  0.0000e+00,  0.0000e+00,  7.0245e-04,  5.4509e-01,
          0.0000e+00,  0.0000e+00,  3.3149e-04,  4.2565e-01],
        [ 1.0681e-01,  0.0000e+00,  0.0000e+00, -3.7303e-03, -8.0749e-02,
          0.0000e+00,  0.0000e+00,  1.1233e-03, -8.1784e-02],
        [-8.9810e-01,  0.0000e+00,  0.0000e+00, -1.7167e-03,  5.4874e-01,
          0.0000e+00,  0.0000e+00, -8.3481e-04,  4.2851e-01]],
       grad_fn=<CatBackward0>)

Our Graph2Mat model will take these node features and convert them to a matrix. To initialize the Graph2Mat module, we therefore need to know the irreps of these node features.

[5]:
# MACE outputs as node features the hidden irreps for each interaction, except
# in the last interaction, where it computes just scalar features.
mace_out_irreps = hidden_irreps * (num_interactions - 1) + str(hidden_irreps[0])

# Initialize the matrix model with this information
matrix_model = E3nnGraph2Mat(
    unique_basis=table,
    irreps=dict(node_feats_irreps=mace_out_irreps),
    symmetric=True,
    # We would need to also implement passing the edge information in order to use
    # preprocessing_edges. As shown later, graph2mat can do this automatically for you.
    preprocessing_edges=None,
)
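
To see why these irreps match the 9 columns of the "node_feats" tensor printed earlier, here is a small sketch in plain Python that counts dimensions by hand. The `irreps_dim` helper is hypothetical and written only for this illustration; with e3nn you would normally read the `dim` attribute of an `o3.Irreps` object instead.

```python
import re


def irreps_dim(irreps: str) -> int:
    """Total dimension of an irreps string such as "2x0e + 1o" (hypothetical helper)."""
    total = 0
    for term in irreps.split("+"):
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term.strip())
        if match is None:
            raise ValueError(f"cannot parse irrep term {term!r}")
        multiplicity = int(match.group(1) or 1)
        ell = int(match.group(2))
        # An irrep of degree l has 2l + 1 components.
        total += multiplicity * (2 * ell + 1)
    return total


hidden = "1x0e + 1x1o"
num_interactions = 3

# Hidden irreps for every interaction except the last, which keeps only the scalars.
blocks = [hidden] * (num_interactions - 1) + ["1x0e"]
widths = [irreps_dim(block) for block in blocks]
total_width = sum(widths)

# Column range that each interaction occupies in the concatenated node features.
offsets = [sum(widths[:i]) for i in range(len(widths) + 1)]
column_ranges = list(zip(offsets[:-1], offsets[1:]))

print(widths, total_width, column_ranges)
```

The three interactions contribute 4, 4 and 1 columns, which adds up to the 9 columns seen in the MACE output.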

Now, we can use the matrix model, passing the node features computed by MACE:

[6]:
node_labels, edge_labels = matrix_model(data=data, node_feats=mace_output["node_feats"])

And plot the obtained matrices:

[7]:
matrices = processor.matrix_from_data(
    data,
    predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()

Using MatrixMACE

If you don’t want to handle the details of interfacing MACE with Graph2Mat, you can also use MatrixMACE, which takes a MACE model and wraps it so that it also outputs the node_labels and edge_labels corresponding to a matrix.

Internally, it just initializes an E3nnGraph2Mat layer. However, it can also handle the interaction between MACE and Graph2Mat in more complex cases, such as an extra preprocessing step for edges that needs some extra inputs from MACE.
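
The wrapping idea can be illustrated with a toy sketch in plain Python. This is not the actual MatrixMACE implementation: `ToyMatrixWrapper`, `toy_base_model` and `toy_readout` are hypothetical stand-ins that only show the pattern of running a base model and extending its output dictionary with matrix labels.

```python
class ToyMatrixWrapper:
    """Hypothetical wrapper: runs a base model, then adds matrix labels to its output."""

    def __init__(self, base_model, readout):
        self.base_model = base_model
        self.readout = readout

    def __call__(self, data):
        # Copy the base model's output so the original dictionary is not modified.
        out = dict(self.base_model(data))
        node_labels, edge_labels = self.readout(data, out["node_feats"])
        out["node_labels"] = node_labels
        out["edge_labels"] = edge_labels
        return out


def toy_base_model(data):
    # Stand-in for MACE: returns an energy and some per-node features.
    return {"energy": 0.0, "node_feats": [[1.0], [2.0]]}


def toy_readout(data, node_feats):
    # Stand-in for the Graph2Mat layer: maps node features to flat label lists.
    return [row[0] * 2 for row in node_feats], []


wrapped = ToyMatrixWrapper(toy_base_model, toy_readout)
out = wrapped(data=None)
print(sorted(out))
```

Because the wrapper keeps every key of the base output, code that only needs the original model's results keeps working, while the two extra keys carry the matrix prediction.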

[8]:
from graph2mat.models import MatrixMACE
from graph2mat.bindings.e3nn import E3nnEdgeMessageBlock
[9]:
matrix_mace_model = MatrixMACE(
    mace_model,
    unique_basis=table,
    readout_per_interaction=True,
    edge_hidden_irreps=o3.Irreps("10x0e + 10x1o + 10x2e"),
    symmetric=True,
)

The output of this model is MACE’s output plus the node_labels and edge_labels for the predicted matrix:

[10]:
out = matrix_mace_model(data)

out
[10]:
{'energy': tensor([0.8221, 1.8579], grad_fn=<SumBackward1>),
 'node_energy': tensor([-0.0717,  0.9773, -0.0835,  0.9688, -0.0858,  0.9749],
        grad_fn=<SumBackward1>),
 'contributions': tensor([[ 0.0000,  0.0000,  0.9226,  0.2668, -0.3673],
         [ 0.0000,  0.0000,  2.1987,  0.6706, -1.0115]],
        grad_fn=<StackBackward0>),
 'forces': None,
 'edge_forces': None,
 'virials': None,
 'stress': None,
 'atomic_virials': None,
 'atomic_stresses': None,
 'displacement': tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]),
 'hessian': None,
 'node_feats': tensor([[ 9.0013e-02,  0.0000e+00,  0.0000e+00,  1.1366e-03, -6.8797e-02,
           0.0000e+00,  0.0000e+00, -3.2418e-04, -6.9800e-02],
         [-9.0031e-01,  0.0000e+00,  0.0000e+00,  1.0144e-03,  5.5003e-01,
           0.0000e+00,  0.0000e+00,  5.0510e-04,  4.2952e-01],
         [ 1.0378e-01,  0.0000e+00,  0.0000e+00,  2.6012e-03, -7.8224e-02,
           0.0000e+00,  0.0000e+00, -7.5807e-04, -7.9285e-02],
         [-8.9246e-01,  0.0000e+00,  0.0000e+00,  7.0245e-04,  5.4509e-01,
           0.0000e+00,  0.0000e+00,  3.3149e-04,  4.2565e-01],
         [ 1.0681e-01,  0.0000e+00,  0.0000e+00, -3.7303e-03, -8.0749e-02,
           0.0000e+00,  0.0000e+00,  1.1233e-03, -8.1784e-02],
         [-8.9810e-01,  0.0000e+00,  0.0000e+00, -1.7167e-03,  5.4874e-01,
           0.0000e+00,  0.0000e+00, -8.3481e-04,  4.2851e-01]],
        grad_fn=<CatBackward0>),
 'node_labels': tensor([ 1.1421e-03, -3.3615e-03, -1.7687e-02,  0.0000e+00,  0.0000e+00,
         -3.0053e-04, -1.7687e-02,  2.2053e-01,  0.0000e+00,  0.0000e+00,
         -1.5540e-04,  0.0000e+00,  0.0000e+00,  6.2363e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  6.2363e-02,
          0.0000e+00, -3.0053e-04, -1.5540e-04,  0.0000e+00,  0.0000e+00,
          6.2363e-02,  1.5125e-03, -3.2697e-03, -1.7366e-02,  0.0000e+00,
          0.0000e+00, -2.0601e-04, -1.7366e-02,  2.1663e-01,  0.0000e+00,
          0.0000e+00, -1.0374e-04,  0.0000e+00,  0.0000e+00,  6.1270e-02,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          6.1270e-02,  0.0000e+00, -2.0601e-04, -1.0374e-04,  0.0000e+00,
          0.0000e+00,  6.1270e-02,  1.6016e-03, -3.3589e-03, -1.7606e-02,
          0.0000e+00,  0.0000e+00,  5.0701e-04, -1.7606e-02,  2.1948e-01,
          0.0000e+00,  0.0000e+00,  2.5915e-04,  0.0000e+00,  0.0000e+00,
          6.2062e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  6.2062e-02,  0.0000e+00,  5.0701e-04,  2.5915e-04,
          0.0000e+00,  0.0000e+00,  6.2062e-02], grad_fn=<MeanBackward1>),
 'edge_labels': tensor([ 2.8645e-05, -3.5231e-05,  0.0000e+00,  0.0000e+00, -1.6517e-05,
         -1.0098e-05, -1.6300e-05,  0.0000e+00,  0.0000e+00,  3.1383e-05,
          3.3831e-05, -4.1323e-05,  0.0000e+00,  0.0000e+00,  1.9075e-05,
         -1.0373e-05, -1.6752e-05,  0.0000e+00,  0.0000e+00, -3.1856e-05,
         -9.8223e-07,  7.5770e-07,  0.0000e+00,  0.0000e+00,  9.5773e-08,
          7.5781e-07,  2.8189e-06,  0.0000e+00,  0.0000e+00, -2.6045e-08,
          0.0000e+00,  0.0000e+00,  6.0607e-07,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  6.0607e-07,  0.0000e+00,
         -9.6107e-08,  2.6127e-08,  0.0000e+00,  0.0000e+00,  1.0831e-06],
        grad_fn=<MeanBackward1>)}
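
The lengths of the flat node_labels and edge_labels tensors above can be reproduced with a short count. This sketch rests on three assumptions that are not stated in this notebook: each point stores its full square block, two points are connected when their distance does not exceed the sum of their R values, and with symmetric=True each connected pair stores a single block. All names are illustrative.

```python
from itertools import combinations

basis_size = {"A": 1, "B": 5}   # basis functions per point type
reach = {"A": 2.0, "B": 5.0}    # the R argument of each PointBasis
x_positions = [0.0, 6.0, 9.0]   # all points lie on the x axis
structures = [["A", "B", "A"], ["B", "A", "B"]]

n_node_labels = 0
n_edge_labels = 0
for point_types in structures:
    # One full (size x size) block per point.
    n_node_labels += sum(basis_size[t] ** 2 for t in point_types)

    # One block per connected pair (assumed: distance <= sum of the two reaches).
    for i, j in combinations(range(len(point_types)), 2):
        distance = abs(x_positions[i] - x_positions[j])
        if distance <= reach[point_types[i]] + reach[point_types[j]]:
            n_edge_labels += basis_size[point_types[i]] * basis_size[point_types[j]]

print(n_node_labels, n_edge_labels)
```

The counts are 78 node label values and 45 edge label values, which agrees with the number of entries in the two tensors printed above.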

You can of course plot the predicted matrices:

[11]:
matrices = processor.matrix_from_data(data, predictions=out)

for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()

Summary and next steps

In this notebook we learned how to interface MACE with Graph2Mat.

The next steps could be:

  • Train a MACE+Graph2Mat model following the steps in this notebook, replacing the model by the MACE+Graph2Mat model.