SNP2P Trainer Utilities

Overview

This page documents loss functions and utility helpers used by the SNP2P training loop.

Usage and examples

Example: compute focal loss

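import torch
from src.utils.trainer.loss import FocalLoss

# alpha weights the positive class; gamma is the focusing parameter.
loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
logits = torch.randn(8, 1)                     # raw model outputs for 8 samples
targets = torch.randint(0, 2, (8, 1)).float()  # binary labels as floats
loss = loss_fn(logits, targets)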

API documentation

class CCCLoss

Concordance correlation coefficient (CCC) loss for regression targets.

Parameters:
  • eps (float, optional) – Numerical stability term.

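  • mean_diff (bool, optional) – Whether to include the mean-difference term in the denominator (in the standard CCC definition, the squared difference between the prediction and target means).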
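
As a rough illustration of what a CCC-based loss computes, the sketch below returns 1 - CCC for two 1-D tensors. The function name ccc_loss_sketch, the default eps value, the placement of eps in the denominator, and the exact effect of mean_diff are assumptions made for this example; the actual CCCLoss implementation may differ.

```python
import torch


def ccc_loss_sketch(pred, target, eps=1e-8, mean_diff=True):
    """Return 1 - CCC for two 1-D tensors (illustrative only, not the SNP2P code)."""
    pred_mean, target_mean = pred.mean(), target.mean()
    # Population (biased) variances and covariance over the batch.
    pred_var = ((pred - pred_mean) ** 2).mean()
    target_var = ((target - target_mean) ** 2).mean()
    cov = ((pred - pred_mean) * (target - target_mean)).mean()

    # Assumption: eps is added to the denominator for numerical stability.
    denom = pred_var + target_var + eps
    if mean_diff:
        # Standard CCC penalizes a shift between the two means.
        denom = denom + (pred_mean - target_mean) ** 2
    return 1.0 - 2.0 * cov / denom


target = torch.tensor([1.0, 2.0, 3.0, 4.0])
shifted = target + 1.0  # perfectly correlated, but offset by a constant
loss_with_shift_penalty = ccc_loss_sketch(shifted, target, mean_diff=True)
loss_without_shift_penalty = ccc_loss_sketch(shifted, target, mean_diff=False)
```

Under these assumptions, disabling mean_diff makes the loss insensitive to a constant offset between predictions and targets, so it behaves more like a correlation-style objective.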

class FocalLoss

Focal loss for imbalanced binary classification.

Parameters:
  • alpha (float, optional) – Weight for positive examples.

  • gamma (float, optional) – Focusing parameter; larger values down-weight well-classified examples more strongly.

  • reduction (str, optional) – Reduction method (mean or sum).
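
For readers who want to see how the three parameters interact, here is a minimal sketch of the commonly used binary focal loss formulation on logits. The function name focal_loss_sketch is hypothetical, and the way alpha is applied to negative examples (as 1 - alpha) is an assumption based on the usual formulation, not something confirmed for FocalLoss itself.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Binary focal loss on raw logits (illustrative only, not the SNP2P code)."""
    # Per-element BCE, kept unreduced so it can be re-weighted.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t is the predicted probability of the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Assumption: positives get weight alpha, negatives get 1 - alpha.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * (1 - p_t) ** gamma * bce

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")
```

With gamma=0 and alpha=0.5 this sketch reduces to half of the ordinary BCE-with-logits loss, which is a convenient sanity check when tuning the parameters.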

class VarianceLoss

Encourages the standard deviation of the predictions to match that of the targets within a batch.
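
One plausible way to express this idea is the squared difference between the two batch standard deviations. The sketch below assumes that form; the function name std_matching_loss_sketch and the choice of a squared penalty with population (biased) standard deviations are illustrative assumptions, not the documented behavior of VarianceLoss.

```python
import torch


def std_matching_loss_sketch(pred, target):
    """Squared gap between batch standard deviations (illustrative only)."""
    # Population standard deviation computed explicitly over the batch.
    pred_std = torch.sqrt(((pred - pred.mean()) ** 2).mean())
    target_std = torch.sqrt(((target - target.mean()) ** 2).mean())
    return (pred_std - target_std) ** 2


target = torch.tensor([1.0, 2.0, 3.0, 4.0])
offset_loss = std_matching_loss_sketch(target + 5.0, target)  # same spread
scaled_loss = std_matching_loss_sketch(2.0 * target, target)  # doubled spread
```

A loss of this shape ignores constant offsets and only reacts to a mismatch in spread, so it would typically be combined with a pointwise loss such as MSE.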

class MultiplePhenotypeLoss

Multi-task loss that applies binary cross-entropy (BCE) to the binary phenotype columns and mean squared error (MSE) to the quantitative phenotype columns, masking missing values in both.

Parameters:
  • bce_cols (list) – Column indices for BCE loss.

  • mse_cols (list) – Column indices for MSE loss.

  • label_smoothing (float, optional) – Label smoothing factor applied to the BCE targets.
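
The masking logic can be sketched as follows. This example assumes missing phenotype values are encoded as NaN in the target tensor, that predictions for BCE columns are raw logits, and that the two loss terms are simply summed; none of these details are confirmed by this page. The function name multi_phenotype_loss_sketch is hypothetical, and label smoothing is left out for brevity.

```python
import torch
import torch.nn.functional as F


def multi_phenotype_loss_sketch(pred, target, bce_cols, mse_cols):
    """Masked BCE + MSE over phenotype columns (illustrative only)."""
    # Assumption: NaN in the target marks a missing phenotype value.
    observed = ~torch.isnan(target)
    total = pred.new_zeros(())

    if bce_cols:
        mask = observed[:, bce_cols]
        if mask.any():
            # Boolean indexing keeps only observed entries, so NaNs never reach the loss.
            total = total + F.binary_cross_entropy_with_logits(
                pred[:, bce_cols][mask], target[:, bce_cols][mask]
            )
    if mse_cols:
        mask = observed[:, mse_cols]
        if mask.any():
            total = total + F.mse_loss(pred[:, mse_cols][mask], target[:, mse_cols][mask])
    return total


# Column 0 is a binary phenotype, column 1 is quantitative.
pred = torch.tensor([[0.0, 1.0], [0.0, 3.0]])
target = torch.tensor([[1.0, 1.0], [float("nan"), 2.0]])
loss = multi_phenotype_loss_sketch(pred, target, bce_cols=[0], mse_cols=[1])
```

Selecting observed entries before calling the loss functions, rather than zeroing them afterwards, keeps NaN values out of the computation graph entirely.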

class BCEWithLogitsLossWithLabelSmoothing

Binary cross-entropy loss with label smoothing.

Parameters:
  • alpha (float, optional) – Smoothing factor.

  • reduction (str, optional) – Reduction method (mean or sum).
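
Label smoothing for binary targets is often implemented by pulling the hard labels toward 0.5 before computing BCE. The sketch below assumes that scheme, with alpha as the smoothing factor; the function name smoothed_bce_sketch, the default alpha, and the exact smoothing formula are assumptions and may not match BCEWithLogitsLossWithLabelSmoothing.

```python
import torch
import torch.nn.functional as F


def smoothed_bce_sketch(logits, targets, alpha=0.1, reduction="mean"):
    """BCE-with-logits on smoothed binary targets (illustrative only)."""
    # Assumption: labels move toward 0.5, e.g. alpha=0.2 maps 1 -> 0.9 and 0 -> 0.1.
    smoothed = targets * (1.0 - alpha) + 0.5 * alpha
    return F.binary_cross_entropy_with_logits(logits, smoothed, reduction=reduction)
```

Setting alpha=0 in this sketch recovers the ordinary BCE-with-logits loss.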

class EarlyStopping

Early stopping utility that tracks a validation score and restores the best weights when training stalls.

Parameters:
  • patience (int, optional) – Number of epochs to wait for improvement.

  • min_delta (float, optional) – Minimum change to qualify as an improvement.

  • mode (str, optional) – max if a higher score is better (for example, accuracy), min if a lower score is better (for example, validation loss).

  • restore_best_weights (bool, optional) – Whether to restore the best weights on stop.

  • verbose (bool, optional) – Whether to print progress messages.
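
To make the interaction of patience, min_delta, and mode concrete, here is a small framework-free sketch of the bookkeeping such a utility typically performs. The class name EarlyStoppingSketch and its step method are hypothetical and do not reflect the actual EarlyStopping interface; a plain dict stands in for the model weights.

```python
import copy


class EarlyStoppingSketch:
    """Minimal early-stopping bookkeeping (illustrative only)."""

    def __init__(self, patience=3, min_delta=0.0, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.best_state = None
        self.bad_epochs = 0

    def step(self, score, state):
        """Record one epoch; return True when training should stop."""
        if self.best_score is None:
            improved = True
        elif self.mode == "max":
            improved = score > self.best_score + self.min_delta
        else:
            improved = score < self.best_score - self.min_delta

        if improved:
            self.best_score = score
            # Keep a copy so later updates cannot overwrite the best weights.
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStoppingSketch(patience=2, mode="max")
stopped_at = None
for epoch, score in enumerate([0.60, 0.70, 0.69, 0.70, 0.68]):
    # The dict is a stand-in for something like model.state_dict().
    if stopper.step(score, {"epoch": epoch}):
        stopped_at = epoch
        break
```

In this run the score peaks at epoch 1, the next two epochs fail to beat it, and the loop stops at epoch 3 with the epoch-1 state retained for restoring.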

class TrainingEfficiencyManager

Cache-aware helper that supports adaptive learning rates and transfer learning across phenotype combinations.

Parameters:
  • base_lr (float, optional) – Base learning rate.

  • warmup_epochs (int, optional) – Number of warmup epochs used for learning-rate scheduling.
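
As an illustration of how base_lr and warmup_epochs might combine, the sketch below implements a simple linear warmup. The linear shape, the function name warmup_lr_sketch, and the default values are assumptions; this page does not specify the schedule that TrainingEfficiencyManager actually uses, and its caching and transfer-learning behavior is not covered here.

```python
def warmup_lr_sketch(epoch, base_lr=1e-3, warmup_epochs=5):
    """Linear warmup to base_lr over warmup_epochs (illustrative only)."""
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        # Warmup disabled or finished: use the base learning rate.
        return base_lr
    # Epochs are 0-indexed; the last warmup epoch reaches base_lr exactly.
    return base_lr * (epoch + 1) / warmup_epochs


schedule = [warmup_lr_sketch(e, base_lr=1.0, warmup_epochs=4) for e in range(6)]
```

With base_lr=1.0 and warmup_epochs=4, the rate climbs in equal steps over the first four epochs and then stays at the base value.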