Tutorial: Using the MILDataset and PyTorch DataLoader in cellink
This tutorial shows how to use the MILDataset together with PyTorch's DataLoader to prepare multimodal input for model training in the cellink framework. The pipeline is designed for donor-level learning that combines donor genotypes with single-cell transcriptomic measurements.
We will use the OneK1K dataset and demonstrate how to:
Load the data as a DonorData object,
Restrict it to chromosome 22 and to CD8 Naive T cells,
Wrap it into a MILDataset,
Batch it with a PyTorch DataLoader,
Train a DonorMILModel.
Setup and Configuration
We start by importing the modules used in this tutorial:
- the dummy OneK1K loader
- MILDataset and its collate function mil_collate_fn
- the DonorMILModel
- PyTorch's DataLoader
- PyTorch Lightning for training
- NumPy for generating example labels
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from cellink.resources import get_dummy_onek1k
from cellink.ml.dataset import MILDataset, mil_collate_fn
from cellink.ml.model import DonorMILModel
Load and Filter the Data
We load the example dataset using get_dummy_onek1k(). It is a subset of the full OneK1K dataset, which can be downloaded and prepared using get_onek1k().

The returned DonorData object holds two AnnData objects:
- .G (gdata), which stores donor-level genotypes at the variant level.
- .C, which stores cell-level expression measurements.

In this tutorial, the genotypes serve as donor-level input features and the cells form each donor's bag. To keep the notebook fast, we restrict both modalities to chromosome 22 and keep only CD8 Naive T cells.
dd = get_dummy_onek1k(config_path="../../src/cellink/resources/config/dummy_onek1k.yaml")
dd
[2025-12-29 01:31:21,754] INFO:root: /Users/larnoldt/cellink_data/dummy_onek1k/dummy_onek1k.dd.h5 already exists
[2025-12-29 01:31:21,755] INFO:root: Veryifying checksum
[2025-12-29 01:31:23,226] INFO:root: Loaded dummy OneK1K dataset: (100, 146939, 125366, 34073)
DonorData(n_donors=100, n_cells_per_donor=[613-2,731], donor_id='donor_id')
G (donors): AnnData object with n_obs × n_vars = 100 × 146,939
    var: 'chrom', 'pos', 'a0', 'a1', 'AC', 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2', 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2', 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length', 'quality', 'pos_hg19', 'id_hg19'
    uns: 'kinship'
    obsm: 'gPCs'
    varm: 'filter'
C (cells): View of AnnData object with n_obs × n_vars = 125,366 × 34,073
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'donor_id', 'pool_number', 'predicted.celltype.l2', 'predicted.celltype.l2.score', 'age', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'is_primary_data', 'suspension_type', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'start', 'end', 'chrom'
    uns: 'cell_type_ontology_term_id_colors', 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_azimuth_spca', 'X_azimuth_umap', 'X_harmony', 'X_pca', 'X_umap'
    varm: 'PCs'
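# Keep only chromosome-22 variants (G) and chromosome-22 genes (C); the 'chrom' columns are compared as strings
chrom = 22
dd = dd.sel(G_var=dd.G.var.chrom == str(chrom), C_var=dd.C.var.chrom == str(chrom)).copy()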
dd
DonorData(n_donors=100, n_cells_per_donor=[613-2,731], donor_id='donor_id')
G (donors): AnnData object with n_obs × n_vars = 100 × 136,776
    var: 'chrom', 'pos', 'a0', 'a1', 'AC', 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2', 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2', 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length', 'quality', 'pos_hg19', 'id_hg19'
    uns: 'kinship'
    obsm: 'gPCs'
    varm: 'filter'
C (cells): AnnData object with n_obs × n_vars = 125,366 × 871
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'donor_id', 'pool_number', 'predicted.celltype.l2', 'predicted.celltype.l2.score', 'age', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'is_primary_data', 'suspension_type', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'start', 'end', 'chrom'
    uns: 'cell_type_ontology_term_id_colors', 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_azimuth_spca', 'X_azimuth_umap', 'X_harmony', 'X_pca', 'X_umap'
    varm: 'PCs'
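Conceptually, selecting on G_var and C_var keeps every donor and every cell (the rows) and drops variant and gene columns. This is why the repr above still reports 100 donors and 125,366 cells while the number of variables shrinks.

The sketch below mimics this with plain NumPy arrays. The toy matrices and chromosome labels are hypothetical, and the sketch does not call the cellink API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 3 donors x 5 variants (G) and 6 cells x 3 genes (C)
G = rng.integers(0, 3, size=(3, 5))        # genotype dosages 0/1/2
C = rng.poisson(1.0, size=(6, 3))          # expression counts
variant_chrom = np.array(["21", "22", "22", "X", "22"])  # like G.var["chrom"]
gene_chrom = np.array(["22", "1", "22"])                  # like C.var["chrom"]

chrom = 22
# Column masks compare against the string label, as in the tutorial code
G_sub = G[:, variant_chrom == str(chrom)]  # all donors, chr22 variants only
C_sub = C[:, gene_chrom == str(chrom)]     # all cells, chr22 genes only

print(G_sub.shape, C_sub.shape)  # → (3, 3) (6, 2)
```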
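# Keep only cells annotated as CD8 Naive in predicted.celltype.l2; all donors and genes are retained
cell_type = "CD8 Naive"
celltype_key = "predicted.celltype.l2"
dd = dd[..., dd.C.obs[celltype_key] == cell_type, :].copy()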
dd
DonorData(n_donors=100, n_cells_per_donor=[1-302], donor_id='donor_id')
G (donors): AnnData object with n_obs × n_vars = 100 × 136,776
    var: 'chrom', 'pos', 'a0', 'a1', 'AC', 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2', 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2', 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length', 'quality', 'pos_hg19', 'id_hg19'
    uns: 'kinship'
    obsm: 'gPCs'
    varm: 'filter'
C (cells): AnnData object with n_obs × n_vars = 4,756 × 871
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'donor_id', 'pool_number', 'predicted.celltype.l2', 'predicted.celltype.l2.score', 'age', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'is_primary_data', 'suspension_type', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'start', 'end', 'chrom'
    uns: 'cell_type_ontology_term_id_colors', 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_azimuth_spca', 'X_azimuth_umap', 'X_harmony', 'X_pca', 'X_umap'
    varm: 'PCs'
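The n_cells_per_donor field in the repr appears to report the minimum and maximum number of cells per donor. After keeping only CD8 Naive T cells it drops from [613-2,731] to [1-302], so some donors now contribute very small bags.

The minimal NumPy sketch below computes such per-donor counts. It uses hypothetical toy labels rather than the OneK1K data, and also checks for donors left with no cells at all.

```python
import numpy as np

# Hypothetical per-cell annotations, standing in for
# dd.C.obs["donor_id"] and dd.C.obs["predicted.celltype.l2"]
donor_id = np.array(["d1", "d1", "d2", "d2", "d2", "d3", "d3"])
cell_type = np.array(["CD8 Naive", "B naive", "CD8 Naive", "CD8 Naive",
                      "NK", "B naive", "CD8 Naive"])

keep = cell_type == "CD8 Naive"                      # boolean mask over cells
donors, counts = np.unique(donor_id[keep], return_counts=True)

print(dict(zip(donors.tolist(), counts.tolist())))   # → {'d1': 1, 'd2': 2, 'd3': 1}
print(f"[{counts.min()}-{counts.max()}]")            # → [1-2]

# Donors that lost all of their cells would be missing from `donors`
missing = np.setdiff1d(np.unique(donor_id), donors)
print(missing.tolist())                              # → []
```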
Wrap the Data with MILDataset
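We use the MILDataset to wrap the data so that each item corresponds to one donor: the donor-level features together with the bag of that donor's cells. In multiple instance learning (MIL), the label is attached to the bag (the donor) rather than to individual cells.

When available, the MILDataset automatically packages the donor labels, categorical and continuous covariates, and data matrices. The keys it reads (here donor_labels_key and cell_batch_key) can be adjusted to match your data.

Note that for demonstration purposes we generate random binary donor labels, so there is no real signal for the model to learn. We also do two preparation steps:
- store the donor IDs as an explicit column of dd.G.obs;
- cast pool_number, which serves as the cell-level batch covariate, to float.

The split_indices argument restricts the dataset to a subset of donors, here the donors at positions 0 to 9. The commented-out split_donors line shows the alternative of selecting donors by ID.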
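dd.G.obs["donor_id"] = dd.G.obs.index  # donor IDs as an explicit column
rng = np.random.default_rng(0)  # fixed seed so the random labels are reproducible
dd.G.obs["donor_labels"] = rng.integers(2, size=len(dd.G.obs))  # random binary labels (demo only)
dd.C.obs["pool_number"] = dd.C.obs["pool_number"].astype("float")  # numeric cell-level batch covariate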
dataset = MILDataset(
dd,
donor_labels_key="donor_labels",
cell_batch_key="pool_number",
# split_donors=["OneK1K_1", "OneK1K_10", "OneK1K_1000"],
split_indices=list(range(10)),
)
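Conceptually, a MIL dataset returns one item per donor: the donor-level features, the bag of that donor's cells, and the donor label. The class below is a hypothetical, simplified illustration of this idea in NumPy. It is not cellink's MILDataset, which additionally packages covariates such as pool_number and supports splits, and its item keys (donor_x, cell_x, label) are made up for this sketch.

```python
import numpy as np


class ToyMILDataset:
    """Hypothetical stand-in for a MIL dataset: one item per donor.

    Implements the map-style protocol (__len__ / __getitem__) that
    torch.utils.data.DataLoader consumes, using NumPy arrays only.
    """

    def __init__(self, G, C, cell_donor, donor_ids, labels):
        self.G = G                # (n_donors, n_variants) donor features
        self.C = C                # (n_cells, n_genes) cell features
        self.labels = labels      # (n_donors,) one label per donor
        self.donor_ids = donor_ids
        # Precompute the row indices of each donor's cells (the "bag")
        self.bag_rows = [np.flatnonzero(cell_donor == d) for d in donor_ids]

    def __len__(self):
        return len(self.donor_ids)

    def __getitem__(self, i):
        return {
            "donor_x": self.G[i],                # donor-level input
            "cell_x": self.C[self.bag_rows[i]],  # (n_cells_i, n_genes) bag
            "label": self.labels[i],
        }


G = np.zeros((2, 4))                             # 2 donors x 4 variants
C = np.arange(10).reshape(5, 2)                  # 5 cells x 2 genes
cell_donor = np.array(["a", "b", "a", "b", "b"])
ds = ToyMILDataset(G, C, cell_donor, np.array(["a", "b"]), np.array([0, 1]))

print(len(ds), ds[0]["cell_x"].shape, ds[1]["cell_x"].shape)  # → 2 (2, 2) (3, 2)
```

Precomputing the bag indices once keeps __getitem__ cheap, which matters when a DataLoader requests donors many times per epoch.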
Create a PyTorch DataLoader
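We now build a PyTorch DataLoader that feeds batches of donors into the model. Each donor contributes a different number of cells; in the filtered data, this ranges from 1 to 302 CD8 Naive T cells per donor. PyTorch's default collate function stacks samples into a single tensor and therefore cannot combine such bags. We pass mil_collate_fn to handle the collation of these variable-length inputs.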
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=mil_collate_fn)
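One common way to batch variable-length bags is to pad every bag to the largest one in the batch and carry a boolean mask marking the real cells. The function below sketches this general technique in NumPy. It is not necessarily what mil_collate_fn does internally; it might, for example, return lists of tensors instead. It also assumes the hypothetical item keys donor_x, cell_x and label.

```python
import numpy as np


def pad_collate(items):
    """Hypothetical collate function for variable-length MIL bags.

    Pads each donor's cell matrix to the largest bag in the batch and
    returns a boolean mask that is True for real (non-padded) cells.
    """
    max_cells = max(item["cell_x"].shape[0] for item in items)
    n_genes = items[0]["cell_x"].shape[1]
    cells = np.zeros((len(items), max_cells, n_genes))
    mask = np.zeros((len(items), max_cells), dtype=bool)
    for b, item in enumerate(items):
        n = item["cell_x"].shape[0]
        cells[b, :n] = item["cell_x"]   # copy real cells, rest stays zero
        mask[b, :n] = True
    return {
        "donor_x": np.stack([item["donor_x"] for item in items]),
        "cell_x": cells,
        "mask": mask,
        "label": np.array([item["label"] for item in items]),
    }


batch = pad_collate([
    {"donor_x": np.zeros(4), "cell_x": np.ones((2, 3)), "label": 0},
    {"donor_x": np.ones(4), "cell_x": np.ones((5, 3)), "label": 1},
])
print(batch["cell_x"].shape, batch["mask"].sum(axis=1))  # → (2, 5, 3) [2 5]
```

Downstream, the mask has to be respected when pooling, so that padded rows never contribute to a donor's representation.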
Initialize and Train a Model
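We now create a DonorMILModel, specifying the dimensionality of the donor and cell input features:
- donor features: the number of genotype variants (dd.G.n_vars, 136,776 after filtering);
- cell features: the number of genes (dd.C.n_vars, 871).

We then use PyTorch Lightning to train for two epochs (max_epochs=2).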
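The model summary printed during training lists an attention module. This suggests that DonorMILModel pools each donor's cell embeddings with learned attention weights before classification.

The function below is a hedged NumPy sketch of generic attention-based MIL pooling in the style of Ilse et al. (2018), not cellink's implementation. The parameter names V and w and the masking convention are assumptions.

```python
import numpy as np


def attention_pool(H, V, w, mask):
    """Attention-based MIL pooling over one (possibly padded) bag.

    H:    (n_cells, d) cell embeddings
    V, w: attention parameters of shape (d, k) and (k,)
    mask: (n_cells,) boolean, True for real cells, False for padding
    """
    scores = np.tanh(H @ V) @ w               # one unnormalised score per cell
    scores = np.where(mask, scores, -np.inf)  # padded cells get zero weight
    scores = scores - scores[mask].max()      # shift for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()         # softmax over real cells only
    return weights @ H, weights               # (d,) bag embedding, (n_cells,) weights


rng = np.random.default_rng(0)
H = rng.normal(size=(4, 8))                   # 4 cell slots, 8-d embeddings
mask = np.array([True, True, True, False])    # last slot is padding
z, a = attention_pool(H, rng.normal(size=(8, 3)), rng.normal(size=3), mask)
print(z.shape, a[3], round(float(a.sum()), 6))  # → (8,) 0.0 1.0
```

Unlike mean pooling, the learned weights let a model emphasise a few informative cells per donor, and they can be inspected after training.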
model = DonorMILModel(n_input_donor=dd.G.n_vars, n_input_cell=dd.C.n_vars)
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, dataloader)
[2025-12-29 01:31:42,127] INFO:pytorch_lightning.utilities.rank_zero: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
[2025-12-29 01:31:42,191] INFO:pytorch_lightning.utilities.rank_zero: GPU available: True (mps), used: True
[2025-12-29 01:31:42,191] INFO:pytorch_lightning.utilities.rank_zero: TPU available: False, using: 0 TPU cores
[2025-12-29 01:31:42,192] INFO:pytorch_lightning.utilities.rank_zero: HPU available: False, using: 0 HPUs
[2025-12-29 01:32:10,856] INFO:pytorch_lightning.callbacks.model_summary:
| Name | Type | Params | Mode
-----------------------------------------------------
0 | donor_encoder | Linear | 17.5 M | train
1 | cell_encoder | Sequential | 111 K | train
2 | attention | Sequential | 8.3 K | train
3 | classifier | Sequential | 257 | train
-----------------------------------------------------
17.6 M Trainable params
0 Non-trainable params
17.6 M Total params
70.511 Total estimated model params size (MB)
11 Modules in train mode
0 Modules in eval mode
[2025-12-29 01:32:14,654] INFO:pytorch_lightning.utilities.rank_zero: `Trainer.fit` stopped: `max_epochs=2` reached.
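The parameter counts in the summary can be reproduced by hand, which is a quick sanity check on the input dimensions passed to the model. The layer sizes below are our assumption, not taken from cellink's documentation. They were chosen because they reproduce every logged value: 17.5 M, 111 K, 8.3 K, 257 and 70.511 MB.

The assumed layers are:
- each encoder maps to a 128-dimensional embedding;
- attention uses a 128 to 64 to 1 projection;
- the classifier reads the concatenated 256-dimensional donor and cell embedding.

```python
# Hypothetical layer sizes inferred from the model summary (not documented API)
N_VARIANTS, N_GENES, HIDDEN, ATTN = 136_776, 871, 128, 64


def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


counts = {
    "donor_encoder": linear_params(N_VARIANTS, HIDDEN),
    "cell_encoder": linear_params(N_GENES, HIDDEN),
    "attention": linear_params(HIDDEN, ATTN) + linear_params(ATTN, 1),
    "classifier": linear_params(2 * HIDDEN, 1),
}
total = sum(counts.values())
size_mb = total * 4 / 1e6  # float32: 4 bytes per parameter

print(counts)
# → {'donor_encoder': 17507456, 'cell_encoder': 111616, 'attention': 8321, 'classifier': 257}
print(total, f"{size_mb:.3f} MB")  # → 17627650 70.511 MB
```

Under this assumption, almost all parameters sit in the donor encoder, because it connects every chromosome-22 variant to the embedding. Selecting fewer variants would therefore shrink the model roughly in proportion.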