Sankey Visualization Module
This module provides a SankeyVisualizer class and a create_sankey_app function to generate interactive, hierarchical Sankey diagrams for visualizing attention flow from machine learning models in genomics. It is designed to illustrate how attention is distributed from SNPs to genes and up through a biological ontology (like the Gene Ontology).
Overview
The primary goal of this module is to make the complex, multi-layered attention data from a model interpretable. It does this by:
Processing Attention Data: It takes a DataFrame of attention scores and recursively factorizes them, ensuring that the flow of attention is conserved as it moves up the hierarchy from SNPs to the root biological system.
Structuring the Hierarchy: It uses an SNPTreeParser object to understand the relationships between SNPs, genes, and systems, and uses this structure to order and position the nodes in the diagram.
Generating Plots: It uses Plotly to create a go.Figure object, providing a powerful and interactive visualization.
Creating an Interactive App: It provides a factory function to quickly launch a Dash web application for interactively exploring the attention flow for different biological systems.
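The idea of flow conservation can be illustrated with a small, self-contained sketch. This is not the module's implementation: the function name factorize_attention, the nested-dict input format, and the node names (ROOT, SYS_A, GENE1, snp1, ...) are all hypothetical, and the sketch assumes a top-down split in which each node divides the flow it receives among its children in proportion to their raw attention scores.

```python
def factorize_attention(raw_attention, root, root_flow=1.0):
    """Split `root_flow` down a hierarchy in proportion to raw attention.

    raw_attention maps parent -> {child: raw score} (hypothetical format).
    Returns {(child, parent): flow}, where the flow entering every
    intermediate node equals the flow leaving it.
    """
    flows = {}

    def distribute(node, incoming):
        children = raw_attention.get(node, {})
        total = sum(children.values())
        if total <= 0:
            return  # leaf node (or no usable attention): nothing to split
        for child in sorted(children):  # sorted for a deterministic result
            share = incoming * children[child] / total
            flows[(child, node)] = flows.get((child, node), 0.0) + share
            distribute(child, share)

    distribute(root, root_flow)
    return flows


def inflow(flows, node):
    """Total flow arriving at `node` from its children."""
    return sum(v for (_, dst), v in flows.items() if dst == node)


def outflow(flows, node):
    """Total flow leaving `node` towards its parents."""
    return sum(v for (src, _), v in flows.items() if src == node)


# Toy hierarchy: GENE2 belongs to two systems, snp1 maps to two genes.
raw_attention = {
    "ROOT": {"SYS_A": 3.0, "SYS_B": 1.0},
    "SYS_A": {"GENE1": 1.0, "GENE2": 1.0},
    "SYS_B": {"GENE2": 2.0},
    "GENE1": {"snp1": 1.0},
    "GENE2": {"snp1": 1.0, "snp2": 3.0},
}
flows = factorize_attention(raw_attention, root="ROOT")
```

Because every split is proportional, the SNP-level flows sum to the same total that reaches the root, which is the property a Sankey diagram needs for its link widths to line up across columns.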
Key Components
SankeyVisualizer Class
This is the core class responsible for creating the Sankey diagrams. It is designed to be reusable and can be used to generate static plots (e.g., in a Jupyter notebook) or as the engine for the interactive Dash app.
__init__(tree_parser, ...): Initializes the visualizer with an SNPTreeParser instance and optional annotation dictionaries for enriching the plot's tooltips.
plot(attention_df, ...): The main public method. It takes the processed attention data and various configuration options and returns a go.Figure object. All the complex data preparation and layout calculations are handled by internal private methods.
create_sankey_app() Function
This is a factory function that builds and returns a Dash application. It takes an initialized SankeyVisualizer instance and the attention data, and wires them up to interactive components like dropdowns and radio buttons. Keeping the application logic separate from the plotting logic means the same visualizer can drive both static plots and the interactive app.
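A factory of this kind might look roughly like the sketch below. It is an illustration under assumptions, not the actual create_sankey_app: the names build_sankey_app_sketch and make_update_figure, the component ids, and the 'backward' direction option are hypothetical. Only the plot keyword arguments (attention_df, target_go, direction) are taken from the usage example in this document. Dash is imported inside the factory so that the callback logic can be exercised without Dash installed.

```python
def make_update_figure(visualizer, attention_df):
    """Return the plain function that the Dash callback will wrap."""

    def update_figure(target_go, direction):
        # Delegate all plotting to the visualizer; the app only routes inputs.
        return visualizer.plot(
            attention_df=attention_df,
            target_go=target_go,
            direction=direction,
        )

    return update_figure


def build_sankey_app_sketch(visualizer, attention_df, target_go_list, initial_go):
    """Build a small Dash app around an existing visualizer (hypothetical)."""
    from dash import Dash, Input, Output, dcc, html  # lazy import

    app = Dash(__name__)
    app.layout = html.Div(
        [
            dcc.Dropdown(
                id="go-dropdown",
                options=[{"label": go_id, "value": go_id} for go_id in target_go_list],
                value=initial_go,
            ),
            dcc.RadioItems(
                id="direction-radio",
                # 'backward' is an assumed second option, not confirmed by the module.
                options=[{"label": d, "value": d} for d in ("forward", "backward")],
                value="forward",
            ),
            dcc.Graph(id="sankey-graph"),
        ]
    )

    # Register the callback: any change to either input redraws the figure.
    app.callback(
        Output("sankey-graph", "figure"),
        Input("go-dropdown", "value"),
        Input("direction-radio", "value"),
    )(make_update_figure(visualizer, attention_df))
    return app
```

Splitting out make_update_figure keeps the callback a plain function of its inputs, so it can be unit-tested with a stub visualizer and no running server.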
Workflow
Data Preparation: The _prepare_sankey_data method is the main internal workhorse. It:
Takes the raw attention data for a specific system.
Calls _factorize_attention_recursively to correctly weight the attention flow through the hierarchy.
Calls _get_component_orders to sort all the nodes (SNPs, genes, systems) in a deterministic way, which is crucial for a stable layout.
Calls _calculate_node_positions to compute the x and y coordinates for each node, creating the distinct columns for SNPs, genes, and different levels of the system hierarchy.
Calls _get_sankey_link_values to prepare the data for the links (the flows) between the nodes.
Plot Generation: The plot method takes the highly structured data from _prepare_sankey_data and passes it directly to plotly.graph_objects.Sankey to generate the figure.
App Creation: The create_sankey_app function defines the Dash layout (e.g., dropdowns, graph component) and the callback function. The callback function is responsible for responding to user input (like selecting a new system), calling the visualizer.plot method with the appropriate parameters, and updating the graph.
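The ordering, positioning, and link steps of the workflow above can be sketched with the standard library alone. The helpers below (order_nodes, node_positions, sankey_trace) and the toy node names are hypothetical stand-ins for the private methods, whose real signatures are not shown in this document. The sketch assumes one column per hierarchy level and produces a plain dict in the shape Plotly uses for a Sankey trace (node.label, node.x, node.y, link.source, link.target, link.value).

```python
def order_nodes(columns):
    """Sort node names within each column so repeated runs give the same layout."""
    return [sorted(column) for column in columns]


def node_positions(ordered_columns):
    """Assign (x, y) in the open interval (0, 1): one x per column, evenly spaced y."""
    positions = {}
    n_cols = len(ordered_columns)
    for col_idx, column in enumerate(ordered_columns):
        x = (col_idx + 0.5) / n_cols  # column centre, never exactly 0 or 1
        for row_idx, name in enumerate(column):
            positions[name] = (x, (row_idx + 0.5) / len(column))
    return positions


def sankey_trace(columns, flows):
    """Build a Sankey trace dict from columns of nodes and {(source, target): value}."""
    ordered = order_nodes(columns)
    positions = node_positions(ordered)
    labels = [name for column in ordered for name in column]
    index = {name: i for i, name in enumerate(labels)}
    links = sorted(flows.items())  # deterministic link order
    return {
        "type": "sankey",
        "arrangement": "fixed",  # keep nodes at the computed coordinates
        "node": {
            "label": labels,
            "x": [positions[name][0] for name in labels],
            "y": [positions[name][1] for name in labels],
        },
        "link": {
            "source": [index[src] for (src, _), _ in links],
            "target": [index[dst] for (_, dst), _ in links],
            "value": [value for _, value in links],
        },
    }


# Toy input: SNP column, gene column, system column.
columns = [{"snp2", "snp1"}, {"GENE1"}, {"SYS_A"}]
flows = {("snp1", "GENE1"): 0.25, ("snp2", "GENE1"): 0.75, ("GENE1", "SYS_A"): 1.0}
trace = sankey_trace(columns, flows)
```

With Plotly installed, a dict like this can be passed as go.Figure(data=[trace]). Sorting both nodes and links is what keeps the diagram from reshuffling between runs when the inputs come from unordered containers such as sets.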
Usage Example
Generating a Static Plot
from src.utils.tree import SNPTreeParser
from src.utils.visualization.sankey import SankeyVisualizer
import pandas as pd
# 1. Initialize the SNPTreeParser
tree_parser = SNPTreeParser(...)
# 2. Load and process your attention data into a DataFrame
# This is a simplified example; your actual data will be more complex
attention_data = pd.read_csv('path/to/attention.csv', index_col=[0, 1, 2])
# 3. Initialize the visualizer
visualizer = SankeyVisualizer(tree_parser)
# 4. Generate the plot
fig = visualizer.plot(
    attention_df=attention_data,
    target_go='GO:0008150',
    direction='forward',
    genotypes=('homozygous', 'heterozygous'),
    title="My First Sankey Plot"
)
# 5. Show the plot
fig.show()
Launching the Interactive Dash App
# (Continuing from the previous example)
from src.utils.visualization.sankey import create_sankey_app
# A list of GO terms you want to be able to select in the app
go_term_list = ['GO:0008150', 'GO:0006915', 'GO:0007049']
# Create the Dash app instance
app = create_sankey_app(
    visualizer=visualizer,
    attention_df=attention_data,  # The full attention DataFrame
    target_go_list=go_term_list,
    initial_go='GO:0008150'
)
# Run the app (on recent Dash releases, call app.run(debug=True) instead)
if __name__ == '__main__':
    app.run_server(debug=True)