scyan.module.ScyanModule

Bases: LightningModule

Core logic contained inside the main class Scyan. Do not use this class directly.
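Although `ScyanModule` is built and driven by `Scyan` itself, a minimal instantiation sketch can help relate the constructor arguments to the attributes and parameters documented below. All values here are illustrative assumptions, not recommended settings:

```python
import torch

from scyan.module import ScyanModule  # import path assumed from this page

# Hypothetical knowledge table rho: 3 populations x 4 markers,
# with -1/1 encoding negative/positive expression of a marker.
rho = torch.tensor([
    [ 1.0, -1.0,  1.0, -1.0],
    [-1.0,  1.0,  1.0, -1.0],
    [-1.0, -1.0, -1.0,  1.0],
])

module = ScyanModule(
    rho=rho,
    n_covariates=2,                                        # M_c
    is_continuum_marker=torch.zeros(4, dtype=torch.bool),  # assumed: boolean mask per marker
    hidden_size=16,
    n_hidden_layers=1,
    n_layers=6,
    prior_std=0.25,
    temperature=0.5,
)
```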

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `real_nvp` | `RealNVP` | The Normalizing Flow (a RealNVP object) |
| `prior` | `PriorDistribution` | The prior \(U\) (a PriorDistribution object) |
| `pi_logit` | `Tensor` | Logits used to learn the population weights |

Source code in scyan/module/scyan_module.py
class ScyanModule(pl.LightningModule):
    """Core logic contained inside the main class [Scyan][scyan.Scyan]. Do not use this class directly.

    Attributes:
        real_nvp (RealNVP): The Normalizing Flow (a [RealNVP][scyan.module.RealNVP] object)
        prior (PriorDistribution): The prior $U$ (a [PriorDistribution][scyan.module.PriorDistribution] object)
        pi_logit (Tensor): Logits used to learn the population weights
    """

    pi_logit_ratio: float = 100  # To learn pi logits faster

    def __init__(
        self,
        rho: Tensor,
        n_covariates: int,
        is_continuum_marker: Tensor,
        hidden_size: int,
        n_hidden_layers: int,
        n_layers: int,
        prior_std: float,
        temperature: float,
    ):
        """
        Args:
            rho: Tensor $\rho$ representing the knowledge table.
            n_covariates: Number of covariates $M_c$ considered.
            is_continuum_marker: Boolean mask flagging markers treated as a continuum.
            hidden_size: MLP (`s` and `t`) hidden size.
            n_hidden_layers: Number of hidden layers for the MLP (`s` and `t`).
            n_layers: Number of coupling layers.
            prior_std: Standard deviation $\sigma$ of the cell-specific random variable $H$.
            temperature: Temperature to favour small populations.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["rho", "n_covariates", "is_continuum_marker"])

        self.n_pops, self.n_markers = rho.shape

        self.pi_logit = nn.Parameter(torch.zeros(self.n_pops))

        self.real_nvp = RealNVP(
            self.n_markers + n_covariates,
            self.hparams.hidden_size,
            self.n_markers,
            self.hparams.n_hidden_layers,
            self.hparams.n_layers,
        )

        self.prior = PriorDistribution(
            rho, is_continuum_marker, self.hparams.prior_std, self.n_markers
        )

    def forward(self, x: Tensor, covariates: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward implementation, going through the complete flow $f_{\phi}$.

        Args:
            x: Inputs of size $(B, M)$.
            covariates: Covariates of size $(B, M_c)$.

        Returns:
            Tuple of (outputs, covariates, log_det_jacobian sum).
        """
        return self.real_nvp(x, covariates)

    @torch.no_grad()
    def inverse(self, u: Tensor, covariates: Tensor) -> Tensor:
        """Go through the flow in reverse direction, i.e. $f_{\phi}^{-1}$.

        Args:
            u: Latent expressions of size $(B, M)$.
            covariates: Covariates of size $(B, M_c)$.

        Returns:
            Outputs of size $(B, M)$.
        """
        return self.real_nvp.inverse(u, covariates)

    @property
    def prior_z(self) -> distributions.Distribution:
        """Population prior, i.e. $Categorical(\pi)$.

        Returns:
            Distribution of the population index.
        """
        return distributions.Categorical(self.pi)

    @property
    def log_pi(self) -> Tensor:
        """Log population weights $log \; \pi$."""
        return torch.log_softmax(self.pi_logit_ratio * self.pi_logit, dim=0)

    @property
    def pi(self) -> Tensor:
        """Population weights $\pi$"""
        return torch.exp(self.log_pi)

    def log_pi_temperature(self, T: float) -> Tensor:
        """Compute the log weights with temperature $log \; \pi^{(-T)}$

        Args:
            T: Temperature.

        Returns:
            Log weights with temperature.
        """
        return torch.log_softmax(self.pi_logit_ratio * self.pi_logit / T, dim=0).detach()

    @torch.no_grad()
    def sample(
        self,
        n_samples: int,
        covariates: Tensor,
        z: Union[int, Tensor, None] = None,
        return_z: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Sample cell-marker expressions.

        Args:
            n_samples: Number of cells to sample.
            covariates: Tensor of covariates.
            z: Either one population index or a Tensor of population indices. If None, sampling from all populations.
            return_z: Whether to return the population Tensor.

        Returns:
            Sampled cell expressions and, if `return_z`, the populations associated with these cells.
        """
        if z is None:
            z = self.prior_z.sample((n_samples,))
        elif isinstance(z, int):
            z = torch.full((n_samples,), z)
        elif isinstance(z, torch.Tensor):
            z = z.to(int)
        else:
            raise ValueError(
                f"z has to be 'None', an 'int' or a 'torch.Tensor'. Found type {type(z)}."
            )

        u = self.prior.sample(z)
        x = self.inverse(u, covariates)

        return (x, z) if return_z else x

    def compute_probabilities(
        self, x: Tensor, covariates: Tensor, use_temp: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute probabilities used in the loss function.

        Args:
            x: Inputs of size $(B, M)$.
            covariates: Covariates of size $(B, M_c)$.
            use_temp: Whether to apply the temperature to the population weights.

        Returns:
            Log probabilities of size $(B, P)$, the log det jacobian and the latent expressions of size $(B, M)$.
        """
        u, _, ldj_sum = self(x, covariates)

        log_pi = (
            self.log_pi_temperature(-self.hparams.temperature)
            if use_temp
            else self.log_pi
        )

        log_probs = self.prior.log_prob(u) + log_pi  # size N x P

        return log_probs, ldj_sum, u

    def kl(
        self,
        x: Tensor,
        covariates: Tensor,
        use_temp: bool,
    ) -> Tensor:
        """Compute the module loss for one mini-batch.

        Args:
            x: Inputs of size $(B, M)$.
            covariates: Covariates of size $(B, M_c)$.
            use_temp: Whether to apply the temperature in the KL term.

        Returns:
            The KL loss term.
        """
        log_probs, ldj_sum, _ = self.compute_probabilities(x, covariates, use_temp)

        return -(torch.logsumexp(log_probs, dim=1) + ldj_sum).mean()

log_pi: Tensor property

Log population weights \(\log \pi\).

pi: Tensor property

Population weights \(\pi\).

prior_z: distributions.Distribution property

Population prior, i.e. \(\mathrm{Categorical}(\pi)\).

Returns:

| Type | Description |
| --- | --- |
| `Distribution` | Distribution of the population index. |
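These three properties are tied together by a single softmax over the learnable logits. A standalone sketch of the same computation, using `pi_logit_ratio = 100` as in the class and a hypothetical logit tensor:

```python
import torch

pi_logit = torch.tensor([0.01, 0.0, -0.01])        # stand-in for the learnable parameter
log_pi = torch.log_softmax(100 * pi_logit, dim=0)  # the log_pi property
pi = log_pi.exp()                                  # the pi property; sums to 1
prior_z = torch.distributions.Categorical(pi)      # the prior_z property
```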

__init__(rho, n_covariates, is_continuum_marker, hidden_size, n_hidden_layers, n_layers, prior_std, temperature)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rho` | `Tensor` | Tensor \(\rho\) representing the knowledge table. | required |
| `n_covariates` | `int` | Number of covariates \(M_c\) considered. | required |
| `is_continuum_marker` | `Tensor` | Boolean mask flagging markers treated as a continuum. | required |
| `hidden_size` | `int` | MLP (`s` and `t`) hidden size. | required |
| `n_hidden_layers` | `int` | Number of hidden layers for the MLP (`s` and `t`). | required |
| `n_layers` | `int` | Number of coupling layers. | required |
| `prior_std` | `float` | Standard deviation \(\sigma\) of the cell-specific random variable \(H\). | required |
| `temperature` | `float` | Temperature to favour small populations. | required |
Source code in scyan/module/scyan_module.py
def __init__(
    self,
    rho: Tensor,
    n_covariates: int,
    is_continuum_marker: Tensor,
    hidden_size: int,
    n_hidden_layers: int,
    n_layers: int,
    prior_std: float,
    temperature: float,
):
    """
    Args:
        rho: Tensor $\rho$ representing the knowledge table.
        n_covariates: Number of covariates $M_c$ considered.
        is_continuum_marker: Boolean mask flagging markers treated as a continuum.
        hidden_size: MLP (`s` and `t`) hidden size.
        n_hidden_layers: Number of hidden layers for the MLP (`s` and `t`).
        n_layers: Number of coupling layers.
        prior_std: Standard deviation $\sigma$ of the cell-specific random variable $H$.
        temperature: Temperature to favour small populations.
    """
    super().__init__()
    self.save_hyperparameters(ignore=["rho", "n_covariates", "is_continuum_marker"])

    self.n_pops, self.n_markers = rho.shape

    self.pi_logit = nn.Parameter(torch.zeros(self.n_pops))

    self.real_nvp = RealNVP(
        self.n_markers + n_covariates,
        self.hparams.hidden_size,
        self.n_markers,
        self.hparams.n_hidden_layers,
        self.hparams.n_layers,
    )

    self.prior = PriorDistribution(
        rho, is_continuum_marker, self.hparams.prior_std, self.n_markers
    )

compute_probabilities(x, covariates, use_temp=False)

Compute probabilities used in the loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Inputs of size \((B, M)\). | required |
| `covariates` | `Tensor` | Covariates of size \((B, M_c)\). | required |
| `use_temp` | `bool` | Whether to apply the temperature to the population weights. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor, Tensor]` | Log probabilities of size \((B, P)\), the log det jacobian and the latent expressions of size \((B, M)\). |

Source code in scyan/module/scyan_module.py
def compute_probabilities(
    self, x: Tensor, covariates: Tensor, use_temp: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute probabilities used in the loss function.

    Args:
        x: Inputs of size $(B, M)$.
        covariates: Covariates of size $(B, M_c)$.
        use_temp: Whether to apply the temperature to the population weights.

    Returns:
        Log probabilities of size $(B, P)$, the log det jacobian and the latent expressions of size $(B, M)$.
    """
    u, _, ldj_sum = self(x, covariates)

    log_pi = (
        self.log_pi_temperature(-self.hparams.temperature)
        if use_temp
        else self.log_pi
    )

    log_probs = self.prior.log_prob(u) + log_pi  # size N x P

    return log_probs, ldj_sum, u
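The returned log probabilities combine the prior density of each latent cell under every population with the log population weights. As a sketch (the random tensor below is a stand-in, not library output), such a \((B, P)\) tensor yields both per-cell posterior population probabilities and the marginal log-density that `kl` aggregates:

```python
import torch

log_probs = torch.randn(8, 3)                 # stand-in for prior.log_prob(u) + log_pi
posterior = torch.softmax(log_probs, dim=1)   # P(z | u), rows sum to 1
marginal = torch.logsumexp(log_probs, dim=1)  # log sum_z pi_z p(u | z), size (B,)
```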

forward(x, covariates)

Forward implementation, going through the complete flow \(f_{\phi}\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Inputs of size \((B, M)\). | required |
| `covariates` | `Tensor` | Covariates of size \((B, M_c)\). | required |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor, Tensor]` | Tuple of (outputs, covariates, log_det_jacobian sum). |

Source code in scyan/module/scyan_module.py
def forward(self, x: Tensor, covariates: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Forward implementation, going through the complete flow $f_{\phi}$.

    Args:
        x: Inputs of size $(B, M)$.
        covariates: Covariates of size $(B, M_c)$.

    Returns:
        Tuple of (outputs, covariates, log_det_jacobian sum).
    """
    return self.real_nvp(x, covariates)

inverse(u, covariates)

Go through the flow in reverse direction, i.e. \(f_{\phi}^{-1}\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `u` | `Tensor` | Latent expressions of size \((B, M)\). | required |
| `covariates` | `Tensor` | Covariates of size \((B, M_c)\). | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Outputs of size \((B, M)\). |

Source code in scyan/module/scyan_module.py
@torch.no_grad()
def inverse(self, u: Tensor, covariates: Tensor) -> Tensor:
    """Go through the flow in reverse direction, i.e. $f_{\phi}^{-1}$.

    Args:
        u: Latent expressions of size $(B, M)$.
        covariates: Covariates of size $(B, M_c)$.

    Returns:
        Outputs of size $(B, M)$.
    """
    return self.real_nvp.inverse(u, covariates)
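Since the flow is bijective, `inverse` undoes `forward` up to floating-point error. A sanity-check sketch, reusing the illustrative `module` from the instantiation example above (shapes assumed: \(M = 4\) markers, \(M_c = 2\) covariates):

```python
import torch

x = torch.randn(8, 4)                  # 8 cells, M = 4 markers
covariates = torch.zeros(8, 2)         # assumed M_c = 2 covariates
u, _, ldj_sum = module(x, covariates)  # forward pass through f_phi
x_rec = module.inverse(u, covariates)  # reverse pass through f_phi^{-1}
print((x - x_rec).abs().max())         # near zero: numerical error only
```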

kl(x, covariates, use_temp)

Compute the module loss for one mini-batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Inputs of size \((B, M)\). | required |
| `covariates` | `Tensor` | Covariates of size \((B, M_c)\). | required |
| `use_temp` | `bool` | Whether to apply the temperature in the KL term. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The KL loss term. |

Source code in scyan/module/scyan_module.py
def kl(
    self,
    x: Tensor,
    covariates: Tensor,
    use_temp: bool,
) -> Tensor:
    """Compute the module loss for one mini-batch.

    Args:
        x: Inputs of size $(B, M)$.
        covariates: Covariates of size $(B, M_c)$.
        use_temp: Whether to apply the temperature in the KL term.

    Returns:
        The KL loss term.
    """
    log_probs, ldj_sum, _ = self.compute_probabilities(x, covariates, use_temp)

    return -(torch.logsumexp(log_probs, dim=1) + ldj_sum).mean()
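By the change-of-variables formula for normalizing flows, this loss is the average negative log-likelihood of the batch:

$$
-\frac{1}{B} \sum_{i=1}^{B} \left[ \log \sum_{z=1}^{P} \pi_z \, p_U\big(f_{\phi}(x_i) \mid z\big) + \log \left| \det J_{f_{\phi}}(x_i) \right| \right]
$$

where the inner sum is `torch.logsumexp(log_probs, dim=1)` and the Jacobian term is `ldj_sum`.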

log_pi_temperature(T)

Compute the log weights with temperature \(\log \pi^{(-T)}\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `T` | `float` | Temperature. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Log weights with temperature. |

Source code in scyan/module/scyan_module.py
def log_pi_temperature(self, T: float) -> Tensor:
    """Compute the log weights with temperature $log \; \pi^{(-T)}$

    Args:
        T: Temperature.

    Returns:
        Log weights with temperature.
    """
    return torch.log_softmax(self.pi_logit_ratio * self.pi_logit / T, dim=0).detach()
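A standalone sketch of the temperature effect, with an illustrative logit tensor. Note that `compute_probabilities` calls this method with `T = -temperature`, and a negative `T` inverts the softmax ordering so that rare populations receive relatively larger weights:

```python
import torch

pi_logit = torch.tensor([0.02, 0.0, -0.02])
for T in (1.0, 2.0, -0.5):
    weights = torch.log_softmax(100 * pi_logit / T, dim=0).exp()
    print(T, weights)  # T > 1 flattens the weights; T < 0 inverts their ordering
```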

sample(n_samples, covariates, z=None, return_z=False)

Sample cell-marker expressions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_samples` | `int` | Number of cells to sample. | required |
| `covariates` | `Tensor` | Tensor of covariates. | required |
| `z` | `Union[int, Tensor, None]` | Either one population index or a Tensor of population indices. If `None`, sampling from all populations. | `None` |
| `return_z` | `bool` | Whether to return the population Tensor. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Union[Tensor, Tuple[Tensor, Tensor]]` | Sampled cell expressions and, if `return_z`, the populations associated with these cells. |

Source code in scyan/module/scyan_module.py
@torch.no_grad()
def sample(
    self,
    n_samples: int,
    covariates: Tensor,
    z: Union[int, Tensor, None] = None,
    return_z: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Sample cell-marker expressions.

    Args:
        n_samples: Number of cells to sample.
        covariates: Tensor of covariates.
        z: Either one population index or a Tensor of population indices. If None, sampling from all populations.
        return_z: Whether to return the population Tensor.

    Returns:
        Sampled cell expressions and, if `return_z`, the populations associated with these cells.
    """
    if z is None:
        z = self.prior_z.sample((n_samples,))
    elif isinstance(z, int):
        z = torch.full((n_samples,), z)
    elif isinstance(z, torch.Tensor):
        z = z.to(int)
    else:
        raise ValueError(
            f"z has to be 'None', an 'int' or a 'torch.Tensor'. Found type {type(z)}."
        )

    u = self.prior.sample(z)
    x = self.inverse(u, covariates)

    return (x, z) if return_z else x
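A usage sketch, again reusing the illustrative `module` from the instantiation example above (covariate shapes assumed):

```python
import torch

covariates = torch.zeros(100, 2)  # assumed M_c = 2 covariates

x0 = module.sample(100, covariates, z=0)              # 100 cells from population 0
x, z = module.sample(100, covariates, return_z=True)  # mixed populations + labels
```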