novae.module.AttentionAggregation

Bases: LightningModule

Aggregate the node embeddings using attention.

Source code in novae/module/aggregate.py
```python
class AttentionAggregation(L.LightningModule):
    """Aggregate the node embeddings using attention."""

    @utils.format_docs
    def __init__(self, output_size: int):
        """

        Args:
            {output_size}
        """
        super().__init__()
        self.gate_nn = nn.Linear(output_size, 1)
        self.nn = nn.Linear(output_size, output_size)

        self.attention_aggregation = AttentionalAggregation(gate_nn=self.gate_nn, nn=self.nn)

    def forward(self, x: Tensor, index: Tensor) -> Tensor:
        """Performs attention aggregation.

        Args:
            x: The node embeddings representing `B` total graphs.
            index: The PyTorch Geometric index used to know to which graph each node belongs.

        Returns:
            A tensor of shape `(B, O)` of graph embeddings.
        """
        return self.attention_aggregation(x, index=index)
```
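
A minimal usage sketch (toy tensors for illustration; assumes `torch` and `novae` are installed, with `AttentionAggregation` imported from `novae.module` as documented above):

```python
import torch

from novae.module import AttentionAggregation

aggregation = AttentionAggregation(output_size=64)

# 10 node embeddings of size O=64, spread over B=3 graphs
x = torch.randn(10, 64)
index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])  # graph ID of each node

graph_embeddings = aggregation(x, index)  # shape (3, 64)
```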

__init__(output_size)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output_size` | `int` | Size of the representations, i.e. the encoder outputs (`O` in the article). | *required* |
Source code in novae/module/aggregate.py
```python
@utils.format_docs
def __init__(self, output_size: int):
    """

    Args:
        {output_size}
    """
    super().__init__()
    self.gate_nn = nn.Linear(output_size, 1)
    self.nn = nn.Linear(output_size, output_size)

    self.attention_aggregation = AttentionalAggregation(gate_nn=self.gate_nn, nn=self.nn)
```
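
For intuition, here is a hypothetical re-implementation (not the actual library code) of what PyG's `AttentionalAggregation` computes with these two layers: `gate_nn` scores each node, the scores are softmax-normalized within each graph, and the normalized weights scale the `nn`-transformed embeddings before a per-graph sum:

```python
import torch
from torch import Tensor, nn
from torch_geometric.utils import softmax


def manual_attentional_aggregation(
    x: Tensor, index: Tensor, gate_nn: nn.Module, feat_nn: nn.Module
) -> Tensor:
    gate = softmax(gate_nn(x), index)   # (N, 1): per-node weights, normalized within each graph
    weighted = gate * feat_nn(x)        # (N, O): transformed embeddings scaled by attention
    out = torch.zeros(int(index.max()) + 1, weighted.size(1))
    out.index_add_(0, index, weighted)  # sum the weighted nodes of each graph
    return out                          # (B, O) graph embeddings
```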

forward(x, index)

Performs attention aggregation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The node embeddings representing `B` total graphs. | *required* |
| `index` | `Tensor` | The PyTorch Geometric index used to know to which graph each node belongs. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | A tensor of shape `(B, O)` of graph embeddings. |
Source code in novae/module/aggregate.py
```python
def forward(self, x: Tensor, index: Tensor) -> Tensor:
    """Performs attention aggregation.

    Args:
        x: The node embeddings representing `B` total graphs.
        index: The PyTorch Geometric index used to know to which graph each node belongs.

    Returns:
        A tensor of shape `(B, O)` of graph embeddings.
    """
    return self.attention_aggregation(x, index=index)
```
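
In practice, `index` is typically the `batch` vector that PyTorch Geometric creates when collating several graphs. A hedged end-to-end example (toy graphs; assumes `torch_geometric` is installed):

```python
import torch
from torch_geometric.data import Batch, Data

from novae.module import AttentionAggregation

# Three toy graphs with 3, 5, and 2 nodes, each node embedded in O=64 dimensions
graphs = [Data(x=torch.randn(n, 64)) for n in (3, 5, 2)]
batch = Batch.from_data_list(graphs)

aggregation = AttentionAggregation(output_size=64)
out = aggregation(batch.x, batch.batch)  # (3, 64): one embedding per graph
```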