Fitting matrices
This notebook shows how to fit a model that predicts the matrices of configurations. We create the target matrices synthetically.
Prerequisites
Before reading this notebook, make sure you have read the notebook on computing a matrix and the notebook on batching, which introduce the basic concepts of graph2mat that we assume are already known. We will also use exactly the same setup as in the batching notebook, with the only difference that we will add target matrices to each structure.
In this notebook we will:
Introduce the addition of a target matrix to a configuration.
Introduce the metrics that can be used as loss functions.
Introduce the simplest training loop.
It is especially useful if you are new to machine learning, because it goes step by step. It also serves as a minimal example from which you can expand to create training flows different from the ones we propose.
[1]:
import numpy as np
import pandas as pd
import torch
# To load plotly templates for sisl visualization
import sisl.viz
from e3nn import o3
from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset, TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat
from graph2mat.tools.viz import plot_basis_matrix
Setting up the model
As usual, let’s create our model:
[2]:
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")
basis = [point_1, point_2]
# The basis table.
table = BasisTableWithEdges(basis)
# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])
# The shape of the node features.
node_feats_irreps = o3.Irreps("0e + 1o")
# The fake environment representation function that we will use
# to compute node features.
def get_environment_representation(data, irreps):
    """Function that mocks a true calculation of an environment representation.

    Computes a random array and then ensures that the numbers obey our particular
    system's symmetries.
    """
    import torch

    torch.manual_seed(0)
    node_features = irreps.randn(data.num_nodes, -1)
    # The middle point of each structure sees the same in the -X and +X
    # directions, therefore the X component of its vector part must be 0.
    # In principle the +/- Y and Z directions are also equivalent, but let's say
    # that something breaks that symmetry to make the numbers more interesting.
    # Note that the spherical harmonics convention is YZX, so index 3 is X.
    # We slice with a step of 3 so that this also works for batched data,
    # where the 3-point structures are concatenated.
    node_features[1::3, 3] = 0
    # We make both outer points have equivalent features except in the X
    # direction, where the features are opposite.
    node_features[2::3, :3] = node_features[0::3, :3]
    node_features[2::3, 3] = -node_features[0::3, 3]
    return node_features
# The matrix readout function
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=node_feats_irreps),
    symmetric=True,
)
Including target matrices in the data
We will now create our data. The difference from the previous notebooks is that each configuration will have an associated matrix, which is what we will try to fit.
Usually, this matrix would be computed by the algorithm we are trying to replace with ML (e.g. DFT for atomic systems) or come from experimental observations, but here we will just use random matrices.
We create a function to compute random symmetric matrices:
[3]:
def true_matrix(size):
    """Mocks the algorithm that provides the training matrices.

    It just computes a random symmetric matrix.
    """
    matrix = np.random.random((size, size)) * 2 - 1
    matrix += matrix.T
    return matrix
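A quick check (not part of the original notebook) that the generated matrices are indeed symmetric:
[ ]:
m = true_matrix(size=4)
assert np.allclose(m, m.T)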
And then initialize the configurations as in the previous notebooks, except that in this case we use the matrix argument to pass the matrix associated with each configuration:
[4]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])
config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
    matrix=true_matrix(size=7),
)
config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
    matrix=true_matrix(size=11),
)
configs = [config1, config2]
# Create the dataset
dataset = TorchBasisMatrixDataset(configs, data_processor=processor)
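As a quick sanity check (hypothetical, not part of the original notebook), the dataset should contain one processed entry per configuration:
[ ]:
assert len(dataset) == len(configs)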
We can take one example from the dataset and check that it now has point_labels and edge_labels, which contain the values of the matrix organized in the same way as they are returned by Graph2Mat:
[5]:
data_example = dataset[0]
data_example.point_labels, data_example.edge_labels
[5]:
(tensor([ 1.7335, 1.4600, 0.0292, -0.4449, 0.6652, -0.0883, 0.0292, -1.6100,
0.0986, -0.7116, -0.4372, -0.4449, 0.0986, -0.3698, -0.1068, -1.0903,
0.6652, -0.7116, -0.1068, -1.6045, 1.5662, -0.0883, -0.4372, -1.0903,
1.5662, 0.9022, -1.9510]),
tensor([ 0.8589, 0.8605, 1.2530, 0.7530, -1.0636, -0.1412, -0.6818, -1.1393,
0.1295, 0.3016]))
During training, we will compare these to the output of Graph2Mat.
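Before training, we can run the untrained model on this example as a quick shape check: the loss will compare predictions and labels element for element, so their shapes must match. This is a minimal sketch that assumes the model accepts a single (unbatched) example; if it did not, one could take a batch from the DataLoader used in the training loop below.
[ ]:
# Node features for the example, then an untrained forward pass.
node_feats = get_environment_representation(data_example, node_feats_irreps)
node_pred, edge_pred = model(data_example, node_feats=node_feats)

# Predictions must match the labels in shape.
assert node_pred.shape == data_example.point_labels.shape
assert edge_pred.shape == data_example.edge_labels.shape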
We can also plot the target matrices from the data example:
[6]:
def plot_matrices(data, predictions=None, title="", show=True):
    """Helper function to plot (possibly batched) matrices"""
    matrices = processor.matrix_from_data(data, predictions=predictions)
    if not isinstance(matrices, (tuple, list)):
        matrices = (matrices,)
    for i, (config, matrix) in enumerate(zip(configs, matrices)):
        if show is True or show == i:
            plot_basis_matrix(
                matrix,
                config,
                point_lines={"color": "black"},
                basis_lines={"color": "blue"},
                colorscale="temps",
                text=".2f",
                basis_labels=True,
            ).update_layout(title=f"{title} [{i}]").show()
plot_matrices(data_example, title="Labels")
The simplest training loop
Below we create a simple pytorch training loop that:
Uses the model to compute predictions for the matrices.
Computes the loss (error).
Computes the gradients and updates the model parameters.
Goes back to 1.
While doing so, we store the errors at each step so that we can plot their evolution later.
There is just one last thing that we need to introduce: graph2mat's metrics. The metrics module contains several functions that compare matrices in different ways. They can be used as loss functions. In this case, we will use elementwise_mse, which just computes the Mean Squared Error of all the matrix elements.
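To get a feeling for what elementwise_mse returns, here is a tiny illustration on dummy tensors. It is a sketch that assumes the function simply pools node and edge elements into one Mean Squared Error, as described above; the keyword arguments and the (loss, info) return value are the same ones used in the training loop below.
[ ]:
from graph2mat import metrics

dummy_pred = torch.zeros(4)
dummy_ref = torch.ones(4)

loss, info = metrics.elementwise_mse(
    nodes_pred=dummy_pred,
    nodes_ref=dummy_ref,
    edges_pred=dummy_pred,
    edges_ref=dummy_ref,
)
# Every element is off by 1, so we expect an MSE of 1 and RMSEs of 1.
print(loss, info["node_rmse"], info["edge_rmse"])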
[7]:
# Create the data loader
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=2)
# Number of training steps
n_steps = 4000
# Initialize an optimizer
optimizer = torch.optim.Adam(model.parameters())
# Initialize arrays to store errors
losses = np.zeros(n_steps)
node_rmse = np.zeros(n_steps)
edge_rmse = np.zeros(n_steps)
# The loss function, which we get from graph2mat's metrics functions
from graph2mat import metrics
loss_fn = metrics.elementwise_mse
# Loop
for i in range(n_steps):
    for data in loader:
        # Reset gradients
        optimizer.zero_grad()

        # Get the node feats. Since this function is not learnable, it could be
        # outside the loop, but we keep it here to show how things could work
        # with a learnable environment representation.
        node_feats = get_environment_representation(data, node_feats_irreps)

        # Make predictions for this batch
        step_predictions = model(data, node_feats=node_feats)

        # Compute the loss
        loss, info = loss_fn(
            nodes_pred=step_predictions[0],
            nodes_ref=data.point_labels,
            edges_pred=step_predictions[1],
            edges_ref=data.edge_labels,
        )

        # Store errors
        losses[i] = loss.item()
        node_rmse[i] = info["node_rmse"]
        edge_rmse[i] = info["edge_rmse"]

        # Compute gradients
        loss.backward()

        # Update weights
        optimizer.step()
Checking results
After training, we store all the errors in a dataframe:
[8]:
df = pd.DataFrame(
    np.array([losses, node_rmse, edge_rmse]).T,
    columns=["loss", "node_rmse", "edge_rmse"],
)
And plot them:
[9]:
df.plot(backend="plotly").update_layout(
    yaxis_type="log", yaxis_showgrid=True, xaxis_showgrid=True
).update_layout(
    yaxis_title="Value",
    xaxis_title="Training step",
    title="Error evolution during training",
)
The model has learned something, but the errors are still quite high.
We can plot the first target matrix and the corresponding prediction:
[10]:
plot_matrices(data, title="Target matrix", show=0)

plot_matrices(
    data,
    predictions={
        "node_labels": step_predictions[0],
        "edge_labels": step_predictions[1],
    },
    title=f"Prediction after {n_steps} training steps",
    show=0,
)
As you can see, the matrices are very different. That is, the model has no idea how to predict the matrices!
This might be shocking, considering that it has only been tasked with fitting 2 matrices, a very simple problem that any model should overfit without trouble. However, you must take two things into account:
The target matrix is random, while the model is designed to learn equivariant matrices! All its operations are equivariant and therefore result in an equivariant predicted matrix. For example, symmetry dictates that the scalar elements of the node blocks for points 0 and 2 (at the top-left and bottom-right corners of the matrix) must be exactly the same, because the two points are equivalent. The random matrix does not satisfy this condition, so it is impossible to fit.
The model is limited by the input node features, which only contain one scalar and one vector, so the combination possibilities are very small. If you increase the node features irreps to 0e + 2x1o (i.e. add one extra vector) and modify get_environment_representation so that it still satisfies the symmetries (see the sketch below), you should see that elements with no symmetry problems (e.g. the 4 scalar elements at the top-left corner of the node block for point 1) get very close to the target matrix.
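The modification suggested in the second point could look like the following sketch (hypothetical; the names richer_irreps, richer_environment_representation and richer_model are ours, and the cell is not executed here):
[ ]:
# Node features with one scalar and two vectors.
richer_irreps = o3.Irreps("0e + 2x1o")  # dimension 1 + 3 + 3 = 7

def richer_environment_representation(data, irreps):
    """Mock representation with two vectors, obeying the same symmetries."""
    torch.manual_seed(0)
    node_features = irreps.randn(data.num_nodes, -1)
    # In the YZX convention, indices 3 and 6 are the X components of the two
    # vectors. They must be 0 for the (symmetric) middle points.
    node_features[1::3, [3, 6]] = 0
    # Outer points: equivalent features, except for opposite X components.
    node_features[2::3] = node_features[0::3]
    node_features[2::3, [3, 6]] = -node_features[0::3, [3, 6]]
    return node_features

# A readout with the richer input irreps.
richer_model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=richer_irreps),
    symmetric=True,
)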
We could work hard to make our fake environment representation and true-matrix functions equivariant and watch the model fit perfectly, but you will see that in real-life examples in other tutorials. Besides, seeing that an equivariant model cannot fit a random matrix is a nice way to appreciate the power of equivariant design!