API

thunder.benchmark.benchmark(model, dataset, task, loading_mode='online_loading', lora=False, ckpt_save_all=False, online_wandb=False, recomp_embs=False, retrain_model=False, **kwargs)

Runs a benchmark for a pretrained model on a dataset with a task-specific approach.

where options are:
  • dataset: bach, bracs, break_his, ccrcc, crc, esca, mhist, ocelot, pannuke, patch_camelyon, segpath_epithelial, segpath_lymphocytes, tcga_crc_msi, tcga_tils, tcga_uniform, wilds
  • model: hiboub, hiboul, hoptimus0, hoptimus1, midnight, phikon, phikon2, uni, uni2h, virchow, virchow2, conch, titan, keep, musk, plip, quiltnetb32, dinov2base, dinov2large, vitbasepatch16224in21k, vitlargepatch16224in21k, clipvitbasepatch32, clipvitlargepatch14
  • task: adversarial_attack, alignment_scoring, image_retrieval, knn, linear_probing, pre_computing_embeddings, segmentation, simple_shot, transformation_invariance
  • loading_mode: online_loading, image_pre_loading, embedding_pre_loading

Parameters:
  • model (str or Callable, required): The name of the pretrained model to use, or a custom model instance.
  • dataset (str, required): The name of the dataset to use.
  • task (str, required): The name of the task to perform.
  • loading_mode (str, default 'online_loading'): The type of data loading to use.
  • lora (bool, default False): Whether to use LoRA (Low-Rank Adaptation) for model adaptation.
  • ckpt_save_all (bool, default False): Whether to save all checkpoints during training; by default only the best checkpoint is saved.
  • online_wandb (bool, default False): Whether to use online mode for Weights & Biases (wandb) logging; by default logging runs offline.
  • recomp_embs (bool, default False): Whether to recompute embeddings even if they have already been saved.
  • retrain_model (bool, default False): Whether to retrain the model even if it has already been trained and checkpoints have been saved.
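
A minimal usage sketch (assumes the package is installed and $THUNDER_BASE_DATA_FOLDER is set; the model, dataset, and task names come from the option lists above, and missing datasets or checkpoints are downloaded automatically, per the source below):

import thunder

# Linear probing of the "uni" model on the "bach" dataset,
# pre-loading images into memory.
thunder.benchmark(
    model="uni",
    dataset="bach",
    task="linear_probing",
    loading_mode="image_pre_loading",
)
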
Source code in src/thunder/benchmark.py
def benchmark(
    model: str | Callable,
    dataset: str,
    task: str,
    loading_mode: str = "online_loading",
    lora: bool = False,
    ckpt_save_all: bool = False,
    online_wandb: bool = False,
    recomp_embs: bool = False,
    retrain_model: bool = False,
    **kwargs,
):
    """
    Runs a benchmark for a pretrained model on a dataset with a task-specific approach.

    where options are:
        - dataset: *bach*, *bracs*, *break_his*, *ccrcc*, *crc*, *esca*, *mhist*, *ocelot*, *pannuke*, *patch_camelyon*, *segpath_epithelial*, *segpath_lymphocytes*, *tcga_crc_msi*, *tcga_tils*, *tcga_uniform*, *wilds*
        - model: *hiboub*, *hiboul*, *hoptimus0*, *hoptimus1*, *midnight*, *phikon*, *phikon2*, *uni*, *uni2h*, *virchow*, *virchow2*, *conch*, *titan*, *keep*, *musk*, *plip*, *quiltnetb32*, *dinov2base*, *dinov2large*, *vitbasepatch16224in21k*, *vitlargepatch16224in21k*, *clipvitbasepatch32*, *clipvitlargepatch14*
        - task: *adversarial_attack*, *alignment_scoring*, *image_retrieval*, *knn*, *linear_probing*, *pre_computing_embeddings*, *segmentation*, *simple_shot*, *transformation_invariance*
        - loading_mode: *online_loading*, *image_pre_loading*, *embedding_pre_loading*

    Args:
        model (str): The name of the pretrained model to use.
        dataset (str): The name of the dataset to use.
        task (str): The name of the task to perform.
        loading_mode (str): The type of data loading to use.
        lora (bool): Whether to use LoRA (Low-Rank Adaptation) for model adaptation. Default is False.
        ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
        online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
        recomp_embs (bool): Whether to recompute embeddings if already saved.
        retrain_model (bool): Whether to retrain model if already trained and saved ckpts.
    """
    from hydra import compose, initialize
    from omegaconf import OmegaConf

    from .utils.config import get_config

    wandb_mode = "online" if online_wandb else "offline"
    adaptation_type = "lora" if lora else "frozen"
    ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
    embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
    model_retraining = "retrain_model" if retrain_model else "no_retrain_model"
    model_name = model if isinstance(model, str) else None
    custom_name = None

    if model_name and model_name.startswith("custom:"):
        model = load_custom_model_from_file(model_name.split(":")[1])
        model_name = None
        custom_name = model.name

    # Get Config
    cfg = get_config(
        task,
        ckpt_saving,
        dataset,
        model_name,
        adaptation_type,
        loading_mode,
        wandb_mode,
        embedding_recomputing,
        model_retraining,
        **kwargs,
    )

    print_task_hyperparams(cfg, custom_name=custom_name)

    if not is_dataset_available(dataset):
        from . import download_datasets

        download_datasets(dataset, make_splits=True)

    if model_name and not is_model_available(model_name):
        from . import download_models

        download_models(model)

    if isinstance(model, str):
        # If model is a string, cfg is already populated with the model details
        run_benchmark(cfg)
    else:
        # If model is a callable, pass it directly to the benchmark function
        run_benchmark(cfg, model)

thunder.download_datasets(datasets, make_splits=False)

Downloads the benchmark datasets specified in the list of dataset names.

This function requires the $THUNDER_BASE_DATA_FOLDER environment variable to be set, which indicates the base directory where the datasets will be downloaded.

The list of all available datasets:
  • bach
  • bracs
  • break_his
  • ccrcc
  • crc
  • esca
  • mhist
  • ocelot
  • pannuke
  • patch_camelyon
  • segpath_epithelial
  • segpath_lymphocytes
  • tcga_crc_msi
  • tcga_tils
  • tcga_uniform
  • wilds

Parameters:
  • datasets (List[str] or str, required): A dataset name, a list of dataset names to download, or one of the following aliases: all, classification, segmentation.
  • make_splits (bool, default False): Whether to generate data splits for the downloaded datasets.
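
A short usage sketch (assumes $THUNDER_BASE_DATA_FOLDER points to a writable directory):

from thunder import download_datasets

# Download two datasets and generate their splits.
download_datasets(["bach", "crc"], make_splits=True)

# Aliases expand to dataset groups, per the source below.
download_datasets("segmentation")
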
Source code in src/thunder/datasets/download.py
def download_datasets(datasets: Union[List[str], str], make_splits: bool = False):
    """Downloads the benchmark datasets specified in the list of dataset names.

    This function requires the `$THUNDER_BASE_DATA_FOLDER` environment variable to be set,
    which indicates the base directory where the datasets will be downloaded.

    The list of all available datasets:
        * bach
        * bracs
        * break_his
        * ccrcc
        * crc
        * esca
        * mhist
        * ocelot
        * pannuke
        * patch_camelyon
        * segpath_epithelial
        * segpath_lymphocytes
        * tcga_crc_msi
        * tcga_tils
        * tcga_uniform
        * wilds

    Args:
        datasets (List[str] or str): A dataset name string or a List of dataset names to download or one of the following aliases: `all`, `classification`, `segmentation`.
        make_splits (bool): Whether to generate data splits for the datasets. Defaults to False.
    """
    if "THUNDER_BASE_DATA_FOLDER" not in os.environ:
        raise EnvironmentError(
            "Please set base data directory of thunder using `export THUNDER_BASE_DATA_FOLDER=/base/data/directory`"
        )

    if isinstance(datasets, str):
        datasets = [datasets]

    if len(datasets) == 1:
        if datasets[0] == "all":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
                "mhist",
            ]
        elif datasets[0] == "classification":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "mhist",
            ]
        elif datasets[0] == "segmentation":
            datasets = [
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
            ]

    for dataset in datasets:
        download_dataset(dataset)
        if make_splits:
            generate_splits([dataset])

thunder.download_models(models)

Download model checkpoints from Hugging Face.

The list of all available models:
  • uni
  • uni2h
  • virchow
  • virchow2
  • hoptimus0
  • hoptimus1
  • conch
  • titan
  • phikon
  • phikon2
  • hiboub
  • hiboul
  • midnight
  • keep
  • quiltb32
  • plip
  • musk
  • dinov2base
  • dinov2large
  • vitbasepatch16224in21k
  • vitlargepatch16224in21k
  • clipvitbasepatch32
  • clipvitlargepatch14

Parameters:
  • models (List[str] or str, required): A list of model names or a single model name string.
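
A minimal sketch (model names come from the list above):

from thunder import download_models

# Download a single checkpoint...
download_models("phikon")

# ...or several at once. Unknown names raise a ValueError, per the source below.
download_models(["uni", "virchow2", "dinov2base"])
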
Source code in src/thunder/models/download.py
def download_models(models: Union[List[str], str]) -> None:
    """Download model checkpoints from Hugging Face.

    The list of all available models:
        * uni
        * uni2h
        * virchow
        * virchow2
        * hoptimus0
        * hoptimus1
        * conch
        * titan
        * phikon
        * phikon2
        * hiboub
        * hiboul
        * midnight
        * keep
        * quiltb32
        * plip
        * musk
        * dinov2base
        * dinov2large
        * vitbasepatch16224in21k
        * vitlargepatch16224in21k
        * clipvitbasepatch32
        * clipvitlargepatch14

    Args:
        models (List[str] or str): a list of model names or single a model name str.
    """
    if isinstance(models, str):
        models = [models]

    for model in models:
        if model not in TAGS_FILENAMES:
            raise ValueError(f"Model {model} is not available.")
        download_model(model)

thunder.generate_splits(datasets)

Generates the data splits for all datasets in the input list.

This function requires the $THUNDER_BASE_DATA_FOLDER environment variable to be set, which indicates the base directory where the datasets will be downloaded.

Parameters:
  • datasets (List[str] or str, required): A list of dataset names to generate splits for, or one of the following aliases: all, classification, segmentation.
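
A short sketch (splits are written under $THUNDER_BASE_DATA_FOLDER/datasets/data_splits, per the source below):

from thunder import generate_splits

# Regenerate splits for already-downloaded datasets.
generate_splits(["bach", "ocelot"])

# Aliases expand to the corresponding dataset groups.
generate_splits("classification")
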
Source code in src/thunder/datasets/data_splits.py
def generate_splits(datasets: Union[List[str], str]) -> None:
    """Generates the data splits for all datasets in input list.

    This function requires the `$THUNDER_BASE_DATA_FOLDER` environment variable to be set,
    which indicates the base directory where the datasets will be downloaded.

    Args:
        datasets (List[str]): List of dataset names to generate splits for or one of the following aliases: `all`, `classification`, `segmentation`.
    """

    if isinstance(datasets, str):
        datasets = [datasets]

    if len(datasets) == 1:
        if datasets[0] == "all":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
                "mhist",
            ]
        elif datasets[0] == "classification":
            datasets = [
                "bach",
                "bracs",
                "break_his",
                "ccrcc",
                "crc",
                "esca",
                "patch_camelyon",
                "tcga_crc_msi",
                "tcga_tils",
                "tcga_uniform",
                "wilds",
                "mhist",
            ]
        elif datasets[0] == "segmentation":
            datasets = [
                "ocelot",
                "pannuke",
                "segpath_epithelial",
                "segpath_lymphocytes",
            ]

    base_folder = os.path.join(os.environ["THUNDER_BASE_DATA_FOLDER"], "datasets")
    data_splits_folder = os.path.join(base_folder, "data_splits")
    os.makedirs(data_splits_folder, exist_ok=True)

    # Generating data splits
    for dataset_name in datasets:
        generate_splits_for_dataset(dataset_name)

thunder.models.PretrainedModel

Bases: Module, ABC

Abstract class to be inherited by custom pretrained models.

Source code in src/thunder/models/pretrained_models.py
class PretrainedModel(torch.nn.Module, ABC):
    """Abstract class to be inherited by custom pretrained models."""

    @abstractmethod
    def get_transform(self):
        """Returns the transform function to be applied to the input images."""
        pass

    @abstractmethod
    def get_linear_probing_embeddings(self, x):
        """Returns the embeddings for linear probing."""
        pass

    @abstractmethod
    def get_segmentation_embeddings(self, x):
        """Returns the pixel dense embeddings for segmentation."""
        pass

    def get_embeddings(self, x, model, task_type):
        if task_type == "linear_probing":
            return self.get_linear_probing_embeddings(x)
        elif task_type == "segmentation":
            return self.get_segmentation_embeddings(x)
        else:
            raise ValueError(f"Invalid task type {task_type}")
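
A minimal sketch of a custom model built on this class. The timm backbone, transform, and token handling below are illustrative assumptions, not part of thunder; only the three abstract methods and the name attribute (read by benchmark() for custom models, per its source above) come from the library:

import timm
import torch
from torchvision import transforms

from thunder.models import PretrainedModel


class MyViT(PretrainedModel):
    """Hypothetical custom backbone wrapping a timm ViT."""

    def __init__(self):
        super().__init__()
        self.name = "my_vit"  # benchmark() reports this name for custom models
        self.backbone = timm.create_model("vit_base_patch16_224", pretrained=True)

    def get_transform(self):
        # Standard ImageNet preprocessing; match your model's training recipe.
        return transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def get_linear_probing_embeddings(self, x):
        # Pooled CLS token as the image-level embedding.
        return self.backbone.forward_features(x)[:, 0]

    def get_segmentation_embeddings(self, x):
        # Patch tokens (CLS dropped) as pixel-dense embeddings.
        return self.backbone.forward_features(x)[:, 1:]

Per benchmark()'s source above, a model defined in a file can also be passed by name with the custom: prefix; the file is loaded through load_custom_model_from_file.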

get_linear_probing_embeddings(x) abstractmethod

Returns the embeddings for linear probing.

Source code in src/thunder/models/pretrained_models.py
@abstractmethod
def get_linear_probing_embeddings(self, x):
    """Returns the embeddings for linear probing."""
    pass

get_segmentation_embeddings(x) abstractmethod

Returns the pixel dense embeddings for segmentation.

Source code in src/thunder/models/pretrained_models.py
@abstractmethod
def get_segmentation_embeddings(self, x):
    """Returns the pixel dense embeddings for segmentation."""
    pass

get_transform() abstractmethod

Returns the transform function to be applied to the input images.

Source code in src/thunder/models/pretrained_models.py
@abstractmethod
def get_transform(self):
    """Returns the transform function to be applied to the input images."""
    pass

thunder.models.get_model_from_name(model_name, device)

Loads a pretrained model from its input name.

The list of all available models:
  • uni
  • uni2h
  • virchow
  • virchow2
  • hoptimus0
  • hoptimus1
  • conch
  • titan
  • phikon
  • phikon2
  • hiboub
  • hiboul
  • midnight
  • keep
  • quiltb32
  • plip
  • musk
  • dinov2base
  • dinov2large
  • vitbasepatch16224in21k
  • vitlargepatch16224in21k
  • clipvitbasepatch32
  • clipvitlargepatch14

Parameters:
  • model_name (str, required): The name of the model to use.
  • device (str, required): Device to use (cpu, cuda).

Returns:
  • model (torch.nn.Module): PyTorch model instance.
  • transform (torchvision.transforms.Compose): Transform to apply to the input image.
  • get_embeddings (Callable): Function to extract embeddings.

The returned get_embeddings function has the following signature:
  • src (torch.Tensor): Batch of transformed images with shape (B, 3, H, W).
  • pretrained_model (torch.nn.Module): Model to extract embeddings with.
  • pooled_emb (bool): Whether to output pooled (True) or spatial (False) embeddings.
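
A minimal sketch tying the three returned objects together (the model name comes from the list above; the random tensor is illustrative and stands in for a batch of transformed images):

import torch

from thunder.models import get_model_from_name

model, transform, get_embeddings = get_model_from_name("phikon", device="cuda")

# `transform` would normally be applied to each input image; here a random
# tensor stands in for an already-transformed batch.
batch = torch.randn(8, 3, 224, 224, device="cuda")

pooled = get_embeddings(batch, model, pooled_emb=True)    # image-level embeddings
spatial = get_embeddings(batch, model, pooled_emb=False)  # dense/spatial embeddings
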
Source code in src/thunder/models/pretrained_models.py
def get_model_from_name(model_name: str, device: str):
    """Loading pretrained model from input name.

    The list of all available models:
        * uni
        * uni2h
        * virchow
        * virchow2
        * hoptimus0
        * hoptimus1
        * conch
        * titan
        * phikon
        * phikon2
        * hiboub
        * hiboul
        * midnight
        * keep
        * quiltb32
        * plip
        * musk
        * dinov2base
        * dinov2large
        * vitbasepatch16224in21k
        * vitlargepatch16224in21k
        * clipvitbasepatch32
        * clipvitlargepatch14

    Args:
        model_name (str): The name of the model to use.
        device (str): Device to use (cpu, cuda).

    Returns:
        model (torch.nn.Module): Pytorch model instance.
        transform (torchvision.transforms.transforms.Compose): Transform to apply to input image.
        get_embeddings (Callable): Function to extract embeddings.

    Tip: output function `get_embeddings` signature.
        * src (torch.Tensor): Batch of transformed images with shape (B, 3, H, W).
        * pretrained_model (torch.nn.Module): Model to extract embeddings with.
        * pooled_emb (bool):  Whether to output pooled (True) or spatial (False) embeddings.
    """

    # Loading model config
    yaml_file = (
        f"{Path(__file__).parent.parent}/config/pretrained_model/{model_name}.yaml"
    )
    model_cfg = OmegaConf.load(yaml_file)

    # Getting model, transform, embedding extraction function
    model, transform, extract_embedding = get_model(model_cfg, device)

    # Defining wrapper function to get embeddings
    def get_embeddings(src, pretrained_model, pooled_emb=True):
        return extract_embedding(
            src,
            pretrained_model,
            task_type="linear_probing" if pooled_emb else "segmentation",
        )

    return model, transform, get_embeddings