SNP2Phenotype
Overview
A hierarchical transformer model to predict phenotypes from genotypes, guided by a biological ontology. This model translates SNP-level genetic information up through a biological hierarchy (SNPs -> Genes -> Biological Systems) to predict one or more phenotypes.
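The following toy sketch traces how one SNP's signal can reach genes and then systems, including parent systems in the ontology. The dictionaries `SNP2GENE`, `GENE2SYS`, and `SYS2PARENT` and the function `upstream_path` are hypothetical stand-ins for what `SNPTreeParser` builds from `ontology.tsv` and `snp2gene.tsv`. Their exact structure inside g2pt is not documented here.

```python
from collections import deque

# Hypothetical toy mappings (not the real g2pt data structures).
SNP2GENE = {"rs1": ["GENE_A"], "rs2": ["GENE_A", "GENE_B"], "rs3": ["GENE_C"]}
GENE2SYS = {"GENE_A": ["SYS_1"], "GENE_B": ["SYS_2"], "GENE_C": ["SYS_2"]}
# Child system -> parent systems (ontology edges); ROOT has no parent.
SYS2PARENT = {"SYS_1": ["ROOT"], "SYS_2": ["ROOT"], "ROOT": []}


def upstream_path(snp):
    """Return the genes and systems that a SNP's signal can reach going up the hierarchy."""
    genes = set(SNP2GENE.get(snp, []))
    systems = set()
    # Breadth-first walk from the SNP's direct systems up through their ancestors.
    queue = deque(s for g in genes for s in GENE2SYS.get(g, []))
    while queue:
        sys_name = queue.popleft()
        if sys_name in systems:
            continue  # already visited via another path
        systems.add(sys_name)
        queue.extend(SYS2PARENT.get(sys_name, []))
    return genes, systems
```

For example, `upstream_path("rs2")` reaches both genes it maps to and, through them, `SYS_1`, `SYS_2`, and their shared parent `ROOT`.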
Usage and examples
Example: initialize the model
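from g2pt.tree import SNPTreeParser
from src.model.model.snp2phenotype import SNP2PhenotypeModel

# Build the SNP -> gene -> system hierarchy from the ontology and SNP-to-gene mapping files
tree_parser = SNPTreeParser(
    ontology="ontology.tsv",
    snp2gene="snp2gene.tsv",
)

# 128-dimensional embeddings, 4 covariates, a single phenotype
model = SNP2PhenotypeModel(
    tree_parser=tree_parser,
    hidden_dims=128,
    n_covariates=4,
    n_phenotypes=1,
)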
API documentation
- class SNP2PhenotypeModel
A hierarchical transformer model to predict phenotypes from genotypes, guided by a biological ontology.
This model translates SNP-level genetic information up through a biological hierarchy (SNPs -> Genes -> Biological Systems) to predict one or more phenotypes. It uses a series of transformer-based modules to propagate information and learn context-aware embeddings at each level of the hierarchy.
The core workflow is as follows:
1. Embedding: SNPs, genes, systems, and phenotypes are embedded into a high-dimensional space.
2. Propagation: Information flows up the hierarchy. SNP effects are propagated to genes, gene effects are propagated to systems, and system-system interactions are resolved.
3. Prediction: The final embeddings for genes and/or systems are used to predict the phenotype, modulated by covariate information.
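The workflow above can be sketched in plain Python as masked attention applied level by level: genes attend only to their mapped SNPs, then systems attend only to their member genes. This is a minimal illustration under stated assumptions. It uses single-head dot-product attention with a residual connection, no learned projections or feed-forward layers, omits system-system propagation, and stands in for the prediction head with mean pooling plus a linear covariate term. The names `masked_attention`, `toy_forward`, `head_w`, and `cov_w` are hypothetical and not part of the SNP2PhenotypeModel API.

```python
import math


def masked_attention(queries, keys, values, mask):
    """Each target i attends only to sources j where mask[i][j] is True.

    queries: target vectors (e.g. gene embeddings)
    keys/values: source vectors (e.g. SNP embeddings)
    Returns updated targets: query + attention-weighted sum of allowed values.
    """
    dim = len(queries[0])
    out = []
    for i, q in enumerate(queries):
        allowed = [j for j in range(len(keys)) if mask[i][j]]
        if not allowed:
            out.append(list(q))  # no connected sources: keep the embedding unchanged
            continue
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(dim) for j in allowed]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        attended = [sum(w * values[j][d] for w, j in zip(weights, allowed)) for d in range(dim)]
        out.append([qd + ad for qd, ad in zip(q, attended)])
    return out


def toy_forward(snp_emb, gene_emb, sys_emb, snp2gene_mask, gene2sys_mask,
                covariates, head_w, cov_w):
    """Hypothetical SNP -> gene -> system pass followed by a linear prediction."""
    genes = masked_attention(gene_emb, snp_emb, snp_emb, snp2gene_mask)
    systems = masked_attention(sys_emb, genes, genes, gene2sys_mask)
    dim = len(systems[0])
    pooled = [sum(s[d] for s in systems) / len(systems) for d in range(dim)]
    # Covariates enter additively here; the real model's cov_effect options may differ.
    return sum(w * x for w, x in zip(head_w, pooled)) + sum(w * c for w, c in zip(cov_w, covariates))
```

Restricting each softmax to the entries allowed by the mask is what keeps information flowing only along edges of the hierarchy.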
- Parameters:
tree_parser (SNPTreeParser) – An object that provides the hierarchical structure (SNP-gene-system mappings) and corresponding masks for the model.
hidden_dims (int) – The dimensionality of the embeddings and hidden layers.
snp2pheno (bool, optional) – Unused parameter for future extension.
gene2pheno (bool, optional) – If True, use the final gene embeddings for phenotype prediction.
sys2pheno (bool, optional) – If True, use the final system embeddings for phenotype prediction.
interaction_types (list, optional) – The types of interactions to use for system-to-system propagation.
n_covariates (int, optional) – The number of covariate features to include in the model.
n_phenotypes (int, optional) – The number of distinct phenotypes the model can predict.
dropout (float, optional) – The dropout rate for regularization.
activation (str, optional) – The activation function for attention mechanisms.
input_format (str, optional) – The format of the genotype input (‘indices’ or ‘block’).
cov_effect (str, optional) – Specifies how covariates affect the model (‘pre’, ‘post’, ‘direct’, or ‘both’).
pretrained_transformer (dict, optional) – A dictionary of pretrained transformer models for block-based input.
freeze_pretrained (bool, optional) – Unused parameter.
phenotypes (tuple, optional) – Unused parameter.
use_hierarchical_transformer (bool, optional) – If True, uses a hierarchical transformer for the final prediction heads.
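To illustrate how `gene2pheno` and `sys2pheno` choose which embeddings reach the prediction head, here is a hypothetical sketch that mean-pools each enabled set and concatenates the results. The pooling, the concatenation, and the error raised when both flags are off are assumptions of this sketch only; the actual model may, for example, let phenotype embeddings attend over genes and systems instead.

```python
def select_prediction_inputs(gene_emb, sys_emb, gene2pheno=True, sys2pheno=True):
    """Hypothetical: build the prediction-head input from the enabled embedding sets."""

    def mean_pool(vectors):
        dim = len(vectors[0])
        return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]

    parts = []
    if gene2pheno:
        parts.extend(mean_pool(gene_emb))  # pooled final gene embeddings
    if sys2pheno:
        parts.extend(mean_pool(sys_emb))  # pooled final system embeddings
    if not parts:
        # Assumption for this sketch: at least one source must be enabled.
        raise ValueError("At least one of gene2pheno or sys2pheno must be True")
    return parts
```

Enabling both flags in this sketch simply widens the head input; the model's actual combination strategy is not described in this documentation.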
- forward(genotype_dict, covariates, phenotype_ids, nested_hierarchical_masks_forward, nested_hierarchical_masks_backward, snp2gene_mask, gene2sys_mask, sys2gene_mask, sys_temp=None, sys2env=True, env2sys=True, sys2gene=True, score=False, attention=False, snp_only=False, predict_snp=False, chunk=False)
Defines the main forward pass of the model.
- Parameters:
genotype_dict (dict) – A dictionary containing genotype information (e.g., SNP indices).
covariates (torch.Tensor) – A tensor of covariate data for the batch.
phenotype_ids (torch.Tensor) – A tensor of phenotype IDs for the batch.
nested_hierarchical_masks_forward (list) – Masks for forward system-system propagation.
nested_hierarchical_masks_backward (list) – Masks for backward system-system propagation.
snp2gene_mask (torch.Tensor) – The attention mask for SNP-to-gene propagation.
gene2sys_mask (torch.Tensor) – The attention mask for gene-to-system propagation.
sys2gene_mask (torch.Tensor) – The attention mask for system-to-gene propagation.
sys_temp (torch.Tensor, optional) – A temperature mask for system attention.
score (bool, optional) – If True, return attention scores.
attention (bool, optional) – If True, return attention weights.
chunk (bool, optional) – If True, use chunk-wise propagation.
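The mask arguments encode which nodes may attend to which. A minimal sketch, assuming a boolean layout of shape (targets, sources) where `True` marks an allowed connection, and assuming `sys2gene_mask` is the transpose of `gene2sys_mask`. Both the orientation and the boolean convention are assumptions; the masks produced by `SNPTreeParser` may use a different shape or an additive (0 / -inf) form. The helpers `build_mask` and `transpose` are hypothetical.

```python
def build_mask(child2parents, children, parents):
    """Return a parents x children boolean matrix; [p][c] is True when child c maps to parent p."""
    parent_index = {p: i for i, p in enumerate(parents)}
    child_index = {c: j for j, c in enumerate(children)}
    mask = [[False] * len(children) for _ in parents]
    for child, parent_list in child2parents.items():
        for parent in parent_list:
            mask[parent_index[parent]][child_index[child]] = True
    return mask


def transpose(mask):
    """Swap rows and columns, e.g. to derive a system-to-gene mask."""
    return [list(row) for row in zip(*mask)]


# Hypothetical toy hierarchy.
snps = ["rs1", "rs2", "rs3"]
genes = ["GENE_A", "GENE_B", "GENE_C"]
systems = ["SYS_1", "SYS_2"]
snp2gene = {"rs1": ["GENE_A"], "rs2": ["GENE_A", "GENE_B"], "rs3": ["GENE_C"]}
gene2sys = {"GENE_A": ["SYS_1"], "GENE_B": ["SYS_2"], "GENE_C": ["SYS_2"]}

snp2gene_mask = build_mask(snp2gene, snps, genes)      # genes x SNPs
gene2sys_mask = build_mask(gene2sys, genes, systems)   # systems x genes
sys2gene_mask = transpose(gene2sys_mask)               # genes x systems
# With PyTorch installed, torch.tensor(snp2gene_mask, dtype=torch.bool) gives a tensor form.
```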
- Returns:
The phenotype prediction tensor. If attention or score is True, returns a tuple containing the prediction and the requested attention/score tensors.
- Return type:
torch.Tensor or tuple
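Because `forward` returns either a single tensor or a tuple, calling code may want to normalize the result. A small sketch, assuming the prediction is the first element of the tuple as the description suggests; the order of the remaining attention/score entries is not specified here, so they are returned as an untyped list. `split_forward_output` is a hypothetical helper.

```python
def split_forward_output(output):
    """Return (prediction, extras) whether forward returned a tensor or a tuple.

    Assumes the prediction is the first tuple element.
    """
    if isinstance(output, tuple):
        prediction, *extras = output
        return prediction, extras
    return output, []
```

This keeps downstream loss computation identical whether or not `attention` or `score` was requested.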