Computing an equivariant matrix
In this notebook we guide you through your first steps in using graph2mat
to compute an equivariant matrix.
Our goal will be to compute a matrix from the coordinates of some points in space.
In particular, we will use a version of Graph2Mat that is designed to deal with e3nn’s conventions: E3nnGraph2Mat.
We will follow these steps:
Create a function to compute the matrix.
Get the coordinates of the system.
Preprocess the system’s data to make it usable by the function.
Generate some input for the function.
Call the function.
Postprocess the output to get the matrix.
[1]:
import numpy as np
# So that we can plot sisl geometries
import sisl.viz
from e3nn import o3
from graph2mat import (
PointBasis,
BasisTableWithEdges,
BasisConfiguration,
MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixData
from graph2mat.bindings.e3nn import E3nnGraph2Mat
Create a function to generate matrices
In this section, we focus on what you need in order to create a function that computes equivariant matrices.
There are three things that you need to know about your problem:
The basis functions. Each point will have a set of basis functions \(\phi_\mu\) that look something like \(\phi_\mu = R(r)Y_{\ell}^m(\theta, \varphi)\), where \(Y_\ell^m\) are the spherical harmonics. Most likely, you will have points of different types, and each type will have a given set of basis functions, e.g. spherical harmonics of different orders (\(\ell\)), or a different number of sets for a given \(\ell\). In any case, you must know beforehand all the unique basis sets that you will use in your problem.
The shape of the inputs. What are the inputs from which you will compute the matrix? Are they scalars, are they vectors, higher order spherical harmonics…? How many of them will you have? This information is all condensed into an irreps specification that you will pass to the function creation.
The symmetries of your output matrix. Is it symmetric? Is each point-point block symmetric?
Define your basis
The first thing to do is to understand which basis functions you will face in your problem.
Let’s say that we know that all the systems that we will deal with have two different types of points:
A, which has only an \(\ell=0\) basis function, with a range of 2.
B, which has two \(\ell=0\) basis functions and a set of \(\ell=1\) basis functions, with a range of 5.
We need to create a PointBasis for each of the types:
[2]:
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical") # "0e"
point_2 = PointBasis(
"B", R=5, basis="2x0e + 1o", basis_convention="spherical"
) # "2x0e + 1o"
basis = [point_1, point_2]
For the basis specification, we have decided to follow e3nn’s string specification for irreps, where in practical terms:
0e means spherical harmonics for \(\ell=0\) (the suffix e or o is the parity: even or odd).
1o means spherical harmonics for \(\ell=1\).
2x means 2 sets of the given spherical harmonics.
+ just merges the multiple spherical harmonics together.
PointBasis’s basis specification can also accept a list like [2, 1], meaning 2 \(\ell=0\) spherical harmonics and 1 set of \(\ell=1\) spherical harmonics.
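To make the notation concrete, here is a small pure-Python sketch that parses irreps strings of the form used above and counts how many basis functions each point type contributes (a set of order \(\ell\) has \(2\ell+1\) functions). The helpers `parse_irreps`, `list_to_irreps` and `irreps_dim` are hypothetical and not part of graph2mat or e3nn; e3nn’s own `o3.Irreps` does this parsing for real. We also assume that the list form gives each \(\ell\) the parity \((-1)^\ell\), as spherical harmonics do, which is consistent with the `basis=((2, 0, 1), (1, 1, -1))` tuples printed later in the data’s metadata.

```python
def parse_irreps(spec: str) -> list[tuple[int, int, int]]:
    """Parse an e3nn-style string like '2x0e + 1o' into (mul, l, parity) tuples."""
    irreps = []
    for term in spec.replace(" ", "").split("+"):
        # '2x0e' -> ('2', 'x', '0e'); '1o' -> ('', '', '1o')
        mul, _, ir = term.rpartition("x")
        parity = {"e": 1, "o": -1}[ir[-1]]
        irreps.append((int(mul) if mul else 1, int(ir[:-1]), parity))
    return irreps


def list_to_irreps(counts: list[int]) -> list[tuple[int, int, int]]:
    """Hypothetical reading of the list form: counts[l] sets of order l, parity (-1)**l."""
    return [(mul, l, (-1) ** l) for l, mul in enumerate(counts) if mul > 0]


def irreps_dim(irreps: list[tuple[int, int, int]]) -> int:
    """Number of basis functions: each set of order l has 2l + 1 functions."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


basis_dims = {"A": irreps_dim(parse_irreps("0e")), "B": irreps_dim(parse_irreps("2x0e + 1o"))}
print(basis_dims)
print(parse_irreps("2x0e + 1o") == list_to_irreps([2, 1]))

# A system with point types A, B, A has a (1 + 5 + 1) x (1 + 5 + 1) matrix.
n = sum(basis_dims[t] for t in ["A", "B", "A"])
print(n)
```

The total of 7 basis functions for an A, B, A system is what determines the 7 × 7 shape of the matrix we obtain at the end of this notebook.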
Note
The basis definition is not specific to using e3nn’s bindings through E3nnGraph2Mat; we would have defined the basis in the same way even if we were using the raw Graph2Mat.
Define the shape of the inputs
The function expects a point-wise input. That is, one input for each point.
For e3nn operations, you need to know the shape of this input and what each number means. You must ask yourself two questions:
What kind of inputs will you receive? Scalars, vectors, higher order spherical harmonics…?
How many of them will you receive?
In this example, we keep it simple: for each node we will pass one scalar and one vector. We just need to define an e3nn Irreps object with the appropriate irreps.
[3]:
# The irreps of the node features that we will input into the model
# One scalar (0e) and one vector (1o)
node_feats_irreps = o3.Irreps("0e + 1o")
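Concretely, "0e + 1o" means that each node carries 1 + 3 = 4 numbers: one scalar followed by the three components of a vector. e3nn orders the \(\ell=1\) components as (y, z, x), as the comments in the mock input function later in this notebook also point out. The numpy sketch below, built around a hypothetical helper `to_node_features`, packs a scalar and a Cartesian vector per node into that layout.

```python
import numpy as np

# e3nn stores l=1 components in (y, z, x) order, so we permute Cartesian (x, y, z).
XYZ_TO_YZX = [1, 2, 0]


def to_node_features(scalars, vectors_xyz) -> np.ndarray:
    """Hypothetical helper: build an (n_nodes, 4) array laid out as '0e + 1o'."""
    scalars = np.asarray(scalars, dtype=float).reshape(-1, 1)
    vectors_yzx = np.asarray(vectors_xyz, dtype=float)[:, XYZ_TO_YZX]
    return np.concatenate([scalars, vectors_yzx], axis=1)


feats = to_node_features([1.0, 2.0, 3.0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
print(feats.shape)
# Row 0: scalar 1 and a vector along x, which ends up in the last column.
print(feats[0])
```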
Initialize the module
Armed with all the information about our inputs and outputs, we can finally create our matrix generating function.
Now we can initialize an E3nnGraph2Mat function. For the simplest usage, we just need to pass:
unique_basis: The list of PointBasis that the function should be able to deal with.
irreps: A dictionary containing the irreps for all relevant features that the model will deal with. In this case we will just use node features, so we only need to pass node_feats_irreps.
symmetric: Whether our target matrices are symmetric.
[4]:
model = E3nnGraph2Mat(
unique_basis=basis,
irreps=dict(node_feats_irreps=node_feats_irreps),
symmetric=True,
)
We now have our first matrix model!
We can explore it. Let’s use its summary property:
[5]:
print(model.summary)
Preprocessing nodes: None
Preprocessing edges: None
Node operations:
(A) E3nnSimpleNodeBlock: (1x0e) x (1x0e) -> 1x0e
(B) E3nnSimpleNodeBlock: (2x0e+1x1o) x (2x0e+1x1o) -> 4x0e+2x1o+1x2e
Edge operations:
(A, A) [XY = YX.T] E3nnSimpleEdgeBlock: (1x0e) x (1x0e) -> 1x0e.
(A, B) E3nnSimpleEdgeBlock: (1x0e) x (2x0e+1x1o) -> 2x0e+1x1o.
(B, B) [XY = YX.T] E3nnSimpleEdgeBlock: (2x0e+1x1o) x (2x0e+1x1o) -> 5x0e+4x1o+1x1e+1x2e.
You can see that the module created 5 different operations:
Two node operations: They will compute the blocks corresponding to interactions within the same point.
Three edge operations: They will compute the blocks corresponding to interactions between different points.
Note that the summary also prints the irreps of each point basis involved and the output needed to generate the corresponding block.
It also marks with [XY = YX.T] the operations that return the transposed block when the two points are swapped.
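The output irreps listed in the summary are consistent with the standard rule for coupling angular momenta: the block between a set of order \(\ell_1\) and a set of order \(\ell_2\) decomposes into irreps \(L = |\ell_1 - \ell_2|, \dots, \ell_1 + \ell_2\) with parity \(p_1 p_2\), and the node blocks of a symmetric matrix only need the symmetric part. The pure-Python sketch below reproduces the irreps shown in the summary. Its helper names are hypothetical and not graph2mat API; it is a consistency check, not a description of how graph2mat computes these irreps internally.

```python
from collections import Counter

# Irreps as (mul, l, parity) tuples; parity +1 is "e", -1 is "o".
A = [(1, 0, 1)]              # "0e"
B = [(2, 0, 1), (1, 1, -1)]  # "2x0e + 1o"


def tensor_product(a, b):
    """Irreps of a general block: each l1 x l2 pair gives L = |l1-l2|..l1+l2."""
    out = Counter()
    for m1, l1, p1 in a:
        for m2, l2, p2 in b:
            for L in range(abs(l1 - l2), l1 + l2 + 1):
                out[(L, p1 * p2)] += m1 * m2
    return out


def symmetric_square(a):
    """Irreps of a symmetric block (M == M.T) of basis `a` with itself."""
    out = Counter()
    for i, (m, l, p) in enumerate(a):
        # m copies of order l paired with themselves: even L is symmetric in the
        # m index pairs (m(m+1)/2 combinations), odd L antisymmetric (m(m-1)/2).
        for L in range(0, 2 * l + 1):
            out[(L, 1)] += m * (m + 1) // 2 if L % 2 == 0 else m * (m - 1) // 2
        # Two different sets are coupled once; the mirrored block is the transpose.
        for m2, l2, p2 in a[i + 1:]:
            for L in range(abs(l - l2), l + l2 + 1):
                out[(L, p * p2)] += m * m2
    return out


def fmt(irreps):
    """Format like the summary: sorted by l, odd before even, zero counts dropped."""
    keys = sorted(k for k, mul in irreps.items() if mul > 0)
    return "+".join(f"{irreps[(l, p)]}x{l}{'e' if p == 1 else 'o'}" for l, p in keys)


print("(A)    node:", fmt(symmetric_square(A)))
print("(B)    node:", fmt(symmetric_square(B)))
print("(A, B) edge:", fmt(tensor_product(A, B)))
print("(B, B) edge:", fmt(tensor_product(B, B)))
```

Edge blocks between two points of the same type use the full product because the block itself need not be symmetric; the symmetry of the matrix only relates it to the transposed block of the reverse edge, which is what [XY = YX.T] indicates.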
However, this short summary doesn’t tell us exactly which operations are performed. Since E3nnGraph2Mat is a torch module, its representation shows us the exact anatomy of the operations:
[6]:
model
[6]:
E3nnGraph2Mat(
(self_interactions): ModuleList(
(0): E3nnIrrepsMatrixBlock(
(operation): E3nnSimpleNodeBlock(
(tsq): TensorSquare(1x0e+1x1o -> 1x0e | 2 paths | 2 weights)
)
)
(1): E3nnIrrepsMatrixBlock(
(operation): E3nnSimpleNodeBlock(
(tsq): TensorSquare(1x0e+1x1o -> 4x0e+2x1o+1x2e | 11 paths | 11 weights)
)
)
)
(interactions): ModuleDict(
((0, 0, 0)): E3nnIrrepsMatrixBlock(
(operation): E3nnSimpleEdgeBlock(
(tensor_products): ModuleList(
(0): FullyConnectedTensorProduct(1x0e+1x1o x 1x0e+1x1o -> 1x0e | 2 paths | 2 weights)
)
)
)
((0, 1, 1)): E3nnIrrepsMatrixBlock(
(operation): E3nnSimpleEdgeBlock(
(tensor_products): ModuleList(
(0): FullyConnectedTensorProduct(1x0e+1x1o x 1x0e+1x1o -> 2x0e+1x1o | 6 paths | 6 weights)
)
)
)
((1, 1, 2)): E3nnIrrepsMatrixBlock(
(operation): E3nnSimpleEdgeBlock(
(tensor_products): ModuleList(
(0): FullyConnectedTensorProduct(1x0e+1x1o x 1x0e+1x1o -> 5x0e+4x1o+1x1e+1x2e | 20 paths | 20 weights)
)
)
)
)
)
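The weight counts in this printout can be checked by hand. The sketch below assumes, as in e3nn’s fully connected products, that every allowed combination (input irrep 1, input irrep 2, output irrep) gets mul1 × mul2 × mul_out weights; for the `TensorSquare` it further assumes that only pairs i ≤ j are kept and that the antisymmetric couplings of an irrep with itself (e.g. 1o × 1o → 1e) are dropped. The helper names are hypothetical. Under these assumptions, the counts follow from the node features `1x0e+1x1o` alone.

```python
# node_feats_irreps "0e + 1o" as (mul, l, parity) tuples.
IN = [(1, 0, 1), (1, 1, -1)]


def couples(l1, p1, l2, p2, l_out, p_out):
    """Selection rule: triangle inequality on l and matching parity."""
    return abs(l1 - l2) <= l_out <= l1 + l2 and p1 * p2 == p_out


def full_tp_weights(in1, in2, out):
    """Weights of a fully connected product: mul1 * mul2 * mul_out per allowed path."""
    return sum(
        m1 * m2 * mo
        for m1, l1, p1 in in1
        for m2, l2, p2 in in2
        for mo, lo, po in out
        if couples(l1, p1, l2, p2, lo, po)
    )


def tensor_square_weights(inp, out):
    """Square of one input with itself, assuming all input multiplicities are 1."""
    total = 0
    for i, (_, l1, p1) in enumerate(inp):
        for j, (_, l2, p2) in enumerate(inp):
            if j < i:
                continue  # (i, j) and (j, i) are the same term in a square
            for mo, lo, po in out:
                if not couples(l1, p1, l2, p2, lo, po):
                    continue
                if i == j and lo % 2 == 1:
                    continue  # l x l -> odd L is antisymmetric, vanishes for x (x) x
                total += mo
    return total


print(tensor_square_weights(IN, [(1, 0, 1)]))                           # node block A
print(tensor_square_weights(IN, [(4, 0, 1), (2, 1, -1), (1, 2, 1)]))    # node block B
print(full_tp_weights(IN, IN, [(1, 0, 1)]))                             # edge (A, A)
print(full_tp_weights(IN, IN, [(2, 0, 1), (1, 1, -1)]))                 # edge (A, B)
print(full_tp_weights(IN, IN, [(5, 0, 1), (4, 1, -1), (1, 1, 1), (1, 2, 1)]))  # edge (B, B)
```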
Try to relate this representation to the summary and identify the role of each input in it. For example:
Where is node_feats_irreps in this representation?
Why are the output irreps different for each type of block?
We encourage you to play with the three arguments and see whether they affect the summary and the architecture of the function in the way you expected.
We have our model, now we are only missing the data!
Coordinates of a system
Let’s say we have to predict a matrix for three interacting points in space: two A points at [0, 0, 0] and [12, 0, 0], and a B point at [6, 0, 0] between them.
Something like: (A)—(B)–(A).
First, we create the positions array:
[7]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12.0, 0, 0]])
From these positions, we create a BasisConfiguration, which, apart from the positions, contains information about the basis, the cell and the periodic boundary conditions.
[8]:
config = BasisConfiguration(
point_types=["A", "B", "A"],
positions=positions,
basis=basis,
cell=np.eye(3) * 100,
pbc=(False, False, False),
)
Note
The configuration could also store an associated matrix (e.g. the target matrix); however, we are not going to use it for now.
Let’s see what this configuration looks like. We can convert it to a sisl geometry and plot it (or you could also plot the points yourself):
[9]:
geometry = config.to_sisl_geometry()
geometry.plot(show_cell=False, atoms_style={"size": geometry.maxR(all=True)})
The A points and the B point are drawn in different colors (gray and blue), and their sizes are set according to their ranges, so you can see which points overlap with which. This will become important when we interpret the matrix!
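Which points interact can be estimated directly from the ranges. Assuming, as the plot suggests, that two points interact when their spheres overlap, i.e. when their distance is smaller than the sum of their ranges, the short numpy sketch below (hypothetical helper `overlapping_pairs`, not graph2mat API) finds the interacting pairs for this configuration.

```python
import numpy as np

positions = np.array([[0, 0, 0], [6.0, 0, 0], [12.0, 0, 0]])
point_types = ["A", "B", "A"]
ranges = {"A": 2.0, "B": 5.0}


def overlapping_pairs(positions, point_types, ranges):
    """Hypothetical helper: directed pairs (i, j), i != j, whose spheres overlap."""
    R = np.array([ranges[t] for t in point_types])
    # Pairwise distances between all points (no periodic images, pbc is off).
    dists = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    overlap = dists < R[:, None] + R[None, :]
    np.fill_diagonal(overlap, False)  # a point does not form an edge with itself
    return np.argwhere(overlap)


edges = overlapping_pairs(positions, point_types, ranges)
print(edges.tolist())
```

This gives 4 directed edges (0↔1 and 1↔2), in line with the `edge_index=[2, 4]` shape of the processed data below. The two A points are 12 apart but their ranges only add up to 4, so no A–A block is expected.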
Preprocessing the data
Now, we need to preprocess the data to make it digestible by our matrix-generating function.
For that, we initialize a MatrixDataProcessor that will take care of all the processing. This object contains all the information needed to process the data correctly, and it exists to make sure that all the processing is consistent (you don’t need to store all the different parameters separately, which avoids mistakes when using data processing routines). It needs:
A basis table (BasisTableWithEdges), which determines all the node and edge types that can be found given our basis. It also knows the size of the blocks and other type-dependent variables.
Some information about the matrix, which will be used to appropriately pre- and postprocess matrices.
First let’s create the basis table and check that it contains all the information about the basis:
[10]:
# Create the basis table.
table = BasisTableWithEdges(basis)
table
[10]:
Index | Type | Irreps | Max R |
---|---|---|---|
0 | A | 1x0e | 2 |
1 | B | 2x0e + 1x1o | 5 |
Then we can create the processor:
[11]:
# Initialize the processor.
processor = MatrixDataProcessor(
basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)
Armed with a processor and the data we need to process, we can now initialize a TorchBasisMatrixData object, which will parse and store all the data in the shape that the torch module expects.
[12]:
data = TorchBasisMatrixData.from_config(config, processor)
data
[12]:
TorchBasisMatrixData(
edge_index=[2, 4],
num_nodes=3,
neigh_isc=[4],
n_edges=4,
positions=[3, 3],
shifts=[4, 3],
cell=[3, 3],
nsc=[1, 3],
node_attrs=[3, 2],
point_types=[3],
edge_types=[4],
edge_type_nlabels=[1, 3],
metadata={ data_processor=MatrixDataProcessor(basis_table=BasisTableWithEdges(spherical, basis=[PointBasis(type='A', R=2, basis=((1, 0, 1),), basis_convention='spherical'), PointBasis(type='B', R=5, basis=((2, 0, 1), (1, 1, -1)), basis_convention='spherical')]), symmetric_matrix=True, sub_point_matrix=False, out_matrix=None, node_attr_getters=[]) }
)
This TorchBasisMatrixData is just an extension of torch_geometric’s Data.
Note
We can batch several configurations, but in this notebook our objective is simply to compute a matrix for one configuration.
Executing the module
The information of the system is now ready to be passed to the function!
We are only missing one very important thing: the input!
Remember that we specified the input of our function to be of shape o3.Irreps("0e + 1o"). Therefore, we need an input that is one scalar and one vector for each node.
This could be anything, really. To keep it simple, we will create a “fake” function that computes some environment representation, and use it.
Note
In practice, you would use a function that computes a true environment representation. If that representation is equivariant, the symmetry constraints will be automatically satisfied.
[13]:
def get_environment_representation(data, irreps):
    """Function that mocks a true calculation of an environment representation.

    Computes a random array and then ensures that the numbers obey our particular
    system's symmetries.
    """
    # One row per node with irreps.dim (= 4) columns: scalar, then vector (y, z, x).
    node_features = irreps.randn(data.num_nodes, -1)

    # The point in the middle sees the same in -X and +X directions,
    # therefore the X component of its vector must be 0.
    # In principle the +/- YZ are also equivalent, but let's say that there
    # is something breaking the symmetry to make the numbers more interesting.
    # Note that the spherical harmonics convention is YZX, so X is column 3.
    node_features[1, 3] = 0

    # We make both A points have equivalent features except in the X direction,
    # where the features are opposite.
    node_features[-1, :3] = node_features[0, :3]
    node_features[-1, 3] = -node_features[0, 3]

    return node_features
# Get the environment representation.
node_inputs = get_environment_representation(data, node_feats_irreps)
node_inputs
[13]:
tensor([[ 0.6541, -0.7908, -0.3080, 0.9873],
[-1.3978, 0.6041, -0.7304, 0.0000],
[ 0.6541, -0.7908, -0.3080, -0.9873]])
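To see where the constraints imposed by hand come from, here is a deliberately simple equivariant descriptor, a hypothetical stand-in for a real environment representation: the scalar is the number of overlapping neighbors, and the vector is the sum of the relative positions of those neighbors, stored in e3nn’s (y, z, x) order. The overlap criterion (distance smaller than the sum of ranges) is an assumption. For the A–B–A chain, it gives the middle point a zero vector and gives the two A points equal scalars and opposite x components, the same pattern that `get_environment_representation` enforces.

```python
import numpy as np

positions = np.array([[0, 0, 0], [6.0, 0, 0], [12.0, 0, 0]])
R = np.array([2.0, 5.0, 2.0])  # ranges of the A, B, A points


def simple_environment(positions, R):
    """Hypothetical descriptor laid out as '0e + 1o': (count, vec_y, vec_z, vec_x)."""
    rel = positions[None, :, :] - positions[:, None, :]  # rel[i, j] = r_j - r_i
    dists = np.linalg.norm(rel, axis=-1)
    neigh = dists < R[:, None] + R[None, :]
    np.fill_diagonal(neigh, False)
    counts = neigh.sum(axis=1).astype(float)        # rotation invariant -> scalar
    vec_xyz = (rel * neigh[..., None]).sum(axis=1)  # rotates with the system -> vector
    return np.column_stack([counts, vec_xyz[:, [1, 2, 0]]])  # reorder to (y, z, x)


feats = simple_environment(positions, R)
print(feats)
```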
And now we can call the function to get a matrix!
The function needs two things:
The structural description of the graph. This is already stored in our TorchBasisMatrixData object, which we have under the data variable.
The computed node features. We have them under the node_inputs variable.
[14]:
node_labels, edge_labels = model(data, node_feats=node_inputs)
Let’s see what we received as output:
[15]:
print("NODE LABELS: ", node_labels)
print("EDGE LABELS:", edge_labels)
NODE LABELS: tensor([ 0.4701, 0.5614, -0.0310, -0.0931, 0.1125, 0.0000, -0.0310, -0.2055,
-0.0047, 0.0057, 0.0000, -0.0931, -0.0047, 0.3242, -0.5045, 0.0000,
0.1125, 0.0057, -0.5045, 0.5170, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.0930, 0.4701], grad_fn=<CatBackward0>)
EDGE LABELS: tensor([-0.3763, -0.4739, 1.3026, -0.0658, -1.1786, -0.3763, -0.4739, 1.3026,
-0.0658, 1.1786], grad_fn=<CopyBackwards>)
We expected our model to produce a matrix, and instead we get two flat arrays!
Don’t worry, it’s just a different representation of the matrix that is much more convenient for the function to compute. When training a model using this function, it is possible that you don’t actually need to convert this to a “real” matrix. That’s why the arrays are returned like this.
In our particular case, however, we are on a mission to get the matrix, so we need to do some simple post processing.
Note
These tensors have been computed with pytorch operations, so they keep track of the operations performed. Therefore, gradients can be computed either from these tensors or from further tensors that you compute with them.
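To understand what the two flat arrays hold, we can rebuild the dense matrix by hand from the printed values. The layout used here is inferred by comparing the printed arrays with the dense matrix shown in the post processing section, so treat it as an assumption rather than a documented format: `node_labels` concatenates each point’s square block flattened row by row (1 + 25 + 1 = 27 values), and `edge_labels` holds one 1 × 5 A–B block per pair of interacting points (assumed edges 0→1 and 2→1), the reverse blocks being their transposes because the matrix is symmetric. `processor.matrix_from_data` remains the supported way to do this.

```python
import numpy as np

# Values copied from the outputs printed above.
node_labels = np.array([
    0.4701, 0.5614, -0.0310, -0.0931, 0.1125, 0.0000, -0.0310, -0.2055,
    -0.0047, 0.0057, 0.0000, -0.0931, -0.0047, 0.3242, -0.5045, 0.0000,
    0.1125, 0.0057, -0.5045, 0.5170, 0.0000, 0.0000, 0.0000, 0.0000,
    0.0000, -0.0930, 0.4701,
])
edge_labels = np.array([
    -0.3763, -0.4739, 1.3026, -0.0658, -1.1786,
    -0.3763, -0.4739, 1.3026, -0.0658, 1.1786,
])

sizes = [1, 5, 1]                 # basis size of points A, B, A
offsets = np.cumsum([0] + sizes)  # first row/column of each point: [0, 1, 6, 7]
edges = [(0, 1), (2, 1)]          # assumed order of the stored A-B blocks

M = np.zeros((offsets[-1], offsets[-1]))
stored = np.zeros_like(M, dtype=bool)  # tracks which entries were actually set

# Node (diagonal) blocks, each flattened row-major.
start = 0
for p, n in enumerate(sizes):
    sl = slice(offsets[p], offsets[p + 1])
    M[sl, sl] = node_labels[start:start + n * n].reshape(n, n)
    stored[sl, sl] = True
    start += n * n

# Edge blocks, plus their transposes for the reverse edges (symmetric matrix).
start = 0
for i, j in edges:
    block = edge_labels[start:start + sizes[i] * sizes[j]].reshape(sizes[i], sizes[j])
    rows, cols = slice(offsets[i], offsets[i + 1]), slice(offsets[j], offsets[j + 1])
    M[rows, cols] = block
    M[cols, rows] = block.T
    stored[rows, cols] = stored[cols, rows] = True
    start += block.size

print(M.shape, int(stored.sum()))
```

The count of set entries, 27 from node blocks plus 2 × 2 × 5 from edge blocks, matches the 47 stored elements of the sparse matrix returned by the processor, and entry (0, 6) between the two A points is never set.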
Post processing
This step is simple. Remember we created a MatrixDataProcessor? It’s time to put it to use!
The processor has a matrix_from_data method that takes:
The information of the configuration, in the form of the preprocessed TorchBasisMatrixData object.
The output of the function, passed to the predictions argument.
and returns the actual sparse matrix:
[16]:
matrix = processor.matrix_from_data(
data,
predictions={"node_labels": node_labels, "edge_labels": edge_labels},
)
matrix
[16]:
<Compressed Sparse Row sparse array of dtype 'float32'
with 47 stored elements and shape (7, 7)>
This is a scipy sparse matrix. If you are not familiar with sparse matrices, they are just an efficient way of storing matrices with many zeros.
You can also use the out_format argument to request any other supported output format (see the graph2mat.Formats documentation for all the supported formats). For example, you can ask for a torch tensor:
[17]:
processor.matrix_from_data(
data,
predictions={"node_labels": node_labels, "edge_labels": edge_labels},
out_format="torch",
)
[17]:
tensor([[ 0.4701, -0.3763, -0.4739, 1.3026, -0.0658, -1.1786, 0.0000],
[-0.3763, 0.5614, -0.0310, -0.0931, 0.1125, 0.0000, -0.3763],
[-0.4739, -0.0310, -0.2055, -0.0047, 0.0057, 0.0000, -0.4739],
[ 1.3026, -0.0931, -0.0047, 0.3242, -0.5045, 0.0000, 1.3026],
[-0.0658, 0.1125, 0.0057, -0.5045, 0.5170, 0.0000, -0.0658],
[-1.1786, 0.0000, 0.0000, 0.0000, 0.0000, -0.0930, 1.1786],
[ 0.0000, -0.3763, -0.4739, 1.3026, -0.0658, 1.1786, 0.4701]],
grad_fn=<ToDenseBackward0>)
We also provide plot_basis_matrix, a nice tool to quickly visualize the matrix and understand what you got.
[18]:
from graph2mat.tools.viz import plot_basis_matrix
plot_basis_matrix(
matrix,
config,
point_lines={"color": "black"},
basis_lines={"color": "blue"},
colorscale="temps",
text=".3f",
basis_labels=True,
)
The black lines delimit blocks of the matrix that correspond to the same point-point interaction, and the blue dashed lines delimit the blocks of interaction between sets of basis functions.
The rows and columns are labeled as \(P: (l, m)\) where \(P\) is the index of the point and \(l\), \(m\) are the indices of the spherical harmonics.
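As an illustration of this labeling, the sketch below (hypothetical helper `basis_labels`, assuming \(m\) runs from \(-\ell\) to \(\ell\) within each set, as in e3nn) lists a label in the same spirit for every row of our A, B, A matrix. The exact label format drawn by `plot_basis_matrix` may differ.

```python
def basis_labels(point_types, basis):
    """Hypothetical helper: one 'P: (l, m)' label per basis function, in matrix order."""
    labels = []
    for p, t in enumerate(point_types):
        for mul, l in basis[t]:
            for _ in range(mul):
                labels.extend(f"{p}: ({l}, {m})" for m in range(-l, l + 1))
    return labels


# (mul, l) pairs for each point type, as defined earlier: A = "0e", B = "2x0e + 1o".
basis = {"A": [(1, 0)], "B": [(2, 0), (1, 1)]}
labels = basis_labels(["A", "B", "A"], basis)
for label in labels:
    print(label)
```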
There are some important things to note:
There are two white squares. These correspond to values of the matrix that were not set. Which points are interacting for those elements? Does it make sense that we have a blank space there, then?
Look at the interactions between points 0 and 1 and compare them to those between points 1 and 2. How similar are they? Does it make sense?
From the previous point, you might conclude that this similarity comes from the symmetry of the structure. Then try to move the third point (change its position) and see if something changes. You will see that nothing changes (as long as it stays within range of the B point). This is because we have used E3nnGraph2Mat with its simplest settings, its defaults (block operations E3nnSimpleNodeBlock and E3nnSimpleEdgeBlock). It simply trusts that the inputs contain all the important information and combines them to generate the matrix. More complex block operations that use edge distances, directions, etc. can be used within E3nnGraph2Mat; see its documentation to understand how.
As a conclusion from the previous point, changing the inputs should induce changes in the matrix. Try to change our fake get_environment_representation function, keeping in mind that the first number is a scalar and the other three form a vector. Maybe try to rotate the vector and see what happens. You will discover that E3nnGraph2Mat is an equivariant function.
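For the rotation experiment, keep in mind that the last three input numbers are a vector in e3nn’s (y, z, x) order, so a Cartesian rotation matrix has to be expressed in that order before it is applied. The numpy sketch below (hypothetical helpers `rotation_z` and `rotate_node_features`) rotates the vector part of the node inputs printed earlier and leaves the scalar untouched.

```python
import numpy as np

YZX = [1, 2, 0]  # positions of y, z, x in Cartesian (x, y, z) indexing


def rotation_z(angle: float) -> np.ndarray:
    """Cartesian rotation matrix about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotate_node_features(feats: np.ndarray, rot_xyz: np.ndarray) -> np.ndarray:
    """Hypothetical helper for '0e + 1o' features: rotate the (y, z, x) vector part."""
    # Express the Cartesian rotation in e3nn's (y, z, x) component order.
    rot_yzx = rot_xyz[np.ix_(YZX, YZX)]
    out = feats.copy()
    out[:, 1:4] = feats[:, 1:4] @ rot_yzx.T  # v' = R v for every row
    return out


# Node inputs printed earlier (scalar, y, z, x).
node_inputs = np.array([
    [0.6541, -0.7908, -0.3080, 0.9873],
    [-1.3978, 0.6041, -0.7304, 0.0000],
    [0.6541, -0.7908, -0.3080, -0.9873],
])
rotated = rotate_node_features(node_inputs, rotation_z(np.pi / 2))
print(rotated.round(4))
```

You could convert the result with `torch.from_numpy(rotated).float()` and pass it to the model in place of `node_inputs`. If the model is equivariant, the matrix entries between \(\ell=0\) basis functions should stay unchanged, while the blocks involving \(\ell=1\) functions should rotate accordingly.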
Summary and next steps
In this notebook we learned the whole process to go from the coordinates of some points in space to an equivariant matrix.
The next steps could be:
Understanding how to compute multiple matrices with the same function call (batching). See this notebook.
Understanding how to train the function to produce the target matrix. See this notebook.
Combining this function with other modules for your particular application.