Tutorial: Using the MILDataset and PyTorch DataLoader in cellink

Contents

  • Setup and Configuration
  • Load Genotype Data (gdata)
  • Wrap the Data with MILDataset
  • Create a PyTorch DataLoader
  • Initialize and Train a Model

This tutorial shows how to use the MILDataset together with PyTorch’s DataLoader to prepare multimodal input for model training in the cellink framework. This pipeline is designed to support donor-level learning using single-cell transcriptomic measurements.

We will use the OneK1K dataset and demonstrate how to:

  • Filter for CD8 Naive T cells,

  • Package the data using DonorData,

  • Wrap it into a MILDataset,

  • Load it with a PyTorch DataLoader,

  • Train a DonorMILModel.

Setup and Configuration

We start by importing the modules used in this tutorial: the dummy OneK1K loader, MILDataset and mil_collate_fn, DonorMILModel, PyTorch's DataLoader, PyTorch Lightning, and NumPy.

from cellink.resources import get_dummy_onek1k
from cellink.ml.dataset import MILDataset
from torch.utils.data import DataLoader
from cellink.ml.dataset import mil_collate_fn
from cellink.ml.model import DonorMILModel
import pytorch_lightning as pl
import numpy as np

Load Genotype Data (gdata)

We load the example dataset with get_dummy_onek1k(). This is a subset of the full OneK1K dataset, which can be downloaded and prepared using get_onek1k(). The returned DonorData object contains a .G attribute (gdata) that stores donor-level genotypes, one column per variant, and a .C attribute that stores the cell-level expression data. To keep the notebook fast, we restrict both modalities to chromosome 22 and then keep only CD8 Naive cells.

dd = get_dummy_onek1k(config_path="../../src/cellink/resources/config/dummy_onek1k.yaml")
dd
[2025-12-29 01:31:21,754] INFO:root: /Users/larnoldt/cellink_data/dummy_onek1k/dummy_onek1k.dd.h5 already exists
[2025-12-29 01:31:21,755] INFO:root: Veryifying checksum
[2025-12-29 01:31:23,226] INFO:root: Loaded dummy OneK1K dataset: (100, 146939, 125366, 34073)
╔═ DonorData(n_donors=100, n_cells_per_donor=[613-2,731], donor_id='donor_id') ═══════════════════════════════╗
║ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ║
║ ┃ G (donors)                                         ┃ C (cells)                                          ┃ ║
║ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ ║
║ │ AnnData object with n_obs × n_vars = 100 × 146,939 │ View of AnnData object with n_obs × n_vars =       │ ║
║ │                                                    │ 125,366 × 34,073                                   │ ║
║ │     var: 'chrom', 'pos', 'a0', 'a1', 'AC',         │     obs: 'orig.ident', 'nCount_RNA',               │ ║
║ │ 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2',  │ 'nFeature_RNA', 'percent.mt', 'donor_id',          │ ║
║ │ 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2',     │ 'pool_number', 'predicted.celltype.l2',            │ ║
║ │ 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length',  │ 'predicted.celltype.l2.score', 'age',              │ ║
║ │ 'quality', 'pos_hg19', 'id_hg19'                   │ 'organism_ontology_term_id',                       │ ║
║ │                                                    │ 'tissue_ontology_term_id',                         │ ║
║ │                                                    │ 'assay_ontology_term_id',                          │ ║
║ │                                                    │ 'disease_ontology_term_id',                        │ ║
║ │                                                    │ 'cell_type_ontology_term_id',                      │ ║
║ │                                                    │ 'self_reported_ethnicity_ontology_term_id',        │ ║
║ │                                                    │ 'development_stage_ontology_term_id',              │ ║
║ │                                                    │ 'sex_ontology_term_id', 'is_primary_data',         │ ║
║ │                                                    │ 'suspension_type', 'tissue_type', 'cell_type',     │ ║
║ │                                                    │ 'assay', 'disease', 'organism', 'sex', 'tissue',   │ ║
║ │                                                    │ 'self_reported_ethnicity', 'development_stage',    │ ║
║ │                                                    │ 'observation_joinid'                               │ ║
║ │     uns: 'kinship'                                 │     var: 'vst.mean', 'vst.variance',               │ ║
║ │                                                    │ 'vst.variance.expected',                           │ ║
║ │                                                    │ 'vst.variance.standardized', 'vst.variable',       │ ║
║ │                                                    │ 'feature_is_filtered', 'feature_name',             │ ║
║ │                                                    │ 'feature_reference', 'feature_biotype',            │ ║
║ │                                                    │ 'feature_length', 'feature_type', 'start', 'end',  │ ║
║ │                                                    │ 'chrom'                                            │ ║
║ │     obsm: 'gPCs'                                   │     uns: 'cell_type_ontology_term_id_colors',      │ ║
║ │                                                    │ 'citation', 'default_embedding',                   │ ║
║ │                                                    │ 'schema_reference', 'schema_version', 'title'      │ ║
║ │     varm: 'filter'                                 │     obsm: 'X_azimuth_spca', 'X_azimuth_umap',      │ ║
║ │                                                    │ 'X_harmony', 'X_pca', 'X_umap'                     │ ║
║ │                                                    │     varm: 'PCs'                                    │ ║
║ └────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════════════════════════════════════════════════╝

chrom = 22
dd = dd.sel(G_var=dd.G.var.chrom == str(chrom), C_var=dd.C.var.chrom == str(chrom)).copy()
dd
╔═ DonorData(n_donors=100, n_cells_per_donor=[613-2,731], donor_id='donor_id') ═══════════════════════════════╗
║ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ║
║ ┃ G (donors)                                         ┃ C (cells)                                          ┃ ║
║ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ ║
║ │ AnnData object with n_obs × n_vars = 100 × 136,776 │ AnnData object with n_obs × n_vars = 125,366 × 871 │ ║
║ │     var: 'chrom', 'pos', 'a0', 'a1', 'AC',         │     obs: 'orig.ident', 'nCount_RNA',               │ ║
║ │ 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2',  │ 'nFeature_RNA', 'percent.mt', 'donor_id',          │ ║
║ │ 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2',     │ 'pool_number', 'predicted.celltype.l2',            │ ║
║ │ 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length',  │ 'predicted.celltype.l2.score', 'age',              │ ║
║ │ 'quality', 'pos_hg19', 'id_hg19'                   │ 'organism_ontology_term_id',                       │ ║
║ │                                                    │ 'tissue_ontology_term_id',                         │ ║
║ │                                                    │ 'assay_ontology_term_id',                          │ ║
║ │                                                    │ 'disease_ontology_term_id',                        │ ║
║ │                                                    │ 'cell_type_ontology_term_id',                      │ ║
║ │                                                    │ 'self_reported_ethnicity_ontology_term_id',        │ ║
║ │                                                    │ 'development_stage_ontology_term_id',              │ ║
║ │                                                    │ 'sex_ontology_term_id', 'is_primary_data',         │ ║
║ │                                                    │ 'suspension_type', 'tissue_type', 'cell_type',     │ ║
║ │                                                    │ 'assay', 'disease', 'organism', 'sex', 'tissue',   │ ║
║ │                                                    │ 'self_reported_ethnicity', 'development_stage',    │ ║
║ │                                                    │ 'observation_joinid'                               │ ║
║ │     uns: 'kinship'                                 │     var: 'vst.mean', 'vst.variance',               │ ║
║ │                                                    │ 'vst.variance.expected',                           │ ║
║ │                                                    │ 'vst.variance.standardized', 'vst.variable',       │ ║
║ │                                                    │ 'feature_is_filtered', 'feature_name',             │ ║
║ │                                                    │ 'feature_reference', 'feature_biotype',            │ ║
║ │                                                    │ 'feature_length', 'feature_type', 'start', 'end',  │ ║
║ │                                                    │ 'chrom'                                            │ ║
║ │     obsm: 'gPCs'                                   │     uns: 'cell_type_ontology_term_id_colors',      │ ║
║ │                                                    │ 'citation', 'default_embedding',                   │ ║
║ │                                                    │ 'schema_reference', 'schema_version', 'title'      │ ║
║ │     varm: 'filter'                                 │     obsm: 'X_azimuth_spca', 'X_azimuth_umap',      │ ║
║ │                                                    │ 'X_harmony', 'X_pca', 'X_umap'                     │ ║
║ │                                                    │     varm: 'PCs'                                    │ ║
║ └────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════════════════════════════════════════════════╝

cell_type = "CD8 Naive"
celltype_key = "predicted.celltype.l2"
dd = dd[..., dd.C.obs[celltype_key] == cell_type, :].copy()
dd
╔═ DonorData(n_donors=100, n_cells_per_donor=[1-302], donor_id='donor_id') ═══════════════════════════════════╗
║ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ║
║ ┃ G (donors)                                         ┃ C (cells)                                          ┃ ║
║ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ ║
║ │ AnnData object with n_obs × n_vars = 100 × 136,776 │ AnnData object with n_obs × n_vars = 4,756 × 871   │ ║
║ │     var: 'chrom', 'pos', 'a0', 'a1', 'AC',         │     obs: 'orig.ident', 'nCount_RNA',               │ ║
║ │ 'AC_Hemi', 'AC_Het', 'AC_Hom', 'AF', 'AN', 'ER2',  │ 'nFeature_RNA', 'percent.mt', 'donor_id',          │ ║
║ │ 'ExcHet', 'HWE', 'IMPUTED', 'maf', 'NS', 'R2',     │ 'pool_number', 'predicted.celltype.l2',            │ ║
║ │ 'TYPED', 'TYPED_ONLY', 'id', 'id_mask', 'length',  │ 'predicted.celltype.l2.score', 'age',              │ ║
║ │ 'quality', 'pos_hg19', 'id_hg19'                   │ 'organism_ontology_term_id',                       │ ║
║ │                                                    │ 'tissue_ontology_term_id',                         │ ║
║ │                                                    │ 'assay_ontology_term_id',                          │ ║
║ │                                                    │ 'disease_ontology_term_id',                        │ ║
║ │                                                    │ 'cell_type_ontology_term_id',                      │ ║
║ │                                                    │ 'self_reported_ethnicity_ontology_term_id',        │ ║
║ │                                                    │ 'development_stage_ontology_term_id',              │ ║
║ │                                                    │ 'sex_ontology_term_id', 'is_primary_data',         │ ║
║ │                                                    │ 'suspension_type', 'tissue_type', 'cell_type',     │ ║
║ │                                                    │ 'assay', 'disease', 'organism', 'sex', 'tissue',   │ ║
║ │                                                    │ 'self_reported_ethnicity', 'development_stage',    │ ║
║ │                                                    │ 'observation_joinid'                               │ ║
║ │     uns: 'kinship'                                 │     var: 'vst.mean', 'vst.variance',               │ ║
║ │                                                    │ 'vst.variance.expected',                           │ ║
║ │                                                    │ 'vst.variance.standardized', 'vst.variable',       │ ║
║ │                                                    │ 'feature_is_filtered', 'feature_name',             │ ║
║ │                                                    │ 'feature_reference', 'feature_biotype',            │ ║
║ │                                                    │ 'feature_length', 'feature_type', 'start', 'end',  │ ║
║ │                                                    │ 'chrom'                                            │ ║
║ │     obsm: 'gPCs'                                   │     uns: 'cell_type_ontology_term_id_colors',      │ ║
║ │                                                    │ 'citation', 'default_embedding',                   │ ║
║ │                                                    │ 'schema_reference', 'schema_version', 'title'      │ ║
║ │     varm: 'filter'                                 │     obsm: 'X_azimuth_spca', 'X_azimuth_umap',      │ ║
║ │                                                    │ 'X_harmony', 'X_pca', 'X_umap'                     │ ║
║ │                                                    │     varm: 'PCs'                                    │ ║
║ └────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
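The header of the filtered object reports n_cells_per_donor=[1-302], so the number of cells per donor varies widely after filtering. A minimal sketch of how such a range can be computed from a per-cell donor column is shown below. It uses a small made-up array in place of dd.C.obs["donor_id"]; the donor names and counts are hypothetical.

```python
import numpy as np

# Hypothetical per-cell donor labels standing in for dd.C.obs["donor_id"].
cell_donor_ids = np.array(["d1", "d2", "d2", "d3", "d3", "d3", "d3"])

# Count how many cells each donor contributes after filtering.
donors, counts = np.unique(cell_donor_ids, return_counts=True)
cells_per_donor = dict(zip(donors.tolist(), counts.tolist()))

print(cells_per_donor)
print(f"range: {counts.min()}-{counts.max()}")
```

Donors at the low end of the range contribute bags with only a handful of cells, which is worth keeping in mind when interpreting per-donor predictions.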

Wrap the Data with MILDataset

We use MILDataset to create a wrapper that returns a bag of cells for each donor. This enables multiple instance learning (MIL) with donor-level labels. MILDataset automatically packages the labels, categorical and continuous covariates, and data matrices when they are available; the keys can be adjusted. Note that, for demonstration purposes, the donor labels here are generated at random.

dd.G.obs["donor_id"] = dd.G.obs.index
dd.G.obs["donor_labels"] = np.random.randint(2, size=len(dd.G.obs))
dd.C.obs["pool_number"] = dd.C.obs["pool_number"].astype("float")
dataset = MILDataset(
    dd,
    donor_labels_key="donor_labels",
    cell_batch_key="pool_number",
    # split_donors=["OneK1K_1", "OneK1K_10", "OneK1K_1000"],
    split_indices=list(range(10)),
)
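To make the notion of a bag concrete, the sketch below groups the rows of a toy cell matrix by donor and pairs each bag with a seeded random label. It illustrates the MIL data layout only and is not the implementation of MILDataset; the array names, shapes, and the seed are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so the toy labels are reproducible

# Toy stand-ins: 6 cells x 4 genes, and the donor each cell belongs to.
X_cells = rng.normal(size=(6, 4))
cell_donor_ids = np.array(["d1", "d2", "d2", "d3", "d3", "d3"])
donor_ids = np.array(["d1", "d2", "d3"])

# One random binary label per donor, mirroring the demonstration labels.
donor_labels = rng.integers(0, 2, size=len(donor_ids))

# A bag is the (n_cells_i x n_genes) block of rows belonging to donor i.
bags = {d: X_cells[cell_donor_ids == d] for d in donor_ids.tolist()}

for d, y in zip(donor_ids.tolist(), donor_labels.tolist()):
    print(d, bags[d].shape, "label:", y)
```

Each donor keeps a single label while its bag can hold any number of cells, which is the defining property of the MIL setting.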

Create a PyTorch DataLoader

We now build a PyTorch DataLoader that will feed batches of donors (each with a variable number of cells) into the model. We use the mil_collate_fn to handle the collation of variable-length inputs.

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=mil_collate_fn)
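Because donors contribute different numbers of cells, their bags cannot be stacked directly into one array. One common way to collate them is to pad every bag to the longest one in the batch and keep a boolean mask of the real cells. The NumPy sketch below shows that idea. It is not the implementation of mil_collate_fn, whose exact behavior is not described here, and pad_bags is a hypothetical helper.

```python
import numpy as np

def pad_bags(bags):
    """Pad a list of (n_cells_i, n_features) arrays to a common length.

    Returns a (batch, max_cells, n_features) array and a (batch, max_cells)
    boolean mask that is True where a row holds a real cell.
    """
    max_cells = max(bag.shape[0] for bag in bags)
    n_features = bags[0].shape[1]
    padded = np.zeros((len(bags), max_cells, n_features), dtype=bags[0].dtype)
    mask = np.zeros((len(bags), max_cells), dtype=bool)
    for i, bag in enumerate(bags):
        padded[i, : bag.shape[0]] = bag
        mask[i, : bag.shape[0]] = True
    return padded, mask

# Two donors with 3 and 1 cells, 2 features each.
batch, mask = pad_bags([np.ones((3, 2)), np.ones((1, 2))])
print(batch.shape)
print(mask.sum(axis=1))
```

The mask lets downstream pooling ignore the padded rows, so padding does not change a donor's aggregated representation.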

Initialize and Train a Model

We now create a DonorMILModel, specifying the dimensionality of the donor and cell input features. We then use PyTorch Lightning to train for two epochs (max_epochs=2).

model = DonorMILModel(n_input_donor=dd.G.n_vars, n_input_cell=dd.C.n_vars)
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, dataloader)
[2025-12-29 01:31:42,127] INFO:pytorch_lightning.utilities.rank_zero: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
[2025-12-29 01:31:42,191] INFO:pytorch_lightning.utilities.rank_zero: GPU available: True (mps), used: True
[2025-12-29 01:31:42,191] INFO:pytorch_lightning.utilities.rank_zero: TPU available: False, using: 0 TPU cores
[2025-12-29 01:31:42,192] INFO:pytorch_lightning.utilities.rank_zero: HPU available: False, using: 0 HPUs
[2025-12-29 01:32:10,856] INFO:pytorch_lightning.callbacks.model_summary: 
  | Name          | Type       | Params | Mode 
-----------------------------------------------------
0 | donor_encoder | Linear     | 17.5 M | train
1 | cell_encoder  | Sequential | 111 K  | train
2 | attention     | Sequential | 8.3 K  | train
3 | classifier    | Sequential | 257    | train
-----------------------------------------------------
17.6 M    Trainable params
0         Non-trainable params
17.6 M    Total params
70.511    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode
[2025-12-29 01:32:14,654] INFO:pytorch_lightning.utilities.rank_zero: `Trainer.fit` stopped: `max_epochs=2` reached.
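The parameter counts in the model summary can be reproduced by hand. A fully connected layer with n_in inputs and n_out outputs has n_in * n_out + n_out parameters. The layer widths below (128 hidden units, a 64-unit attention layer, and a classifier on a 256-dimensional concatenated embedding) are assumptions chosen because they match the printed totals; they are not taken from the DonorMILModel source.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

n_variants, n_genes = 136_776, 871  # dd.G.n_vars and dd.C.n_vars after filtering
hidden, attn_hidden = 128, 64       # assumed widths, inferred from the summary

counts = {
    "donor_encoder": linear_params(n_variants, hidden),
    "cell_encoder": linear_params(n_genes, hidden),
    "attention": linear_params(hidden, attn_hidden) + linear_params(attn_hidden, 1),
    "classifier": linear_params(2 * hidden, 1),
}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:14s} {n:>12,}")
print(f"{'total':14s} {total:>12,}")
print(f"{'size (MB)':14s} {total * 4 / 1e6:>12.3f}")  # 4 bytes per float32 parameter
```

Under these assumptions, nearly all parameters sit in the donor encoder, because its input has one feature per variant. Restricting the variant set is therefore the most direct way to shrink the model.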


By Jan Engelmann, Lucas Arnoldt, Eva Holtkamp

© Copyright 2026, Theislab.