Refactor overview page

This commit is contained in:
Tobias Hölzer 2025-12-28 17:06:35 +01:00
parent a304c96e4e
commit 1ee3d532fc

View file

@ -1,6 +1,8 @@
"""Overview page: List of available result directories with some summary statistics.""" """Overview page: List of available result directories with some summary statistics."""
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import TypedDict
import pandas as pd import pandas as pd
import plotly.express as px import plotly.express as px
@ -11,22 +13,110 @@ from entropice.dashboard.utils.data import load_all_training_results
from entropice.dataset import DatasetEnsemble from entropice.dataset import DatasetEnsemble
def render_sample_count_overview(): # Type definitions for dataset statistics
"""Render overview of sample counts per task+target+grid+level combination.""" class GridConfig(TypedDict):
st.subheader("📊 Sample Counts by Configuration") """Grid configuration specification with metadata."""
st.markdown( grid: str
""" level: int
This visualization shows the number of available samples for each combination of: grid_name: str
- **Task**: binary, count, density grid_sort_key: str
- **Target Dataset**: darts_rts, darts_mllabels disable_alphaearth: bool
- **Grid System**: hex, healpix
- **Grid Level**: varying by grid type
""" class SampleCountData(TypedDict):
"""Sample count statistics for a specific grid/target/task combination."""
grid_config: GridConfig
target: str
task: str
samples_coverage: int
samples_labels: int
samples_both: int
class FeatureCountData(TypedDict):
"""Feature count statistics for a specific grid configuration."""
grid_config: GridConfig
total_features: int
data_sources: list[str]
inference_cells: int
total_samples: int
member_breakdown: dict[str, int]
@dataclass(frozen=True)
class DatasetAnalysisCache:
"""Cache for dataset analysis data to avoid redundant computations."""
grid_configs: list[GridConfig]
sample_counts: list[SampleCountData]
feature_counts: list[FeatureCountData]
def get_sample_count_df(self) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for item in self.sample_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Target": item["target"].replace("darts_", ""),
"Task": item["task"].capitalize(),
"Samples (Coverage)": item["samples_coverage"],
"Samples (Labels)": item["samples_labels"],
"Samples (Both)": item["samples_both"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
) )
return pd.DataFrame(rows)
def get_feature_count_df(self) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for item in self.feature_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Total Features": item["total_features"],
"Data Sources": len(item["data_sources"]),
"Inference Cells": item["inference_cells"],
"Total Samples": item["total_samples"],
"AlphaEarth": "AlphaEarth" in item["data_sources"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
def get_feature_breakdown_df(self) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for item in self.feature_counts:
for source, count in item["member_breakdown"].items():
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Data Source": source,
"Number of Features": count,
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
@st.cache_data(show_spinner=False)
def load_dataset_analysis_data() -> DatasetAnalysisCache:
"""Load and cache all dataset analysis data.
This function computes both sample counts and feature counts for all grid configurations.
Results are cached to avoid redundant computations across different tabs.
"""
# Define all possible grid configurations # Define all possible grid configurations
grid_configs = [ grid_configs_raw = [
("hex", 3), ("hex", 3),
("hex", 4), ("hex", 4),
("hex", 5), ("hex", 5),
@ -38,19 +128,34 @@ def render_sample_count_overview():
("healpix", 10), ("healpix", 10),
] ]
# Create structured grid config objects
grid_configs: list[GridConfig] = []
for grid, level in grid_configs_raw:
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
grid_configs.append(
{
"grid": grid,
"level": level,
"grid_name": f"{grid}-{level}",
"grid_sort_key": f"{grid}_{level:02d}",
"disable_alphaearth": disable_alphaearth,
}
)
# Compute sample counts
sample_counts: list[SampleCountData] = []
target_datasets = ["darts_rts", "darts_mllabels"] target_datasets = ["darts_rts", "darts_mllabels"]
tasks = ["binary", "count", "density"] tasks = ["binary", "count", "density"]
# Collect sample counts for grid_config in grid_configs:
sample_data = []
with st.spinner("Computing sample counts for all configurations..."):
for grid, level in grid_configs:
for target in target_datasets: for target in target_datasets:
# Create minimal ensemble just to get target data # Create minimal ensemble just to get target data
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[]) # type: ignore[arg-type] ensemble = DatasetEnsemble(
grid=grid_config["grid"],
# Read target data level=grid_config["level"],
target=target,
members=[], # type: ignore[arg-type]
)
targets = ensemble._read_target() targets = ensemble._read_target()
for task in tasks: for task in tasks:
@ -64,21 +169,85 @@ def render_sample_count_overview():
valid_labels = targets[taskcol].notna().sum() valid_labels = targets[taskcol].notna().sum()
valid_both = (targets[covcol] & targets[taskcol].notna()).sum() valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
sample_data.append( sample_counts.append(
{ {
"Grid": f"{grid}-{level}", "grid_config": grid_config,
"Grid Type": grid, "target": target,
"Level": level, "task": task,
"Target": target.replace("darts_", ""), "samples_coverage": valid_coverage,
"Task": task.capitalize(), "samples_labels": valid_labels,
"Samples (Coverage)": valid_coverage, "samples_both": valid_both,
"Samples (Labels)": valid_labels,
"Samples (Both)": valid_both,
"Grid_Level_Sort": f"{grid}_{level:02d}",
} }
) )
sample_df = pd.DataFrame(sample_data) # Compute feature counts
feature_counts: list[FeatureCountData] = []
for grid_config in grid_configs:
# Determine which members are available for this configuration
if grid_config["disable_alphaearth"]:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
# Use darts_rts as default target for comparison
ensemble = DatasetEnsemble(
grid=grid_config["grid"],
level=grid_config["level"],
target="darts_rts",
members=members, # type: ignore[arg-type]
)
stats = ensemble.get_stats()
# Calculate minimum cells across all data sources
min_cells = min(
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
# Build member breakdown including lon/lat
member_breakdown = {}
for member, member_stats in stats["members"].items():
member_breakdown[member] = member_stats["num_features"]
if ensemble.add_lonlat:
member_breakdown["Lon/Lat"] = 2
feature_counts.append(
{
"grid_config": grid_config,
"total_features": stats["total_features"],
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
"inference_cells": min_cells,
"total_samples": stats["num_target_samples"],
"member_breakdown": member_breakdown,
}
)
return DatasetAnalysisCache(
grid_configs=grid_configs,
sample_counts=sample_counts,
feature_counts=feature_counts,
)
def render_sample_count_overview(cache: DatasetAnalysisCache):
"""Render overview of sample counts per task+target+grid+level combination."""
st.subheader("📊 Sample Counts by Configuration")
st.markdown(
"""
This visualization shows the number of available samples for each combination of:
- **Task**: binary, count, density
- **Target Dataset**: darts_rts, darts_mllabels
- **Grid System**: hex, healpix
- **Grid Level**: varying by grid type
"""
)
# Get sample count DataFrame from cache
sample_df = cache.get_sample_count_df()
target_datasets = ["darts_rts", "darts_mllabels"]
# Create tabs for different views # Create tabs for different views
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"]) tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
@ -159,73 +328,15 @@ def render_sample_count_overview():
st.dataframe(display_df, hide_index=True, use_container_width=True) st.dataframe(display_df, hide_index=True, use_container_width=True)
@st.fragment def render_feature_count_comparison(cache: DatasetAnalysisCache):
def render_feature_count_fragment(): """Render static comparison of feature counts across all grid configurations."""
"""Render interactive feature count visualization using fragments."""
st.subheader("🔢 Feature Counts by Dataset Configuration")
st.markdown(
"""
This visualization shows the total number of features that would be generated
for different combinations of data sources and grid configurations.
"""
)
# First section: Comparison across all grid configurations
st.markdown("### Feature Count Comparison Across Grid Configurations") st.markdown("### Feature Count Comparison Across Grid Configurations")
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled") st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
# Define all possible grid configurations # Get data from cache
grid_configs = [ comparison_df = cache.get_feature_count_df()
("hex", 3), breakdown_df = cache.get_feature_breakdown_df()
("hex", 4), breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
("hex", 5),
("hex", 6),
("healpix", 6),
("healpix", 7),
("healpix", 8),
("healpix", 9),
("healpix", 10),
]
# Collect feature statistics for all configurations
feature_comparison_data = []
with st.spinner("Computing feature counts for all grid configurations..."):
for grid, level in grid_configs:
# Determine which members are available for this configuration
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
if disable_alphaearth:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
# Use darts_rts as default target for comparison
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
stats = ensemble.get_stats()
# Calculate minimum cells across all data sources
min_cells = min(
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
feature_comparison_data.append(
{
"Grid": f"{grid}-{level}",
"Grid Type": grid,
"Level": level,
"Total Features": stats["total_features"],
"Data Sources": len(members),
"Inference Cells": min_cells,
"Total Samples": stats["num_target_samples"],
"AlphaEarth": "AlphaEarth" in members,
"Grid_Level_Sort": f"{grid}_{level:02d}",
}
)
comparison_df = pd.DataFrame(feature_comparison_data)
# Create tabs for different comparison views # Create tabs for different comparison views
comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"]) comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
@ -233,56 +344,14 @@ def render_feature_count_fragment():
with comp_tab1: with comp_tab1:
st.markdown("#### Total Features by Grid Configuration") st.markdown("#### Total Features by Grid Configuration")
# Collect breakdown data for stacked bar chart
stacked_data = []
for idx, row in comparison_df.iterrows():
grid_config = row["Grid"]
grid, level_str = grid_config.split("-")
level = int(level_str)
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
if disable_alphaearth:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
stats = ensemble.get_stats()
# Add data for each member
for member, member_stats in stats["members"].items():
stacked_data.append(
{
"Grid": grid_config,
"Data Source": member,
"Number of Features": member_stats["num_features"],
"Grid_Level_Sort": row["Grid_Level_Sort"],
}
)
# Add lon/lat
if ensemble.add_lonlat:
stacked_data.append(
{
"Grid": grid_config,
"Data Source": "Lon/Lat",
"Number of Features": 2,
"Grid_Level_Sort": row["Grid_Level_Sort"],
}
)
stacked_df = pd.DataFrame(stacked_data)
stacked_df = stacked_df.sort_values("Grid_Level_Sort")
# Get color palette for data sources # Get color palette for data sources
unique_sources = stacked_df["Data Source"].unique() unique_sources = breakdown_df["Data Source"].unique()
n_sources = len(unique_sources) n_sources = len(unique_sources)
source_colors = get_palette("data_sources", n_colors=n_sources) source_colors = get_palette("data_sources", n_colors=n_sources)
# Create stacked bar chart # Create stacked bar chart
fig = px.bar( fig = px.bar(
stacked_df, breakdown_df,
x="Grid", x="Grid",
y="Number of Features", y="Number of Features",
color="Data Source", color="Data Source",
@ -336,58 +405,8 @@ def render_feature_count_fragment():
st.markdown("#### Feature Breakdown by Data Source") st.markdown("#### Feature Breakdown by Data Source")
st.markdown("Showing percentage contribution of each data source across all grid configurations") st.markdown("Showing percentage contribution of each data source across all grid configurations")
# Collect breakdown data for all grid configurations
all_breakdown_data = []
for idx, row in comparison_df.iterrows():
grid_config = row["Grid"]
grid, level_str = grid_config.split("-")
level = int(level_str)
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
if disable_alphaearth:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
stats = ensemble.get_stats()
total_features = stats["total_features"]
# Add data for each member with percentage
for member, member_stats in stats["members"].items():
percentage = (member_stats["num_features"] / total_features) * 100 # type: ignore[operator]
all_breakdown_data.append(
{
"Grid": grid_config,
"Data Source": member,
"Percentage": percentage,
"Number of Features": member_stats["num_features"],
"Grid_Level_Sort": row["Grid_Level_Sort"],
}
)
# Add lon/lat
if ensemble.add_lonlat:
percentage = (2 / total_features) * 100 # type: ignore[operator]
all_breakdown_data.append(
{
"Grid": grid_config,
"Data Source": "Lon/Lat",
"Percentage": percentage,
"Number of Features": 2,
"Grid_Level_Sort": row["Grid_Level_Sort"],
}
)
breakdown_all_df = pd.DataFrame(all_breakdown_data)
# Sort by grid configuration
breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")
# Get color palette for data sources # Get color palette for data sources
unique_sources = breakdown_all_df["Data Source"].unique() unique_sources = breakdown_df["Data Source"].unique()
n_sources = len(unique_sources) n_sources = len(unique_sources)
source_colors = get_palette("data_sources", n_colors=n_sources) source_colors = get_palette("data_sources", n_colors=n_sources)
@ -403,7 +422,7 @@ def render_feature_count_fragment():
grid_idx = row_idx * cols_per_row + col_idx grid_idx = row_idx * cols_per_row + col_idx
if grid_idx < num_grids: if grid_idx < num_grids:
grid_config = comparison_df.iloc[grid_idx]["Grid"] grid_config = comparison_df.iloc[grid_idx]["Grid"]
grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config] grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
with cols[col_idx]: with cols[col_idx]:
fig = px.pie( fig = px.pie(
@ -435,24 +454,15 @@ def render_feature_count_fragment():
st.dataframe(display_df, hide_index=True, use_container_width=True) st.dataframe(display_df, hide_index=True, use_container_width=True)
st.divider()
# Second section: Detailed configuration with user selection @st.fragment
def render_feature_count_explorer(cache: DatasetAnalysisCache):
"""Render interactive detailed configuration explorer using fragments."""
st.markdown("### Detailed Configuration Explorer") st.markdown("### Detailed Configuration Explorer")
st.markdown("Select specific grid configuration and data sources for detailed statistics") st.markdown("Select specific grid configuration and data sources for detailed statistics")
# Grid selection # Grid selection
grid_options = [ grid_options = [gc["grid_name"] for gc in cache.grid_configs]
"hex-3",
"hex-4",
"hex-5",
"hex-6",
"healpix-6",
"healpix-7",
"healpix-8",
"healpix-9",
"healpix-10",
]
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
@ -474,16 +484,15 @@ def render_feature_count_fragment():
key="feature_target_select", key="feature_target_select",
) )
# Parse grid type and level # Find the selected grid config
grid, level_str = grid_level_combined.split("-") selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
level = int(level_str) grid = selected_grid_config["grid"]
level = selected_grid_config["level"]
disable_alphaearth = selected_grid_config["disable_alphaearth"]
# Members selection # Members selection
st.markdown("#### Select Data Sources") st.markdown("#### Select Data Sources")
# Check if AlphaEarth should be disabled
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
# Use columns for checkboxes # Use columns for checkboxes
@ -624,35 +633,46 @@ def render_feature_count_fragment():
st.info("👆 Select at least one data source to see feature statistics") st.info("👆 Select at least one data source to see feature statistics")
def render_feature_count_section(cache: DatasetAnalysisCache):
"""Render the feature count section with comparison and explorer."""
st.subheader("🔢 Feature Counts by Dataset Configuration")
st.markdown(
"""
This visualization shows the total number of features that would be generated
for different combinations of data sources and grid configurations.
"""
)
# Static comparison across all grids
render_feature_count_comparison(cache)
st.divider()
# Interactive explorer for detailed analysis
render_feature_count_explorer(cache)
def render_dataset_analysis(): def render_dataset_analysis():
"""Render the dataset analysis section with sample and feature counts.""" """Render the dataset analysis section with sample and feature counts."""
st.header("📈 Dataset Analysis") st.header("📈 Dataset Analysis")
# Load all data once and cache it
with st.spinner("Loading dataset analysis data..."):
cache = load_dataset_analysis_data()
# Create tabs for the two different analyses # Create tabs for the two different analyses
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"]) analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
with analysis_tabs[0]: with analysis_tabs[0]:
render_sample_count_overview() render_sample_count_overview(cache)
with analysis_tabs[1]: with analysis_tabs[1]:
render_feature_count_fragment() render_feature_count_section(cache)
def render_overview_page(): def render_training_results_summary(training_results):
"""Render the Overview page of the dashboard.""" """Render summary metrics for training results."""
st.title("🏡 Training Results Overview")
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
return
st.write(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Summary statistics at the top
st.header("📊 Training Results Summary") st.header("📊 Training Results Summary")
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
@ -673,14 +693,9 @@ def render_overview_page():
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d") latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
st.metric("Latest Run", latest_date) st.metric("Latest Run", latest_date)
st.divider()
# Add dataset analysis section def render_experiment_results(training_results):
render_dataset_analysis() """Render detailed experiment results table and expandable details."""
st.divider()
# Detailed results table
st.header("🎯 Experiment Results") st.header("🎯 Experiment Results")
st.subheader("Results Table") st.subheader("Results Table")
@ -791,3 +806,33 @@ def render_overview_page():
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
st.write(f"\n**Path:** `{tr.path}`") st.write(f"\n**Path:** `{tr.path}`")
def render_overview_page():
"""Render the Overview page of the dashboard."""
st.title("🏡 Training Results Overview")
# Load training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
return
st.write(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Render training results sections
render_training_results_summary(training_results)
st.divider()
render_experiment_results(training_results)
st.divider()
# Render dataset analysis section
render_dataset_analysis()
st.balloons()