Batching equivariant matrices
This notebook introduces you to one aspect of generating matrices that you will inevitably face when training a model: batching.
Prerequisites
Before reading this notebook, make sure you have read the notebook on computing a matrix, which introduces the basic concepts of graph2mat that we assume are already known here. We will also use exactly the same setup; the only difference is that we will compute two matrices at the same time instead of just one.
[1]:
import numpy as np
# So that we can plot sisl geometries
import sisl.viz
from e3nn import o3
from graph2mat import (
    PointBasis,
    BasisTableWithEdges,
    BasisConfiguration,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixData, TorchBasisMatrixDataset
from graph2mat.bindings.e3nn import E3nnGraph2Mat
from graph2mat.tools.viz import plot_basis_matrix
The matrix-computing function
As we have already seen in the notebook on computing a matrix, we need to define a basis, a basis table, a data processor and the shape of the node features. With all this, we can initialize the matrix-computing function. We define everything exactly as in the other notebook:
[2]:
# The basis
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")

basis = [point_1, point_2]

# The basis table.
table = BasisTableWithEdges(basis)

# The data processor.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

# The shape of the node features.
node_feats_irreps = o3.Irreps("0e + 1o")

# The fake environment representation function that we will use
# to compute node features.
def get_environment_representation(data, irreps):
    """Function that mocks a true calculation of an environment representation.

    Computes a random array and then ensures that the numbers obey our particular
    system's symmetries.
    """
    node_features = irreps.randn(data.num_nodes, -1)
    # The point in the middle sees the same in the -X and +X directions,
    # therefore the X component of its vector feature must be 0.
    # In principle the +/- YZ are also equivalent, but let's say that there
    # is something breaking the symmetry to make the numbers more interesting.
    # Note that the spherical harmonics convention is YZX.
    node_features[1, 3] = 0
    # We make both A points have equivalent features except in the X direction,
    # where the features are opposite.
    node_features[2::3, :3] = node_features[0::3, :3]
    node_features[2::3, 3] = -node_features[0::3, 3]
    return node_features

# The matrix readout function
model = E3nnGraph2Mat(
    unique_basis=basis,
    irreps=dict(node_feats_irreps=node_feats_irreps),
    symmetric=True,
)
Creating two configurations
Now we will create two configurations instead of one. Both will have the same coordinates; the only difference is that the point types are swapped. However, you could also give each configuration different coordinates, or even a different number of points.
We’ll store both configurations in a configs list.
[3]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])

config1 = BasisConfiguration(
    point_types=["A", "B", "A"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

config2 = BasisConfiguration(
    point_types=["B", "A", "B"],
    positions=positions,
    basis=basis,
    cell=np.eye(3) * 100,
    pbc=(False, False, False),
)

configs = [config1, config2]
As we did in the other notebook, we plot the configurations to see what they look like and to visualize the overlaps:
[4]:
geom1 = config1.to_sisl_geometry()
geom1.plot(show_cell=False, atoms_style={"size": geom1.maxR(all=True)}).update_layout(
    title="Config 1"
).show()

geom2 = config2.to_sisl_geometry()
geom2.plot(show_cell=False, atoms_style={"size": geom2.maxR(all=True)}).update_layout(
    title="Config 2"
).show()
Build a dataset
With all our configurations, we can create a dataset. The specific class that does this is TorchBasisMatrixDataset, which, apart from the configurations, needs the data processor as usual.
[5]:
dataset = TorchBasisMatrixDataset(configs, data_processor=processor)
This dataset contains all the configurations. We now just need some tool to create batches from it.
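Since the dataset follows the standard torch.utils.data.Dataset protocol (more on this below), you can already inspect it with plain len() and indexing. A minimal sketch:

# One entry per configuration.
print(len(dataset))  # -> 2

# Indexing returns a single, unbatched data object.
example = dataset[0]
print(type(example).__name__)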
Batching with a DataLoader
TorchBasisMatrixDataset is just an extension of torch.utils.data.Dataset. Therefore, you don’t need a graph2mat-specific tool to create batches; in fact, we recommend that you use torch_geometric’s DataLoader:
[6]:
from torch_geometric.loader import DataLoader
All you need to do is pass the dataset and specify a batch size.
[7]:
loader = DataLoader(dataset, batch_size=2)
In this case we use a batch size of 2, which is our total number of configurations. Therefore, we will only get one batch.
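As an aside, and assuming only standard DataLoader behavior, a batch size of 1 would instead give us two batches, one per configuration:

# With batch_size=1, each configuration becomes its own batch.
single_loader = DataLoader(dataset, batch_size=1)
print(len(single_loader))  # -> 2 batches

# In a real training setting you would typically also shuffle:
# shuffled_loader = DataLoader(dataset, batch_size=2, shuffle=True)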
Let’s loop through the batches of our original loader (just one) and print them:
[8]:
for data in loader:
    print(data)
TorchBasisMatrixDataBatch(
  metadata={ data_processor=[2] },
  edge_index=[2, 8],
  num_nodes=6,
  neigh_isc=[8],
  n_edges=[2],
  positions=[6, 3],
  shifts=[8, 3],
  cell=[6, 3],
  nsc=[2, 3],
  node_attrs=[6, 2],
  point_types=[6],
  edge_types=[8],
  edge_type_nlabels=[2, 3],
  batch=[6],
  ptr=[3]
)
Calling the function
We now have our batch object, data, which is a Batch object. In the previous notebook, we called the function with a single TorchBasisMatrixData object, and one might think that batched data makes calling the function more complicated. However, the code to compute matrices for a batch is exactly the same. First, of course, we need to get our inputs, which we generate artificially here (the batch contains 6 nodes, each of which needs a scalar and a vector):
[9]:
node_inputs = get_environment_representation(data, node_feats_irreps)
node_inputs
[9]:
tensor([[-1.5750, -0.8732,  0.7904,  0.4111],
        [-1.6839,  0.9178,  0.8768,  0.0000],
        [-1.5750, -0.8732,  0.7904, -0.4111],
        [ 0.2914,  0.8705,  0.7386,  0.5319],
        [-0.9043, -0.5457, -0.1159, -0.2057],
        [ 0.2914,  0.8705,  0.7386, -0.5319]])
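As a quick check (ours, not part of the original workflow), we can verify that these inputs obey the symmetries imposed by the mock function: the outer points of each configuration (nodes 0 and 2, and nodes 3 and 5) share all features except the X component, which is opposite:

import torch

# Outer points of each configuration share their scalar and YZ features...
assert torch.allclose(node_inputs[2::3, :3], node_inputs[0::3, :3])
# ...and have opposite X components (the spherical harmonics order is YZX).
assert torch.allclose(node_inputs[2::3, 3], -node_inputs[0::3, 3])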
From these inputs and the preprocessed data in the batch, we compute the matrices with exactly the same code that we have already seen:
[10]:
node_labels, edge_labels = model(data, node_feats=node_inputs)
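As in the single-matrix case, the model returns one flat array of node labels and one of edge labels, now spanning both configurations at once. A quick look at their shapes (just a sanity check, nothing graph2mat-specific):

# Flat label tensors covering all nodes/edges of the whole batch.
print(node_labels.shape)
print(edge_labels.shape)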
Disentangling the batch
As simple as it is to run a batched calculation, disentangling everything back into individual cases is harder, and even more so in our case, where the batched matrices are sparse. Not only do you have to handle the indices of the sparsity pattern, but also the extra bookkeeping introduced by batching. This is why BasisMatrixData objects contain so many pointer arrays: they are needed to keep track of the organization.
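Two of these index arrays come directly from torch_geometric and are easy to interpret: batch assigns each node to the configuration it belongs to, while ptr marks where each configuration’s nodes start and end. For our batch of two 3-point configurations:

# Index of the configuration each node belongs to.
print(data.batch)  # tensor([0, 0, 0, 1, 1, 1])

# Node boundaries of each configuration within the batch.
print(data.ptr)  # tensor([0, 3, 6])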
Making use of those indices, the data processor can disentangle the batch and give you back the individual cases. You’ll be happy to see that you can call the processor’s matrix_from_data method just as you did in the single-matrix case, and it will return a tuple of matrices instead of just one:
[11]:
matrices = processor.matrix_from_data(
    data,
    predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)
matrices
[11]:
(<Compressed Sparse Row sparse array of dtype 'float32'
     with 47 stored elements and shape (7, 7)>,
 <Compressed Sparse Row sparse array of dtype 'float32'
     with 71 stored elements and shape (11, 11)>)
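Each element of the tuple is a regular scipy sparse array, so the usual scipy tools apply. For instance, densifying them makes it easy to check the sizes: config 1 (A B A) has 1 + 5 + 1 = 7 basis functions and config 2 (B A B) has 5 + 1 + 5 = 11, matching the shapes above:

# Densify each matrix to inspect it with plain numpy.
for matrix in matrices:
    print(matrix.toarray().shape)  # (7, 7), then (11, 11)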
Note
matrix_from_data automatically detected that the data passed was a torch_geometric Batch object. There is also an is_batch argument to state explicitly whether the data is a batch or not. In addition, the processor has a yield_from_batch method, which is more explicit and returns a generator instead of a tuple; this is better for very big matrices if you want to process them one by one.
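For completeness, here is a sketch of how yield_from_batch could be used, assuming it takes the same arguments as matrix_from_data (check the API reference for the exact signature):

# Process one matrix at a time without building the full tuple in memory.
for matrix in processor.yield_from_batch(
    data, predictions={"node_labels": node_labels, "edge_labels": edge_labels}
):
    print(matrix.shape)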
As we already did in the previous notebook, we can plot the matrices:
[12]:
for config, matrix in zip(configs, matrices):
    plot_basis_matrix(
        matrix,
        config,
        point_lines={"color": "black"},
        basis_lines={"color": "blue"},
        colorscale="temps",
        text=".2f",
        basis_labels=True,
    ).show()
Try to relate the matrices to the systems we created and see if their shapes make sense :)
Summary and next steps
In this notebook we learned how to batch systems and then use the data processor to unbatch them.
The next steps could be:
Understanding how to train the function to produce the target matrix. See this notebook.
Combining this function with other modules for your particular application.