Source code for graph2mat.core.modules.matrixblock

import numpy as np
from typing import Type, Tuple, TypeVar
from types import ModuleType

from ..data.basis import PointBasis

__all__ = ["MatrixBlock"]

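# Placeholder type variable for the array type returned by the block
# (presumably numpy arrays or torch tensors, depending on the backend in use).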
ArrayType = TypeVar("ArrayType")


class MatrixBlock:
    """Wrapper that produces a dense, fixed size matrix block from an operation.

    There are two things to note:

    - It computes a dense matrix.
    - It computes a fixed size matrix.

    This class doesn't implement any computation by itself. On initialization
    it instantiates ``operation_cls`` and, when called, it simply forwards all
    the arguments to that operation. The only logic added on top is the
    optional symmetrization controlled by ``symm_transpose``.

    NOTE(review): earlier documentation described this class as determining
    the irreps needed for a block and applying the change of basis. Nothing in
    this class does that; presumably it is the job of the operation (or of a
    subclass) -- confirm.

    Parameters
    ----------
    i_basis: PointBasis
        Basis associated to the rows of the block. It is passed to
        ``operation_cls`` as the ``i_basis`` keyword argument.
    j_basis: PointBasis
        Basis associated to the columns of the block. It is passed to
        ``operation_cls`` as the ``j_basis`` keyword argument.
    operation_cls: Type
        Class used to actually do the computation. It is initialized as
        ``operation_cls(i_basis=i_basis, j_basis=j_basis, **operation_kwargs)``
        and the resulting object must be callable.
    symm_transpose: bool, optional
        If falsy (default), calling the block returns the output of the
        operation as is. Otherwise the operation is evaluated a second time
        with every 2-tuple argument reversed, and the result is the average
        of the first output and the transpose of the second one.
    preprocessor:
        Accepted for interface compatibility. It is not used by this class.
    **operation_kwargs: dict
        Any extra arguments needed for the initialization of ``operation_cls``.

    Returns
    -------
    matrix: ArrayType
        Whatever the operation returns (documented upstream as a 2D array
        whose rows/columns correspond to ``i_basis``/``j_basis``; the shape
        is not checked here).
    """

    # Declared here but never assigned by this class
    # (assumed to be set elsewhere -- TODO confirm).
    block_shape: Tuple[int, int]
    block_size: int

    # Whether calls symmetrize the block (see ``forward``).
    symm_transpose: bool

    # Array module exposed as a class attribute; numpy by default.
    numpy: ModuleType = np

    def __init__(
        self,
        i_basis: "PointBasis",
        j_basis: "PointBasis",
        operation_cls: Type,
        symm_transpose: bool = False,
        preprocessor=None,
        **operation_kwargs,
    ):
        # Cooperative init so that the class can be combined with other bases.
        super().__init__()

        self.symm_transpose = symm_transpose

        self.operation = operation_cls(
            i_basis=i_basis, j_basis=j_basis, **operation_kwargs
        )

    def _compute_block(self, *args, **kwargs):
        """Evaluate the wrapped operation with the given arguments."""
        return self.operation(*args, **kwargs)

    @staticmethod
    def _swap_pair(value):
        """Return ``value`` reversed if it is a 2-tuple, otherwise unchanged."""
        if isinstance(value, tuple) and len(value) == 2:
            return (value[1], value[0])
        return value

    def forward(self, *args, **kwargs):
        """Compute the block, symmetrizing it if ``symm_transpose`` is set.

        Parameters
        ----------
        *args, **kwargs:
            Passed to the operation untouched for the first evaluation.

        Returns
        -------
        The output of the operation if ``symm_transpose`` is falsy. Otherwise
        ``(forward + backward^T) / 2``, where ``backward`` is the output of
        the operation when every argument that is a 2-tuple is reversed.
        """
        forward = self._compute_block(*args, **kwargs)

        if not self.symm_transpose:
            return forward

        # Second evaluation with the two members of each pair exchanged.
        # Arguments that are not 2-tuples are passed as they were.
        back_args = [self._swap_pair(arg) for arg in args]
        back_kwargs = {key: self._swap_pair(value) for key, value in kwargs.items()}
        backward = self._compute_block(*back_args, **back_kwargs)

        # NOTE: ``transpose(-1, -2)`` exchanges the last two axes of a torch
        # tensor. On numpy arrays the arguments are a full axis permutation,
        # so this only works for 2D outputs.
        return (forward + backward.transpose(-1, -2)) / 2

    def __call__(self, *args, **kwargs):
        """Calling the block is equivalent to calling ``forward``."""
        return self.forward(*args, **kwargs)