
GMMConv forward pass does not support optional edge_attr #10143

Open
lilian-83 opened this issue Mar 26, 2025 · 0 comments · May be fixed by #10182
@lilian-83
🐛 Describe the bug

edge_attr.size() is accessed even when edge_attr is not passed to the forward pass (see the snippet and traceback below), even though the docs list the edge features $(|\mathcal{E}|, D)$ as "(optional)" and the signature declares edge_attr: OptTensor = None.

import torch
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import to_torch_coo_tensor

x1 = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_coo_tensor(edge_index, size=(4, 4))

conv1 = pyg_nn.GMMConv(8, 12, dim=4, kernel_size=4)

print(conv1(x1, adj1))

which gives:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[612], line 8
      4 adj1 = to_torch_coo_tensor(edge_index, size=(4, 4))
      6 conv1 = pyg_nn.GMMConv(8, 12, dim=4, kernel_size=4)
----> 8 print(conv1(x1, adj1))

File ~/miniconda3/envs/thesis/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/thesis/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/miniconda3/envs/thesis/lib/python3.12/site-packages/torch_geometric/nn/conv/gmm_conv.py:139, in GMMConv.forward(self, x, edge_index, edge_attr, size)
    137 if not self.separate_gaussians:
    138     out: OptPairTensor = (torch.matmul(x[0], self.g), x[1])
--> 139     out = self.propagate(edge_index, x=out, edge_attr=edge_attr,
    140                          size=size)
    141 else:
    142     out = self.propagate(edge_index, x=x, edge_attr=edge_attr,
    143                          size=size)

File /var/folders/wz/xhstnmw54gbfx1t5ggvjy1vw0000gn/T/torch_geometric.nn.conv.gmm_conv_GMMConv_propagate_o724hgti.py:183, in propagate(self, edge_index, x, edge_attr, size)
    174             kwargs = CollectArgs(
    175                 x_j=hook_kwargs['x_j'],
    176                 edge_attr=hook_kwargs['edge_attr'],
   (...)
    179                 dim_size=kwargs.dim_size,
    180             )
    181 # End Message Forward Pre Hook #########################################
--> 183 out = self.message(
    184     x_j=kwargs.x_j,
    185     edge_attr=kwargs.edge_attr,
    186 )
    188 # Begin Message Forward Hook ###########################################
    189 if not torch.jit.is_scripting() and not is_compiling():

File ~/miniconda3/envs/thesis/lib/python3.12/site-packages/torch_geometric/nn/conv/gmm_conv.py:157, in GMMConv.message(self, x_j, edge_attr)
    155 EPS = 1e-15
    156 F, M = self.rel_in_channels, self.out_channels
--> 157 (E, D), K = edge_attr.size(), self.kernel_size
    159 if not self.separate_gaussians:
    160     gaussian = -0.5 * (edge_attr.view(E, 1, D) -
    161                        self.mu.view(1, K, D)).pow(2)

AttributeError: 'NoneType' object has no attribute 'size'
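Since the Gaussian kernels in message() genuinely need per-edge pseudo-coordinates, one way to honor the optional annotation is to guard edge_attr and fail fast with a descriptive error rather than let the None reach .size(). The sketch below is torch-free and purely illustrative: the message function here is a hypothetical stand-in, not the actual torch_geometric implementation.

```python
# Illustrative, torch-free stand-in for GMMConv.message(): guard the
# optional edge_attr before touching .size()-style attributes.
# All names below are hypothetical; this is not the PyG implementation.

def message(x_j, edge_attr=None, kernel_size=4):
    if edge_attr is None:
        # Fail fast with a clear message instead of
        # "AttributeError: 'NoneType' object has no attribute 'size'".
        raise ValueError(
            "GMMConv requires 'edge_attr' (pseudo-coordinates of shape "
            "[num_edges, dim]); got None.")
    # Stands in for (E, D) = edge_attr.size() on a real tensor.
    E, D = len(edge_attr), len(edge_attr[0])
    return E, D, kernel_size

# Omitting edge_attr now raises a descriptive error:
try:
    message(x_j=[[0.0]])
except ValueError as err:
    print(err)

# With edge features present, the shapes unpack as before:
print(message(x_j=[[0.0]], edge_attr=[[0.1, 0.2], [0.3, 0.4]]))
# → (2, 2, 4)
```

An alternative (what the linked PR may do instead) is to support a genuinely edge-attr-free path; either way, the current silent dereference of None is the bug.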

Versions

PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.3.2 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 12:55:12) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.3.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] numpy==2.1.0
[pip3] pytorch-lightning==2.5.0.post0
[pip3] torch==2.6.0
[pip3] torch-geometric==2.6.1
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torchaudio==2.6.0
[pip3] torchmetrics==1.6.3
[pip3] torchvision==0.21.0
[conda] numpy 2.1.0 pypi_0 pypi
[conda] pytorch-lightning 2.5.0.post0 pypi_0 pypi
[conda] torch 2.6.0 pypi_0 pypi
[conda] torch-geometric 2.6.1 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torchaudio 2.6.0 pypi_0 pypi
[conda] torchmetrics 1.6.3 pypi_0 pypi
[conda] torchvision 0.21.0 pypi_0 pypi

@lilian-83 lilian-83 added the bug label Mar 26, 2025
@xnuohz xnuohz linked a pull request Apr 12, 2025 that will close this issue