SNP2P Trainer

Overview

This page documents the SNP2P training utilities: helper functions (losses, a temperature schedule, and optimizer parameter grouping) and the main trainer class. The trainer handles optimization, evaluation, and metric reporting for both continuous (quantitative) and binary phenotypes.

Usage and examples

Example: initialize and run training

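import torch
from torch.utils.data import DataLoader
from src.utils.data.dataset import SNP2PCollator, PLINKDataset
from src.utils.trainer.snp2p_trainer import SNP2PTrainer

# Assumes `tree_parser` (an SNPTreeParser), `model` (the SNP2P model) and
# `args` (the argparse.Namespace with training settings) were created earlier.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data: PLINK genotypes plus covariate and phenotype tables.
dataset = PLINKDataset(tree_parser, bfile="data/geno/plink_prefix", cov="cov.tsv", pheno="pheno.tsv")
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=SNP2PCollator(tree_parser))

trainer = SNP2PTrainer(
    snp2p_model=model,
    tree_parser=tree_parser,
    snp2p_dataloader=loader,
    device=device,
    args=args,
    target_phenotype="BMI",
)
# Train for 10 epochs; "checkpoints/snp2p" is used as the checkpoint path prefix.
trainer.train(epochs=10, output_path="checkpoints/snp2p")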

API documentation

correlation_matching_loss(pred, target, lam=0.05)

Computes a loss that penalizes differences between the correlation structure of the predictions and that of the labels.

Parameters:
  • pred (torch.Tensor) – Model predictions, shape [B, P] (batch size × number of phenotypes).

  • target (torch.Tensor) – Target labels, same shape as pred.

  • lam (float, optional) – Weight applied to the loss (default 0.05).

Returns:

Scalar loss value.

Return type:

torch.Tensor
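The page does not spell out the formula. A minimal sketch, assuming the loss compares the P × P Pearson correlation matrices (computed across the batch) of the predictions and of the targets with a mean squared difference, weighted by `lam`. The name `correlation_matching_loss_sketch` and the `eps` guard are hypothetical and not taken from the module.

```python
import torch


def correlation_matching_loss_sketch(
    pred: torch.Tensor, target: torch.Tensor, lam: float = 0.05, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical correlation-structure loss for [B, P] tensors."""

    def corr(x: torch.Tensor) -> torch.Tensor:
        # Center each column (phenotype) over the batch dimension.
        x = x - x.mean(dim=0, keepdim=True)
        # Scale each column to unit L2 norm; eps guards against constant columns.
        x = x / (x.norm(dim=0, keepdim=True) + eps)
        # Dot products of normalized, centered columns are Pearson correlations: [P, P].
        return x.T @ x

    # Mean squared difference between the two correlation matrices, weighted by lam.
    return lam * ((corr(pred) - corr(target)) ** 2).mean()
```

Because the result is a differentiable scalar, it can be added to a primary loss, for example `total = mse + correlation_matching_loss_sketch(pred, target)`.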

linear_temperature_schedule(epoch, total_epochs, T_init=1.0, T_final=0.1)

Returns a temperature that moves linearly from T_init toward T_final as training progresses; used for annealing.

Parameters:
  • epoch (int) – Current epoch.

  • total_epochs (int) – Total training epochs.

  • T_init (float, optional) – Initial temperature.

  • T_final (float, optional) – Final temperature.

Returns:

Temperature for the epoch.

Return type:

float
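A minimal sketch of such a schedule, assuming epochs are counted from 0 and the temperature reaches T_final exactly at `epoch == total_epochs` (the real function may instead divide by `total_epochs - 1`). The name `linear_temperature_schedule_sketch` is hypothetical.

```python
def linear_temperature_schedule_sketch(
    epoch: int, total_epochs: int, T_init: float = 1.0, T_final: float = 0.1
) -> float:
    """Hypothetical linear interpolation from T_init (epoch 0) to T_final (epoch total_epochs)."""
    if total_epochs <= 0:
        return T_final
    # Fraction of training completed, clamped to [0, 1] so later epochs stay at T_final.
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    return T_init + (T_final - T_init) * progress
```

With the defaults and `total_epochs=10`, epoch 5 gives a temperature of about 0.55, halfway between 1.0 and 0.1.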

get_param_groups(model, base_lr)

Splits the model's parameters into a base group and a LoRA (low-rank adaptation) group, each with its own learning rate.

Parameters:
  • model (torch.nn.Module) – Model with named parameters.

  • base_lr (float) – Base learning rate.

Returns:

Parameter groups that can be passed to a torch.optim optimizer.

Return type:

list
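A minimal sketch of this kind of split, using PyTorch's per-parameter-group optimizer options. It assumes LoRA parameters are recognized by "lora" in their name and that the LoRA group uses a multiple of `base_lr`; both the naming convention and the `lora_lr_scale` argument are hypothetical, not the module's actual rule.

```python
import torch


def get_param_groups_sketch(model: torch.nn.Module, base_lr: float, lora_lr_scale: float = 10.0):
    """Hypothetical split into a base group and a LoRA group with a scaled learning rate."""
    base_params, lora_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are left out of the optimizer
        # Assumed convention: LoRA adapter parameters have "lora" in their name.
        (lora_params if "lora" in name.lower() else base_params).append(param)
    groups = [{"params": base_params, "lr": base_lr}]
    if lora_params:
        groups.append({"params": lora_params, "lr": base_lr * lora_lr_scale})
    return groups
```

Usage: `optimizer = torch.optim.AdamW(get_param_groups_sketch(model, base_lr=1e-4))`. Each dict becomes one optimizer parameter group, so the two groups are updated with different learning rates.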

class SNP2PTrainer

Trainer for SNP2P models with optional validation and MLflow logging.

Parameters:
  • snp2p_model (torch.nn.Module) – Model instance to train.

  • tree_parser (SNPTreeParser) – Parsed SNP ontology and masks.

  • snp2p_dataloader (torch.utils.data.DataLoader) – Training dataloader.

  • device (torch.device) – Device for model and tensors.

  • args (argparse.Namespace) – Training configuration namespace.

  • target_phenotype (str) – Target phenotype name or ID for logging.

  • validation_dataloader (torch.utils.data.DataLoader, optional) – Optional validation dataloader.

  • fix_system (bool, optional) – Whether to freeze system embeddings.

  • pretrain_dataloader (torch.utils.data.DataLoader, optional) – Optional pretraining dataloader.

  • label_smoothing (float, optional) – Label smoothing for phenotype loss.

  • use_mlflow (bool, optional) – Whether to log artifacts to MLflow.
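How `label_smoothing` is applied inside the trainer is not documented here. For binary phenotypes, one common form pulls hard 0/1 labels toward 0.5 before binary cross-entropy; the sketch below assumes that form, and `smooth_binary_labels` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def smooth_binary_labels(labels: torch.Tensor, smoothing: float) -> torch.Tensor:
    """Hypothetical helper: map 0 -> smoothing / 2 and 1 -> 1 - smoothing / 2."""
    return labels * (1.0 - smoothing) + 0.5 * smoothing


logits = torch.tensor([2.0, -1.0, 0.5])
labels = torch.tensor([1.0, 0.0, 1.0])
# Binary cross-entropy against the softened targets (soft targets in [0, 1] are allowed).
loss = F.binary_cross_entropy_with_logits(logits, smooth_binary_labels(labels, 0.1))
```

Softened targets stop the loss from pushing logits toward infinity on confidently labeled examples.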

train(epochs, output_path=None)

Run the training loop for the given number of epochs.

Parameters:
  • epochs (int) – Number of epochs to train.

  • output_path (str, optional) – Optional checkpoint path prefix.

evaluate(model, dataloader, epoch, phenotypes, name='Validation', print_importance=False, snp_only=False)

Evaluate a model on a dataloader and compute phenotype metrics.

Parameters:
  • model (torch.nn.Module) – Model to evaluate.

  • dataloader (torch.utils.data.DataLoader) – Dataloader to iterate over.

  • epoch (int) – Epoch index for logging.

  • phenotypes (list) – Phenotype IDs to evaluate.

  • name (str, optional) – Label used for logging output.

  • print_importance (bool, optional) – Whether to print attention importance scores.

  • snp_only (bool, optional) – Whether to evaluate predictions made from SNP inputs alone.

Returns:

Aggregate performance score.

Return type:

float
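The evaluation internals (batch format, how the aggregate score is formed) are not described here. The sketch below shows only the generic PyTorch pattern such a method typically follows: switch to eval mode, disable gradients, collect predictions, then restore the previous mode. The `(inputs, targets)` batch format and averaging per-phenotype Pearson correlations into one score are assumptions, and `evaluate_sketch` is hypothetical.

```python
import numpy as np
import torch


def evaluate_sketch(model: torch.nn.Module, batches) -> float:
    """Hypothetical evaluation loop over (inputs, targets) batches with [B, P] targets."""
    was_training = model.training
    model.eval()  # disable dropout and use running batch-norm statistics
    preds, trues = [], []
    with torch.no_grad():  # no gradient bookkeeping during evaluation
        for inputs, targets in batches:
            preds.append(model(inputs).cpu().numpy())
            trues.append(targets.cpu().numpy())
    model.train(was_training)  # restore the previous mode
    preds, trues = np.concatenate(preds), np.concatenate(trues)
    # Assumed aggregate: mean Pearson r across phenotype columns.
    scores = [np.corrcoef(trues[:, j], preds[:, j])[0, 1] for j in range(trues.shape[1])]
    return float(np.mean(scores))
```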

evaluate_continuous_phenotype(trues, results, covariates=None, phenotype_name='', epoch=0, rank=0)

Compute regression metrics for continuous phenotypes.

Parameters:
  • trues (numpy.ndarray) – Ground-truth values.

  • results (numpy.ndarray) – Model predictions.

  • covariates (numpy.ndarray, optional) – Optional covariates for logging.

  • phenotype_name (str, optional) – Phenotype name for logging.

  • epoch (int, optional) – Epoch index for logging.

  • rank (int, optional) – Process rank in distributed training; used to restrict output to a single process.
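The exact metrics reported are not listed here. A minimal sketch of three standard regression metrics (mean squared error, Pearson correlation, and the coefficient of determination R²) that a function like this might compute; the choice of metrics and the name `regression_metrics_sketch` are assumptions.

```python
import numpy as np


def regression_metrics_sketch(trues: np.ndarray, results: np.ndarray) -> dict:
    """Hypothetical metric set for a continuous phenotype."""
    trues = np.asarray(trues, dtype=float).ravel()
    results = np.asarray(results, dtype=float).ravel()
    mse = float(np.mean((trues - results) ** 2))
    pearson_r = float(np.corrcoef(trues, results)[0, 1])
    # Coefficient of determination: 1 - SS_res / SS_tot.
    ss_tot = float(np.sum((trues - trues.mean()) ** 2))
    r2 = 1.0 - float(np.sum((trues - results) ** 2)) / ss_tot
    return {"mse": mse, "pearson_r": pearson_r, "r2": r2}
```

Pearson r ignores systematic offsets and scaling, while MSE and R² do not, so reporting both gives a fuller picture of prediction quality.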

evaluate_binary_phenotype(trues, results, covariates=None, phenotype_name='', epoch=0, rank=0)

Compute classification metrics for binary phenotypes.

Parameters:
  • trues (numpy.ndarray) – Ground-truth labels.

  • results (numpy.ndarray) – Model predictions.

  • covariates (numpy.ndarray, optional) – Optional covariates for logging.

  • phenotype_name (str, optional) – Phenotype name for logging.

  • epoch (int, optional) – Epoch index for logging.

  • rank (int, optional) – Process rank in distributed training; used to restrict output to a single process.
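Which classification metrics are reported is not listed here. A common one is the area under the ROC curve; the sketch below computes it from the Mann-Whitney U statistic, with tied scores given average ranks, and assumes `results` holds scores where larger means "more likely positive". The name `auroc_sketch` is hypothetical; if scikit-learn is available, `sklearn.metrics.roc_auc_score` computes the same quantity for binary labels.

```python
import numpy as np


def auroc_sketch(trues: np.ndarray, scores: np.ndarray) -> float:
    """Hypothetical AUROC via the Mann-Whitney U statistic (ties get average ranks)."""
    trues = np.asarray(trues).astype(bool).ravel()
    scores = np.asarray(scores, dtype=float).ravel()
    n_pos, n_neg = int(trues.sum()), int((~trues).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")  # AUROC is undefined when only one class is present
    order = np.argsort(scores, kind="mergesort")
    sorted_scores = scores[order]
    ranks = np.empty(len(scores), dtype=float)
    i = 0
    while i < len(scores):
        j = i
        # Extend j over a run of tied scores.
        while j + 1 < len(scores) and sorted_scores[j + 1] == sorted_scores[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0  # average 1-based rank of the tie group
        i = j + 1
    # U counts (positive, negative) pairs ranked correctly; dividing by all pairs gives AUROC.
    u = ranks[trues].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))
```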

train_epoch(epoch, ccc=False, sex=False)

Train for a single epoch over the training dataloader.

Parameters:
  • epoch (int) – Epoch index.

  • ccc (bool, optional) – Whether to compute a concordance correlation coefficient (CCC) loss.

  • sex (bool, optional) – Whether to include sex-specific logic.
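The CCC term is not defined on this page. A minimal sketch of a loss based on Lin's concordance correlation coefficient, CCC = 2·cov(x, y) / (var(x) + var(y) + (mean(x) − mean(y))²), taking 1 − CCC so that perfect agreement gives zero loss. The name `ccc_loss_sketch`, the use of population variances, and the `eps` guard are assumptions.

```python
import torch


def ccc_loss_sketch(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical 1 - CCC loss for 1-D prediction and target tensors."""
    pred_mean, target_mean = pred.mean(), target.mean()
    pred_c, target_c = pred - pred_mean, target - target_mean
    pred_var = (pred_c ** 2).mean()  # population variance
    target_var = (target_c ** 2).mean()
    cov = (pred_c * target_c).mean()
    # CCC reaches 1 only when predictions equal targets exactly.
    ccc = 2.0 * cov / (pred_var + target_var + (pred_mean - target_mean) ** 2 + eps)
    return 1.0 - ccc
```

Unlike a Pearson-based loss, this penalizes constant offsets: shifting every prediction by 1 keeps Pearson r at 1 but lowers the CCC.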

iter_minibatches(model, dataloader, optimizer, epoch, name='', snp_only=False, sex=False)

Iterate over minibatches and update model parameters.

Parameters:
  • model (torch.nn.Module) – Model to train.

  • dataloader (torch.utils.data.DataLoader) – Data loader to iterate.

  • optimizer (torch.optim.Optimizer) – Optimizer for parameter updates.

  • epoch (int) – Epoch index.

  • name (str, optional) – Label for progress logging.

  • snp_only (bool, optional) – Whether to train on SNP-only inputs.

  • sex (bool, optional) – Whether to include sex-specific logic.
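The SNP2P batch structure and loss composition are not described here, so the sketch below shows only the standard PyTorch update cycle that a minibatch loop like this performs: zero the gradients, run the forward pass, backpropagate, and step the optimizer. The `(inputs, targets)` batch format, the default MSE loss, and the name `iter_minibatches_sketch` are all assumptions.

```python
import torch


def iter_minibatches_sketch(model, batches, optimizer, loss_fn=torch.nn.functional.mse_loss) -> float:
    """Hypothetical single pass over (inputs, targets) batches; returns the mean batch loss."""
    model.train()
    total, n_batches = 0.0, 0
    for inputs, targets in batches:
        optimizer.zero_grad()                   # clear gradients from the previous step
        loss = loss_fn(model(inputs), targets)  # forward pass and loss
        loss.backward()                         # backpropagate
        optimizer.step()                        # update parameters
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)
```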