SNP2P Model Utilities

Overview

This page documents model helper layers that are shared across SNP2P architectures.

Usage and examples

Example: apply a FiLM layer

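import torch
from src.model.utils import FiLM

# Covariates have in_cov = 4 values per sample; token features have hid = 16 channels.
film = FiLM(in_cov=4, hid=16)
covariates = torch.randn(8, 4)           # [batch, in_cov]
features = torch.randn(8, 10, 16)        # [batch, tokens, hid]
modulated = film(features, covariates)   # modulate the token features with the covariates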

API documentation

class FiLM

Feature-wise linear modulation (FiLM) layer that conditions token features on covariates.

Parameters:
  • in_cov (int) – Covariate input dimension.

  • hid (int) – Hidden dimension of the modulation.
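The modulation can be illustrated without PyTorch. The sketch below assumes the standard FiLM formulation: two linear maps turn a covariate vector into a per-channel scale `gamma` and shift `beta`, each of length `hid`, which are applied to every token as `gamma * x + beta`. The helper names `linear` and `film_sketch` are hypothetical, the sketch handles a single sample, and the real `FiLM` class may compute `gamma` and `beta` differently (for example with a small MLP).

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Compute W @ x + b for one vector; weight has shape [out, in]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def film_sketch(
    features: Matrix,  # [tokens, hid]
    cov: Vector,       # [in_cov]
    w_gamma: Matrix,   # [hid, in_cov]
    b_gamma: Vector,   # [hid]
    w_beta: Matrix,    # [hid, in_cov]
    b_beta: Vector,    # [hid]
) -> Matrix:
    """Hypothetical FiLM: scale and shift every token by covariate-derived vectors."""
    gamma = linear(cov, w_gamma, b_gamma)  # per-channel scale, length hid
    beta = linear(cov, w_beta, b_beta)     # per-channel shift, length hid
    # The same gamma and beta are broadcast over all tokens of the sample.
    return [[g * f + b for g, f, b in zip(gamma, token, beta)] for token in features]


feats = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # 2 tokens, hid = 3
cov = [2.0, 3.0]                            # in_cov = 2
w_gamma = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
zeros = [[0.0, 0.0] for _ in range(3)]
out = film_sketch(feats, cov, w_gamma, [0.0] * 3, zeros, [1.0] * 3)
print(out)  # [[3.0, 7.0, 16.0], [9.0, 16.0, 31.0]]
```

Here `gamma` evaluates to `[2.0, 3.0, 5.0]` and `beta` to `[1.0, 1.0, 1.0]`. Setting the gamma weights to zero with bias 1 and the beta map to zero recovers the identity, which is a common initialization choice for FiLM layers.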

class MoEHeadPrediction

Mixture-of-experts head that produces per-position scalar predictions.

Parameters:
  • hid (int) – Hidden dimension of token embeddings.

  • k_experts (int, optional) – Number of experts.

  • top_k (int, optional) – Number of experts to select per token.
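A common way to implement top-k mixture-of-experts routing is to score every expert with a gating layer, keep only the `top_k` best experts for each token, renormalize their scores with a softmax, and combine the selected experts' outputs. The framework-free sketch below shows that scheme for a per-position scalar head. The linear gate, the linear experts, and the names `top_k_gate` and `moe_head_sketch` are assumptions for illustration; `MoEHeadPrediction` may use different expert networks or a different gating rule (for example, a softmax over all experts before selecting).

```python
import math
from typing import List, Sequence, Tuple


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_gate(logits: Sequence[float], top_k: int) -> List[Tuple[int, float]]:
    """Keep the top_k highest-scoring experts and softmax over just those scores."""
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    return list(zip(chosen, softmax([logits[i] for i in chosen])))


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def moe_head_sketch(
    tokens: List[List[float]],    # [positions, hid]
    gate_w: List[List[float]],    # [k_experts, hid]: one gating vector per expert
    expert_w: List[List[float]],  # [k_experts, hid]: each expert maps hid -> 1 scalar
    top_k: int,
) -> List[float]:
    """Return one scalar prediction per position."""
    preds = []
    for token in tokens:
        logits = [dot(g, token) for g in gate_w]
        # Weighted sum over the selected experts' scalar outputs only.
        preds.append(sum(w * dot(expert_w[i], token) for i, w in top_k_gate(logits, top_k)))
    return preds


gate_w = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # k_experts = 3, hid = 2
expert_w = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
tokens = [[2.0, 1.0]]                           # gate logits: [2.0, 1.0, 0.0]
print(moe_head_sketch(tokens, gate_w, expert_w, top_k=1))  # [3.0]
print(moe_head_sketch(tokens, gate_w, expert_w, top_k=2))
```

With `top_k=1` only expert 0 contributes, so the prediction is its output, 3.0. With `top_k=2` the prediction blends experts 0 and 1 (outputs 3.0 and 4.0) with softmax weights over the logits 2.0 and 1.0.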

class LayerNormNormedScaleOnly

Layer normalization variant with normalized scaling weights.

Parameters:
  • normalized_shape (int or tuple) – Shape of the input to normalize.

  • eps (float, optional) – Numerical stability term.
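Reading the class name as "scale only", the layer presumably applies a learned per-channel scale without an additive bias. How the scaling weights are normalized is not documented here, so the sketch below uses a plain scale vector and only illustrates the underlying layer normalization: subtract the mean over the feature dimension, divide by the standard deviation (with `eps` inside the square root), then scale. The function name and the `1e-5` default are assumptions.

```python
import math
from typing import List, Sequence


def layer_norm_scale_only(
    x: Sequence[float], weight: Sequence[float], eps: float = 1e-5
) -> List[float]:
    """Normalize one feature vector, then apply a per-channel scale (no bias)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as in standard LayerNorm
    inv_std = 1.0 / math.sqrt(var + eps)
    return [w * (v - mean) * inv_std for v, w in zip(x, weight)]


out = layer_norm_scale_only([1.0, 2.0, 3.0, 4.0], weight=[1.0] * 4, eps=0.0)
print(out)  # approximately zero mean and unit variance
```

Because there is no bias term, a constant input always maps to zeros regardless of the weights.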

class RMSNorm

Root-mean-square normalization.

Parameters:
  • dim (int) – Feature dimension.

  • eps (float, optional) – Numerical stability term.

  • elementwise_affine (bool, optional) – Whether to learn a scale parameter.

  • memory_efficient (bool, optional) – Unused placeholder for compatibility.
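Unlike layer normalization, RMS normalization does not subtract the mean. It divides each feature vector by its root mean square, sqrt(mean(x^2) + eps), and optionally multiplies by a learned per-channel scale, which is the role of `elementwise_affine`. A framework-free sketch of that formula follows; the function name and the `1e-6` default for `eps` are assumptions.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(
    x: Sequence[float],
    weight: Optional[Sequence[float]] = None,  # None mimics elementwise_affine=False
    eps: float = 1e-6,
) -> List[float]:
    """Rescale x so that its mean square is (about) 1; no mean subtraction."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    out = [v * inv_rms for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    return out


out = rms_norm([3.0, 4.0], eps=0.0)
print(out)  # mean square is 1.0; the values are not centered
```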

class BatchNorm1d_BatchOnly_NLC

Batch-only normalization over [B, L, C] inputs.

Parameters:
  • num_features (int) – Number of feature channels.

  • eps (float, optional) – Numerical stability term.

  • momentum (float, optional) – Momentum for running statistics.

  • affine (bool, optional) – Whether to learn affine parameters.

  • track_running_stats (bool, optional) – Whether to track running mean/variance.

  • length (int, optional) – Sequence length used to size running statistics.
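The `length` parameter suggests that running statistics are kept per (position, channel) pair, so each of the `L * C` entries is normalized using statistics computed across the batch only. The sketch below assumes exactly that; it is an interpretation of the parameter list, not the class's code. It also assumes the running-statistic convention of `torch.nn.BatchNorm1d`, `running = (1 - momentum) * running + momentum * batch_stat`, and for simplicity feeds it the biased batch variance (PyTorch stores an unbiased estimate in `running_var`). The function names are hypothetical and the learned affine parameters are omitted.

```python
import math
from typing import List, Tuple

Grid = List[List[float]]            # [L][C]
Tensor3 = List[List[List[float]]]   # [B][L][C]


def batch_only_norm(x: Tensor3, eps: float = 1e-5) -> Tuple[Tensor3, Grid, Grid]:
    """Training-mode normalization with statistics over the batch axis only."""
    B, L, C = len(x), len(x[0]), len(x[0][0])
    mean = [[sum(x[b][p][c] for b in range(B)) / B for c in range(C)] for p in range(L)]
    var = [
        [sum((x[b][p][c] - mean[p][c]) ** 2 for b in range(B)) / B for c in range(C)]
        for p in range(L)
    ]
    out = [
        [[(x[b][p][c] - mean[p][c]) / math.sqrt(var[p][c] + eps) for c in range(C)]
         for p in range(L)]
        for b in range(B)
    ]
    return out, mean, var


def update_running(running: Grid, batch_stat: Grid, momentum: float = 0.1) -> Grid:
    """Exponential moving average of an [L][C] statistic."""
    return [
        [(1 - momentum) * r + momentum * s for r, s in zip(r_row, s_row)]
        for r_row, s_row in zip(running, batch_stat)
    ]


x = [[[1.0, 10.0]], [[3.0, 30.0]]]  # B = 2, L = 1, C = 2
out, mean, var = batch_only_norm(x, eps=0.0)
print(out)  # [[[-1.0, -1.0]], [[1.0, 1.0]]]
running_mean = update_running([[0.0, 0.0]], mean, momentum=0.1)
```

Each channel is normalized independently, so channels with very different scales (here 1 to 3 versus 10 to 30) end up with identical normalized values.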