
API

thunder.benchmark.benchmark(model, dataset, task, loading_mode='online_loading', lora=False, ckpt_save_all=False, online_wandb=False, **kwargs)

Runs a benchmark for a pretrained model on a dataset with a task-specific approach.

The available options are:
  • dataset: bach, bcss, bracs, break_his, ccrcc, crc, esca, mhist, ocelot, pannuke, patch_camelyon, segpath_epithelial, segpath_lymphocytes, tcga_crc_msi, tcga_tils, tcga_uniform, wilds
  • model: hiboub, hiboul, hoptimus0, hoptimus1, midnight, phikon, phikon2, uni, uni2h, virchow, virchow2, conch, titan, keep, musk, plip, quiltnetb32, dinov2base, dinov2large, vitbasepatch16224in21k, vitlargepatch16224in21k, clipvitbasepatch32, clipvitlargepatch14
  • task: adversarial_attack_linear, alignment_scoring, image_retrieval, knn, linear_probing, pre_computing_embeddings, segmentation, simple_shot, transformation_invariance
  • loading_mode: online_loading, image_pre_loading, embedding_pre_loading

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | str or Callable | The name of the pretrained model to use, a string of the form custom:<path> pointing to a file that defines a custom model, or a callable model that is passed directly to the benchmark. | required |
| dataset | str | The name of the dataset to use. | required |
| task | str | The name of the task to perform. | required |
| loading_mode | str | The type of data loading to use. | 'online_loading' |
| lora | bool | Whether to use LoRA (Low-Rank Adaptation) for model adaptation. If False, the model is kept frozen. | False |
| ckpt_save_all | bool | Whether to save a checkpoint at every epoch during training. If False, only the best checkpoint is saved. | False |
| online_wandb | bool | Whether to use online mode for Weights & Biases (wandb) logging. If False, wandb runs in offline mode. | False |

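
A minimal sketch of a `benchmark` call, assuming the package is installed and the base data folder is configured. The option sets are copied from the lists above; `check_options` and `run_knn_on_bach` are hypothetical helper names introduced for this example and are not part of thunder.

```python
# Option names copied from the lists above; the helper names are hypothetical.
TASKS = {
    "adversarial_attack_linear", "alignment_scoring", "image_retrieval", "knn",
    "linear_probing", "pre_computing_embeddings", "segmentation", "simple_shot",
    "transformation_invariance",
}
LOADING_MODES = {"online_loading", "image_pre_loading", "embedding_pre_loading"}


def check_options(task: str, loading_mode: str = "online_loading") -> None:
    """Fail early on a misspelled task or loading mode."""
    if task not in TASKS:
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(TASKS)}")
    if loading_mode not in LOADING_MODES:
        raise ValueError(
            f"Unknown loading_mode {loading_mode!r}; expected one of {sorted(LOADING_MODES)}"
        )


def run_knn_on_bach() -> None:
    """Benchmark the `phikon` model on the `bach` dataset with the `knn` task."""
    # Imported lazily so that check_options can be used without thunder installed.
    from thunder.benchmark import benchmark

    check_options("knn")
    benchmark(
        "phikon",
        dataset="bach",
        task="knn",
        online_wandb=False,  # keep Weights & Biases logging offline
    )
```

Calling `run_knn_on_bach()` starts the actual benchmark; according to the source listing, a missing dataset or model is downloaded first.
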
Source code in .venv/lib/python3.10/site-packages/thunder/benchmark.py
def benchmark(
    model: str | Callable,
    dataset: str,
    task: str,
    loading_mode: str = "online_loading",
    lora: bool = False,
    ckpt_save_all: bool = False,
    online_wandb: bool = False,
    **kwargs,
):
    """
    Runs a benchmark for a pretrained model on a dataset with a task-specific approach.

    where options are:
        - dataset: *bach*, *bcss*, *bracs*, *break_his*, *ccrcc*, *crc*, *esca*, *mhist*, *ocelot*, *pannuke*, *patch_camelyon*, *segpath_epithelial*, *segpath_lymphocytes*, *tcga_crc_msi*, *tcga_tils*, *tcga_uniform*, *wilds*
        - model: *hiboub*, *hiboul*, *hoptimus0*, *hoptimus1*, *midnight*, *phikon*, *phikon2*, *uni*, *uni2h*, *virchow*, *virchow2*, *conch*, *titan*, *keep*, *musk*, *plip*, *quiltnetb32*, *dinov2base*, *dinov2large*, *vitbasepatch16224in21k*, *vitlargepatch16224in21k*, *clipvitbasepatch32*, *clipvitlargepatch14*
        - task: *adversarial_attack_linear*, *alignment_scoring*, *image_retrieval*, *knn*, *linear_probing*, *pre_computing_embeddings*, *segmentation*, *simple_shot*, *transformation_invariance*
        - loading_mode: *online_loading*, *image_pre_loading*, *embedding_pre_loading*

    Args:
        model (str): The name of the pretrained model to use.
        dataset (str): The name of the dataset to use.
        task (str): The name of the task to perform.
        loading_mode (str): The type of data loading to use.
        lora (bool): Whether to use LoRA (Low-Rank Adaptation) for model adaptation. Default is False.
        ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
        online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
    """
    from hydra import compose, initialize
    from omegaconf import OmegaConf

    from .utils.config import get_config

    wandb_mode = "online" if online_wandb else "offline"
    adaptation_type = "lora" if lora else "frozen"
    ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
    model_name = model if isinstance(model, str) else None

    if model_name and model_name.startswith("custom:"):
        model = load_custom_model_from_file(model_name.split(":")[1])
        model_name = None

    # Get Config
    cfg = get_config(
        task,
        ckpt_saving,
        dataset,
        model_name,
        adaptation_type,
        loading_mode,
        wandb_mode,
        **kwargs,
    )

    if not is_dataset_available(dataset):
        from . import download_datasets

        download_datasets(dataset, make_splits=True)

    if model_name and not is_model_available(model_name):
        from . import download_models

        download_models(model)

    if isinstance(model, str):
        # If model is a string, cfg is already populated with the model details
        run_benchmark(cfg)
    else:
        # If model is a callable, pass it directly to the benchmark function
        run_benchmark(cfg, model)
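
The flag handling at the top of `benchmark` can be isolated as two small pure functions, which may help when checking what a given call will request from the configuration. This is a sketch: `resolve_config_names` and `parse_model_spec` are hypothetical names, and only the returned strings are taken from the source above. One deliberate deviation is noted in the comments: the source uses `split(":")[1]`, which keeps only the text between the first and second colon, while the sketch strips the prefix so that a path containing further colons stays intact.

```python
from typing import Dict, Optional, Tuple


def resolve_config_names(
    lora: bool = False, ckpt_save_all: bool = False, online_wandb: bool = False
) -> Dict[str, str]:
    """Map the boolean flags of `benchmark` to the config names used in the source."""
    return {
        "wandb_mode": "online" if online_wandb else "offline",
        "adaptation_type": "lora" if lora else "frozen",
        "ckpt_saving": "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only",
    }


def parse_model_spec(model: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a model string into (built-in model name, custom model file path)."""
    prefix = "custom:"
    if model.startswith(prefix):
        # Deviation from the source: slicing off the prefix keeps later colons
        # in the path, whereas split(":")[1] would cut the path at the next colon.
        return None, model[len(prefix):]
    return model, None
```
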

thunder.download_datasets(datasets, make_splits=False)

Downloads the benchmark datasets specified in the list of dataset names.

This function requires the $THUNDER_BASE_DATA_FOLDER environment variable to be set, which indicates the base directory where the datasets will be downloaded.

The list of all available datasets:
  • bach
  • bracs
  • break_his
  • ccrcc
  • crc
  • esca
  • mhist
  • ocelot
  • pannuke
  • patch_camelyon
  • segpath_epithelial
  • segpath_lymphocytes
  • tcga_crc_msi
  • tcga_tils
  • tcga_uniform
  • wilds

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| datasets | List[str] or str | A dataset name, a list of dataset names to download, or one of the following aliases: all, classification, segmentation. | required |
| make_splits | bool | Whether to generate data splits for the downloaded datasets. | False |

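
Because the function raises when `$THUNDER_BASE_DATA_FOLDER` is missing, a script can set the variable itself before downloading. The sketch below assumes that behaviour is acceptable in your environment; `ensure_base_data_folder` and `download_bach_with_splits` are hypothetical helper names, not part of thunder.

```python
import os
from pathlib import Path


def ensure_base_data_folder(default: str) -> Path:
    """Return the thunder base data folder, setting the variable if it is unset."""
    # setdefault leaves an already exported value untouched.
    base = Path(os.environ.setdefault("THUNDER_BASE_DATA_FOLDER", default))
    base.mkdir(parents=True, exist_ok=True)
    return base


def download_bach_with_splits(default_folder: str) -> None:
    """Download the `bach` dataset and generate its data splits."""
    ensure_base_data_folder(default_folder)
    # Imported lazily so that the helper above works without thunder installed.
    from thunder import download_datasets

    download_datasets("bach", make_splits=True)
```
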
Source code in .venv/lib/python3.10/site-packages/thunder/datasets/download.py
def download_datasets(datasets: Union[List[str], str], make_splits: bool = False):
    """Downloads the benchmark datasets specified in the list of dataset names.

    This function requires the `$THUNDER_BASE_DATA_FOLDER` environment variable to be set,
    which indicates the base directory where the datasets will be downloaded.

    The list of all available datasets:
        * bach
        * bracs
        * break_his
        * ccrcc
        * crc
        * esca
        * mhist
        * ocelot
        * pannuke
        * patch_camelyon
        * segpath_epithelial
        * segpath_lymphocytes
        * tcga_crc_msi
        * tcga_tils
        * tcga_uniform
        * wilds

    Args:
        datasets (List[str] or str): A dataset name string or a List of dataset names to download or one of the following aliases: `all`, `classification`, `segmentation`.
        make_splits (bool): Whether to generate data splits for the datasets. Defaults to False.
    """
    if "THUNDER_BASE_DATA_FOLDER" not in os.environ:
        raise EnvironmentError(
            "Please set base data directory of thunder using `export THUNDER_BASE_DATA_FOLDER=/base/data/directory`"
        )

    if isinstance(datasets, str):
        datasets = [datasets]

    if len(datasets) == 1:
        if datasets[0] == "all":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
                "mhist",
            ]
        elif datasets[0] == "classification":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "mhist",
            ]
        elif datasets[0] == "segmentation":
            datasets = [
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
            ]

    for dataset in datasets:
        download_dataset(dataset)
        if make_splits:
            generate_splits([dataset])

thunder.download_models(models)

Download model checkpoints from Hugging Face.

The list of all available models:
  • uni
  • uni2h
  • virchow
  • virchow2
  • hoptimus0
  • hoptimus1
  • conch
  • titan
  • phikon
  • phikon2
  • hiboub
  • hiboul
  • midnight
  • keep
  • quiltb32
  • plip
  • musk
  • dinov2base
  • dinov2large
  • vitbasepatch16224in21k
  • vitlargepatch16224in21k
  • clipvitbasepatch32
  • clipvitlargepatch14

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| models | List[str] or str | A list of model names or a single model name. | required |

Source code in .venv/lib/python3.10/site-packages/thunder/models/download.py
def download_models(models: Union[List[str], str]) -> None:
    """Download model checkpoints from Hugging Face.

    The list of all available models:
        * uni
        * uni2h
        * virchow
        * virchow2
        * hoptimus0
        * hoptimus1
        * conch
        * titan
        * phikon
        * phikon2
        * hiboub
        * hiboul
        * midnight
        * keep
        * quiltb32
        * plip
        * musk
        * dinov2base
        * dinov2large
        * vitbasepatch16224in21k
        * vitlargepatch16224in21k
        * clipvitbasepatch32
        * clipvitlargepatch14

    Args:
        models (List[str] or str): a list of model names or a single model name.
    """
    if isinstance(models, str):
        models = [models]

    for model in models:
        if model not in TAGS_FILENAMES:
            raise ValueError(f"Model {model} is not available.")
        download_model(model)
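
In the loop above, the `ValueError` is raised when the first unknown name is reached, so models listed before it are already processed. A caller who prefers to reject the whole request up front could check the names first. The sketch below is one way to do that; `unknown_models` is a hypothetical helper, and `AVAILABLE_MODELS` is copied from the list in this section on the assumption that it matches the keys of `TAGS_FILENAMES`, which are not shown here.

```python
from typing import List, Union

# Copied from the model list in this section. Assumption: these are the keys of
# TAGS_FILENAMES. The benchmark option list spells one model `quiltnetb32`,
# whereas this list spells it `quiltb32`; the name here follows this list.
AVAILABLE_MODELS = {
    "uni", "uni2h", "virchow", "virchow2", "hoptimus0", "hoptimus1", "conch",
    "titan", "phikon", "phikon2", "hiboub", "hiboul", "midnight", "keep",
    "quiltb32", "plip", "musk", "dinov2base", "dinov2large",
    "vitbasepatch16224in21k", "vitlargepatch16224in21k",
    "clipvitbasepatch32", "clipvitlargepatch14",
}


def unknown_models(models: Union[List[str], str]) -> List[str]:
    """Return the requested names that are not in AVAILABLE_MODELS, in input order."""
    if isinstance(models, str):
        models = [models]
    return [name for name in models if name not in AVAILABLE_MODELS]
```

An empty result means every requested name is in the documented list, and the request can then be passed to `download_models`.
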

thunder.generate_splits(datasets)

Generates the data splits for all datasets in the input list.

This function requires the $THUNDER_BASE_DATA_FOLDER environment variable to be set, which indicates the base directory where the datasets will be downloaded.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| datasets | List[str] or str | A dataset name, a list of dataset names to generate splits for, or one of the following aliases: all, classification, segmentation. | required |

Source code in .venv/lib/python3.10/site-packages/thunder/datasets/data_splits.py
def generate_splits(datasets: Union[List[str], str]) -> None:
    """Generates the data splits for all datasets in input list.

    This function requires the `$THUNDER_BASE_DATA_FOLDER` environment variable to be set,
    which indicates the base directory where the datasets will be downloaded.

    Args:
        datasets (List[str]): List of dataset names to generate splits for or one of the following aliases: `all`, `classification`, `segmentation`.
    """

    if isinstance(datasets, str):
        datasets = [datasets]

    if len(datasets) == 1:
        if datasets[0] == "all":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
                "mhist",
            ]
        elif datasets[0] == "classification":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "mhist",
            ]
        elif datasets[0] == "segmentation":
            datasets = [
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
            ]

    base_folder = os.path.join(os.environ["THUNDER_BASE_DATA_FOLDER"], "datasets")
    data_splits_folder = os.path.join(base_folder, "data_splits")
    os.makedirs(data_splits_folder, exist_ok=True)

    # Generating data splits
    for dataset_name in datasets:
        generate_splits_for_dataset(dataset_name)
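
The source above creates the splits folder at `<base>/datasets/data_splits`. Unlike `download_datasets`, it reads `os.environ["THUNDER_BASE_DATA_FOLDER"]` directly, so an unset variable surfaces as a `KeyError` rather than the explicit `EnvironmentError`. The following sketch reproduces the path logic with an explicit check; both function names are hypothetical.

```python
import os


def data_splits_folder(base_data_folder: str) -> str:
    """Build the data splits path the same way the source does."""
    return os.path.join(base_data_folder, "datasets", "data_splits")


def splits_folder_from_env() -> str:
    """Resolve the data splits folder from $THUNDER_BASE_DATA_FOLDER."""
    base = os.environ.get("THUNDER_BASE_DATA_FOLDER")
    if base is None:
        # Explicit message instead of the bare KeyError the source would raise.
        raise EnvironmentError("THUNDER_BASE_DATA_FOLDER is not set")
    return data_splits_folder(base)
```
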

thunder.models.PretrainedModel

Bases: Module, ABC

Abstract class to be inherited by custom pretrained models.

Source code in .venv/lib/python3.10/site-packages/thunder/models/pretrained_models.py
class PretrainedModel(torch.nn.Module, ABC):
    """Abstract class to be inherited by custom pretrained models."""

    @abstractmethod
    def get_transform(self):
        """Returns the transform function to be applied to the input images."""
        pass

    @abstractmethod
    def get_linear_probing_embeddings(self, x):
        """Returns the embeddings for linear probing."""
        pass

    @abstractmethod
    def get_segmentation_embeddings(self, x):
        """Returns the pixel dense embeddings for segmentation."""
        pass

    def get_embeddings(self, x, model, task_type):
        if task_type == "linear_probing":
            return self.get_linear_probing_embeddings(x)
        elif task_type == "segmentation":
            return self.get_segmentation_embeddings(x)
        else:
            raise ValueError(f"Invalid task type {task_type}")
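
To illustrate the contract without requiring PyTorch, the sketch below mirrors the interface in a torch-free stand-in. `PretrainedModelSketch`, `ToyModel`, and `IncompleteModel` are hypothetical names, and the tuples returned by `ToyModel` are placeholders rather than real embeddings. A real custom model would inherit from `thunder.models.PretrainedModel` and return tensors; the expected tensor shapes are not specified in this section.

```python
from abc import ABC, abstractmethod


class PretrainedModelSketch(ABC):
    """Torch-free stand-in that mirrors the PretrainedModel interface."""

    @abstractmethod
    def get_transform(self):
        """Returns the transform function to be applied to the input images."""

    @abstractmethod
    def get_linear_probing_embeddings(self, x):
        """Returns the embeddings for linear probing."""

    @abstractmethod
    def get_segmentation_embeddings(self, x):
        """Returns the pixel dense embeddings for segmentation."""

    def get_embeddings(self, x, model, task_type):
        # Same dispatch as in the source; the `model` argument is unused there too.
        if task_type == "linear_probing":
            return self.get_linear_probing_embeddings(x)
        elif task_type == "segmentation":
            return self.get_segmentation_embeddings(x)
        raise ValueError(f"Invalid task type {task_type}")


class ToyModel(PretrainedModelSketch):
    """Implements all three abstract methods with placeholder outputs."""

    def get_transform(self):
        return lambda image: image  # identity transform, for illustration only

    def get_linear_probing_embeddings(self, x):
        return ("linear_probing", x)

    def get_segmentation_embeddings(self, x):
        return ("segmentation", x)


class IncompleteModel(PretrainedModelSketch):
    """Omits two abstract methods, so it cannot be instantiated."""

    def get_transform(self):
        return lambda image: image
```

Instantiating `IncompleteModel` raises a `TypeError`, which is how the abstract base class signals that a required method is missing. Only `linear_probing` and `segmentation` are accepted as `task_type`; any other value raises a `ValueError`.
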

get_linear_probing_embeddings(x) abstractmethod

Returns the embeddings for linear probing.

Source code in .venv/lib/python3.10/site-packages/thunder/models/pretrained_models.py
@abstractmethod
def get_linear_probing_embeddings(self, x):
    """Returns the embeddings for linear probing."""
    pass

get_segmentation_embeddings(x) abstractmethod

Returns the pixel dense embeddings for segmentation.

Source code in .venv/lib/python3.10/site-packages/thunder/models/pretrained_models.py
@abstractmethod
def get_segmentation_embeddings(self, x):
    """Returns the pixel dense embeddings for segmentation."""
    pass

get_transform() abstractmethod

Returns the transform function to be applied to the input images.

Source code in .venv/lib/python3.10/site-packages/thunder/models/pretrained_models.py
@abstractmethod
def get_transform(self):
    """Returns the transform function to be applied to the input images."""
    pass