Add Dataset Analysis to the overview page

This commit is contained in:
Tobias Hölzer 2025-12-28 15:31:51 +01:00
parent 6960571742
commit a304c96e4e

View file

@@ -3,9 +3,639 @@
from datetime import datetime
import pandas as pd
import plotly.express as px
import streamlit as st
from entropice.dashboard.plots.colors import get_palette
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dataset import DatasetEnsemble
def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    Builds one `DatasetEnsemble` per (grid, level, target) combination, reads
    its target table, and counts how many rows have coverage, valid labels,
    and both. The results are shown in three tabs: a per-target heatmap, a
    grouped bar chart faceted by target, and a formatted data table.
    """
    st.subheader("📊 Sample Counts by Configuration")
    st.markdown(
        """
This visualization shows the number of available samples for each combination of:
- **Task**: binary, count, density
- **Target Dataset**: darts_rts, darts_mllabels
- **Grid System**: hex, healpix
- **Grid Level**: varying by grid type
"""
    )
    # Define all possible grid configurations (grid system, resolution level).
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    target_datasets = ["darts_rts", "darts_mllabels"]
    tasks = ["binary", "count", "density"]
    # Collect sample counts for every configuration combination.
    sample_data = []
    with st.spinner("Computing sample counts for all configurations..."):
        for grid, level in grid_configs:
            for target in target_datasets:
                # Create minimal ensemble (no members) just to get target data.
                ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]
                # Read target data.
                # NOTE(review): uses the private _read_target() API — verify it
                # stays available on DatasetEnsemble.
                targets = ensemble._read_target()
                for task in tasks:
                    # Resolve the task-specific label column and the coverage column.
                    taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
                    covcol = ensemble.covcol
                    # Count samples with coverage and valid labels; skip
                    # combinations whose columns are missing from the table.
                    if covcol in targets.columns and taskcol in targets.columns:
                        valid_coverage = targets[covcol].sum()
                        valid_labels = targets[taskcol].notna().sum()
                        valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
                        sample_data.append(
                            {
                                "Grid": f"{grid}-{level}",
                                "Grid Type": grid,
                                "Level": level,
                                "Target": target.replace("darts_", ""),
                                "Task": task.capitalize(),
                                "Samples (Coverage)": valid_coverage,
                                "Samples (Labels)": valid_labels,
                                "Samples (Both)": valid_both,
                                # Zero-padded level so lexicographic sort matches numeric order.
                                "Grid_Level_Sort": f"{grid}_{level:02d}",
                            }
                        )
    sample_df = pd.DataFrame(sample_data)
    # Create tabs for different views.
    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
    with tab1:
        st.markdown("### Sample Counts Heatmap")
        st.markdown("Showing counts of samples with both coverage and valid labels")
        # Create one heatmap (Grid x Task) per target dataset.
        for target in target_datasets:
            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
            # Pivot for heatmap: Grid x Task.
            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
            # Sort index by grid type and level using the zero-padded sort key.
            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)
            # Get color palette for sample counts.
            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
            fig = px.imshow(
                pivot_df,
                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
                x=pivot_df.columns,
                y=pivot_df.index,
                color_continuous_scale=sample_colors,
                aspect="auto",
                title=f"Target: {target}",
            )
            # Add cell text annotations with thousands separators.
            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)
    with tab2:
        st.markdown("### Sample Counts Bar Chart")
        st.markdown("Showing counts of samples with both coverage and valid labels")
        # Create a faceted bar chart showing both targets side by side.
        # Get color palette for tasks.
        n_tasks = sample_df["Task"].nunique()
        task_colors = get_palette("task_types", n_colors=n_tasks)
        fig = px.bar(
            sample_df,
            x="Grid",
            y="Samples (Both)",
            color="Task",
            facet_col="Target",
            barmode="group",
            title="Sample Counts by Grid Configuration and Target Dataset",
            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
            color_discrete_sequence=task_colors,
            height=500,
        )
        # Update facet labels to show only the value (strip the "Target=" prefix).
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
        fig.update_xaxes(tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)
    with tab3:
        st.markdown("### Detailed Sample Counts")
        # Display full table with formatting.
        display_df = sample_df[
            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
        ].copy()
        # Format numbers with thousands separators.
        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
        st.dataframe(display_df, hide_index=True, use_container_width=True)
def _available_members(grid: str, level: int) -> list[str]:
    """Return the data-source members available for a grid configuration.

    AlphaEarth is excluded at the finest resolutions (healpix level 10 and
    hex level 6); every other configuration gets the full member list.
    """
    if (grid == "healpix" and level == 10) or (grid == "hex" and level == 6):
        return ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
    return ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]


@st.fragment
def render_feature_count_fragment():
    """Render interactive feature count visualization using fragments.

    Two sections: (1) a comparison of feature counts across all grid
    configurations with every available data source enabled, and (2) an
    interactive explorer where the user picks a grid configuration, target
    dataset, and subset of data sources to see detailed statistics.
    """
    st.subheader("🔢 Feature Counts by Dataset Configuration")
    st.markdown(
        """
This visualization shows the total number of features that would be generated
for different combinations of data sources and grid configurations.
"""
    )
    # First section: comparison across all grid configurations.
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
    # All supported (grid system, resolution level) combinations.
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    # Collect feature statistics for all configurations.
    feature_comparison_data = []
    with st.spinner("Computing feature counts for all grid configurations..."):
        for grid, level in grid_configs:
            members = _available_members(grid, level)
            # Use darts_rts as default target for comparison.
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            # Minimum cell count across sources bounds where inference is possible.
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )
            feature_comparison_data.append(
                {
                    "Grid": f"{grid}-{level}",
                    "Grid Type": grid,
                    "Level": level,
                    "Total Features": stats["total_features"],
                    "Data Sources": len(members),
                    "Inference Cells": min_cells,
                    "Total Samples": stats["num_target_samples"],
                    "AlphaEarth": "AlphaEarth" in members,
                    # Zero-padded level so lexicographic sort matches numeric order.
                    "Grid_Level_Sort": f"{grid}_{level:02d}",
                }
            )
    comparison_df = pd.DataFrame(feature_comparison_data)
    # Create tabs for different comparison views.
    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")
        # Collect per-source breakdown data for the stacked bar chart.
        stacked_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            grid, level_str = grid_config.split("-")
            level = int(level_str)
            members = _available_members(grid, level)
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            # Add data for each member.
            for member, member_stats in stats["members"].items():
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
            # Lon/lat contributes two extra feature columns when enabled.
            if ensemble.add_lonlat:
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
        stacked_df = pd.DataFrame(stacked_data)
        stacked_df = stacked_df.sort_values("Grid_Level_Sort")
        # Get color palette for data sources.
        unique_sources = stacked_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)
        # Create stacked bar chart.
        fig = px.bar(
            stacked_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            barmode="stack",
            title="Total Features by Data Source Across Grid Configurations",
            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
            color_discrete_sequence=source_colors,
            text_auto=False,
        )
        fig.update_layout(height=500, xaxis_tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)
        # Add secondary metrics: inference cells and total samples side by side.
        col1, col2 = st.columns(2)
        with col1:
            # Get color palette for grid configs (also reused in col2).
            n_grids = len(comparison_df)
            grid_colors = get_palette("grid_configs", n_colors=n_grids)
            fig_cells = px.bar(
                comparison_df,
                x="Grid",
                y="Inference Cells",
                color="Grid",
                title="Inference Cells by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
                color_discrete_sequence=grid_colors,
                text="Inference Cells",
            )
            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_cells, use_container_width=True)
        with col2:
            fig_samples = px.bar(
                comparison_df,
                x="Grid",
                y="Total Samples",
                color="Grid",
                title="Total Samples by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
                color_discrete_sequence=grid_colors,
                text="Total Samples",
            )
            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_samples, use_container_width=True)
    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")
        # Collect breakdown data for all grid configurations.
        all_breakdown_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            grid, level_str = grid_config.split("-")
            level = int(level_str)
            members = _available_members(grid, level)
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            total_features = stats["total_features"]
            # Add data for each member with its percentage share.
            for member, member_stats in stats["members"].items():
                percentage = (member_stats["num_features"] / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Percentage": percentage,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
            # Lon/lat contributes two extra feature columns when enabled.
            if ensemble.add_lonlat:
                percentage = (2 / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Percentage": percentage,
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
        breakdown_all_df = pd.DataFrame(all_breakdown_data)
        # Sort by grid configuration.
        breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")
        # Get color palette for data sources.
        unique_sources = breakdown_all_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)
        # Create donut charts for each grid configuration, laid out in a
        # grid of 3 columns per row.
        num_grids = len(comparison_df)
        cols_per_row = 3
        num_rows = (num_grids + cols_per_row - 1) // cols_per_row
        for row_idx in range(num_rows):
            cols = st.columns(cols_per_row)
            for col_idx in range(cols_per_row):
                grid_idx = row_idx * cols_per_row + col_idx
                if grid_idx < num_grids:
                    grid_config = comparison_df.iloc[grid_idx]["Grid"]
                    grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]
                    with cols[col_idx]:
                        fig = px.pie(
                            grid_data,
                            names="Data Source",
                            values="Number of Features",
                            title=grid_config,
                            hole=0.4,
                            color_discrete_sequence=source_colors,
                        )
                        fig.update_traces(textposition="inside", textinfo="percent")
                        fig.update_layout(showlegend=True, height=350)
                        st.plotly_chart(fig, use_container_width=True)
    with comp_tab3:
        st.markdown("#### Detailed Feature Count Comparison")
        # Display full comparison table with formatting.
        display_df = comparison_df[
            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
        ].copy()
        # Format numbers with thousands separators.
        for col in ["Total Features", "Inference Cells", "Total Samples"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
        # Format boolean as Yes/No. (Bug fix: both branches previously
        # produced an empty string, so the column always rendered blank.)
        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "Yes" if x else "No")
        st.dataframe(display_df, hide_index=True, use_container_width=True)
    st.divider()
    # Second section: detailed configuration with user selection.
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")
    # Grid selection.
    grid_options = [
        "hex-3",
        "hex-4",
        "hex-5",
        "hex-6",
        "healpix-6",
        "healpix-7",
        "healpix-8",
        "healpix-9",
        "healpix-10",
    ]
    col1, col2 = st.columns(2)
    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )
    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )
    # Parse grid type and level from the "grid-level" option string.
    grid, level_str = grid_level_combined.split("-")
    level = int(level_str)
    # Members selection.
    st.markdown("#### Select Data Sources")
    # AlphaEarth is unavailable for some configurations; show it disabled there.
    disable_alphaearth = "AlphaEarth" not in _available_members(grid, level)
    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
    # Use columns for checkboxes.
    cols = st.columns(len(all_members))
    selected_members = []
    for idx, member in enumerate(all_members):
        with cols[idx]:
            if member == "AlphaEarth" and disable_alphaearth:
                st.checkbox(
                    member,
                    value=False,
                    disabled=True,
                    help=f"Not available for {grid} level {level}",
                    key=f"feature_member_{member}",
                )
            else:
                if st.checkbox(member, value=True, key=f"feature_member_{member}"):
                    selected_members.append(member)
    # Show results only if at least one member is selected.
    if selected_members:
        st.markdown("---")
        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
        with st.spinner("Computing dataset statistics..."):
            stats = ensemble.get_stats()
        # High-level metrics.
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Total Features", f"{stats['total_features']:,}")
        with col2:
            # Calculate minimum cells across all data sources (for inference capability).
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )
            # Bug fix: the help text previously said "union" although the
            # value is the minimum across data sources.
            st.metric("Inference Cells", f"{min_cells:,}", help="Minimum number of cells across all data sources")
        with col3:
            st.metric("Data Sources", len(selected_members))
        with col4:
            st.metric("Total Samples", f"{stats['num_target_samples']:,}")
        with col5:
            # Calculate total data points (features x samples).
            total_points = stats["total_features"] * stats["num_target_samples"]
            st.metric("Total Data Points", f"{total_points:,}")
        # Feature breakdown by source.
        st.markdown("#### Feature Breakdown by Data Source")
        breakdown_data = []
        for member, member_stats in stats["members"].items():
            breakdown_data.append(
                {
                    "Data Source": member,
                    "Number of Features": member_stats["num_features"],
                    "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
                }
            )
        # Lon/lat contributes two extra feature columns when enabled.
        if ensemble.add_lonlat:
            breakdown_data.append(
                {
                    "Data Source": "Lon/Lat",
                    "Number of Features": 2,
                    "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
                }
            )
        breakdown_df = pd.DataFrame(breakdown_data)
        # Get color palette for data sources.
        n_sources = len(breakdown_df)
        source_colors = get_palette("data_sources", n_colors=n_sources)
        # Create pie chart.
        fig = px.pie(
            breakdown_df,
            names="Data Source",
            values="Number of Features",
            title="Feature Distribution by Data Source",
            hole=0.4,
            color_discrete_sequence=source_colors,
        )
        fig.update_traces(textposition="inside", textinfo="percent+label")
        st.plotly_chart(fig, use_container_width=True)
        # Show detailed table.
        st.dataframe(breakdown_df, hide_index=True, use_container_width=True)
        # Detailed member information.
        with st.expander("📦 Detailed Source Information", expanded=False):
            for member, member_stats in stats["members"].items():
                st.markdown(f"### {member}")
                metric_cols = st.columns(4)
                with metric_cols[0]:
                    st.metric("Features", member_stats["num_features"])
                with metric_cols[1]:
                    st.metric("Variables", member_stats["num_variables"])
                with metric_cols[2]:
                    dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
                    st.metric("Shape", dim_str)
                with metric_cols[3]:
                    # Total data points = product of all dimension sizes.
                    total_points = 1
                    for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                        total_points *= dim_size
                    st.metric("Data Points", f"{total_points:,}")
                # Variables rendered as inline pill badges.
                st.markdown("**Variables:**")
                vars_html = " ".join(
                    [
                        f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                        for v in member_stats["variables"]  # type: ignore[union-attr]
                    ]
                )
                st.markdown(vars_html, unsafe_allow_html=True)
                # Dimensions rendered as inline pill badges.
                st.markdown("**Dimensions:**")
                dim_html = " ".join(
                    [
                        f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                        f"{dim_name}: {dim_size}</span>"
                        for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                    ]
                )
                st.markdown(dim_html, unsafe_allow_html=True)
                st.markdown("---")
    else:
        st.info("👆 Select at least one data source to see feature statistics")
def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts.

    Shows two tabs: one delegating to the sample-count overview and one
    delegating to the interactive feature-count fragment.
    """
    st.header("📈 Dataset Analysis")
    # Unpack the two analysis tabs into named handles for readability.
    sample_tab, feature_tab = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
    with sample_tab:
        render_sample_count_overview()
    with feature_tab:
        render_feature_count_fragment()
def render_overview_page():
@@ -20,8 +650,10 @@ def render_overview_page():
st.write(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Summary statistics at the top
st.subheader("Summary Statistics")
st.header("📊 Training Results Summary")
col1, col2, col3, col4 = st.columns(4)
with col1:
@@ -43,8 +675,14 @@ def render_overview_page():
st.divider()
# Add dataset analysis section
render_dataset_analysis()
st.divider()
# Detailed results table
st.subheader("Training Results")
st.header("🎯 Experiment Results")
st.subheader("Results Table")
# Build a summary dataframe
summary_data = []
@@ -93,7 +731,7 @@ def render_overview_page():
st.divider()
# Expandable details for each result
st.subheader("Detailed Results")
st.subheader("Individual Experiment Details")
for tr in training_results:
with st.expander(tr.get_display_name("task_first")):