Add Dataset Analysis to the overview page

This commit is contained in:
Tobias Hölzer 2025-12-28 15:31:51 +01:00
parent 6960571742
commit a304c96e4e

View file

@ -3,9 +3,639 @@
from datetime import datetime

import pandas as pd
import plotly.express as px
import streamlit as st

from entropice.dashboard.plots.colors import get_palette
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dataset import DatasetEnsemble
def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    For every (grid, level, target) combination a minimal :class:`DatasetEnsemble`
    (no members) is built just to read the target table, and the number of samples
    with coverage, with valid labels, and with both is counted per task. The
    results are shown as a heatmap, a grouped bar chart, and a data table.
    """
    st.subheader("📊 Sample Counts by Configuration")
    st.markdown(
        """
        This visualization shows the number of available samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    # Define all possible grid configurations
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    target_datasets = ["darts_rts", "darts_mllabels"]
    tasks = ["binary", "count", "density"]

    # Collect sample counts for every configuration.
    sample_data = []
    with st.spinner("Computing sample counts for all configurations..."):
        for grid, level in grid_configs:
            for target in target_datasets:
                # Create minimal ensemble (no members) just to get target data.
                ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]
                # Read target data once per (grid, level, target).
                targets = ensemble._read_target()
                for task in tasks:
                    # Task-specific label column and shared coverage column.
                    taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
                    covcol = ensemble.covcol
                    # Count samples with coverage and valid labels.
                    if covcol in targets.columns and taskcol in targets.columns:
                        valid_coverage = targets[covcol].sum()
                        valid_labels = targets[taskcol].notna().sum()
                        valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
                        sample_data.append(
                            {
                                "Grid": f"{grid}-{level}",
                                "Grid Type": grid,
                                "Level": level,
                                "Target": target.replace("darts_", ""),
                                "Task": task.capitalize(),
                                "Samples (Coverage)": valid_coverage,
                                "Samples (Labels)": valid_labels,
                                "Samples (Both)": valid_both,
                                # Zero-padded sort key so healpix-10 sorts after healpix-9.
                                "Grid_Level_Sort": f"{grid}_{level:02d}",
                            }
                        )

    sample_df = pd.DataFrame(sample_data)

    # Create tabs for different views of the same data.
    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])

    with tab1:
        st.markdown("### Sample Counts Heatmap")
        st.markdown("Showing counts of samples with both coverage and valid labels")
        # One heatmap per target dataset.
        for target in target_datasets:
            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
            # Pivot for heatmap: Grid x Task.
            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
            # Sort index by grid type and level (via the zero-padded sort key).
            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)
            # Color palette used as a continuous scale for sample counts.
            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
            fig = px.imshow(
                pivot_df,
                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
                x=pivot_df.columns,
                y=pivot_df.index,
                color_continuous_scale=sample_colors,
                aspect="auto",
                title=f"Target: {target}",
            )
            # Annotate each cell with its (comma-formatted) count.
            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

    with tab2:
        st.markdown("### Sample Counts Bar Chart")
        st.markdown("Showing counts of samples with both coverage and valid labels")
        # Faceted bar chart showing both targets side by side, colored by task.
        n_tasks = sample_df["Task"].nunique()
        task_colors = get_palette("task_types", n_colors=n_tasks)
        fig = px.bar(
            sample_df,
            x="Grid",
            y="Samples (Both)",
            color="Task",
            facet_col="Target",
            barmode="group",
            title="Sample Counts by Grid Configuration and Target Dataset",
            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
            color_discrete_sequence=task_colors,
            height=500,
        )
        # Strip the "Target=" prefix plotly puts on facet titles.
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
        fig.update_xaxes(tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

    with tab3:
        st.markdown("### Detailed Sample Counts")
        # Display full table with formatting.
        display_df = sample_df[
            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
        ].copy()
        # Format numbers with thousands separators for readability.
        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
        st.dataframe(display_df, hide_index=True, use_container_width=True)
def _members_for(grid: str, level: int) -> list[str]:
    """Return the data-source members available for a grid configuration.

    AlphaEarth embeddings are not produced for the finest resolutions
    (healpix level 10 and hex level 6); every other configuration gets
    the full set of data sources.
    """
    if (grid == "healpix" and level == 10) or (grid == "hex" and level == 6):
        return ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
    return ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]


@st.fragment
def render_feature_count_fragment():
    """Render interactive feature count visualization using fragments.

    Two sections:
    1. A fixed comparison of feature counts across all grid configurations
       (with every available data source enabled, target ``darts_rts``).
    2. An interactive explorer where the user picks a grid configuration,
       target dataset, and data sources, and gets detailed statistics.
    Runs as a Streamlit fragment so widget interaction only reruns this part.
    """
    st.subheader("🔢 Feature Counts by Dataset Configuration")
    st.markdown(
        """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """
    )

    # First section: Comparison across all grid configurations
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")

    # Define all possible grid configurations
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]

    # Collect feature statistics for all configurations.
    feature_comparison_data = []
    with st.spinner("Computing feature counts for all grid configurations..."):
        for grid, level in grid_configs:
            # Determine which members are available for this configuration.
            members = _members_for(grid, level)
            # Use darts_rts as default target for comparison.
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            # Minimum cell count across data sources limits where inference is possible.
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )
            feature_comparison_data.append(
                {
                    "Grid": f"{grid}-{level}",
                    "Grid Type": grid,
                    "Level": level,
                    "Total Features": stats["total_features"],
                    "Data Sources": len(members),
                    "Inference Cells": min_cells,
                    "Total Samples": stats["num_target_samples"],
                    "AlphaEarth": "AlphaEarth" in members,
                    # Zero-padded sort key so healpix-10 sorts after healpix-9.
                    "Grid_Level_Sort": f"{grid}_{level:02d}",
                }
            )

    comparison_df = pd.DataFrame(feature_comparison_data)

    # Create tabs for different comparison views.
    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])

    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")
        # Collect per-source breakdown data for the stacked bar chart.
        stacked_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            grid, level_str = grid_config.split("-")
            level = int(level_str)
            members = _members_for(grid, level)
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            # Add data for each member.
            for member, member_stats in stats["members"].items():
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
            # Lon/lat contribute two extra features when enabled on the ensemble.
            if ensemble.add_lonlat:
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
        stacked_df = pd.DataFrame(stacked_data)
        stacked_df = stacked_df.sort_values("Grid_Level_Sort")

        # Get color palette for data sources.
        unique_sources = stacked_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # Create stacked bar chart.
        fig = px.bar(
            stacked_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            barmode="stack",
            title="Total Features by Data Source Across Grid Configurations",
            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
            color_discrete_sequence=source_colors,
            text_auto=False,
        )
        fig.update_layout(height=500, xaxis_tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

        # Secondary metrics side by side; one palette shared by both charts
        # (hoisted out of the column contexts so each chart can use it).
        col1, col2 = st.columns(2)
        n_grids = len(comparison_df)
        grid_colors = get_palette("grid_configs", n_colors=n_grids)
        with col1:
            fig_cells = px.bar(
                comparison_df,
                x="Grid",
                y="Inference Cells",
                color="Grid",
                title="Inference Cells by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
                color_discrete_sequence=grid_colors,
                text="Inference Cells",
            )
            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_cells, use_container_width=True)
        with col2:
            fig_samples = px.bar(
                comparison_df,
                x="Grid",
                y="Total Samples",
                color="Grid",
                title="Total Samples by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
                color_discrete_sequence=grid_colors,
                text="Total Samples",
            )
            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_samples, use_container_width=True)

    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        # Collect breakdown data for all grid configurations.
        all_breakdown_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            grid, level_str = grid_config.split("-")
            level = int(level_str)
            members = _members_for(grid, level)
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()
            total_features = stats["total_features"]
            # Add data for each member with its percentage of the total.
            for member, member_stats in stats["members"].items():
                percentage = (member_stats["num_features"] / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Percentage": percentage,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
            # Lon/lat contribute two extra features when enabled on the ensemble.
            if ensemble.add_lonlat:
                percentage = (2 / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Percentage": percentage,
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )
        breakdown_all_df = pd.DataFrame(all_breakdown_data)
        # Sort by grid configuration.
        breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")

        # Get color palette for data sources.
        unique_sources = breakdown_all_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # Create donut charts for each grid configuration, laid out 3 per row.
        num_grids = len(comparison_df)
        cols_per_row = 3
        num_rows = (num_grids + cols_per_row - 1) // cols_per_row
        for row_idx in range(num_rows):
            cols = st.columns(cols_per_row)
            for col_idx in range(cols_per_row):
                grid_idx = row_idx * cols_per_row + col_idx
                if grid_idx < num_grids:
                    grid_config = comparison_df.iloc[grid_idx]["Grid"]
                    grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]
                    with cols[col_idx]:
                        fig = px.pie(
                            grid_data,
                            names="Data Source",
                            values="Number of Features",
                            title=grid_config,
                            hole=0.4,
                            color_discrete_sequence=source_colors,
                        )
                        fig.update_traces(textposition="inside", textinfo="percent")
                        fig.update_layout(showlegend=True, height=350)
                        st.plotly_chart(fig, use_container_width=True)

    with comp_tab3:
        st.markdown("#### Detailed Feature Count Comparison")
        # Display full comparison table with formatting.
        display_df = comparison_df[
            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
        ].copy()
        # Format numbers with thousands separators.
        for col in ["Total Features", "Inference Cells", "Total Samples"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
        # Format boolean as Yes/No. BUG FIX: both branches previously rendered
        # an empty string, so the AlphaEarth column was always blank.
        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "Yes" if x else "No")
        st.dataframe(display_df, hide_index=True, use_container_width=True)

    st.divider()

    # Second section: Detailed configuration with user selection
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    # Grid selection
    grid_options = [
        "hex-3",
        "hex-4",
        "hex-5",
        "hex-6",
        "healpix-6",
        "healpix-7",
        "healpix-8",
        "healpix-9",
        "healpix-10",
    ]
    col1, col2 = st.columns(2)
    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )
    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    # Parse grid type and level from the "grid-level" option string.
    grid, level_str = grid_level_combined.split("-")
    level = int(level_str)

    # Members selection
    st.markdown("#### Select Data Sources")
    # AlphaEarth is disabled where it is not available (see _members_for).
    disable_alphaearth = "AlphaEarth" not in _members_for(grid, level)
    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]

    # Use columns for checkboxes.
    cols = st.columns(len(all_members))
    selected_members = []
    for idx, member in enumerate(all_members):
        with cols[idx]:
            if member == "AlphaEarth" and disable_alphaearth:
                st.checkbox(
                    member,
                    value=False,
                    disabled=True,
                    help=f"Not available for {grid} level {level}",
                    key=f"feature_member_{member}",
                )
            elif st.checkbox(member, value=True, key=f"feature_member_{member}"):
                selected_members.append(member)

    # Guard clause: nothing selected, nothing to compute.
    if not selected_members:
        st.info("👆 Select at least one data source to see feature statistics")
        return

    st.markdown("---")
    ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
    with st.spinner("Computing dataset statistics..."):
        stats = ensemble.get_stats()

    # High-level metrics
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Total Features", f"{stats['total_features']:,}")
    with col2:
        # Minimum cell count across data sources limits inference coverage.
        min_cells = min(
            member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
            for member_stats in stats["members"].values()
        )
        # BUG FIX: help text said "union" but the code computes the minimum.
        st.metric("Inference Cells", f"{min_cells:,}", help="Minimum number of cells across all data sources")
    with col3:
        st.metric("Data Sources", len(selected_members))
    with col4:
        st.metric("Total Samples", f"{stats['num_target_samples']:,}")
    with col5:
        # Total data points = features x samples.
        total_points = stats["total_features"] * stats["num_target_samples"]
        st.metric("Total Data Points", f"{total_points:,}")

    # Feature breakdown by source
    st.markdown("#### Feature Breakdown by Data Source")
    breakdown_data = []
    for member, member_stats in stats["members"].items():
        breakdown_data.append(
            {
                "Data Source": member,
                "Number of Features": member_stats["num_features"],
                "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
            }
        )
    # Lon/lat contribute two extra features when enabled on the ensemble.
    if ensemble.add_lonlat:
        breakdown_data.append(
            {
                "Data Source": "Lon/Lat",
                "Number of Features": 2,
                "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
            }
        )
    breakdown_df = pd.DataFrame(breakdown_data)

    # Get color palette for data sources.
    n_sources = len(breakdown_df)
    source_colors = get_palette("data_sources", n_colors=n_sources)

    # Create pie chart.
    fig = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_sequence=source_colors,
    )
    fig.update_traces(textposition="inside", textinfo="percent+label")
    st.plotly_chart(fig, use_container_width=True)

    # Show detailed table.
    st.dataframe(breakdown_df, hide_index=True, use_container_width=True)

    # Detailed member information
    with st.expander("📦 Detailed Source Information", expanded=False):
        for member, member_stats in stats["members"].items():
            st.markdown(f"### {member}")
            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Features", member_stats["num_features"])
            with metric_cols[1]:
                st.metric("Variables", member_stats["num_variables"])
            with metric_cols[2]:
                dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
                st.metric("Shape", dim_str)
            with metric_cols[3]:
                # Product of all dimension sizes = raw data points for this source.
                total_points = 1
                for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                    total_points *= dim_size
                st.metric("Data Points", f"{total_points:,}")
            # Variables rendered as styled inline badges.
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats["variables"]  # type: ignore[union-attr]
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)
            # Dimensions rendered as styled inline badges.
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size}</span>"
                    for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)
            st.markdown("---")
def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts."""
    st.header("📈 Dataset Analysis")
    # Two analyses, one tab each: sample counts and feature counts.
    sample_tab, feature_tab = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
    with sample_tab:
        render_sample_count_overview()
    with feature_tab:
        render_feature_count_fragment()
def render_overview_page(): def render_overview_page():
@ -20,8 +650,10 @@ def render_overview_page():
st.write(f"Found **{len(training_results)}** training result(s)") st.write(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Summary statistics at the top # Summary statistics at the top
st.subheader("Summary Statistics") st.header("📊 Training Results Summary")
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
with col1: with col1:
@ -43,8 +675,14 @@ def render_overview_page():
st.divider() st.divider()
# Add dataset analysis section
render_dataset_analysis()
st.divider()
# Detailed results table # Detailed results table
st.subheader("Training Results") st.header("🎯 Experiment Results")
st.subheader("Results Table")
# Build a summary dataframe # Build a summary dataframe
summary_data = [] summary_data = []
@ -93,7 +731,7 @@ def render_overview_page():
st.divider() st.divider()
# Expandable details for each result # Expandable details for each result
st.subheader("Detailed Results") st.subheader("Individual Experiment Details")
for tr in training_results: for tr in training_results:
with st.expander(tr.get_display_name("task_first")): with st.expander(tr.get_display_name("task_first")):