SNP2P Trainer
Overview
This page documents the SNP2P training utilities, including helper losses and the main trainer class. The trainer orchestrates optimization, evaluation, and metric reporting for mixed phenotype types (quantitative and binary).
Usage and examples
Example: initialize and run training
# tree_parser, model, device, and args are assumed to have been created earlier.
from torch.utils.data import DataLoader
from src.utils.data.dataset import SNP2PCollator, PLINKDataset
from src.utils.trainer.snp2p_trainer import SNP2PTrainer

# Build the training dataset from PLINK genotype files plus covariate and phenotype tables.
dataset = PLINKDataset(tree_parser, bfile="data/geno/plink_prefix", cov="cov.tsv", pheno="pheno.tsv")
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=SNP2PCollator(tree_parser))

trainer = SNP2PTrainer(
    snp2p_model=model,
    tree_parser=tree_parser,
    snp2p_dataloader=loader,
    device=device,
    args=args,
    target_phenotype="BMI",
)
# Train for 10 epochs, passing the output path for checkpoints.
trainer.train(epochs=10, output_path="checkpoints/snp2p")
API documentation
- correlation_matching_loss(pred, target, lam=0.05)
Computes a correlation-structure matching loss between predictions and labels.
- Parameters:
pred (torch.Tensor) – Model predictions ([B, P]).
target (torch.Tensor) – Target labels ([B, P]).
lam (float, optional) – Scaling factor for the loss.
- Returns:
Scalar loss value.
- Return type:
torch.Tensor
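The entry above does not spell out the formula, so the following is only a minimal sketch of one plausible reading: build the [P, P] correlation matrix across phenotypes for the predictions and for the labels, then penalize the mean squared difference between the two, scaled by lam. The function name correlation_matching_loss_sketch and the eps argument are hypothetical and are not part of the documented API.

```python
import torch


def correlation_matching_loss_sketch(pred, target, lam=0.05, eps=1e-8):
    """Hypothetical sketch: match the phenotype-by-phenotype correlation structure.

    pred and target are [B, P] tensors. This is an assumed formulation,
    not necessarily the one used by the SNP2P code base.
    """

    def _corr(x):
        # Center and scale each phenotype column across the batch.
        x = x - x.mean(dim=0, keepdim=True)
        x = x / (x.pow(2).mean(dim=0, keepdim=True).sqrt() + eps)
        # [P, P] matrix of pairwise correlations between phenotype columns.
        return (x.T @ x) / x.shape[0]

    diff = _corr(pred) - _corr(target)
    # Scalar loss: scaled mean squared difference of the two matrices.
    return lam * diff.pow(2).mean()


torch.manual_seed(0)
pred = torch.randn(8, 3)
target = torch.randn(8, 3)
loss = correlation_matching_loss_sketch(pred, target)
```

Under this reading the loss is zero when predictions reproduce the label correlation structure exactly, regardless of per-phenotype offset or scale, which is why it would typically be added to a primary prediction loss rather than used alone.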
- linear_temperature_schedule(epoch, total_epochs, T_init=1.0, T_final=0.1)
Linear temperature schedule used for annealing.
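As an illustration of the signature above, a linear schedule can be sketched as a straight-line interpolation from T_init to T_final. The sketch assumes zero-indexed epochs and that T_final is reached at the last epoch (epoch == total_epochs - 1); the actual implementation may use a different endpoint convention. The name linear_temperature_schedule_sketch is hypothetical.

```python
def linear_temperature_schedule_sketch(epoch, total_epochs, T_init=1.0, T_final=0.1):
    """Hypothetical sketch: linearly anneal the temperature from T_init to T_final."""
    if total_epochs <= 1:
        # Nothing to interpolate over; fall back to the final temperature.
        return T_final
    # Fraction of training completed, clamped to [0, 1].
    frac = min(max(epoch / (total_epochs - 1), 0.0), 1.0)
    return T_init + frac * (T_final - T_init)


# Temperatures for a 10-epoch run.
temperatures = [linear_temperature_schedule_sketch(e, 10) for e in range(10)]
```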
- get_param_groups(model, base_lr)
Split parameters into base and LoRA groups with different learning rates.
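The sketch below shows how such a split can be handed to a PyTorch optimizer through per-parameter-group options. It makes two assumptions that the documentation does not confirm: LoRA parameters are recognized by the substring "lora" in their names, and their learning rate is the base rate multiplied by a hypothetical lora_lr_scale factor. TinyModel is a toy module used only for the demonstration.

```python
import torch
from torch import nn


def get_param_groups_sketch(model, base_lr, lora_lr_scale=10.0):
    """Hypothetical sketch: separate LoRA parameters from the rest by name."""
    base_params, lora_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized
        if "lora" in name.lower():
            lora_params.append(param)
        else:
            base_params.append(param)

    groups = []
    if base_params:
        groups.append({"params": base_params, "lr": base_lr})
    if lora_params:
        groups.append({"params": lora_params, "lr": base_lr * lora_lr_scale})
    return groups


class TinyModel(nn.Module):
    """Toy module with one ordinary layer and two LoRA-style factors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.lora_A = nn.Parameter(torch.zeros(4, 1))
        self.lora_B = nn.Parameter(torch.zeros(1, 2))


groups = get_param_groups_sketch(TinyModel(), base_lr=1e-3)
optimizer = torch.optim.AdamW(groups)
```

Each dictionary in the returned list becomes one entry of optimizer.param_groups, so the two groups keep separate learning rates throughout training.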
- class SNP2PTrainer
Trainer for SNP2P models with optional validation and MLflow logging.
- Parameters:
snp2p_model (torch.nn.Module) – Model instance to train.
tree_parser (SNPTreeParser) – Parsed SNP ontology and masks.
snp2p_dataloader (torch.utils.data.DataLoader) – Training dataloader.
device (torch.device) – Device for model and tensors.
args (argparse.Namespace) – Training configuration namespace.
target_phenotype (str) – Target phenotype name or ID for logging.
validation_dataloader (torch.utils.data.DataLoader, optional) – Optional validation dataloader.
fix_system (bool, optional) – Whether to freeze system embeddings.
pretrain_dataloader (torch.utils.data.DataLoader, optional) – Optional pretraining dataloader.
label_smoothing (float, optional) – Label smoothing for phenotype loss.
use_mlflow (bool, optional) – Whether to log artifacts to MLflow.
- train(epochs, output_path=None)
Run the training loop for the given number of epochs.
- evaluate(model, dataloader, epoch, phenotypes, name='Validation', print_importance=False, snp_only=False)
Evaluate a model on a dataloader and compute phenotype metrics.
- Parameters:
model (torch.nn.Module) – Model to evaluate.
dataloader (torch.utils.data.DataLoader) – Dataloader to iterate over.
epoch (int) – Epoch index for logging.
phenotypes (list) – Phenotype IDs to evaluate.
name (str, optional) – Label used for logging output.
print_importance (bool, optional) – Whether to print attention importance scores.
snp_only (bool, optional) – Whether to evaluate SNP-only prediction.
- Returns:
Aggregate performance score.
- Return type:
- evaluate_continuous_phenotype(trues, results, covariates=None, phenotype_name='', epoch=0, rank=0)
Compute regression metrics for continuous phenotypes.
- Parameters:
trues (numpy.ndarray) – Ground-truth values.
results (numpy.ndarray) – Model predictions.
covariates (numpy.ndarray, optional) – Optional covariates for logging.
phenotype_name (str, optional) – Phenotype name for logging.
epoch (int, optional) – Epoch index for logging.
rank (int, optional) – Distributed rank for gating output.
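The documentation does not list which regression metrics are reported. As a hedged illustration, the sketch below computes three common ones (mean squared error, Pearson correlation, and the coefficient of determination) from the same trues and results arrays. The function name regression_metrics_sketch and the choice of metrics are assumptions.

```python
import numpy as np


def regression_metrics_sketch(trues, results):
    """Hypothetical sketch: common regression metrics for a continuous phenotype."""
    trues = np.asarray(trues, dtype=float).ravel()
    results = np.asarray(results, dtype=float).ravel()

    residual_ss = float(np.sum((trues - results) ** 2))
    total_ss = float(np.sum((trues - trues.mean()) ** 2))

    return {
        "mse": residual_ss / trues.size,
        # Off-diagonal entry of the 2x2 correlation matrix.
        "pearson_r": float(np.corrcoef(trues, results)[0, 1]),
        # Coefficient of determination: 1 - SS_res / SS_tot.
        "r2": 1.0 - residual_ss / total_ss,
    }


metrics = regression_metrics_sketch([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
```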
- evaluate_binary_phenotype(trues, results, covariates=None, phenotype_name='', epoch=0, rank=0)
Compute classification metrics for binary phenotypes.
- Parameters:
trues (numpy.ndarray) – Ground-truth labels.
results (numpy.ndarray) – Model predictions.
covariates (numpy.ndarray, optional) – Optional covariates for logging.
phenotype_name (str, optional) – Phenotype name for logging.
epoch (int, optional) – Epoch index for logging.
rank (int, optional) – Distributed rank for gating output.
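Which classification metrics are computed is likewise not stated. The sketch below assumes that results holds scores or probabilities for the positive class and computes two typical metrics: AUROC, via the pairwise definition (the probability that a positive sample outscores a negative one, counting ties as one half), and accuracy at a hypothetical threshold of 0.5. The pairwise comparison uses memory proportional to the number of positive-negative pairs, so it suits small evaluation sets only.

```python
import numpy as np


def binary_metrics_sketch(trues, results, threshold=0.5):
    """Hypothetical sketch: AUROC and accuracy for a binary phenotype."""
    trues = np.asarray(trues).astype(int).ravel()
    scores = np.asarray(results, dtype=float).ravel()

    pos = scores[trues == 1]
    neg = scores[trues == 0]
    # Compare every positive score against every negative score.
    diff = pos[:, None] - neg[None, :]
    auroc = float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

    predictions = (scores >= threshold).astype(int)
    accuracy = float(np.mean(predictions == trues))
    return {"auroc": auroc, "accuracy": accuracy}


metrics = binary_metrics_sketch([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```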
- train_epoch(epoch, ccc=False, sex=False)
Train for a single epoch over the training dataloader.
- iter_minibatches(model, dataloader, optimizer, epoch, name='', snp_only=False, sex=False)
Iterate over minibatches and update model parameters.
- Parameters:
model (torch.nn.Module) – Model to train.
dataloader (torch.utils.data.DataLoader) – Data loader to iterate.
optimizer (torch.optim.Optimizer) – Optimizer for parameter updates.
epoch (int) – Epoch index.
name (str, optional) – Label for progress logging.
snp_only (bool, optional) – Whether to train on SNP-only inputs.
sex (bool, optional) – Whether to include sex-specific logic.
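For orientation, a generic minibatch update loop in PyTorch looks roughly like the sketch below. It is not the trainer's implementation: the batch layout (a features/labels pair), the explicit loss_fn argument, and the toy linear-regression data are all assumptions made so that the example is self-contained, and the snp_only and sex options are not modeled.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def iter_minibatches_sketch(model, dataloader, optimizer, loss_fn):
    """Hypothetical sketch: one pass over the dataloader with parameter updates."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for features, labels in dataloader:
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = loss_fn(model(features), labels)
        loss.backward()                  # backpropagate
        optimizer.step()                 # update parameters
        total_loss += loss.item()
        n_batches += 1
    # Mean loss over the minibatches seen in this pass.
    return total_loss / max(n_batches, 1)


# Toy regression problem standing in for genotype/phenotype batches.
torch.manual_seed(0)
x = torch.randn(64, 5)
y = x @ torch.randn(5, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
losses = [iter_minibatches_sketch(model, loader, optimizer, nn.MSELoss()) for _ in range(20)]
```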