Sankey Visualization Module

This module provides a SankeyVisualizer class and a create_sankey_app function to generate interactive, hierarchical Sankey diagrams for visualizing attention flow from machine learning models in genomics. It is designed to illustrate how attention is distributed from SNPs to genes and up through a biological ontology (like the Gene Ontology).

Overview

The primary goal of this module is to make the complex, multi-layered attention data from a model interpretable. It does this by:

  1. Processing Attention Data: It takes a DataFrame of attention scores and recursively factorizes them, ensuring that the flow of attention is conserved as it moves up the hierarchy from SNPs to the root biological system.

  2. Structuring the Hierarchy: It uses an SNPTreeParser object to understand the relationships between SNPs, genes, and systems, and uses this structure to order and position the nodes in the diagram.

  3. Generating Plots: It uses Plotly to create a go.Figure object, providing a powerful and interactive visualization.

  4. Creating an Interactive App: It provides a factory function to quickly launch a Dash web application for interactively exploring the attention flow for different biological systems.

Key Components

SankeyVisualizer Class

This is the core class responsible for creating the Sankey diagrams. It is designed to be reusable and can be used to generate static plots (e.g., in a Jupyter notebook) or as the engine for the interactive Dash app.

  • __init__(tree_parser, ...): Initializes the visualizer with an SNPTreeParser instance and optional annotation dictionaries for enriching the plot’s tooltips.

  • plot(attention_df, ...): The main public method. It takes the processed attention data and various configuration options and returns a go.Figure object. All the complex data preparation and layout calculations are handled by internal private methods.

create_sankey_app() Function

This is a factory function that builds and returns a Dash application. It takes an initialized SankeyVisualizer instance and the attention data, and wires them up to interactive components like dropdowns and radio buttons. This separates the application logic from the plotting logic, making the code much cleaner and more modular.

Workflow

  1. Data Preparation: The _prepare_sankey_data method is the main internal workhorse. It:

    • Takes the raw attention data for a specific system.

    • Calls _factorize_attention_recursively to correctly weight the attention flow through the hierarchy.

    • Calls _get_component_orders to sort all the nodes (SNPs, genes, systems) in a deterministic way, which is crucial for a stable layout.

    • Calls _calculate_node_positions to compute the x and y coordinates for each node, creating the distinct columns for SNPs, genes, and different levels of the system hierarchy.

    • Calls _get_sankey_link_values to prepare the data for the links (the flows) between the nodes.

  2. Plot Generation: The plot method takes the highly structured data from _prepare_sankey_data and passes it directly to plotly.graph_objects.Sankey to generate the figure.

  3. App Creation: The create_sankey_app function defines the Dash layout (e.g., dropdowns, graph component) and the callback function. The callback function is responsible for responding to user input (like selecting a new system), calling the visualizer.plot method with the appropriate parameters, and updating the graph.

Usage Example

Generating a Static Plot

from src.utils.tree import SNPTreeParser
from src.utils.visualization.sankey import SankeyVisualizer
import pandas as pd

# 1. Initialize the SNPTreeParser
tree_parser = SNPTreeParser(...)

# 2. Load and process your attention data into a DataFrame
# This is a simplified example; your actual data will be more complex
attention_data = pd.read_csv('path/to/attention.csv', index_col=[0, 1, 2])

# 3. Initialize the visualizer
visualizer = SankeyVisualizer(tree_parser)

# 4. Generate the plot
fig = visualizer.plot(
    attention_df=attention_data,
    target_go='GO:0008150',
    direction='forward',
    genotypes=('homozygous', 'heterozygous'),
    title="My First Sankey Plot"
)

# 5. Show the plot
fig.show()

Launching the Interactive Dash App

# (Continuing from the previous example)
from src.utils.visualization.sankey import create_sankey_app

# A list of GO terms you want to be able to select in the app
go_term_list = ['GO:0008150', 'GO:0006915', 'GO:0007049']

# Create the Dash app instance
app = create_sankey_app(
    visualizer=visualizer,
    attention_df=attention_data, # The full attention DataFrame
    target_go_list=go_term_list,
    initial_go='GO:0008150'
)

# Run the app
if __name__ == '__main__':
    app.run_server(debug=True)