Skip to content

Plots

scyan.plot.umap(adata, color=None, vmax='p95', vmin='p05', show=True, **scanpy_kwargs)

Plot a UMAP using scanpy.

Note

If you trained your UMAP with scyan.tools.umap on a subset of cells, it will only display the desired subset of cells.

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required
color Union[str, List[str]]

Marker(s) or obs name(s) to color. It can be either just one string, or a list (it will plot one UMAP per element in the list).

None
vmax Union[str, float]

scanpy.pl.umap vmax argument.

'p95'
vmin Union[str, float]

scanpy.pl.umap vmin argument.

'p05'
show bool

Whether or not to display the figure.

True
**scanpy_kwargs int

Optional kwargs provided to scanpy.pl.umap.

{}
Source code in scyan/plot/dot.py
@plot_decorator(adata=True)
def umap(
    adata: AnnData,
    color: Union[str, List[str]] = None,
    vmax: Union[str, float] = "p95",
    vmin: Union[str, float] = "p05",
    show: bool = True,
    **scanpy_kwargs: int,
):
    """Plot a UMAP using scanpy.

    !!! note
        If you trained your UMAP with [scyan.tools.umap][] on a subset of cells, it will only display the desired subset of cells.

    Args:
        adata: An `AnnData` object.
        color: Marker(s) or `obs` name(s) to color. It can be either just one string, or a list (it will plot one UMAP per element in the list).
        vmax: `scanpy.pl.umap` vmax argument (only forwarded when `color` is given).
        vmin: `scanpy.pl.umap` vmin argument (only forwarded when `color` is given).
        show: Whether or not to display the figure.
        **scanpy_kwargs: Optional kwargs provided to `scanpy.pl.umap`.
    """
    assert isinstance(adata, AnnData), f"umap first argument has to be an AnnData object. Received type {type(adata)}."

    umap_mask = _has_umap(adata)
    if not umap_mask.all():
        # Restrict the plot to the cells flagged by `_has_umap`.
        adata = adata[umap_mask]

    if color is not None:
        return scanpy_pl_umap(adata, color=color, vmax=vmax, vmin=vmin, **scanpy_kwargs)

    # Without a color, the vmin/vmax bounds are not forwarded.
    return scanpy_pl_umap(adata, **scanpy_kwargs)

scyan.plot.scatter(adata, population, markers=None, n_markers=3, key='scyan_pop', max_obs=2000, s=1.0, show=True)

Display marker expressions on 2D scatter plots with colors per population. One scatter plot is displayed for each pair of markers.

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required
population Union[str, List[str], None]

One population, or a list of populations to be colored, or None. If not None, the population name(s) must be in adata.obs[key].

required
markers Optional[List[str]]

List of markers to plot. If None, the list is chosen automatically.

None
n_markers Optional[int]

Number of markers to choose automatically if markers is None.

3
key str

Key to look for populations in adata.obs. By default, uses the model predictions.

'scyan_pop'
max_obs int

Maximum number of cells per population to be displayed. If population is None, then this number is multiplied by 10.

2000
s float

Dot marker size.

1.0
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/dot.py
@plot_decorator(adata=True)
@check_population(return_list=True)
def scatter(
    adata: AnnData,
    population: Union[str, List[str], None],
    markers: Optional[List[str]] = None,
    n_markers: Optional[int] = 3,
    key: str = "scyan_pop",
    max_obs: int = 2000,
    s: float = 1.0,
    show: bool = True,
) -> None:
    """Display marker expressions on 2D scatter plots with colors per population.
    One scatter plot is displayed for each pair of markers.

    Args:
        adata: An `AnnData` object.
        population: One population, or a list of populations to be colored, or `None`. If not `None`, the population name(s) must be in `adata.obs[key]`.
        markers: List of markers to plot. If `None`, the list is chosen automatically.
        n_markers: Number of markers to choose automatically if `markers is None`.
        key: Key to look for populations in `adata.obs`. By default, uses the model predictions.
        max_obs: Maximum number of cells per population to be displayed. If population is None, then this number is multiplied by 10.
        s: Dot marker size.
        show: Whether or not to display the figure.
    """
    markers = select_markers(adata, markers, n_markers, key, population)

    if population is None:
        indices = _get_subset_indices(adata.n_obs, max_obs * 10)
        data = adata[indices, markers].to_df()
        g = sns.PairGrid(data, corner=True)
        g.map_offdiag(sns.scatterplot, s=s)
        g.map_diag(sns.histplot)
        g.add_legend()
        return

    data = adata[:, markers].to_df()
    keys = adata.obs[key].astype(str)
    # Every cell outside the requested populations is grouped under "Others".
    data["Population"] = np.where(~np.isin(keys, population), "Others", keys)

    pops = list(population) + ["Others"]
    if max_obs is not None:
        groups = data.groupby("Population").groups
        # A group may be empty, e.g. "Others" when the requested populations cover
        # every cell. Skip absent groups instead of raising a KeyError.
        # The reversed order puts the "Others" rows first in `data`.
        data = data.loc[
            [
                i
                for pop in reversed(pops)
                if pop in groups
                for i in _subset(groups[pop], max_obs)
            ]
        ]

    palette = get_palette_others(data, "Population")

    g = sns.PairGrid(data, hue="Population", corner=True, palette=palette, hue_order=pops)
    g.map_offdiag(sns.scatterplot, s=s)
    g.map_diag(sns.histplot)
    g.add_legend()

scyan.plot.probs_per_marker(model, population, key='scyan_pop', prob_name='Prob', vmin_threshold=-100, figsize=(10, 6), show=True)

Interpretability tool: get a group of cells and plot a heatmap of marker probabilities for each population.

Parameters:

Name Type Description Default
model Scyan

Scyan model.

required
population str

Name of one population to interpret. To be valid, the population name has to be in adata.obs[key].

required
key str

Key to look for population in adata.obs. By default, uses the model predictions.

'scyan_pop'
prob_name str

Name to display on the plot.

'Prob'
vmin_threshold int

Minimum threshold for the heatmap colorbar.

-100
figsize Tuple[float]

Pair (width, height) indicating the size of the figure.

(10, 6)
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/heatmap.py
@torch.no_grad()
@plot_decorator()
@check_population(one=True)
def probs_per_marker(
    model: Scyan,
    population: str,
    key: str = "scyan_pop",
    prob_name: str = "Prob",
    vmin_threshold: int = -100,
    figsize: Tuple[float] = (10, 6),
    show: bool = True,
):
    """Interpretability tool: get a group of cells and plot a heatmap of marker probabilities for each population.

    The heatmap has one row per population (`model.pop_names`) and one column per marker
    (`model.var_names`). Each value is the mean log probability per marker returned by
    `model.module.prior.log_prob_per_marker` over the cells predicted as `population`.
    An extra leading column (`prob_name`) summarizes each row. Rows are sorted by that
    column, and marker columns by their mean over populations, both in decreasing order.

    Args:
        model: Scyan model.
        population: Name of one population to interpret. To be valid, the population name has to be in `adata.obs[key]`.
        key: Key to look for population in `adata.obs`. By default, uses the model predictions.
        prob_name: Name to display on the plot.
        vmin_threshold: Minimum threshold for the heatmap colorbar.
        figsize: Pair `(width, height)` indicating the size of the figure.
        show: Whether or not to display the figure.
    """
    selected_cells = model.adata.obs[key] == population
    latent = model(selected_cells)

    per_marker_log_probs = model.module.prior.log_prob_per_marker(latent)
    avg_log_probs = per_marker_log_probs.mean(dim=0).numpy(force=True)

    table = pd.DataFrame(avg_log_probs, columns=model.var_names, index=model.pop_names)
    marker_order = table.mean().sort_values(ascending=False).index
    table = table.reindex(marker_order, axis=1)

    # Row means rescaled so that the smallest one equals the table's overall minimum.
    row_means = table.mean(axis=1)
    summary = row_means / row_means.min() * table.values.min()

    table.insert(0, prob_name, summary)
    table.insert(1, " ", np.nan)  # empty spacer column between summary and markers
    table = table.sort_values(by=prob_name, ascending=False)

    lower_bound = max(vmin_threshold, avg_log_probs.min())
    plt.figure(figsize=figsize)
    sns.heatmap(table, cmap="magma", vmin=lower_bound)
    plt.title("Log probabilities per marker for each population")

scyan.plot.pop_percentage(adata, groupby=None, key='scyan_pop', figsize=None, dendogram=False, show=True)

Show populations percentages. Depending on groupby, this is either done globally, or as a stacked bar plot (one bar for each group).

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required
groupby Union[str, List[str], None]

Key(s) of adata.obs used to create groups (e.g. the patient ID).

None
key str

Key of adata.obs containing the population names (or the values) for which percentage will be displayed.

'scyan_pop'
figsize tuple[float, float]

matplotlib figure size.

None
dendogram bool

If True, the groups are sorted based on a dendrogram (hierarchical) clustering.

False
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/ratios.py
@plot_decorator(adata=True)
def pop_percentage(
    adata: AnnData,
    groupby: Union[str, List[str], None] = None,
    key: str = "scyan_pop",
    figsize: tuple[float, float] = None,
    dendogram: bool = False,
    show: bool = True,
):
    """Show populations percentages. Depending on `groupby`, this is either done globally, or as a stacked bar plot (one bar for each group).

    Args:
        adata: An `AnnData` object.
        groupby: Key(s) of `adata.obs` used to create groups (e.g. the patient ID).
        key: Key of `adata.obs` containing the population names (or the values) for which percentage will be displayed.
        figsize: matplotlib figure size.
        dendogram: If True, the groups are sorted based on a dendrogram (Ward hierarchical) clustering of their population percentages. Ignored when there is a single group.
        show: Whether or not to display the figure.
    """
    if groupby is None:
        adata.obs[key].value_counts(normalize=True).mul(100).plot.bar(figsize=figsize)
    else:
        df = adata.obs.groupby(groupby)[key].value_counts(normalize=True)
        # A population absent from a group would become NaN after unstacking. It is
        # really 0%, and NaN values would make `hierarchy.linkage` fail.
        df = df.mul(100).unstack(fill_value=0)

        # `linkage` needs at least two observations (groups) to build a tree.
        if dendogram and len(df) > 1:
            Z = hierarchy.linkage(df.values, method="ward")
            dendrogram = hierarchy.dendrogram(Z, no_plot=True)
            df = df.iloc[dendrogram["leaves"]]

        df.plot.bar(stacked=True, figsize=figsize)
        plt.legend(
            bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0, frameon=False
        )

    plt.ylabel(f"{key} percentage")
    sns.despine(offset=10, trim=True)
    plt.xticks(rotation=90)

scyan.plot.pop_dynamics(adata, time_key, groupby=None, key='scyan_pop', among=None, n_cols=4, size_mul=None, figsize=None, show=True)

Show populations percentages dynamics for different timepoints. Depending on groupby, this is either done globally, or for each group.

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required
time_key str

Key of adata.obs containing the timepoints. We recommend to use a categorical series (to use the right timepoint order).

required
groupby Union[str, List[str], None]

Key(s) of adata.obs used to create groups (e.g. the patient ID).

None
key str

Key of adata.obs containing the population names (or the values) for which dynamics will be displayed.

'scyan_pop'
among str

Key of adata.obs containing the parent population name. See scyan.tools.cell_type_ratios.

None
n_cols int

Number of figures per row.

4
size_mul Optional[float]

Dot size multiplication factor. By default, it is computed using the population counts.

None
figsize tuple[float, float]

matplotlib figure size.

None
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/ratios.py
@plot_decorator(adata=True)
def pop_dynamics(
    adata: AnnData,
    time_key: str,
    groupby: Union[str, List[str], None] = None,
    key: str = "scyan_pop",
    among: str = None,
    n_cols: int = 4,
    size_mul: Optional[float] = None,
    figsize: tuple[float, float] = None,
    show: bool = True,
):
    """Show populations percentages dynamics for different timepoints. Depending on `groupby`, this is either done globally, or for each group.

    !!! note
        If `adata.obs[time_key]` is not categorical, it is converted in place to a categorical series.

    Args:
        adata: An `AnnData` object.
        time_key: Key of `adata.obs` containing the timepoints. We recommend to use a categorical series (to use the right timepoint order).
        groupby: Key(s) of `adata.obs` used to create groups (e.g. the patient ID).
        key: Key of `adata.obs` containing the population names (or the values) for which dynamics will be displayed.
        among: Key of `adata.obs` containing the parent population name. See [scyan.tools.cell_type_ratios][].
        n_cols: Number of figures per row.
        size_mul: Dot size multiplication factor. By default, it is computed using the population counts.
        figsize: matplotlib figure size.
        show: Whether or not to display the figure.
    """
    if adata.obs[time_key].dtype.name != "category":
        log.info(f"Converting adata.obs['{time_key}'] to categorical")
        adata.obs[time_key] = adata.obs[time_key].astype("category")

    if groupby is None:
        groupby = [time_key]
    else:
        # Accept a single key or any sequence of keys (list, tuple, ...); time_key goes last.
        groupby = ([groupby] if isinstance(groupby, str) else list(groupby)) + [time_key]

    df = cell_type_ratios(adata, groupby=groupby, key=key, normalize="%", among=among)
    # Replace the timepoint level by its integer codes, which are used as x positions.
    # NOTE(review): assumes the last index level is a CategoricalIndex (from the
    # categorical time_key) and that the index is a MultiIndex — confirm.
    df.index = df.index.set_levels(df.index.levels[-1].codes, level=-1)
    df_log_count = np.log(
        1 + cell_type_ratios(adata, groupby=groupby, key=key, normalize=False)
    )

    n_pops = df.shape[1]
    n_rows = ceil(n_pops / n_cols)
    # squeeze=False always returns a 2D array of axes. With the default squeeze, a 1x1
    # grid returns a bare Axes and `.flatten()` below would fail.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=figsize or (12, 3 * n_rows), squeeze=False
    )

    if size_mul is None:
        size_mul = 40 / df_log_count.mean().mean()

    axes = axes.flatten()

    if len(groupby) == 1:
        for i, pop in enumerate(df.columns):
            axes[i].plot(df.index, df.iloc[:, i])
            axes[i].scatter(df.index, df.iloc[:, i], s=df_log_count.iloc[:, i] * size_mul)
    else:
        drop_levels = list(range(len(groupby) - 1))

        for group, group_df in df.groupby(level=drop_levels):
            label = " ".join(map(str, group)) if isinstance(group, tuple) else group
            group_df = group_df.droplevel(drop_levels)
            group_df_log_count = df_log_count.loc[group]

            for i, pop in enumerate(group_df.columns):
                axes[i].plot(group_df.index, group_df.iloc[:, i], label=label)
                axes[i].scatter(
                    group_df.index,
                    group_df.iloc[:, i],
                    s=(group_df_log_count.iloc[:, i] * size_mul).clip(0, 100),
                )

        fig.legend(
            *axes[0].get_legend_handles_labels(),
            bbox_to_anchor=(1.04, 0.55),
            loc="lower left",
            borderaxespad=0,
            frameon=False,
        )

    timepoints = adata.obs[time_key].cat.categories
    for i, pop in enumerate(df.columns):
        axes[i].set_ylabel(pop)
        axes[i].set_xlabel(time_key)
        axes[i].set_xticks(range(len(timepoints)), timepoints)

    # Size legend. A dot size s corresponds to log(1 + count) * size_mul, so the
    # displayed count is exp(s / size_mul) - 1.
    sizes = [1, 10, 25, 40, 60]
    legend_markers = [
        Line2D([0], [0], linewidth=0, marker="o", markersize=np.sqrt(s)) for s in sizes
    ]
    legend2 = fig.legend(
        legend_markers,
        [f" {ceil(np.exp(s / size_mul) - 1):,} cells" for s in sizes],
        bbox_to_anchor=(1.04, 0.45),
        loc="upper left",
        borderaxespad=0,
        frameon=False,
    )
    fig.add_artist(legend2)

    # Hide the unused trailing axes of the grid.
    for ax in axes[n_pops:]:
        ax.set_axis_off()

    sns.despine(offset=10, trim=True)
    plt.tight_layout()

scyan.plot.pop_expressions(model, population, key='scyan_pop', max_value=1.5, num_pieces=100, radius=0.05, figsize=(2, 6), show=True)

Plot latent cell expressions for one population. Contrary to scyan.plot.pops_expressions, it displays expressions on a vertical bar, from Neg to Pos.

Parameters:

Name Type Description Default
model Scyan

Scyan model.

required
population str

Name of one population to interpret. To be valid, the population name has to be in adata.obs[key].

required
key str

Key to look for populations in adata.obs. By default, uses the model predictions.

'scyan_pop'
max_value float

Maximum absolute latent value.

1.5
num_pieces int

Number of pieces to display the colorbar.

100
radius float

Radius used to chunk the colorbar. Increase this value if multiple names overlap.

0.05
figsize Tuple[float]

Pair (width, height) indicating the size of the figure.

(2, 6)
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/expressions.py
@torch.no_grad()
@plot_decorator()
@check_population(one=True)
def pop_expressions(
    model: Scyan,
    population: str,
    key: str = "scyan_pop",
    max_value: float = 1.5,
    num_pieces: int = 100,
    radius: float = 0.05,
    figsize: Tuple[float] = (2, 6),
    show: bool = True,
):
    """Plot latent cell expressions for one population. Contrary to `scyan.plot.pops_expressions`, it displays expressions on a vertical bar, from `Neg` to `Pos`.

    Args:
        model: Scyan model.
        population: Name of one population to interpret. To be valid, the population name has to be in `adata.obs[key]`.
        key: Key to look for populations in `adata.obs`. By default, uses the model predictions.
        max_value: Maximum absolute latent value. Mean latent values are clipped to `[-max_value, max_value]`.
        num_pieces: Number of pieces to display the colorbar.
        radius: Radius used to chunk the colorbar. Markers are grouped on ticks spaced by `2 * radius`. Increase this value if multiple names overlap.
        figsize: Pair `(width, height)` indicating the size of the figure.
        show: Whether or not to display the figure.
    """
    condition = model.adata.obs[key] == population
    u_mean = model(condition).mean(dim=0)
    values = u_mean.numpy(force=True).clip(-max_value, max_value)

    y = np.linspace(-max_value, max_value, num_pieces + 1)
    cmap = plt.get_cmap("RdBu")
    # Color intensity peaks at |y| = 1 (where "Pos"/"Neg" are annotated). The sign of
    # y selects the side of the diverging colormap.
    y_cmap = norm.pdf(np.abs(y) - 1, scale=model.hparams.prior_std)
    y_cmap = y_cmap - y_cmap.min()
    y_cmap = 0.5 - np.sign(y) * (y_cmap / y_cmap.max() / 2)
    colors = cmap(y_cmap).clip(0, 0.8)

    plt.figure(figsize=figsize, dpi=100)
    plt.vlines(np.zeros(num_pieces), y[:-1], y[1:], colors=colors, linewidth=5)
    plt.annotate("Pos", (-0.7, 1), fontsize=15)
    plt.annotate("Neg", (-0.7, -1), fontsize=15)

    # Assign each marker to its nearest tick center (-max_value + k * step). Every
    # finite value is labeled exactly once, including values clipped to +max_value.
    # The previous range-based loop never reached those values and dropped
    # values lying exactly on a bin boundary.
    step = 2 * radius
    labels_per_tick = {}
    for value, label in zip(values, model.var_names):
        if np.isfinite(value):
            tick = int(np.rint((value + max_value) / step))
            labels_per_tick.setdefault(tick, []).append(label)

    for tick in sorted(labels_per_tick):
        v = -max_value + tick * step
        plt.plot([0, 0.1], [v, v], "k")
        plt.annotate(", ".join(labels_per_tick[tick]), (0.2, v - 0.03))

    plt.xlim([-1, 1])
    plt.axis("off")

scyan.plot.pops_expressions(model, latent=True, key='scyan_pop', n_cells=200000, vmax=1.2, vmin=-1.2, cmap=None, figsize=(10, 6), show=True)

Heatmap that shows (latent or standardized) cell expressions for all populations.

Note

If using the latent space, it will only show the marker you provided to Scyan. Else, it shows every marker of the panel.

Parameters:

Name Type Description Default
model Scyan

Scyan model.

required
latent bool

If True, displays Scyan's latent expressions, else just the standardized expressions.

True
key str

Key to look for populations in adata.obs. By default, uses the model predictions.

'scyan_pop'
n_cells Optional[int]

Number of cells to be considered for the heatmap (to accelerate it when \(N\) is very high). If None, consider all cells.

200000
vmax float

Maximum value on the heatmap.

1.2
vmin float

Minimum value on the heatmap.

-1.2
cmap Optional[str]

Colormap name. By default, uses "coolwarm" if latent, else "viridis".

None
figsize Tuple[float]

Pair (width, height) indicating the size of the figure.

(10, 6)
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/expressions.py
@torch.no_grad()
@plot_decorator()
def pops_expressions(
    model: Scyan,
    latent: bool = True,
    key: str = "scyan_pop",
    n_cells: Optional[int] = 200_000,
    vmax: float = 1.2,
    vmin: float = -1.2,
    cmap: Optional[str] = None,
    figsize: Tuple[float] = (10, 6),
    show: bool = True,
):
    """Heatmap that shows (latent or standardized) cell expressions for all populations.

    Each heatmap row is the mean expression of the cells of one population (cells whose
    `adata.obs[key]` is missing are ignored).

    !!! note
        If using the latent space, it will only show the marker you provided to Scyan. Else, it shows every marker of the panel.

    Args:
        model: Scyan model.
        latent: If `True`, displays Scyan's latent expressions, else just the standardized expressions.
        key: Key to look for populations in `adata.obs`. By default, uses the model predictions.
        n_cells: Number of cells to be considered for the heatmap (to accelerate it when $N$ is very high). If `None`, consider all cells.
        vmax: Maximum value on the heatmap.
        vmin: Minimum value on the heatmap.
        cmap: Colormap name. By default, uses `"coolwarm"` if `latent`, else `"viridis"`.
        figsize: Pair `(width, height)` indicating the size of the figure.
        show: Whether or not to display the figure.
    """
    # Subsample among annotated cells only, then map back to positions in adata.
    not_na = ~model.adata.obs[key].isna()
    indices = _get_subset_indices(not_na.sum(), n_cells)
    indices = np.where(not_na)[0][indices]

    x = model(indices).numpy(force=True) if latent else model.adata[indices].X
    columns = model.var_names if latent else model.adata.var_names

    df = pd.DataFrame(x, columns=columns, index=model.adata.obs.index[indices])
    df["Population"] = model.adata[indices].obs[key]

    if cmap is None:
        cmap = "coolwarm" if latent else "viridis"

    plt.figure(figsize=figsize)
    sns.heatmap(df.groupby("Population").mean(), vmax=vmax, vmin=vmin, cmap=cmap)
    plt.title(f"{'Latent' if latent else 'Standardized'} expressions grouped by {key}")

scyan.plot.kde(adata, population, markers=None, key='scyan_pop', n_markers=3, n_cells=100000, ncols=2, var_name='Marker', value_name='Expression', show=True)

Plot Kernel-Density-Estimation for each provided population and for multiple markers.

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required
population Union[str, List[str], None]

One population, or a list of populations to be analyzed, or None. If not None, the population name(s) must be in adata.obs[key].

required
markers Optional[List[str]]

List of markers to plot. If None, the list is chosen automatically.

None
key str

Key to look for populations in adata.obs. By default, uses the model predictions.

'scyan_pop'
n_markers Optional[int]

Number of markers to choose automatically if markers is None.

3
n_cells Optional[int]

Number of cells to be considered for the KDE plots (to accelerate it when \(N\) is very high). If None, consider all cells.

100000
ncols int

Number of figures per row.

2
var_name str

Name displayed on the graphs.

'Marker'
value_name str

Name displayed on the graphs.

'Expression'
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/density.py
@plot_decorator(adata=True)
@check_population(return_list=True)
def kde(
    adata: AnnData,
    population: Union[str, List[str], None],
    markers: Optional[List[str]] = None,
    key: str = "scyan_pop",
    n_markers: Optional[int] = 3,
    n_cells: Optional[int] = 100_000,
    ncols: int = 2,
    var_name: str = "Marker",
    value_name: str = "Expression",
    show: bool = True,
):
    """Plot Kernel-Density-Estimation for each provided population and for multiple markers.

    One KDE panel is drawn per marker. When `population` is not `None`, each panel has one
    curve per requested population plus an "Others" curve for every remaining cell.

    Args:
        adata: An `AnnData` object.
        population: One population, or a list of populations to be analyzed, or `None`. If not `None`, the population name(s) must be in `adata.obs[key]`.
        markers: List of markers to plot. If `None`, the list is chosen automatically.
        key: Key to look for populations in `adata.obs`. By default, uses the model predictions.
        n_markers: Number of markers to choose automatically if `markers is None`.
        n_cells: Number of cells to be considered for the KDE plots (to accelerate it when $N$ is very high). If `None`, consider all cells.
        ncols: Number of figures per row.
        var_name: Name displayed on the graphs.
        value_name: Name displayed on the graphs.
        show: Whether or not to display the figure.
    """
    indices = _get_subset_indices(adata.n_obs, n_cells)
    adata = adata[indices]

    markers = select_markers(adata, markers, n_markers, key, population, 1)

    df = adata.to_df()

    if population is None:
        df = pd.melt(
            df,
            value_vars=markers,
            var_name=var_name,
            value_name=value_name,
        )

        sns.displot(
            df,
            x=value_name,
            col=var_name,
            col_wrap=ncols,
            kind="kde",
            common_norm=False,
            facet_kws=dict(sharey=False),
        )
        return

    keys = adata.obs[key]
    # Cells outside the requested populations are pooled into a single "Others" group.
    df[key] = np.where(~np.isin(keys, population), "Others", keys)

    df = pd.melt(
        df,
        id_vars=[key],
        value_vars=markers,
        var_name=var_name,
        value_name=value_name,
    )

    sns.displot(
        df,
        x=value_name,
        col=var_name,
        hue=key,
        col_wrap=ncols,
        kind="kde",
        common_norm=False,
        facet_kws=dict(sharey=False),
        palette=get_palette_others(df, key),
        # Stable sort on a boolean key: keeps the order of appearance, "Others" last.
        hue_order=sorted(df[key].unique(), key="Others".__eq__),
    )

scyan.plot.log_prob_threshold(adata, show=True)

Plot the number of cells annotated depending on the log probability threshold (below which cells are left non-classified). It can help determine the best threshold value, i.e. one placed before a significant decrease in the number of cells annotated.

Note

To use this function, you first need to fit a scyan.Scyan model and use the model.predict() method.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object used during the model training.

required
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/density.py
@plot_decorator(adata=True)
def log_prob_threshold(adata: AnnData, show: bool = True):
    """Plot the number of cells annotated depending on the log probability threshold (below which cells are left non-classified). It can help determine the best threshold value, i.e. one placed before a significant decrease in the number of cells annotated.

    The curve gives, for each threshold `t` taken among the observed log probabilities,
    the ratio of cells whose log probability is at least `t`. The x-axis starts at -100.

    !!! note
        To use this function, you first need to fit a `scyan.Scyan` model and use the `model.predict()` method.

    Args:
        adata: The `AnnData` object used during the model training.
        show: Whether or not to display the figure.
    """
    assert "scyan_log_probs" in adata.obs, "Cannot find 'scyan_log_probs' in adata.obs. Have you run model.predict()?"

    thresholds = np.sort(adata.obs["scyan_log_probs"])
    n_total = len(thresholds)
    # With the scores sorted increasingly, n_total - i cells have a score >= thresholds[i].
    kept_ratio = 1 - np.arange(n_total) / float(n_total)

    plt.plot(thresholds, kept_ratio)
    plt.xlim(-100, thresholds.max())
    sns.despine(offset=10, trim=True)
    plt.xlabel("Log density threshold")
    plt.ylabel("Ratio of predicted cells")

scyan.plot.pop_level(model, group_name, level_name='level', key='scyan_pop', **scanpy_kwargs)

Plot all subpopulations of a group at a certain level on a UMAP (according to the populations levels provided in the knowledge table).

Parameters:

Name Type Description Default
model Scyan

Scyan model.

required
group_name str

The group to look at among the populations of the selected level.

required
level_name str

Name of the column of the knowledge table containing the names of the grouped populations.

'level'
key str

Key of adata.obs to access the model predictions.

'scyan_pop'
Source code in scyan/plot/dot.py
def pop_level(
    model: Scyan,
    group_name: str,
    level_name: str = "level",
    key: str = "scyan_pop",
    **scanpy_kwargs: int,
) -> None:
    """Plot all subpopulations of a group at a certain level on a UMAP (according to the populations levels provided in the knowledge table).

    !!! note
        This writes a categorical column `f"{key}_one_level"` into `model.adata.obs`. The
        column holds the predicted population for cells belonging to `group_name`, and
        NaN for every other cell.

    Args:
        model: Scyan model.
        group_name: The group to look at among the populations of the selected level.
        level_name: Name of the column of the knowledge table containing the names of the grouped populations.
        key: Key of `adata.obs` to access the model predictions.
        **scanpy_kwargs: Optional kwargs forwarded to `scyan.plot.umap` (and from there to `scanpy.pl.umap`).
    """
    adata = model.adata
    table = model.table

    assert isinstance(
        table.index, pd.MultiIndex
    ), "To use this function, you need a MultiIndex DataFrame, see: https://mics-lab.github.io/scyan/tutorials/usage/#working-with-hierarchical-populations"

    level_names = table.index.names[1:]
    assert (
        level_name in level_names
    ), f"Level '{level_name}' unknown. Choose one of: {level_names}"

    base_pops = table.index.get_level_values(0)
    group_pops = table.index.get_level_values(level_name)
    # List each valid group once, and stringify them so that join cannot fail on
    # non-string (e.g. NaN) level values.
    assert (
        group_name in group_pops
    ), f"Invalid group name '{group_name}'. It has to be one of: {', '.join(map(str, group_pops.unique()))}."

    # A set keeps the per-cell membership test below O(1).
    valid_populations = {
        pop for pop, group in zip(base_pops, group_pops) if group == group_name
    }
    key_name = f"{key}_one_level"
    adata.obs[key_name] = pd.Categorical(
        [pop if pop in valid_populations else np.nan for pop in adata.obs[key]]
    )
    umap(
        adata,
        color=key_name,
        title=f"Among {group_name}",
        na_in_legend=False,
        **scanpy_kwargs,
    )

scyan.plot.pops_hierarchy(model, figsize=(18, 5), show=True)

Plot populations as a tree, where each level corresponds to more detailed populations. To run this function, your knowledge table needs to contain at least one population 'level' (see this tutorial), and you need to install graphviz.

Parameters:

Name Type Description Default
model Scyan

Scyan model.

required
figsize tuple

Matplotlib figure size.

(18, 5)
show bool

Whether or not to display the figure.

True
Source code in scyan/plot/graph.py
@plot_decorator()
def pops_hierarchy(model: Scyan, figsize: tuple = (18, 5), show: bool = True) -> None:
    """Plot populations as a tree, where each level corresponds to more detailed populations. To run this function, your knowledge table needs to contain at least one population 'level' (see [this tutorial](../../tutorials/usage/#working-with-hierarchical-populations)), and you need to install `graphviz`.

    The tree is built from the knowledge table's MultiIndex. The root is "All populations",
    the last index level comes right under it, and level 0 is at the leaves. When a child
    has the same name as its parent, no extra node is added and recursion continues from
    that name.

    Args:
        model: Scyan model.
        figsize: Matplotlib figure size.
        show: Whether or not to display the figure.
    """
    import networkx as nx
    from networkx.drawing.nx_agraph import graphviz_layout

    table = model.table

    assert isinstance(
        table.index, pd.MultiIndex
    ), "To plot population hierarchy, you need a MultiIndex DataFrame. See the documentation for more details."

    root = "All populations"

    G = nx.DiGraph()
    G.add_node(root)

    def add_nodes(table, indices, level, parent=root):
        # Recursively attach the names found at `level` (for the given rows) under `parent`.
        if level == -1:
            return

        level_values = table.index.get_level_values(level)
        rows_by_name = defaultdict(list)
        for row in indices:
            rows_by_name[level_values[row]].append(row)

        for child, child_rows in rows_by_name.items():
            if child != parent:
                G.add_node(child)
                G.add_edge(parent, child)
            add_nodes(table, child_rows, level - 1, child)

    add_nodes(table, range(len(table)), table.index.nlevels - 1)

    plt.figure(figsize=figsize)
    positions = graphviz_layout(G, prog="dot")
    nx.draw_networkx(G, positions, with_labels=False, arrows=False, node_size=0)

    label_box = dict(facecolor="wheat", edgecolor="black", boxstyle="round,pad=0.5")
    for node_name, (x_pos, y_pos) in positions.items():
        plt.text(
            x_pos,
            y_pos,
            node_name,
            ha="center",
            va="center",
            rotation=90,
            bbox=label_box,
        )

    plt.grid(False)
    plt.box(False)