Skip to content

Dataloader

novae.data.AnnDataTorch

Source code in novae/data/convert.py
class AnnDataTorch:
    """Serve the (per-slide standardized) gene expressions of `AnnData` objects as PyTorch tensors.

    Each cell's expression is standardized with the mean / std of its own slide, computed on the kept
    genes only. When the total number of cells is below `Nums.N_OBS_THRESHOLD`, all tensors are
    precomputed once in `__init__`; otherwise they are computed on demand in `__getitem__`.
    """

    tensors: list[Tensor] | None
    genes_indices_list: list[Tensor]

    def __init__(self, adatas: list[AnnData], cell_embedder: CellEmbedder):
        """Converting AnnData objects to PyTorch tensors.

        Args:
            adatas: A list of `AnnData` objects.
            cell_embedder: A [novae.module.CellEmbedder][] object.
        """
        super().__init__()
        self.adatas = adatas
        self.cell_embedder = cell_embedder

        self.genes_indices_list = [self._adata_to_genes_indices(adata) for adata in self.adatas]
        self.tensors = None

        self.means, self.stds, self.label_encoder = self._compute_means_stds()

        # Tensors are loaded in memory for low numbers of cells
        if sum(adata.n_obs for adata in self.adatas) < Nums.N_OBS_THRESHOLD:
            self.tensors = [self.to_tensor(adata) for adata in self.adatas]

    def _adata_to_genes_indices(self, adata: AnnData) -> Tensor:
        # Embedder indices of the kept genes; `[None, :]` adds a leading axis of size 1.
        return self.cell_embedder.genes_to_indices(adata.var_names[self._keep_var(adata)])[None, :]

    def _keep_var(self, adata: AnnData) -> AnnData:
        # Returns the `Keys.USE_GENE` column of `adata.var`, used to select genes
        # (presumably a boolean mask — the annotation above does not reflect that).
        return adata.var[Keys.USE_GENE]

    def _compute_means_stds(self) -> tuple[list[Tensor], list[Tensor], LabelEncoder]:
        """Compute, for every slide, the mean and std of each kept gene.

        Returns:
            A list of float32 mean tensors and a list of float32 std tensors (both ordered like
            `label_encoder.classes_`), and the `LabelEncoder` fitted on the slide ids.
        """
        means, stds = {}, {}

        for adata in self.adatas:
            slide_ids = adata.obs[Keys.SLIDE_ID]
            for slide_id in slide_ids.cat.categories:
                slide_mask = slide_ids == slide_id
                if not slide_mask.any():
                    # Unused category: an empty slice would only yield NaN / invalid statistics.
                    continue

                adata_slide = adata[slide_mask, self._keep_var(adata)]

                mean = adata_slide.X.mean(0)
                mean = mean.A1 if isinstance(mean, np.matrix) else mean  # sparse `.mean` returns np.matrix
                means[slide_id] = mean.astype(np.float32)

                std = adata_slide.X.std(0) if isinstance(adata_slide.X, np.ndarray) else sparse_std(adata_slide.X, 0).A1
                stds[slide_id] = std.astype(np.float32)

        label_encoder = LabelEncoder()
        label_encoder.fit(list(means.keys()))

        means = [torch.tensor(means[slide_id]) for slide_id in label_encoder.classes_]
        stds = [torch.tensor(stds[slide_id]) for slide_id in label_encoder.classes_]

        return means, stds, label_encoder

    def to_tensor(self, adata: AnnData) -> Tensor:
        """Get the normalized gene expressions of the cells in the dataset.
        Only the genes of interest are kept (known genes and highly variable).

        Args:
            adata: An `AnnData` object.

        Returns:
            A `Tensor` containing the normalized gene expressions.
        """
        adata = adata[:, self._keep_var(adata)]

        if len(np.unique(adata.obs[Keys.SLIDE_ID])) == 1:
            # Single slide: its statistics are broadcast over all cells.
            slide_id_index = self.label_encoder.transform([adata.obs.iloc[0][Keys.SLIDE_ID]])[0]
            mean, std = self.means[slide_id_index], self.stds[slide_id_index]
        else:
            # Several slides: stack the statistics of the slides present once, then gather one row
            # per cell (instead of stacking one tensor per cell).
            slide_id_indices = self.label_encoder.transform(adata.obs[Keys.SLIDE_ID])
            unique_indices, inverse = np.unique(slide_id_indices, return_inverse=True)
            inverse = torch.as_tensor(inverse.reshape(-1), dtype=torch.long)
            mean = torch.stack([self.means[i] for i in unique_indices])[inverse]
            std = torch.stack([self.stds[i] for i in unique_indices])[inverse]

        X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
        X = torch.tensor(X, dtype=torch.float32)
        X = (X - mean) / (std + Nums.EPS)

        return X

    def __getitem__(self, item: tuple[int, slice]) -> tuple[Tensor, Tensor]:
        """Get the expression values for a subset of cells (corresponding to a subgraph).

        Args:
            item: A `tuple` containing the index of the `AnnData` object and the indices of the cells in the neighborhoods.

        Returns:
            A `Tensor` of normalized gene expressions and a `Tensor` of gene indices.
        """
        adata_index, obs_indices = item

        if self.tensors is not None:
            return self.tensors[adata_index][obs_indices], self.genes_indices_list[adata_index]

        adata = self.adatas[adata_index]
        adata_view = adata[obs_indices]

        return self.to_tensor(adata_view), self.genes_indices_list[adata_index]

__getitem__(item)

Get the expression values for a subset of cells (corresponding to a subgraph).

Parameters:

Name Type Description Default
item tuple[int, slice]

A tuple containing the index of the AnnData object and the indices of the cells in the neighborhoods.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A Tensor of normalized gene expressions and a Tensor of gene indices.

Source code in novae/data/convert.py
def __getitem__(self, item: tuple[int, slice]) -> tuple[Tensor, Tensor]:
    """Return the normalized expressions of some cells of one `AnnData`, with its gene indices.

    Args:
        item: A `(adata_index, obs_indices)` pair: which `AnnData` object to read, and which cells of it.

    Returns:
        The normalized gene expressions of the selected cells, and the gene indices of that `AnnData`.
    """
    idx, cells = item
    if self.tensors is None:
        # Lazy path: restrict the AnnData to the requested cells, then normalize only those.
        expressions = self.to_tensor(self.adatas[idx][cells])
    else:
        # Eager path: tensors were precomputed, simply index them.
        expressions = self.tensors[idx][cells]
    return expressions, self.genes_indices_list[idx]

__init__(adatas, cell_embedder)

Converting AnnData objects to PyTorch tensors.

Parameters:

Name Type Description Default
adatas list[AnnData]

A list of AnnData objects.

required
cell_embedder CellEmbedder

A novae.module.CellEmbedder object.

required
Source code in novae/data/convert.py
def __init__(self, adatas: list[AnnData], cell_embedder: CellEmbedder):
    """Wrap `AnnData` objects so that their expressions can be served as PyTorch tensors.

    Args:
        adatas: A list of `AnnData` objects.
        cell_embedder: A [novae.module.CellEmbedder][] object.
    """
    super().__init__()
    self.adatas = adatas
    self.cell_embedder = cell_embedder

    self.genes_indices_list = list(map(self._adata_to_genes_indices, self.adatas))
    self.tensors = None

    statistics = self._compute_means_stds()
    self.means, self.stds, self.label_encoder = statistics

    # Small datasets are converted eagerly and kept in memory.
    n_cells_total = sum(a.n_obs for a in self.adatas)
    if n_cells_total < Nums.N_OBS_THRESHOLD:
        self.tensors = list(map(self.to_tensor, self.adatas))

to_tensor(adata)

Get the normalized gene expressions of the cells in the dataset. Only the genes of interest are kept (known genes and highly variable).

Parameters:

Name Type Description Default
adata AnnData

An AnnData object.

required

Returns:

Type Description
Tensor

A Tensor containing the normalized gene expressions.

Source code in novae/data/convert.py
def to_tensor(self, adata: AnnData) -> Tensor:
    """Get the normalized gene expressions of the cells in the dataset.
    Only the genes of interest are kept (known genes and highly variable).

    Args:
        adata: An `AnnData` object.

    Returns:
        A `Tensor` containing the normalized gene expressions.
    """
    adata = adata[:, self._keep_var(adata)]

    if len(np.unique(adata.obs[Keys.SLIDE_ID])) == 1:
        # Single slide: its statistics are broadcast over all cells.
        slide_id_index = self.label_encoder.transform([adata.obs.iloc[0][Keys.SLIDE_ID]])[0]
        mean, std = self.means[slide_id_index], self.stds[slide_id_index]
    else:
        # Several slides: stack the statistics of the slides present once, then gather one row
        # per cell (instead of stacking one tensor per cell).
        slide_id_indices = self.label_encoder.transform(adata.obs[Keys.SLIDE_ID])
        unique_indices, inverse = np.unique(slide_id_indices, return_inverse=True)
        inverse = torch.as_tensor(inverse.reshape(-1), dtype=torch.long)
        mean = torch.stack([self.means[i] for i in unique_indices])[inverse]
        std = torch.stack([self.stds[i] for i in unique_indices])[inverse]

    X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
    X = torch.tensor(X, dtype=torch.float32)
    X = (X - mean) / (std + Nums.EPS)

    return X

novae.data.NovaeDataset

Bases: Dataset

Dataset used for training and inference.

It extracts the neighborhood of a cell and converts it to PyTorch Geometric Data.

Attributes:

Name Type Description
valid_indices list[ndarray]

List containing, for each adata, an array that denotes the indices of the cells whose neighborhood is valid.

obs_ilocs ndarray

An array of shape (total_valid_indices, 2). The first column corresponds to the adata index, and the second column is the cell index for the corresponding adata.

shuffled_obs_ilocs ndarray

Same as obs_ilocs, but shuffled. Each batch will contain cells from the same slide.

Source code in novae/data/dataset.py
class NovaeDataset(Dataset):
    """
    Dataset used for training and inference.

    It extracts the neighborhood of a cell and converts it to PyTorch Geometric Data.

    Attributes:
        valid_indices (list[np.ndarray]): List containing, for each `adata`, an array that denotes the indices of the cells whose neighborhood is valid.
        obs_ilocs (np.ndarray): An array of shape `(total_valid_indices, 2)`. The first column corresponds to the adata index, and the second column is the cell index for the corresponding adata.
        shuffled_obs_ilocs (np.ndarray): Same as obs_ilocs, but shuffled. Each batch will contain cells from the same slide.
    """

    valid_indices: list[np.ndarray]
    obs_ilocs: np.ndarray  # set to None when there are several adatas (see `_init_dataset`)
    shuffled_obs_ilocs: np.ndarray

    def __init__(
        self,
        adatas: list[AnnData],
        cell_embedder: CellEmbedder,
        batch_size: int,
        n_hops_local: int,
        n_hops_view: int,
        sample_cells: int | None = None,
    ) -> None:
        """NovaeDataset constructor.

        Args:
            adatas: A list of `AnnData` objects.
            cell_embedder: A [novae.module.CellEmbedder][] object.
            batch_size: The model batch size.
            n_hops_local: Number of hops between a cell and its neighborhood cells.
            n_hops_view: Number of hops between a cell and the origin of a second graph (or 'view').
            sample_cells: If not None, the dataset is used to sample the subgraphs from precisely `sample_cells` cells.
        """
        super().__init__()
        self.adatas = adatas
        self.cell_embedder = cell_embedder
        self.anndata_torch = AnnDataTorch(self.adatas, self.cell_embedder)

        self.training = False

        self.batch_size = batch_size
        self.n_hops_local = n_hops_local
        self.n_hops_view = n_hops_view
        self.sample_cells = sample_cells

        self.single_adata = len(self.adatas) == 1
        self.single_slide_mode = self.single_adata and len(np.unique(self.adatas[0].obs[Keys.SLIDE_ID])) == 1

        self._init_dataset()

    def __repr__(self) -> str:
        multi_slide_mode, multi_adata = not self.single_slide_mode, not self.single_adata
        n_samples = sum(len(indices) for indices in self.valid_indices)
        return f"{self.__class__.__name__} with {n_samples} samples ({multi_slide_mode=}, {multi_adata=})"

    def _init_dataset(self):
        # Build (once, cached in `obsp`/`obs`) the local and view adjacencies and the validity flag.
        for adata in self.adatas:
            adjacency: csr_matrix = adata.obsp[Keys.ADJ]

            if Keys.ADJ_LOCAL not in adata.obsp:
                adata.obsp[Keys.ADJ_LOCAL] = _to_adjacency_local(adjacency, self.n_hops_local)
            if Keys.ADJ_PAIR not in adata.obsp:
                adata.obsp[Keys.ADJ_PAIR] = _to_adjacency_view(adjacency, self.n_hops_view)
            if Keys.IS_VALID_OBS not in adata.obs:
                # A cell is valid if it has at least one candidate "view" cell.
                adata.obs[Keys.IS_VALID_OBS] = adata.obsp[Keys.ADJ_PAIR].sum(1).A1 > 0

        ratio_valid_obs = pd.concat([adata.obs[Keys.IS_VALID_OBS] for adata in self.adatas]).mean()
        if ratio_valid_obs < Nums.RATIO_VALID_CELLS_TH:
            log.warning(
                f"Only {ratio_valid_obs:.2%} of the cells have a valid neighborhood.\n"
                "Consider running `novae.utils.spatial_neighbors` with a larger `radius`."
            )

        self.valid_indices = [utils.valid_indices(adata) for adata in self.adatas]

        self.obs_ilocs = None
        if self.single_adata:
            self.obs_ilocs = np.array([(0, obs_index) for obs_index in self.valid_indices[0]])

        self.slides_metadata: pd.DataFrame = pd.concat(
            [
                self._adata_slides_metadata(adata_index, obs_indices)
                for adata_index, obs_indices in enumerate(self.valid_indices)
            ],
            axis=0,
        )

        self.shuffle_obs_ilocs()

    def __len__(self) -> int:
        if self.sample_cells is not None:
            return min(self.sample_cells, len(self.shuffled_obs_ilocs))

        if self.training:
            # An epoch only covers a fraction of the (shuffled) cells, bounded below by MIN_DATASET_LENGTH.
            n_obs = len(self.shuffled_obs_ilocs)
            return min(n_obs, max(Nums.MIN_DATASET_LENGTH, int(n_obs * Nums.MAX_DATASET_LENGTH_RATIO)))

        assert self.single_adata, "Multi-adata mode not supported for inference"

        return len(self.obs_ilocs)

    def __getitem__(self, index: int) -> dict[str, Data]:
        """Gets a sample from the dataset, with one "main" graph and its corresponding "view" graph (only during training).

        Args:
            index: Index of the sample to retrieve.

        Returns:
            A dictionary whose keys are names, and values are PyTorch Geometric `Data` objects. The `"view"` graph is only provided during training.
        """
        if self.training or self.sample_cells is not None:
            adata_index, obs_index = self.shuffled_obs_ilocs[index]
        else:
            adata_index, obs_index = self.obs_ilocs[index]

        data = self.to_pyg_data(adata_index, obs_index)

        if not self.training:
            return {"main": data}

        # The "view" graph is centered on a random cell among the candidates of the pair adjacency.
        adjacency_pair: csr_matrix = self.adatas[adata_index].obsp[Keys.ADJ_PAIR]
        cell_view_index = np.random.choice(adjacency_pair[obs_index].indices, size=1)[0]

        return {"main": data, "view": self.to_pyg_data(adata_index, cell_view_index)}

    def to_pyg_data(self, adata_index: int, obs_index: int) -> Data:
        """Create a PyTorch Geometric Data object for the input cell

        Args:
            adata_index: The index of the `AnnData` object
            obs_index: The index of the input cell for the corresponding `AnnData` object

        Returns:
            A Data object
        """
        adata = self.adatas[adata_index]
        adjacency_local: csr_matrix = adata.obsp[Keys.ADJ_LOCAL]
        obs_indices = adjacency_local[obs_index].indices

        adjacency: csr_matrix = adata.obsp[Keys.ADJ]
        edge_index, edge_weight = from_scipy_sparse_matrix(adjacency[obs_indices][:, obs_indices])
        edge_attr = edge_weight[:, None].to(torch.float32) / Nums.CELLS_CHARACTERISTIC_DISTANCE

        x, genes_indices = self.anndata_torch[adata_index, obs_indices]

        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            genes_indices=genes_indices,
            # Slide of the input cell itself (an adata may contain several slides).
            slide_id=adata.obs[Keys.SLIDE_ID].iloc[obs_index],
        )

    def shuffle_obs_ilocs(self):
        """Shuffle the indices of the cells to be used in the dataset (for training only)."""
        if self.single_slide_mode:
            self.shuffled_obs_ilocs = self.obs_ilocs[np.random.permutation(len(self.obs_ilocs))]
            return

        # Build full batches slide by slide, so that every batch only contains cells from one slide.
        adata_indices = np.empty((0, self.batch_size), dtype=int)
        batched_obs_indices = np.empty((0, self.batch_size), dtype=int)

        for uid in self.slides_metadata.index:
            adata_index = self.slides_metadata.loc[uid, Keys.ADATA_INDEX]
            adata = self.adatas[adata_index]
            _obs_indices = np.where((adata.obs[Keys.SLIDE_ID] == uid) & adata.obs[Keys.IS_VALID_OBS])[0]
            if len(_obs_indices) == 0:
                continue  # no valid cell to sample from (np.random.choice would raise on an empty array)
            _obs_indices = np.random.permutation(_obs_indices)

            n_elements = self.slides_metadata.loc[uid, Keys.N_BATCHES] * self.batch_size
            if len(_obs_indices) >= n_elements:
                _obs_indices = _obs_indices[:n_elements]
            else:
                # Small slides are oversampled (with replacement) to fill their batches.
                _obs_indices = np.random.choice(_obs_indices, size=n_elements)

            _obs_indices = _obs_indices.reshape((-1, self.batch_size))

            adata_indices = np.concatenate([adata_indices, np.full_like(_obs_indices, adata_index)], axis=0)
            batched_obs_indices = np.concatenate([batched_obs_indices, _obs_indices], axis=0)

        # Shuffle whole batches (rows), not individual cells, to keep batches single-slide.
        permutation = np.random.permutation(len(batched_obs_indices))
        adata_indices = adata_indices[permutation].flatten()
        obs_indices = batched_obs_indices[permutation].flatten()

        self.shuffled_obs_ilocs = np.stack([adata_indices, obs_indices], axis=1)

    def _adata_slides_metadata(self, adata_index: int, obs_indices: list[int]) -> pd.DataFrame:
        obs_counts: pd.Series = self.adatas[adata_index].obs.iloc[obs_indices][Keys.SLIDE_ID].value_counts()
        # On a categorical column, `value_counts` also lists unused categories with a 0 count;
        # such slides have no valid cell to sample from, so they are dropped.
        obs_counts = obs_counts[obs_counts > 0]
        slides_metadata = obs_counts.to_frame()
        slides_metadata[Keys.ADATA_INDEX] = adata_index
        slides_metadata[Keys.N_BATCHES] = (slides_metadata["count"] // self.batch_size).clip(1)
        return slides_metadata

__getitem__(index)

Gets a sample from the dataset, with one "main" graph and its corresponding "view" graph (only during training).

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Type Description
dict[str, Data]

A dictionary whose keys are names, and values are PyTorch Geometric Data objects. The "view" graph is only provided during training.

Source code in novae/data/dataset.py
def __getitem__(self, index: int) -> dict[str, Data]:
    """Build the sample at position `index`: a "main" graph, plus a "view" graph when training.

    Args:
        index: Index of the sample to retrieve.

    Returns:
        A dictionary mapping names to PyTorch Geometric `Data` objects; the `"view"` entry is only present during training.
    """
    use_shuffled = self.training or self.sample_cells is not None
    ilocs = self.shuffled_obs_ilocs if use_shuffled else self.obs_ilocs
    adata_index, obs_index = ilocs[index]

    sample = {"main": self.to_pyg_data(adata_index, obs_index)}

    if self.training:
        # Pick the center of the second graph among the cells linked to `obs_index` in the pair adjacency.
        adjacency_pair: csr_matrix = self.adatas[adata_index].obsp[Keys.ADJ_PAIR]
        candidates = list(adjacency_pair[obs_index].indices)
        cell_view_index = np.random.choice(candidates, size=1)[0]
        sample["view"] = self.to_pyg_data(adata_index, cell_view_index)

    return sample

__init__(adatas, cell_embedder, batch_size, n_hops_local, n_hops_view, sample_cells=None)

NovaeDataset constructor.

Parameters:

Name Type Description Default
adatas list[AnnData]

A list of AnnData objects.

required
cell_embedder CellEmbedder

A novae.module.CellEmbedder object.

required
batch_size int

The model batch size.

required
n_hops_local int

Number of hops between a cell and its neighborhood cells.

required
n_hops_view int

Number of hops between a cell and the origin of a second graph (or 'view').

required
sample_cells int | None

If not None, the dataset is used to sample the subgraphs from precisely sample_cells cells.

None
Source code in novae/data/dataset.py
def __init__(
    self,
    adatas: list[AnnData],
    cell_embedder: CellEmbedder,
    batch_size: int,
    n_hops_local: int,
    n_hops_view: int,
    sample_cells: int | None = None,
) -> None:
    """Create the dataset and precompute everything needed to sample subgraphs.

    Args:
        adatas: A list of `AnnData` objects.
        cell_embedder: A [novae.module.CellEmbedder][] object.
        batch_size: The model batch size.
        n_hops_local: Number of hops between a cell and its neighborhood cells.
        n_hops_view: Number of hops between a cell and the origin of a second graph (or 'view').
        sample_cells: If not None, the dataset is used to sample the subgraphs from precisely `sample_cells` cells.
    """
    super().__init__()
    self.adatas = adatas
    self.cell_embedder = cell_embedder
    self.anndata_torch = AnnDataTorch(self.adatas, self.cell_embedder)

    self.training = False

    # Sampling hyper-parameters
    self.batch_size, self.n_hops_local, self.n_hops_view = batch_size, n_hops_local, n_hops_view
    self.sample_cells = sample_cells

    self.single_adata = len(self.adatas) == 1
    if self.single_adata:
        n_slides = len(np.unique(self.adatas[0].obs[Keys.SLIDE_ID]))
        self.single_slide_mode = n_slides == 1
    else:
        self.single_slide_mode = False

    self._init_dataset()

shuffle_obs_ilocs()

Shuffle the indices of the cells to be used in the dataset (for training only).

Source code in novae/data/dataset.py
def shuffle_obs_ilocs(self):
    """Shuffle the indices of the cells to be used in the dataset (for training only).

    In single-slide mode, the valid cells are simply permuted. Otherwise, full batches are built
    slide by slide (each slide contributing `N_BATCHES` batches, oversampled with replacement if it
    has too few valid cells), and whole batches are then shuffled, so that every batch only
    contains cells from one slide. Slides without any valid cell are skipped.
    """
    if self.single_slide_mode:
        self.shuffled_obs_ilocs = self.obs_ilocs[np.random.permutation(len(self.obs_ilocs))]
        return

    adata_indices = np.empty((0, self.batch_size), dtype=int)
    batched_obs_indices = np.empty((0, self.batch_size), dtype=int)

    for uid in self.slides_metadata.index:
        adata_index = self.slides_metadata.loc[uid, Keys.ADATA_INDEX]
        adata = self.adatas[adata_index]
        _obs_indices = np.where((adata.obs[Keys.SLIDE_ID] == uid) & adata.obs[Keys.IS_VALID_OBS])[0]
        if len(_obs_indices) == 0:
            continue  # no valid cell to sample from (np.random.choice would raise on an empty array)
        _obs_indices = np.random.permutation(_obs_indices)

        n_elements = self.slides_metadata.loc[uid, Keys.N_BATCHES] * self.batch_size
        if len(_obs_indices) >= n_elements:
            _obs_indices = _obs_indices[:n_elements]
        else:
            _obs_indices = np.random.choice(_obs_indices, size=n_elements)

        _obs_indices = _obs_indices.reshape((-1, self.batch_size))

        adata_indices = np.concatenate([adata_indices, np.full_like(_obs_indices, adata_index)], axis=0)
        batched_obs_indices = np.concatenate([batched_obs_indices, _obs_indices], axis=0)

    # Shuffle whole batches (rows), not individual cells, to keep batches single-slide.
    permutation = np.random.permutation(len(batched_obs_indices))
    adata_indices = adata_indices[permutation].flatten()
    obs_indices = batched_obs_indices[permutation].flatten()

    self.shuffled_obs_ilocs = np.stack([adata_indices, obs_indices], axis=1)

to_pyg_data(adata_index, obs_index)

Create a PyTorch Geometric Data object for the input cell

Parameters:

Name Type Description Default
adata_index int

The index of the AnnData object

required
obs_index int

The index of the input cell for the corresponding AnnData object

required

Returns:

Type Description
Data

A Data object

Source code in novae/data/dataset.py
def to_pyg_data(self, adata_index: int, obs_index: int) -> Data:
    """Create a PyTorch Geometric Data object for the input cell

    The nodes are the cells listed in row `obs_index` of the local adjacency (`Keys.ADJ_LOCAL`),
    and the edges are those of `Keys.ADJ` restricted to these cells.

    Args:
        adata_index: The index of the `AnnData` object
        obs_index: The index of the input cell for the corresponding `AnnData` object

    Returns:
        A Data object with the normalized expressions (`x`), the edges and their scaled weights
        (`edge_index`, `edge_attr`), the gene indices, and the slide id of the input cell.
    """
    adata = self.adatas[adata_index]
    adjacency_local: csr_matrix = adata.obsp[Keys.ADJ_LOCAL]
    obs_indices = adjacency_local[obs_index].indices

    adjacency: csr_matrix = adata.obsp[Keys.ADJ]
    edge_index, edge_weight = from_scipy_sparse_matrix(adjacency[obs_indices][:, obs_indices])
    # One feature per edge: its weight divided by the characteristic cell distance.
    edge_attr = edge_weight[:, None].to(torch.float32) / Nums.CELLS_CHARACTERISTIC_DISTANCE

    x, genes_indices = self.anndata_torch[adata_index, obs_indices]

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        genes_indices=genes_indices,
        # Slide of the input cell itself (an adata may contain several slides).
        slide_id=adata.obs[Keys.SLIDE_ID].iloc[obs_index],
    )

novae.data.NovaeDatamodule

Bases: LightningDataModule

Datamodule used for training and inference. Small wrapper around the novae.data.NovaeDataset.

Source code in novae/data/datamodule.py
class NovaeDatamodule(L.LightningDataModule):
    """
    Datamodule used for training and inference. Small wrapper around the [novae.data.NovaeDataset][]
    """

    def __init__(
        self,
        adatas: list[AnnData],
        cell_embedder: CellEmbedder,
        batch_size: int,
        n_hops_local: int,
        n_hops_view: int,
        num_workers: int = 0,
        sample_cells: int | None = None,
    ) -> None:
        """Build the underlying `NovaeDataset` and store the dataloader settings.

        Args:
            adatas: A list of `AnnData` objects.
            cell_embedder: A [novae.module.CellEmbedder][] object.
            batch_size: The model batch size.
            n_hops_local: Number of hops between a cell and its neighborhood cells.
            n_hops_view: Number of hops between a cell and the origin of a second graph (or 'view').
            num_workers: Number of workers used by the dataloaders.
            sample_cells: If not None, the dataset is used to sample the subgraphs from precisely `sample_cells` cells.
        """
        super().__init__()
        self.dataset = NovaeDataset(
            adatas,
            cell_embedder=cell_embedder,
            batch_size=batch_size,
            n_hops_local=n_hops_local,
            n_hops_view=n_hops_view,
            sample_cells=sample_cells,
        )
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        """Get a PyTorch dataloader for training.

        Returns:
            The training dataloader.
        """
        self.dataset.training = True
        # No DataLoader shuffling: the dataset provides its own (slide-consistent) shuffled order.
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=self.num_workers,
        )

    def predict_dataloader(self) -> DataLoader:
        """Get a PyTorch dataloader for prediction or inference.

        Returns:
            The prediction dataloader.
        """
        self.dataset.training = False
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )

predict_dataloader()

Get a PyTorch dataloader for prediction or inference.

Returns:

Type Description
DataLoader

The prediction dataloader.

Source code in novae/data/datamodule.py
def predict_dataloader(self) -> DataLoader:
    """Switch the dataset to inference mode and wrap it in a PyTorch dataloader.

    Returns:
        The prediction dataloader (ordered, keeping the last incomplete batch).
    """
    self.dataset.training = False
    loader_options = dict(
        batch_size=self.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=self.num_workers,
    )
    return DataLoader(self.dataset, **loader_options)

train_dataloader()

Get a PyTorch dataloader for training.

Returns:

Type Description
DataLoader

The training dataloader.

Source code in novae/data/datamodule.py
def train_dataloader(self) -> DataLoader:
    """Get a PyTorch dataloader for training.

    Returns:
        The training dataloader.
    """
    self.dataset.training = True
    # No DataLoader shuffling: the dataset provides its own (slide-consistent) shuffled order.
    return DataLoader(
        self.dataset,
        batch_size=self.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=self.num_workers,
    )