SNP2P Trainer Utilities
Overview
This page documents loss functions and utility helpers used by the SNP2P training loop.
Usage and examples
Example: compute focal loss
import torch
from src.utils.trainer.loss import FocalLoss
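# Focal-loss hyperparameters: alpha (class weight) and gamma (focusing exponent)
loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
# One raw logit and one binary {0, 1} label per sample
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()
loss = loss_fn(logits, targets)  # scalar tensor, assuming mean reduction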
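To sanity-check what the example computes, the per-sample formula can be written out without PyTorch. The sketch below is a framework-free reference that assumes the standard binary focal loss, FL = -alpha_t * (1 - p_t)^gamma * log(p_t), averaged over the batch. Whether src.utils.trainer.loss.FocalLoss uses the same alpha convention and mean reduction is an assumption, and focal_loss_reference is a hypothetical helper, not part of the module.

```python
import math
from typing import Sequence


def _softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def focal_loss_reference(
    logits: Sequence[float],
    targets: Sequence[float],
    alpha: float = 0.25,
    gamma: float = 2.0,
) -> float:
    """Hypothetical reference: mean binary focal loss from raw logits."""
    if not logits or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and equal in length")
    total = 0.0
    for x, y in zip(logits, targets):
        positive = y == 1.0
        # log p_t, where p_t is the probability assigned to the true class:
        # log(sigmoid(x)) = -softplus(-x), log(1 - sigmoid(x)) = -softplus(x)
        log_p_t = -_softplus(-x) if positive else -_softplus(x)
        p_t = math.exp(log_p_t)
        alpha_t = alpha if positive else 1.0 - alpha  # class-balancing weight
        # (1 - p_t) ** gamma shrinks the loss of easy, confidently correct samples
        total += -alpha_t * (1.0 - p_t) ** gamma * log_p_t
    return total / len(logits)
```

With gamma=0 and alpha=0.5 the expression reduces to half the ordinary binary cross-entropy, which is a quick way to check any focal-loss implementation.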
API documentation
- class CCCLoss
Concordance correlation coefficient (CCC) loss for regression targets.
- class FocalLoss
Focal loss for imbalanced binary classification.
- class VarianceLoss
Encourages the within-batch standard deviation of the predictions to match that of the targets.
- class MultiplePhenotypeLoss
Multi-task loss that applies BCE to binary phenotype indices and MSE to quantitative phenotype indices while masking missing values.
- class BCEWithLogitsLossWithLabelSmoothing
Binary cross-entropy loss with label smoothing.
- class EarlyStopping
Early stopping utility that tracks a validation score, stops training once the score has not improved for a set number of epochs, and can optionally restore the best weights.
- Parameters:
patience (int, optional) – Number of epochs without improvement to wait before stopping.
min_delta (float, optional) – Minimum change to qualify as an improvement.
mode (str, optional) – "max" for metrics to maximize, "min" for metrics to minimize.
restore_best_weights (bool, optional) – Whether to restore the best weights on stop.
verbose (bool, optional) – Whether to print progress messages.
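The CCCLoss entry does not spell out the formula. A common choice, assumed here, is loss = 1 - CCC with CCC = 2 * cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y))^2), computed over the batch with population (1/n) statistics. The pure-Python sketch below (ccc_loss_reference is a hypothetical name) makes the behaviour concrete: unlike a correlation-based loss, it also penalizes shifted or rescaled predictions.

```python
from typing import Sequence


def ccc_loss_reference(
    preds: Sequence[float], targets: Sequence[float], eps: float = 1e-8
) -> float:
    """Hypothetical reference: 1 - concordance correlation coefficient."""
    n = len(preds)
    if n == 0 or n != len(targets):
        raise ValueError("preds and targets must be non-empty and equal in length")
    mean_p = sum(preds) / n
    mean_t = sum(targets) / n
    # Population (1/n) variances and covariance over the batch.
    var_p = sum((p - mean_p) ** 2 for p in preds) / n
    var_t = sum((t - mean_t) ** 2 for t in targets) / n
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(preds, targets)) / n
    # eps guards against division by zero when both inputs are constant and equal.
    ccc = 2.0 * cov / (var_p + var_t + (mean_p - mean_t) ** 2 + eps)
    return 1.0 - ccc
```

For example, ccc_loss_reference([2, 3, 4, 5], [1, 2, 3, 4]) is about 0.29 even though the two sequences are perfectly correlated, because the constant offset of 1 enters the denominator.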
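The VarianceLoss description does not say how the mismatch in spread is penalized. One plausible form, assumed here purely for illustration, is the squared difference between the batch standard deviations of predictions and targets; the actual class may use a different penalty (absolute difference, variance rather than standard deviation) or the unbiased n - 1 estimator. variance_loss_reference is a hypothetical name.

```python
import math
from typing import Sequence


def _std(values: Sequence[float]) -> float:
    # Population standard deviation (divides by n); an assumption of this sketch.
    n = len(values)
    mean = sum(values) / n
    return math.sqrt(sum((v - mean) ** 2 for v in values) / n)


def variance_loss_reference(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Hypothetical reference: squared gap between batch standard deviations."""
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and equal in length")
    return (_std(preds) - _std(targets)) ** 2
```

Only the spread enters this term, so a constant offset between predictions and targets costs nothing; for instance, [0, 2] against [10, 12] gives a loss of 0.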
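MultiplePhenotypeLoss combines two per-column losses with masking. The sketch below assumes that missing phenotype values are encoded as NaN, that binary columns hold raw logits, and that the result is the mean BCE over observed binary entries plus the mean MSE over observed quantitative entries. The real class's missing-value encoding, any per-task weighting, and its reduction are not documented here, and multiple_phenotype_loss_reference is a hypothetical name.

```python
import math
from typing import List, Sequence


def _bce_with_logits(x: float, y: float) -> float:
    # Numerically stable binary cross-entropy on a raw logit x with target y in [0, 1].
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def multiple_phenotype_loss_reference(
    preds: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    binary_indices: Sequence[int],
    quantitative_indices: Sequence[int],
) -> float:
    """Hypothetical reference: masked BCE on binary columns plus masked MSE on quantitative ones."""
    bce_terms: List[float] = []
    mse_terms: List[float] = []
    for pred_row, target_row in zip(preds, targets):
        for j in binary_indices:
            if not math.isnan(target_row[j]):  # NaN marks a missing phenotype (assumed)
                bce_terms.append(_bce_with_logits(pred_row[j], target_row[j]))
        for j in quantitative_indices:
            if not math.isnan(target_row[j]):
                mse_terms.append((pred_row[j] - target_row[j]) ** 2)
    loss = 0.0
    # Each task is averaged over its observed entries only; an all-missing task adds 0.
    if bce_terms:
        loss += sum(bce_terms) / len(bce_terms)
    if mse_terms:
        loss += sum(mse_terms) / len(mse_terms)
    return loss
```

Masking before averaging means a sample with a missing quantitative value still contributes to the binary term, rather than being dropped entirely.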
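Label smoothing for a binary target is commonly implemented by pulling hard labels toward 0.5, y' = y * (1 - s) + s / 2, before applying BCE with logits. That rule and the parameter name smoothing are assumptions; BCEWithLogitsLossWithLabelSmoothing may use a different parameter name or smoothing rule. A pure-Python sketch, with bce_label_smoothing_reference as a hypothetical name:

```python
import math
from typing import Sequence


def bce_label_smoothing_reference(
    logits: Sequence[float], targets: Sequence[float], smoothing: float = 0.1
) -> float:
    """Hypothetical reference: mean BCE-with-logits against smoothed binary targets."""
    if not 0.0 <= smoothing < 1.0:
        raise ValueError("smoothing must be in [0, 1)")
    if not logits or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and equal in length")
    total = 0.0
    for x, y in zip(logits, targets):
        y_smooth = y * (1.0 - smoothing) + 0.5 * smoothing  # 1 -> 1 - s/2, 0 -> s/2
        # Numerically stable BCE on the raw logit.
        total += max(x, 0.0) - x * y_smooth + math.log1p(math.exp(-abs(x)))
    return total / len(logits)
```

With smoothing=0.1, an extremely confident correct logit (10.0 with target 1) costs about 0.5 instead of almost nothing, so the loss no longer rewards pushing logits toward infinity.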
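The EarlyStopping parameters map onto a small state machine: keep the best score seen so far, count epochs that fail to beat it by more than min_delta in the direction set by mode, and stop once that count reaches patience. The sketch below is illustrative only. The class name EarlyStoppingSketch, its default values, the step method, and the best_weights attribute are assumptions rather than the documented interface, and a plain dict stands in for a PyTorch state_dict.

```python
import copy
from typing import Any, Optional


class EarlyStoppingSketch:
    """Hypothetical stand-in for EarlyStopping; not the project's implementation."""

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.0,
        mode: str = "min",
        restore_best_weights: bool = True,
        verbose: bool = False,
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_score: Optional[float] = None
        self.best_weights: Any = None
        self.counter = 0  # consecutive epochs without sufficient improvement

    def _improved(self, score: float) -> bool:
        if self.best_score is None:
            return True
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def step(self, score: float, weights: Any) -> bool:
        """Record one epoch's validation score; return True when training should stop."""
        if self._improved(score):
            self.best_score = score
            self.counter = 0
            if self.restore_best_weights:
                # Keep a copy so later in-place parameter updates do not alter it.
                self.best_weights = copy.deepcopy(weights)
            return False
        self.counter += 1
        if self.verbose:
            print(f"No improvement for {self.counter}/{self.patience} epochs")
        return self.counter >= self.patience
```

In a PyTorch loop one would pass model.state_dict() as weights and, after stopping, call model.load_state_dict(stopper.best_weights) when restore_best_weights is enabled.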