SNP2Phenotype

Overview

SNP2Phenotype is a hierarchical transformer model that predicts phenotypes from genotypes, guided by a biological ontology. It propagates SNP-level genetic information up through a biological hierarchy (SNPs -> Genes -> Biological Systems) to predict one or more phenotypes.

Usage and examples

Example: initialize the model

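from g2pt.tree import SNPTreeParser
from src.model.model.snp2phenotype import SNP2PhenotypeModel

# Build the SNP -> gene -> system hierarchy from the two input files.
tree_parser = SNPTreeParser(
    ontology="ontology.tsv",   # biological ontology file
    snp2gene="snp2gene.tsv",   # SNP-to-gene mapping file
)

# 128-dimensional embeddings, 4 covariate features, a single phenotype.
model = SNP2PhenotypeModel(
    tree_parser=tree_parser,
    hidden_dims=128,
    n_covariates=4,
    n_phenotypes=1,
)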

API documentation

class SNP2PhenotypeModel

A hierarchical transformer model to predict phenotypes from genotypes, guided by a biological ontology.

This model translates SNP-level genetic information up through a biological hierarchy (SNPs -> Genes -> Biological Systems) to predict one or more phenotypes. It uses a series of transformer-based modules to propagate information and learn context-aware embeddings at each level of the hierarchy.

The core workflow is as follows:

  1. Embedding: SNPs, genes, systems, and phenotypes are embedded into a high-dimensional space.

  2. Propagation: Information flows up the hierarchy. SNP effects are propagated to genes, gene effects are propagated to systems, and system-system interactions are resolved.

  3. Prediction: The final embeddings for genes and/or systems are used to predict the phenotype, modulated by covariate information.
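To make the upward flow concrete, here is a toy sketch of the SNP -> gene -> system steps in plain Python. It is not the model's implementation: element-wise mean pooling stands in for the transformer attention the model actually uses, system-system interactions are omitted, and the SNP, gene, and system names and vectors are made up for illustration.

```python
# Toy sketch of hierarchical propagation: SNP -> gene -> system.
# Mean pooling is a stand-in for transformer attention; all data is hypothetical.
from statistics import fmean


def propagate(child_embs, parent_to_children):
    """Pool each parent's embedding from the embeddings of its children."""
    parent_embs = {}
    for parent, children in parent_to_children.items():
        vecs = [child_embs[c] for c in children if c in child_embs]
        if not vecs:
            continue  # parent has no observed children; skip it
        # Element-wise mean across the child vectors.
        parent_embs[parent] = [fmean(dim) for dim in zip(*vecs)]
    return parent_embs


snp_embs = {"rs1": [1.0, 0.0], "rs2": [3.0, 2.0], "rs3": [0.0, 4.0]}
gene2snps = {"GENE_A": ["rs1", "rs2"], "GENE_B": ["rs3"]}
sys2genes = {"SYS_1": ["GENE_A", "GENE_B"]}

gene_embs = propagate(snp_embs, gene2snps)  # step 2a: SNPs -> genes
sys_embs = propagate(gene_embs, sys2genes)  # step 2b: genes -> systems
print(gene_embs)  # {'GENE_A': [2.0, 1.0], 'GENE_B': [0.0, 4.0]}
print(sys_embs)   # {'SYS_1': [1.0, 2.5]}
```

The same pooling function is reused at each level because every step has the same shape: a parent summarizes a set of children defined by the hierarchy. In the real model, attention lets each parent weight its children unequally instead of averaging them.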

Parameters:
  • tree_parser (SNPTreeParser) – An object that provides the hierarchical structure (SNP-gene-system mappings) and corresponding masks for the model.

  • hidden_dims (int) – The dimensionality of the embeddings and hidden layers.

  • snp2pheno (bool, optional) – Unused parameter for future extension.

  • gene2pheno (bool, optional) – If True, use the final gene embeddings for phenotype prediction.

  • sys2pheno (bool, optional) – If True, use the final system embeddings for phenotype prediction.

  • interaction_types (list, optional) – The types of interactions to use for system-to-system propagation.

  • n_covariates (int, optional) – The number of covariate features to include in the model.

  • n_phenotypes (int, optional) – The number of distinct phenotypes the model can predict.

  • dropout (float, optional) – The dropout rate for regularization.

  • activation (str, optional) – The activation function for attention mechanisms.

  • input_format (str, optional) – The format of the genotype input (‘indices’ or ‘block’).

  • cov_effect (str, optional) – Specifies how covariates affect the model (‘pre’, ‘post’, ‘direct’, or ‘both’).

  • pretrained_transformer (dict, optional) – A dictionary of pretrained transformer models for block-based input.

  • freeze_pretrained (bool, optional) – Unused parameter.

  • phenotypes (tuple, optional) – Unused parameter.

  • use_hierarchical_transformer (bool, optional) – If True, uses a hierarchical transformer for the final prediction heads.
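Two constructor arguments accept only a fixed set of strings. The snippet below is a hypothetical pre-flight check (check_model_options is not part of the library) that fails fast on a typo before the model is built; the allowed values are copied from the parameter list above.

```python
# Hypothetical helper, not part of SNP2PhenotypeModel: validate the string
# options documented for input_format and cov_effect.
VALID_INPUT_FORMATS = {"indices", "block"}
VALID_COV_EFFECTS = {"pre", "post", "direct", "both"}


def check_model_options(input_format, cov_effect):
    """Raise ValueError if either option is not a documented value."""
    if input_format not in VALID_INPUT_FORMATS:
        raise ValueError(
            f"input_format must be one of {sorted(VALID_INPUT_FORMATS)}, "
            f"got {input_format!r}"
        )
    if cov_effect not in VALID_COV_EFFECTS:
        raise ValueError(
            f"cov_effect must be one of {sorted(VALID_COV_EFFECTS)}, "
            f"got {cov_effect!r}"
        )
    return {"input_format": input_format, "cov_effect": cov_effect}


print(check_model_options("block", "both"))
```

The returned dictionary could then be splatted into the constructor, for example SNP2PhenotypeModel(tree_parser=tree_parser, hidden_dims=128, **options).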

forward(genotype_dict, covariates, phenotype_ids, nested_hierarchical_masks_forward, nested_hierarchical_masks_backward, snp2gene_mask, gene2sys_mask, sys2gene_mask, sys_temp=None, sys2env=True, env2sys=True, sys2gene=True, score=False, attention=False, snp_only=False, predict_snp=False, chunk=False)

Defines the main forward pass of the model.

Parameters:
  • genotype_dict (dict) – A dictionary containing genotype information (e.g., SNP indices).

  • covariates (torch.Tensor) – A tensor of covariate data for the batch.

  • phenotype_ids (torch.Tensor) – A tensor of phenotype IDs for the batch.

  • nested_hierarchical_masks_forward (list) – Masks for forward system-system propagation.

  • nested_hierarchical_masks_backward (list) – Masks for backward system-system propagation.

  • snp2gene_mask (torch.Tensor) – The attention mask for SNP-to-gene propagation.

  • gene2sys_mask (torch.Tensor) – The attention mask for gene-to-system propagation.

  • sys2gene_mask (torch.Tensor) – The attention mask for system-to-gene propagation.

  • sys_temp (torch.Tensor, optional) – A temperature mask for system attention.

  • score (bool, optional) – If True, return attention scores.

  • attention (bool, optional) – If True, return attention weights.

  • chunk (bool, optional) – If True, use chunk-wise propagation.
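The mask arguments restrict which elements of one level each element of the next level may attend to. The sketch below illustrates that idea for snp2gene_mask with a row-wise masked softmax. It rests on assumptions not confirmed by this documentation: the mask is a genes x SNPs 0/1 matrix where 1 means "may attend", and a temperature (in the spirit of sys_temp) divides the logits before the softmax. The real model works on tensors; plain lists keep the example self-contained.

```python
# Toy masked attention. Assumptions: mask[g][s] == 1 lets gene g attend to
# SNP s, and temperature divides the scores before the softmax.
import math


def masked_softmax(scores, mask, temperature=1.0):
    """Row-wise softmax over allowed positions; masked positions get 0."""
    out = []
    for row_scores, row_mask in zip(scores, mask):
        allowed = [s / temperature for s, m in zip(row_scores, row_mask) if m]
        if not allowed:
            out.append([0.0] * len(row_scores))  # gene with no mapped SNPs
            continue
        peak = max(allowed)  # subtract the max for numerical stability
        exps = [math.exp(s / temperature - peak) if m else 0.0
                for s, m in zip(row_scores, row_mask)]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# 2 genes x 3 SNPs: GENE_A owns rs1 and rs2, GENE_B owns rs3.
snp2gene_mask = [[1, 1, 0],
                 [0, 0, 1]]
scores = [[2.0, 2.0, 9.0],
          [5.0, 1.0, 0.5]]
weights = masked_softmax(scores, snp2gene_mask)
print(weights)  # [[0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]
```

Note that the high score for rs3 in the first row (9.0) has no effect: the mask removes SNPs outside a gene's mapping before normalization, which is how the ontology constrains information flow. A temperature below 1 sharpens the distribution over the allowed positions; above 1 flattens it.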

Returns:

The phenotype prediction tensor. If attention or score is True, returns a tuple containing the prediction and the requested attention/score tensors.

Return type:

torch.Tensor or tuple
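Because forward returns either a bare tensor or a tuple, calling code often needs to branch on the result. The helper below is a hypothetical convenience (split_forward_output is not part of the library) that normalizes both cases, assuming the prediction is the first element of the tuple. Strings stand in for tensors so the sketch runs without PyTorch.

```python
# Hypothetical helper: always return (prediction, [extra outputs]).
# Assumes that when forward() returns a tuple, the prediction comes first.
def split_forward_output(output):
    if isinstance(output, tuple):
        prediction, *extras = output
        return prediction, extras
    return output, []


# Stand-ins for tensors; with the real model, `output` would come from e.g.
# model(genotype_dict, covariates, ..., attention=True).
print(split_forward_output("pred"))            # ('pred', [])
print(split_forward_output(("pred", "attn")))  # ('pred', ['attn'])
```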