Skip to content

Modules

novae.module.AttentionAggregation

Bases: Aggregation, LightningModule

Aggregate the node embeddings using attention.

Source code in novae/module/aggregate.py
class AttentionAggregation(Aggregation, L.LightningModule):
    """Aggregate the node embeddings using attention."""

    def __init__(self, output_size: int):
        """

        Args:
            output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
        """
        super().__init__()
        self.attention_aggregation = ProjectionLayers(output_size)  # for backward compatibility when loading models
        # Plain tensor attribute (not a registered buffer): it stays on the device it was
        # created on, whatever device the module is moved to.
        self._entropies = torch.tensor([], dtype=torch.float32)

    def forward(
        self,
        x: Tensor,
        index: Tensor | None = None,
        ptr: None = None,
        dim_size: None = None,
        dim: int = -2,
    ) -> Tensor:
        """Performs attention aggregation.

        Args:
            x: The nodes embeddings representing `B` total graphs.
            index: The Pytorch Geometric index used to know to which graph each node belongs.
            ptr: Unused by this implementation.
            dim_size: Unused by this implementation.
            dim: Dimension along which the softmax and the reduction are performed.

        Returns:
            A tensor of shape `(B, O)` of graph embeddings.
        """
        gate = self.attention_aggregation.gate_nn(x)
        x = self.attention_aggregation.nn(x)

        # Per-graph attention weights over the nodes
        gate = softmax(gate, index, dim=dim)

        if settings.store_attention_entropy:
            # Sharpened attention (temperature 0.01), then its base-2 entropy per graph
            att = softmax(gate / 0.01, index, dim=dim)
            attention_entropy = scatter(-att * (att + Nums.EPS).log2(), index=index)[:, 0]
            # Stored values are accumulated across calls: detach them so the autograd graph of
            # each batch is not kept alive, and move them to the storage device so `torch.cat`
            # does not fail when the module runs on another device than `_entropies`.
            attention_entropy = attention_entropy.detach().to(self._entropies.device)
            self._entropies = torch.cat([self._entropies, attention_entropy])

        return self.reduce(gate * x, index, dim=dim)

    def reset_parameters(self):
        """Reset the parameters of both projection networks."""
        reset(self.attention_aggregation.gate_nn)
        reset(self.attention_aggregation.nn)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(gate_nn={self.attention_aggregation.gate_nn}, nn={self.attention_aggregation.nn})"

__init__(output_size)

Parameters:

Name Type Description Default
output_size int

Size of the representations, i.e. the encoder outputs (O in the article).

required
Source code in novae/module/aggregate.py
def __init__(self, output_size: int):
    """Create the projection layers and an empty store for attention entropies.

    Args:
        output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
    """
    super().__init__()
    # Attribute name kept for backward compatibility when loading models
    projection_layers = ProjectionLayers(output_size)
    self.attention_aggregation = projection_layers
    # Starts empty; `forward` appends to it when entropy storing is enabled
    self._entropies = torch.empty(0, dtype=torch.float32)

forward(x, index=None, ptr=None, dim_size=None, dim=-2)

Performs attention aggregation.

Parameters:

Name Type Description Default
x Tensor

The nodes embeddings representing B total graphs.

required
index Tensor | None

The Pytorch Geometric index used to know to which graph each node belongs.

None

Returns:

Type Description
Tensor

A tensor of shape (B, O) of graph embeddings.

Source code in novae/module/aggregate.py
def forward(
    self,
    x: Tensor,
    index: Tensor | None = None,
    ptr: None = None,
    dim_size: None = None,
    dim: int = -2,
) -> Tensor:
    """Performs attention aggregation.

    Args:
        x: The nodes embeddings representing `B` total graphs.
        index: The Pytorch Geometric index used to know to which graph each node belongs.
        ptr: Unused by this implementation.
        dim_size: Unused by this implementation.
        dim: Dimension along which the softmax and the reduction are performed.

    Returns:
        A tensor of shape `(B, O)` of graph embeddings.
    """
    gate = self.attention_aggregation.gate_nn(x)
    x = self.attention_aggregation.nn(x)

    # Per-graph attention weights over the nodes
    gate = softmax(gate, index, dim=dim)

    if settings.store_attention_entropy:
        # Sharpened attention (temperature 0.01), then its base-2 entropy per graph
        att = softmax(gate / 0.01, index, dim=dim)
        attention_entropy = scatter(-att * (att + Nums.EPS).log2(), index=index)[:, 0]
        # Stored values are accumulated across calls: detach them so the autograd graph of
        # each batch is not kept alive, and move them to the storage device so `torch.cat`
        # does not fail when the module runs on another device than `_entropies`.
        attention_entropy = attention_entropy.detach().to(self._entropies.device)
        self._entropies = torch.cat([self._entropies, attention_entropy])

    return self.reduce(gate * x, index, dim=dim)

novae.module.CellEmbedder

Bases: LightningModule

Convert a cell into an embedding using a gene embedding matrix.

Source code in novae/module/embed.py
class CellEmbedder(L.LightningModule):
    """Convert a cell into an embedding using a gene embedding matrix."""

    def __init__(
        self,
        gene_names: list[str] | dict[str, int],
        embedding_size: int | None,
        embedding: torch.Tensor | None = None,
    ) -> None:
        """

        Args:
            gene_names: Names of the genes to be used in the embedding, or a dictionary mapping each gene name to its index.
            embedding_size: Size of the embeddings of the genes (`E` in the article). Must be `None` if `embedding` is provided.
            embedding: Optional pre-trained embedding matrix. If provided, `embedding_size` must be `None`.
        """
        super().__init__()
        assert (embedding_size is None) ^ (embedding is None), "Either embedding_size or embedding must be provided"

        # Gene names are stored lower-cased in both branches
        if isinstance(gene_names, dict):
            self.gene_to_index = {gene.lower(): index for gene, index in gene_names.items()}
            self.gene_names = sorted(self.gene_to_index, key=self.gene_to_index.get)
            _check_gene_to_index(self.gene_to_index)
        else:
            self.gene_names = [gene.lower() for gene in gene_names]
            self.gene_to_index = {gene: i for i, gene in enumerate(self.gene_names)}

        self.voc_size = len(self.gene_names)

        if embedding is None:
            self.embedding_size = embedding_size
            self.embedding = nn.Embedding(self.voc_size, embedding_size)
        else:
            self.embedding_size = embedding.size(1)
            self.embedding = nn.Embedding.from_pretrained(embedding)

        self.linear = nn.Linear(self.embedding_size, self.embedding_size)
        self._init_linear()

    @torch.no_grad()
    def _init_linear(self):
        """Initialize the linear layer as the identity (identity weights, zero bias)."""
        self.linear.weight.data.copy_(torch.eye(self.embedding_size))
        self.linear.bias.data.zero_()

    @classmethod
    def from_scgpt_embedding(cls, scgpt_model_dir: str) -> "CellEmbedder":
        """Initialize the CellEmbedder from a scGPT pretrained model directory.

        Args:
            scgpt_model_dir: Path to a directory containing a scGPT checkpoint, i.e. a `vocab.json` and a `best_model.pt` file.

        Returns:
            A CellEmbedder instance.
        """
        scgpt_model_dir = Path(scgpt_model_dir)

        vocab_file = scgpt_model_dir / "vocab.json"

        # Explicit encoding: do not depend on the platform default to decode the JSON vocabulary
        with open(vocab_file, "r", encoding="utf-8") as file:
            gene_to_index: dict[str, int] = json.load(file)

        # NOTE(security): `torch.load` may unpickle arbitrary objects; only use checkpoints from a trusted source.
        checkpoint = torch.load(scgpt_model_dir / "best_model.pt", map_location=torch.device("cpu"))
        embedding = checkpoint["encoder.embedding.weight"]

        return cls(gene_to_index, None, embedding=embedding)

    def genes_to_indices(self, gene_names: pd.Index | list[str], as_torch: bool = True) -> torch.Tensor | np.ndarray:
        """Convert gene names to their corresponding indices.

        Args:
            gene_names: Names of the genes to convert.
            as_torch: Whether to return a `torch` tensor or a `numpy` array.

        Returns:
            A tensor (dtype `long`) or array of gene indices. The array is `int16` when every
            index fits in it, and `int32` otherwise.

        Raises:
            KeyError: If a gene name is not in the vocabulary.
        """
        indices = [self.gene_to_index[gene] for gene in utils.lower_var_names(gene_names)]

        if as_torch:
            return torch.tensor(indices, dtype=torch.long)

        # int16 cannot hold indices above 32767 (large vocabularies): widen only when needed
        fits_int16 = max(indices, default=0) <= np.iinfo(np.int16).max
        return np.array(indices, dtype=np.int16 if fits_int16 else np.int32)

    def forward(self, data: Data) -> Data:
        """Embed the input data.

        Args:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. The number of node features is variable.

        Returns:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. Each node now has a size of `E`.
        """
        # Only the first row of `genes_indices` is used to look up the embeddings
        genes_embeddings = self.embedding(data.genes_indices[0])
        genes_embeddings = self.linear(genes_embeddings)
        # L2-normalization along dim 0, i.e. across the genes
        genes_embeddings = F.normalize(genes_embeddings, dim=0, p=2)

        data.x = data.x @ genes_embeddings
        return data

    def pca_init(self, adatas: list[AnnData] | None):
        """Initialize the Novae embeddings with PCA components.

        The PCA is fitted on the `AnnData` object with the most used genes. In the other objects,
        each gene unknown to that object copies the embedding of its nearest known gene.

        Args:
            adatas: A list of `AnnData` objects to use for PCA initialization.
        """
        if adatas is None:
            return

        def _dense(matrix):
            # Expression matrices can be sparse or already dense
            return matrix.toarray() if issparse(matrix) else np.asarray(matrix)

        adatas = [adata[:, adata.var[Keys.USE_GENE]] for adata in adatas]

        adata = max(adatas, key=lambda adata: adata.n_vars)

        if adata.X.shape[1] <= self.embedding_size:
            log.warning(
                f"PCA with {self.embedding_size} components can not be run on shape {adata.X.shape}.\nTo use PCA initialization, set a lower `embedding_size` (<{adata.X.shape[1]}) in novae.Novae()."
            )
            return

        X = _dense(adata.X)

        log.info("Running PCA embedding initialization")

        pca = PCA(n_components=self.embedding_size)
        pca.fit(X.astype(np.float32))

        indices = self.genes_to_indices(adata.var_names)
        self.embedding.weight.data[indices] = torch.tensor(pca.components_.T)

        known_var_names = utils.lower_var_names(adata.var_names)

        for other_adata in adatas:
            other_var_names = utils.lower_var_names(other_adata.var_names)
            where_in = np.isin(other_var_names, known_var_names)

            if where_in.all():
                continue

            # One row per gene: known genes (X) are the neighbors candidates of unknown genes (Y)
            X = _dense(other_adata[:, where_in].X).T
            Y = _dense(other_adata[:, ~where_in].X).T

            tree = KDTree(X)
            _, ind = tree.query(Y, k=1)
            neighbor_indices = self.genes_to_indices(other_adata[:, where_in].var_names[ind[:, 0]])

            indices = self.genes_to_indices(other_adata[:, ~where_in].var_names)
            self.embedding.weight.data[indices] = self.embedding.weight.data[neighbor_indices].clone()

__init__(gene_names, embedding_size, embedding=None)

Parameters:

Name Type Description Default
gene_names list[str] | dict[str, int]

Names of the genes to be used in the embedding, or a dictionary mapping each gene name to its index.

required
embedding_size int | None

Size of the embeddings of the genes (E in the article). Optional if embedding is provided.

required
embedding Tensor | None

Optional pre-trained embedding matrix. If provided, embedding_size shouldn't be provided.

None
Source code in novae/module/embed.py
def __init__(
    self,
    gene_names: list[str] | dict[str, int],
    embedding_size: int | None,
    embedding: torch.Tensor | None = None,
) -> None:
    """Build the gene vocabulary, the embedding table and the linear layer.

    Args:
        gene_names: Names of the genes to be used in the embedding, or a dictionary mapping each gene name to its index.
        embedding_size: Size of the embeddings of the genes (`E` in the article). Must be `None` if `embedding` is provided.
        embedding: Optional pre-trained embedding matrix. If provided, `embedding_size` must be `None`.
    """
    super().__init__()
    assert (embedding_size is None) ^ (embedding is None), "Either embedding_size or embedding must be provided"

    # Gene names are lower-cased whichever form they come in
    if not isinstance(gene_names, dict):
        self.gene_names = [gene.lower() for gene in gene_names]
        self.gene_to_index = {gene: i for i, gene in enumerate(self.gene_names)}
    else:
        self.gene_to_index = {gene.lower(): index for gene, index in gene_names.items()}
        self.gene_names = sorted(self.gene_to_index, key=self.gene_to_index.get)
        _check_gene_to_index(self.gene_to_index)

    self.voc_size = len(self.gene_names)

    if embedding is not None:
        # Pre-trained matrix: its second dimension gives the embedding size
        self.embedding_size = embedding.size(1)
        self.embedding = nn.Embedding.from_pretrained(embedding)
    else:
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(self.voc_size, embedding_size)

    size = self.embedding_size
    self.linear = nn.Linear(size, size)
    self._init_linear()

forward(data)

Embed the input data.

Parameters:

Name Type Description Default
data Data

A Pytorch Geometric Data object representing a batch of B graphs. The number of node features is variable.

required

Returns:

Name Type Description
data Data

A Pytorch Geometric Data object representing a batch of B graphs. Each node now has a size of E.

Source code in novae/module/embed.py
def forward(self, data: Data) -> Data:
    """Project the node features of `data` onto the gene embeddings.

    Args:
        data: A Pytorch Geometric `Data` object representing a batch of `B` graphs. The number of node features is variable.

    Returns:
        data: The same `Data` object, whose `x` attribute has been replaced. Each node now has a size of `E`.
    """
    # Only the first row of `genes_indices` is used for the lookup
    panel_indices = data.genes_indices[0]
    projected = self.linear(self.embedding(panel_indices))
    # L2-normalization along dim 0, i.e. across the genes
    gene_matrix = F.normalize(projected, p=2, dim=0)

    data.x = data.x @ gene_matrix
    return data

from_scgpt_embedding(scgpt_model_dir) classmethod

Initialize the CellEmbedder from a scGPT pretrained model directory.

Parameters:

Name Type Description Default
scgpt_model_dir str

Path to a directory containing a scGPT checkpoint, i.e. a vocab.json and a best_model.pt file.

required

Returns:

Type Description
CellEmbedder

A CellEmbedder instance.

Source code in novae/module/embed.py
@classmethod
def from_scgpt_embedding(cls, scgpt_model_dir: str) -> "CellEmbedder":
    """Initialize the CellEmbedder from a scGPT pretrained model directory.

    Args:
        scgpt_model_dir: Path to a directory containing a scGPT checkpoint, i.e. a `vocab.json` and a `best_model.pt` file.

    Returns:
        A CellEmbedder instance.
    """
    scgpt_model_dir = Path(scgpt_model_dir)

    vocab_file = scgpt_model_dir / "vocab.json"

    # Explicit encoding: do not depend on the platform default to decode the JSON vocabulary
    with open(vocab_file, "r", encoding="utf-8") as file:
        gene_to_index: dict[str, int] = json.load(file)

    # NOTE(security): `torch.load` may unpickle arbitrary objects; only use checkpoints from a trusted source.
    checkpoint = torch.load(scgpt_model_dir / "best_model.pt", map_location=torch.device("cpu"))
    embedding = checkpoint["encoder.embedding.weight"]

    return cls(gene_to_index, None, embedding=embedding)

genes_to_indices(gene_names, as_torch=True)

Convert gene names to their corresponding indices.

Parameters:

Name Type Description Default
gene_names Index | list[str]

Names of the genes to convert.

required
as_torch bool

Whether to return a torch tensor or a numpy array.

True

Returns:

Type Description
Tensor | ndarray

A tensor or array of gene indices.

Source code in novae/module/embed.py
def genes_to_indices(self, gene_names: pd.Index | list[str], as_torch: bool = True) -> torch.Tensor | np.ndarray:
    """Convert gene names to their corresponding indices.

    Args:
        gene_names: Names of the genes to convert.
        as_torch: Whether to return a `torch` tensor or a `numpy` array.

    Returns:
        A tensor (dtype `long`) or array of gene indices. The array is `int16` when every
        index fits in it, and `int32` otherwise.

    Raises:
        KeyError: If a gene name is not in the vocabulary.
    """
    indices = [self.gene_to_index[gene] for gene in utils.lower_var_names(gene_names)]

    if as_torch:
        return torch.tensor(indices, dtype=torch.long)

    # int16 cannot hold indices above 32767 (large vocabularies): widen only when needed
    fits_int16 = max(indices, default=0) <= np.iinfo(np.int16).max
    return np.array(indices, dtype=np.int16 if fits_int16 else np.int32)

pca_init(adatas)

Initialize the Novae embeddings with PCA components.

Parameters:

Name Type Description Default
adatas list[AnnData] | None

A list of AnnData objects to use for PCA initialization.

required
Source code in novae/module/embed.py
def pca_init(self, adatas: list[AnnData] | None):
    """Initialize the Novae embeddings with PCA components.

    The PCA is fitted on the `AnnData` object with the most used genes. In the other objects,
    each gene unknown to that object copies the embedding of its nearest known gene.

    Args:
        adatas: A list of `AnnData` objects to use for PCA initialization.
    """
    if adatas is None:
        return

    def _dense(matrix):
        # Expression matrices can be sparse or already dense
        return matrix.toarray() if issparse(matrix) else np.asarray(matrix)

    adatas = [adata[:, adata.var[Keys.USE_GENE]] for adata in adatas]

    adata = max(adatas, key=lambda adata: adata.n_vars)

    if adata.X.shape[1] <= self.embedding_size:
        log.warning(
            f"PCA with {self.embedding_size} components can not be run on shape {adata.X.shape}.\nTo use PCA initialization, set a lower `embedding_size` (<{adata.X.shape[1]}) in novae.Novae()."
        )
        return

    X = _dense(adata.X)

    log.info("Running PCA embedding initialization")

    pca = PCA(n_components=self.embedding_size)
    pca.fit(X.astype(np.float32))

    indices = self.genes_to_indices(adata.var_names)
    self.embedding.weight.data[indices] = torch.tensor(pca.components_.T)

    known_var_names = utils.lower_var_names(adata.var_names)

    for other_adata in adatas:
        other_var_names = utils.lower_var_names(other_adata.var_names)
        where_in = np.isin(other_var_names, known_var_names)

        if where_in.all():
            continue

        # One row per gene: known genes (X) are the neighbors candidates of unknown genes (Y)
        X = _dense(other_adata[:, where_in].X).T
        Y = _dense(other_adata[:, ~where_in].X).T

        tree = KDTree(X)
        _, ind = tree.query(Y, k=1)
        neighbor_indices = self.genes_to_indices(other_adata[:, where_in].var_names[ind[:, 0]])

        indices = self.genes_to_indices(other_adata[:, ~where_in].var_names)
        self.embedding.weight.data[indices] = self.embedding.weight.data[neighbor_indices].clone()

novae.module.GraphAugmentation

Bases: LightningModule

Perform graph augmentation for Novae. It adds noise to the data and keeps a subset of the genes.

Source code in novae/module/augment.py
class GraphAugmentation(L.LightningModule):
    """Augment graphs for Novae: keep a random subset of the gene panel, then add noise to the data."""

    def __init__(
        self,
        panel_subset_size: float,
        background_noise_lambda: float,
        sensitivity_noise_std: float,
    ):
        """Store the augmentation settings and build the background-noise distribution.

        Args:
            panel_subset_size: Ratio of genes kept from the panel during augmentation.
            background_noise_lambda: Parameter of the exponential distribution for the noise augmentation.
            sensitivity_noise_std: Standard deviation of the multiplicative factor for the noise augmentation.
        """
        super().__init__()
        self.panel_subset_size = panel_subset_size
        self.background_noise_lambda = background_noise_lambda
        self.sensitivity_noise_std = sensitivity_noise_std

        rate = torch.tensor(float(background_noise_lambda))
        self.background_noise_distribution = Exponential(rate)

    def noise(self, data: Batch):
        """Add noise (inplace) to the data as detailed in the article.

        One additive term and one multiplicative factor are drawn per graph and per gene,
        then applied to every node of that graph.

        Args:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.
        """
        n_graphs = data.batch_size
        shape = (n_graphs, data.x.shape[1])

        additions = self.background_noise_distribution.sample(sample_shape=shape).to(self.device)
        # Multiplicative factors centered on 1 and clipped to [0, 2]
        factors = (1 + torch.randn(shape, device=self.device) * self.sensitivity_noise_std).clip(0, 2)

        for graph_id in range(n_graphs):
            # `ptr` holds the boundaries of the rows of each graph
            rows = slice(data.ptr[graph_id], data.ptr[graph_id + 1])
            data.x[rows] = data.x[rows] * factors[graph_id] + additions[graph_id]

    def panel_subset(self, data: Batch):
        """Keep a ratio of `panel_subset_size` of the input genes (inplace operation).

        Args:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.
        """
        panel_size = len(data.genes_indices[0])
        n_kept = int(panel_size * self.panel_subset_size)

        # Random genes, the same ones for `x` and `genes_indices`
        kept = torch.randperm(panel_size)[:n_kept]

        data.x, data.genes_indices = data.x[:, kept], data.genes_indices[:, kept]

    def forward(self, data: Batch) -> Batch:
        """Perform data augmentation (`panel_subset`, then `noise`).

        Args:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.

        Returns:
            The augmented `Data` object (the input object, modified inplace).
        """
        for augment in (self.panel_subset, self.noise):
            augment(data)
        return data

__init__(panel_subset_size, background_noise_lambda, sensitivity_noise_std)

Parameters:

Name Type Description Default
panel_subset_size float

Ratio of genes kept from the panel during augmentation.

required
background_noise_lambda float

Parameter of the exponential distribution for the noise augmentation.

required
sensitivity_noise_std float

Standard deviation of the multiplicative factor for the noise augmentation.

required
Source code in novae/module/augment.py
def __init__(
    self,
    panel_subset_size: float,
    background_noise_lambda: float,
    sensitivity_noise_std: float,
):
    """Store the augmentation settings and build the background-noise distribution.

    Args:
        panel_subset_size: Ratio of genes kept from the panel during augmentation.
        background_noise_lambda: Parameter of the exponential distribution for the noise augmentation.
        sensitivity_noise_std: Standard deviation of the multiplicative factor for the noise augmentation.
    """
    super().__init__()

    # Exponential distribution whose rate is `background_noise_lambda`
    rate = torch.tensor(float(background_noise_lambda))
    self.background_noise_distribution = Exponential(rate)

    self.sensitivity_noise_std = sensitivity_noise_std
    self.background_noise_lambda = background_noise_lambda
    self.panel_subset_size = panel_subset_size

forward(data)

Perform data augmentation (noise and panel_subset).

Parameters:

Name Type Description Default
data Batch

A Pytorch Geometric Data object representing a batch of B graphs.

required

Returns:

Type Description
Batch

The augmented Data object

Source code in novae/module/augment.py
def forward(self, data: Batch) -> Batch:
    """Perform data augmentation (`panel_subset`, then `noise`).

    Args:
        data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.

    Returns:
        The augmented `Data` object (the same object that was given as input).
    """
    for augment in (self.panel_subset, self.noise):
        augment(data)
    return data

noise(data)

Add noise (inplace) to the data as detailed in the article.

Parameters:

Name Type Description Default
data Batch

A Pytorch Geometric Data object representing a batch of B graphs.

required
Source code in novae/module/augment.py
def noise(self, data: Batch):
    """Add noise (inplace) to the data as detailed in the article.

    One additive term and one multiplicative factor are drawn per graph and per gene,
    then applied to every node of that graph.

    Args:
        data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.
    """
    n_graphs = data.batch_size
    shape = (n_graphs, data.x.shape[1])

    additions = self.background_noise_distribution.sample(sample_shape=shape).to(self.device)
    # Multiplicative factors centered on 1 and clipped to [0, 2]
    factors = (1 + torch.randn(shape, device=self.device) * self.sensitivity_noise_std).clip(0, 2)

    for graph_id in range(n_graphs):
        # `ptr` holds the boundaries of the rows of each graph
        rows = slice(data.ptr[graph_id], data.ptr[graph_id + 1])
        data.x[rows] = data.x[rows] * factors[graph_id] + additions[graph_id]

panel_subset(data)

Keep a ratio of panel_subset_size of the input genes (inplace operation).

Parameters:

Name Type Description Default
data Batch

A Pytorch Geometric Data object representing a batch of B graphs.

required
Source code in novae/module/augment.py
def panel_subset(self, data: Batch):
    """Keep a ratio of `panel_subset_size` of the input genes (inplace operation).

    Args:
        data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.
    """
    panel_size = len(data.genes_indices[0])
    n_kept = int(panel_size * self.panel_subset_size)

    # Random genes, the same ones for `x` and `genes_indices`
    kept = torch.randperm(panel_size)[:n_kept]

    data.x, data.genes_indices = data.x[:, kept], data.genes_indices[:, kept]

novae.module.GraphEncoder

Bases: LightningModule

Graph encoder of Novae. It uses a graph attention network.

Source code in novae/module/encode.py
class GraphEncoder(L.LightningModule):
    """Graph encoder of Novae: a graph attention network followed by an attention aggregation of the nodes."""

    def __init__(
        self,
        embedding_size: int,
        hidden_size: int,
        num_layers: int,
        output_size: int,
        heads: int,
    ) -> None:
        """Build the GAT and the node aggregation layer.

        Args:
            embedding_size: Size of the embeddings of the genes (`E` in the article).
            hidden_size: The size of the hidden layers in the GAT.
            num_layers: The number of layers in the GAT.
            output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
            heads: The number of attention heads in the GAT.
        """
        super().__init__()

        gat_kwargs = {
            "hidden_channels": hidden_size,
            "num_layers": num_layers,
            "out_channels": output_size,
            "edge_dim": 1,
            "v2": True,
            "heads": heads,
            "act": "ELU",
        }
        self.gnn = GAT(embedding_size, **gat_kwargs)

        self.node_aggregation = AttentionAggregation(output_size)

    def forward(self, data: Batch) -> Tensor:
        """Encode the input data.

        Args:
            data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.

        Returns:
            A tensor of shape `(B, O)` containing the encoded graphs.
        """
        node_outputs = self.gnn(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)
        # `data.batch` is passed as the index used to aggregate the nodes
        return self.node_aggregation(node_outputs, index=data.batch)

__init__(embedding_size, hidden_size, num_layers, output_size, heads)

Parameters:

Name Type Description Default
embedding_size int

Size of the embeddings of the genes (E in the article).

required
hidden_size int

The size of the hidden layers in the GAT.

required
num_layers int

The number of layers in the GAT.

required
output_size int

Size of the representations, i.e. the encoder outputs (O in the article).

required
heads int

The number of attention heads in the GAT.

required
Source code in novae/module/encode.py
def __init__(
    self,
    embedding_size: int,
    hidden_size: int,
    num_layers: int,
    output_size: int,
    heads: int,
) -> None:
    """Build the GAT and the node aggregation layer.

    Args:
        embedding_size: Size of the embeddings of the genes (`E` in the article).
        hidden_size: The size of the hidden layers in the GAT.
        num_layers: The number of layers in the GAT.
        output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
        heads: The number of attention heads in the GAT.
    """
    super().__init__()

    gat_kwargs = {
        "hidden_channels": hidden_size,
        "num_layers": num_layers,
        "out_channels": output_size,
        "edge_dim": 1,
        "v2": True,
        "heads": heads,
        "act": "ELU",
    }
    self.gnn = GAT(embedding_size, **gat_kwargs)

    self.node_aggregation = AttentionAggregation(output_size)

forward(data)

Encode the input data.

Parameters:

Name Type Description Default
data Batch

A Pytorch Geometric Data object representing a batch of B graphs.

required

Returns:

Type Description
Tensor

A tensor of shape (B, O) containing the encoded graphs.

Source code in novae/module/encode.py
def forward(self, data: Batch) -> Tensor:
    """Encode the input data.

    Args:
        data: A Pytorch Geometric `Data` object representing a batch of `B` graphs.

    Returns:
        A tensor of shape `(B, O)` containing the encoded graphs.
    """
    node_outputs = self.gnn(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)
    # `data.batch` is passed as the index used to aggregate the nodes
    return self.node_aggregation(node_outputs, index=data.batch)

novae.module.SwavHead

Bases: LightningModule

Source code in novae/module/swav.py
class SwavHead(L.LightningModule):
    queue: None | Tensor  # (n_slides, QUEUE_SIZE, num_prototypes)

    def __init__(
        self,
        mode: utils.Mode,
        output_size: int,
        num_prototypes: int,
        temperature: float,
    ):
        """SwavHead module, adapted from the paper ["Unsupervised Learning of Visual Features by Contrasting Cluster Assignments"](https://arxiv.org/abs/2006.09882).

        Args:
            mode: Mode object. This class reads its `zero_shot` flag and its clustering attribute names.
            output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
            num_prototypes: Number of prototypes (`K` in the article).
            temperature: Temperature used in the cross-entropy loss.
        """
        super().__init__()
        self.mode = mode
        self.output_size = output_size
        self.num_prototypes = num_prototypes
        self.temperature = temperature

        # Prototypes of shape (K, O), initialized inplace and then L2-normalized
        prototypes = nn.Parameter(torch.empty((self.num_prototypes, self.output_size)))
        self._prototypes = nn.init.kaiming_uniform_(prototypes, a=math.sqrt(5), mode="fan_out")
        self.normalize_prototypes()
        self.min_prototypes = 0

        # No slide-queue until `init_queue` is called
        self.queue = None

        self.reset_clustering()

    def set_min_prototypes(self, min_prototypes_ratio: float):
        """Set `min_prototypes` to the given ratio of the number of prototypes (truncated to an integer).

        Args:
            min_prototypes_ratio: Ratio applied to `num_prototypes`.
        """
        n_min = min_prototypes_ratio * self.num_prototypes
        self.min_prototypes = int(n_min)

    def init_queue(self, slide_ids: list[str]) -> None:
        """Initialize the slide-queue.

        The queue is a buffer of shape `(n_slides, Nums.QUEUE_SIZE, K)` filled with `1 / K`.

        Args:
            slide_ids: A list of slide ids.
        """
        # Remove the existing attribute before registering the buffer under the same name
        del self.queue

        n_slides = len(slide_ids)
        uniform_weight = 1 / self.num_prototypes
        queue = torch.full((n_slides, Nums.QUEUE_SIZE, self.num_prototypes), uniform_weight)
        self.register_buffer("queue", queue)

        # Position of each slide along the first dimension of the queue
        self.slide_label_encoder = dict(zip(slide_ids, range(n_slides)))

    @torch.no_grad()
    def normalize_prototypes(self):
        """L2-normalize each prototype (each row of the prototype matrix)."""
        prototypes = self.prototypes
        prototypes.data = F.normalize(prototypes.data, p=2, dim=1)

    def forward(self, z1: Tensor, z2: Tensor, slide_id: str | None) -> tuple[Tensor, Tensor]:
        """Compute the SwAV loss for two batches of neighborhood graph views.

        Args:
            z1: Batch containing graphs representations `(B, output_size)`
            z2: Batch containing graphs representations `(B, output_size)`
            slide_id: ID of the slide, or `None`. Used to select the prototypes.

        Returns:
            The SwAV loss, and the mean entropy normalized (for monitoring).
        """
        self.normalize_prototypes()

        projections1, projections2 = self.projection(z1), self.projection(z2)  # (B, K) each

        # Prototypes selected for this slide (an `Ellipsis` keeps all of them)
        ilocs = self.prototype_ilocs(projections1, slide_id)
        projections1 = projections1[:, ilocs]
        projections2 = projections2[:, ilocs]

        q1, q2 = self.sinkhorn(projections1), self.sinkhorn(projections2)  # (B, K) or (B, len(ilocs))

        # Swapped prediction: the codes of one view are predicted from the other view
        swapped = self.cross_entropy_loss(q1, projections2) + self.cross_entropy_loss(q2, projections1)

        return -0.5 * swapped, _mean_entropy_normalized(q1)

    def cross_entropy_loss(self, q: Tensor, p: Tensor) -> Tensor:
        """Return the batch mean of `sum(q * log_softmax(p / temperature))` over dim 1 (no sign flip here)."""
        log_probs = F.log_softmax(p / self.temperature, dim=1)
        per_row = torch.sum(q * log_probs, dim=1)
        return torch.mean(per_row)

    def projection(self, z: Tensor) -> Tensor:
        """Project the L2-normalized representations onto the prototypes.

        Args:
            z: The representations of one batch, of size `(B, O)`.

        Returns:
            The projections of size `(B, K)`.
        """
        return F.normalize(z, p=2, dim=1) @ self.prototypes.T

    @torch.no_grad()
    def prototype_ilocs(self, projections: Tensor, slide_id: str | None = None) -> Tensor:
        """Get the indices of the prototypes to use for the current slide.

        When prototypes are selected, the queue of the slide is also updated (inplace) with the
        projections of the current batch.

        Args:
            projections: Projections of the (normalized) representations over the prototypes, of size `(B, K)`.
            slide_id: ID of the slide, or `None`.

        Returns:
            The indices of the prototypes to use, or an `Ellipsis` if all prototypes.
        """
        if (self.queue is None) or (slide_id is None) or self.mode.zero_shot:
            return ...

        slide_index = self.slide_label_encoder[slide_id]

        # Shift the queue of this slide by one position, dropping its last entry
        self.queue[slide_index, 1:] = self.queue[slide_index, :-1].clone()
        # top3 more robust than max; batches with fewer than 3 graphs use their lowest value
        # instead of raising in `topk`
        k = min(3, projections.shape[0])
        self.queue[slide_index, 0] = projections.topk(k, dim=0).values[-1]

        weights, thresholds = self.queue_weights()
        slide_weights = weights[slide_index]

        ilocs = torch.where(slide_weights >= thresholds)[0]
        # Always keep at least `min_prototypes` prototypes (the ones with the highest weights)
        return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(slide_weights, self.min_prototypes).indices

    def queue_weights(self) -> tuple[Tensor, Tensor]:
        """Convert the queue to a matrix of prototype weight per slide, with one threshold per prototype.

        Returns:
            A tuple `(weights, thresholds)`. `weights` is the maximum of the queue over its
            dimension 1, i.e. a tensor of shape `(n_slides, K)`. `thresholds` is the maximum of
            `weights` over the slides, multiplied by `Nums.QUEUE_WEIGHT_THRESHOLD_RATIO`.
        """
        slide_weights = self.queue.max(dim=1).values
        best_over_slides = slide_weights.max(0).values

        return slide_weights, best_over_slides * Nums.QUEUE_WEIGHT_THRESHOLD_RATIO

    @torch.no_grad()
    def sinkhorn(self, projections: Tensor) -> Tensor:
        """Apply the Sinkhorn-Knopp algorithm to the projections.

        Args:
            projections: Projections of the (normalized) representations over the prototypes, of size `(B, K)`.

        Returns:
            The soft codes from the Sinkhorn-Knopp algorithm, with shape `(B, K)`.
        """
        codes = torch.exp(projections / Nums.SWAV_EPSILON)  # (B, K)
        codes /= torch.sum(codes)

        n_rows, n_cols = codes.shape

        for _ in range(Nums.SINKHORN_ITERATIONS):
            # Normalize over dim 0 and divide by K, then over dim 1 and divide by B
            codes.div_(torch.sum(codes, dim=0, keepdim=True)).div_(n_cols)
            codes.div_(torch.sum(codes, dim=1, keepdim=True)).div_(n_rows)

        # ensure rows sum to 1 (for cross-entropy loss)
        return codes / codes.sum(dim=1, keepdim=True)

    def set_kmeans_prototypes(self, latent: np.ndarray):
        """Fit k-means on the row-normalized `latent` and store the normalized centers as `_kmeans_prototypes`.

        Args:
            latent: Array with one row per cell; it must contain at least `num_prototypes` rows.
        """
        assert (
            len(latent) >= self.num_prototypes
        ), f"The number of valid cells ({len(latent)}) must be greater than the number of prototypes ({self.num_prototypes})."

        def _unit_rows(matrix):
            # EPS keeps the division defined for rows whose norm is zero
            return matrix / (Nums.EPS + np.linalg.norm(matrix, axis=1)[:, None])

        model = KMeans(n_clusters=self.num_prototypes, random_state=0, n_init="auto")
        centers = model.fit(_unit_rows(latent)).cluster_centers_

        self._kmeans_prototypes = torch.nn.Parameter(torch.tensor(_unit_rows(centers)))

    @property
    def prototypes(self) -> nn.Parameter:
        """Prototypes currently in use: the k-means ones in zero-shot mode, else the regular ones."""
        if self.mode.zero_shot:
            return self._kmeans_prototypes
        return self._prototypes

    @property
    def clustering(self) -> AgglomerativeClustering:
        """Clustering object of the current mode, computed on first access if it is still `None`."""
        attr_name = self.mode.clustering_attr
        cached = getattr(self, attr_name)

        if cached is None:
            # running the hierarchical clustering stores its result under `attr_name`
            self.hierarchical_clustering()
            cached = getattr(self, attr_name)

        return cached

    @property
    def clusters_levels(self) -> np.ndarray:
        """Cluster levels of the current mode, computed on first access if they are still `None`."""
        attr_name = self.mode.clusters_levels_attr
        cached = getattr(self, attr_name)

        if cached is None:
            # running the hierarchical clustering stores its result under `attr_name`
            self.hierarchical_clustering()
            cached = getattr(self, attr_name)

        return cached

    def reset_clustering(self, only_zero_shot: bool = False) -> None:
        """Set the clustering attributes back to `None`.

        Args:
            only_zero_shot: If `True`, only the attributes listed in `mode.zero_shot_clustering_attrs` are reset; otherwise all the ones in `mode.all_clustering_attrs`.
        """
        mode = self.mode
        if only_zero_shot:
            names = mode.zero_shot_clustering_attrs
        else:
            names = mode.all_clustering_attrs

        for name in names:
            setattr(self, name, None)

    def set_clustering(self, clustering: None, clusters_levels: None) -> None:
        """Store a clustering and its levels under the attribute names of the current mode.

        Args:
            clustering: Value stored under `mode.clustering_attr` (`hierarchical_clustering` passes a fitted `AgglomerativeClustering`).
            clusters_levels: Value stored under `mode.clusters_levels_attr` (`hierarchical_clustering` passes a NumPy array).
        """
        mode = self.mode
        targets = (
            (mode.clustering_attr, clustering),
            (mode.clusters_levels_attr, clusters_levels),
        )
        for attr_name, value in targets:
            setattr(self, attr_name, value)

    def hierarchical_clustering(self) -> None:
        """
        Run an agglomerative clustering (cosine metric, average linkage) on the prototypes and
        save the fitted object together with the cluster ids of every prototype at every merge step.
        """
        prototypes = self.prototypes.data.numpy(force=True)  # (K, O)
        n_leaves = len(prototypes)

        tree = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=0,
            compute_full_tree=True,
            metric="cosine",
            linkage="average",
        )
        tree.fit(prototypes)

        # Row 0: every prototype is its own cluster. Row `step`: state after `step` merges.
        levels = np.zeros((n_leaves, n_leaves), dtype=np.uint16)
        levels[0] = np.arange(n_leaves)

        for step, (left, right) in enumerate(tree.children_, start=1):
            previous = levels[step - 1]
            merged = (previous == left) | (previous == right)
            levels[step] = previous
            # the cluster created by this merge gets the id `n_leaves + (step - 1)`
            levels[step, merged] = n_leaves + step - 1

        self.set_clustering(tree, levels)

    def map_leaves_domains(self, series: pd.Series, level: int) -> pd.Series:
        """Map leaves to the parent domain from the corresponding level of the hierarchical tree.

        Args:
            series: Leaves classes. String values are parsed as one prefix character followed by the leaf index; other values are left unchanged.
            level: Level of the hierarchical clustering tree (or, number of clusters)

        Returns:
            Series of classes, where each string value is replaced by `"D<cluster id>"`.
        """
        # Resolve the row once: `clusters_levels` is a property, there is no need to evaluate it for every element
        domains = self.clusters_levels[-level]

        def _to_domain(leaf):
            # non-string values are passed through untouched
            return f"D{domains[int(leaf[1:])]}" if isinstance(leaf, str) else leaf

        return series.map(_to_domain)

    def find_level(self, leaves_indices: np.ndarray, n_domains: int) -> int:
        """Find the first level of the tree at which the given leaves are split into `n_domains` clusters.

        Levels are searched from the coarsest (`level=1`, last row of `clusters_levels`) to the finest
        (first row of `clusters_levels`, where every leaf is its own cluster).

        Args:
            leaves_indices: Indices of the leaves (columns of `clusters_levels`) to consider.
            n_domains: Desired number of distinct clusters among these leaves.

        Returns:
            The first level at which the selected leaves have exactly `n_domains` distinct clusters.

        Raises:
            ValueError: If no level yields exactly `n_domains` clusters.
        """
        sub_clusters_levels = self.clusters_levels[:, leaves_indices]

        # `+ 1` so that the finest level (first row, one cluster per leaf) is searched too
        for level in range(1, len(sub_clusters_levels) + 1):
            _n_domains = len(np.unique(sub_clusters_levels[-level]))
            if _n_domains == n_domains:
                return level

        raise ValueError(f"Could not find a level with {n_domains=}")

__init__(mode, output_size, num_prototypes, temperature)

SwavHead module, adapted from the paper "Unsupervised Learning of Visual Features by Contrasting Cluster Assignments".

Parameters:

Name Type Description Default
output_size int

Size of the representations, i.e. the encoder outputs (O in the article).

required
num_prototypes int

Number of prototypes (K in the article).

required
temperature float

Temperature used in the cross-entropy loss.

required
Source code in novae/module/swav.py
def __init__(
    self,
    mode: utils.Mode,
    output_size: int,
    num_prototypes: int,
    temperature: float,
):
    """SwavHead module, adapted from the paper ["Unsupervised Learning of Visual Features by Contrasting Cluster Assignments"](https://arxiv.org/abs/2006.09882).

    Args:
        mode: Mode object; the other methods read its `zero_shot` flag and its clustering attribute names.
        output_size: Size of the representations, i.e. the encoder outputs (`O` in the article).
        num_prototypes: Number of prototypes (`K` in the article).
        temperature: Temperature used in the cross-entropy loss.
    """
    super().__init__()

    self.mode = mode
    self.temperature = temperature
    self.num_prototypes = num_prototypes
    self.output_size = output_size

    # (K, O) prototypes, initialized in place and then normalized
    self._prototypes = nn.Parameter(torch.empty((self.num_prototypes, self.output_size)))
    nn.init.kaiming_uniform_(self._prototypes, a=math.sqrt(5), mode="fan_out")
    self.normalize_prototypes()

    # no minimum number of prototypes and no slide-queue by default
    self.min_prototypes = 0
    self.queue = None

    self.reset_clustering()

forward(z1, z2, slide_id)

Compute the SwAV loss for two batches of neighborhood graph views.

Parameters:

Name Type Description Default
z1 Tensor

Batch containing graphs representations (B, output_size)

required
z2 Tensor

Batch containing graphs representations (B, output_size)

required

Returns:

Type Description
tuple[Tensor, Tensor]

The SwAV loss, and the normalized mean entropy (for monitoring).

Source code in novae/module/swav.py
def forward(self, z1: Tensor, z2: Tensor, slide_id: str | None) -> tuple[Tensor, Tensor]:
    """Compute the SwAV loss for two batches of neighborhood graph views.

    Args:
        z1: Batch containing graphs representations `(B, output_size)`
        z2: Batch containing graphs representations `(B, output_size)`
        slide_id: ID of the slide, or `None`; forwarded to `prototype_ilocs` to select the prototypes.

    Returns:
        The SwAV loss, and the normalized mean entropy (for monitoring).
    """
    self.normalize_prototypes()

    scores1, scores2 = self.projection(z1), self.projection(z2)  # each (B, K)

    # only the first view is passed to `prototype_ilocs` (which may update the slide-queue)
    kept = self.prototype_ilocs(scores1, slide_id)
    scores1 = scores1[:, kept]
    scores2 = scores2[:, kept]

    codes1 = self.sinkhorn(scores1)  # (B, K) or (B, len(kept))
    codes2 = self.sinkhorn(scores2)  # (B, K) or (B, len(kept))

    # swapped prediction: the codes of one view are compared with the projections of the other
    swapped = self.cross_entropy_loss(codes1, scores2) + self.cross_entropy_loss(codes2, scores1)
    loss = -0.5 * swapped

    return loss, _mean_entropy_normalized(codes1)

hierarchical_clustering()

Perform hierarchical clustering on the prototypes. Saves the full tree of clusters.

Source code in novae/module/swav.py
def hierarchical_clustering(self) -> None:
    """
    Run an agglomerative clustering (cosine metric, average linkage) on the prototypes and
    save the fitted object together with the cluster ids of every prototype at every merge step.
    """
    prototypes = self.prototypes.data.numpy(force=True)  # (K, O)
    n_leaves = len(prototypes)

    tree = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=0,
        compute_full_tree=True,
        metric="cosine",
        linkage="average",
    )
    tree.fit(prototypes)

    # Row 0: every prototype is its own cluster. Row `step`: state after `step` merges.
    levels = np.zeros((n_leaves, n_leaves), dtype=np.uint16)
    levels[0] = np.arange(n_leaves)

    for step, (left, right) in enumerate(tree.children_, start=1):
        previous = levels[step - 1]
        merged = (previous == left) | (previous == right)
        levels[step] = previous
        # the cluster created by this merge gets the id `n_leaves + (step - 1)`
        levels[step, merged] = n_leaves + step - 1

    self.set_clustering(tree, levels)

init_queue(slide_ids)

Initialize the slide-queue.

Parameters:

Name Type Description Default
slide_ids list[str]

A list of slide ids.

required
Source code in novae/module/swav.py
def init_queue(self, slide_ids: list[str]) -> None:
    """Create the slide-queue buffer and the mapping from slide ID to queue index.

    Args:
        slide_ids: A list of slide ids.
    """
    # NOTE(review): presumably the existing attribute must be removed before a buffer of the same name is registered
    del self.queue

    n_slides = len(slide_ids)
    initial_queue = torch.full(
        (n_slides, Nums.QUEUE_SIZE, self.num_prototypes),
        1 / self.num_prototypes,
    )
    self.register_buffer("queue", initial_queue)

    # position of each slide along the first dimension of the queue
    self.slide_label_encoder = dict(zip(slide_ids, range(n_slides)))

map_leaves_domains(series, level)

Map leaves to the parent domain from the corresponding level of the hierarchical tree.

Parameters:

Name Type Description Default
series Series

Leaves classes

required
level int

Level of the hierarchical clustering tree (or, number of clusters)

required

Returns:

Type Description
Series

Series of classes.

Source code in novae/module/swav.py
def map_leaves_domains(self, series: pd.Series, level: int) -> pd.Series:
    """Map leaves to the parent domain from the corresponding level of the hierarchical tree.

    Args:
        series: Leaves classes. String values are parsed as one prefix character followed by the leaf index; other values are left unchanged.
        level: Level of the hierarchical clustering tree (or, number of clusters)

    Returns:
        Series of classes, where each string value is replaced by `"D<cluster id>"`.
    """
    # Resolve the row once: `clusters_levels` is a property, there is no need to evaluate it for every element
    domains = self.clusters_levels[-level]

    def _to_domain(leaf):
        # non-string values are passed through untouched
        return f"D{domains[int(leaf[1:])]}" if isinstance(leaf, str) else leaf

    return series.map(_to_domain)

projection(z)

Compute the projection of the (normalized) representations over the prototypes.

Parameters:

Name Type Description Default
z Tensor

The representations of one batch, of size (B, O).

required

Returns:

Type Description
Tensor

The projections of size (B, K).

Source code in novae/module/swav.py
def projection(self, z: Tensor) -> Tensor:
    """Project the L2-normalized representations onto the prototypes.

    Args:
        z: The representations of one batch, of size `(B, O)`.

    Returns:
        The projections of size `(B, K)`.
    """
    unit_z = F.normalize(z, p=2, dim=1)  # each row has unit L2 norm
    return torch.matmul(unit_z, self.prototypes.T)

prototype_ilocs(projections, slide_id=None)

Get the indices of the prototypes to use for the current slide.

Parameters:

Name Type Description Default
projections Tensor

Projections of the (normalized) representations over the prototypes, of size (B, K).

required
slide_id str | None

ID of the slide, or None.

None

Returns:

Type Description
Tensor

The indices of the prototypes to use, or an Ellipsis if all prototypes.

Source code in novae/module/swav.py
@torch.no_grad()
def prototype_ilocs(self, projections: Tensor, slide_id: str | None = None) -> Tensor:
    """Get the indices of the prototypes to use for the current slide.

    Args:
        projections: Projections of the (normalized) representations over the prototypes, of size `(B, K)`.
        slide_id: ID of the slide, or `None`.

    Returns:
        The indices of the prototypes to use, or an `Ellipsis` if all prototypes.
    """
    if (self.queue is None) or (slide_id is None) or self.mode.zero_shot:
        return ...

    slide_index = self.slide_label_encoder[slide_id]

    self.queue[slide_index, 1:] = self.queue[slide_index, :-1].clone()
    self.queue[slide_index, 0] = projections.topk(3, dim=0).values[-1]  # top3 more robust than max

    weights, thresholds = self.queue_weights()
    slide_weights = weights[slide_index]

    ilocs = torch.where(slide_weights >= thresholds)[0]
    return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(slide_weights, self.min_prototypes).indices

queue_weights()

Convert the queue to a matrix of prototype weight per slide.

Returns:

Type Description
tuple[Tensor, Tensor]

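A tuple of two tensors: the prototype weights per slide, of shape (n_slides, K), and the corresponding threshold for each prototype, of shape (K,).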

Source code in novae/module/swav.py
def queue_weights(self) -> tuple[Tensor, Tensor]:
    """Summarize the queue as per-slide prototype weights and per-prototype thresholds.

    Returns:
        A tuple `(weights, thresholds)`. `weights` is the maximum of the queue over its dimension 1, of shape `(n_slides, K)`. `thresholds` is the maximum of `weights` over the slides, scaled by `Nums.QUEUE_WEIGHT_THRESHOLD_RATIO`.
    """
    weights, _ = self.queue.max(dim=1)
    best_over_slides, _ = weights.max(0)

    return weights, best_over_slides * Nums.QUEUE_WEIGHT_THRESHOLD_RATIO

sinkhorn(projections)

Apply the Sinkhorn-Knopp algorithm to the projections.

Parameters:

Name Type Description Default
projections Tensor

Projections of the (normalized) representations over the prototypes, of size (B, K).

required

Returns:

Type Description
Tensor

The soft codes from the Sinkhorn-Knopp algorithm, with shape (B, K).

Source code in novae/module/swav.py
@torch.no_grad()
def sinkhorn(self, projections: Tensor) -> Tensor:
    """Turn the projections into soft codes with the Sinkhorn-Knopp algorithm.

    Args:
        projections: Projections of the (normalized) representations over the prototypes, of size `(B, K)`.

    Returns:
        The soft codes from the Sinkhorn-Knopp algorithm, with shape `(B, K)`.
    """
    codes = (projections / Nums.SWAV_EPSILON).exp()  # (B, K)
    codes /= codes.sum()

    n_rows, n_cols = codes.shape

    for _ in range(Nums.SINKHORN_ITERATIONS):
        # normalize the columns (prototypes), then the rows (samples)
        codes /= codes.sum(dim=0, keepdim=True)
        codes /= n_cols
        codes /= codes.sum(dim=1, keepdim=True)
        codes /= n_rows

    # final row normalization so that each row sums to 1 (for the cross-entropy loss)
    return codes / codes.sum(dim=1, keepdim=True)