diff --git a/src/entropice/dashboard/overview_page.py b/src/entropice/dashboard/overview_page.py
index e373dae..7d087c5 100644
--- a/src/entropice/dashboard/overview_page.py
+++ b/src/entropice/dashboard/overview_page.py
@@ -3,9 +3,639 @@
from datetime import datetime
import pandas as pd
+import plotly.express as px
import streamlit as st
+from entropice.dashboard.plots.colors import get_palette
from entropice.dashboard.utils.data import load_all_training_results
+from entropice.dataset import DatasetEnsemble
+
+
+def render_sample_count_overview():
+    """Render overview of sample counts per task+target+grid+level combination."""
+    st.subheader("📊 Sample Counts by Configuration")
+
+    st.markdown(
+        """
+        This visualization shows the number of available samples for each combination of:
+        - **Task**: binary, count, density
+        - **Target Dataset**: darts_rts, darts_mllabels
+        - **Grid System**: hex, healpix
+        - **Grid Level**: varying by grid type
+        """
+    )
+
+    # Define all possible grid configurations
+    # NOTE(review): this same (grid, level) list is duplicated in
+    # render_feature_count_fragment — consider a shared module-level constant.
+    grid_configs = [
+        ("hex", 3),
+        ("hex", 4),
+        ("hex", 5),
+        ("hex", 6),
+        ("healpix", 6),
+        ("healpix", 7),
+        ("healpix", 8),
+        ("healpix", 9),
+        ("healpix", 10),
+    ]
+
+    target_datasets = ["darts_rts", "darts_mllabels"]
+    tasks = ["binary", "count", "density"]
+
+    # Collect sample counts: one row per (grid, level, target, task) combination.
+    sample_data = []
+
+    with st.spinner("Computing sample counts for all configurations..."):
+        for grid, level in grid_configs:
+            for target in target_datasets:
+                # Create minimal ensemble just to get target data
+                # NOTE(review): members=[] relies on DatasetEnsemble accepting an
+                # empty member list purely to read targets — confirm this is a
+                # supported construction rather than an accident of implementation.
+                ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]
+
+                # Read target data
+                # NOTE(review): _read_target() is private API of DatasetEnsemble;
+                # prefer exposing a public accessor to avoid breakage on refactor.
+                targets = ensemble._read_target()
+
+                for task in tasks:
+                    # Get task-specific column
+                    taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
+                    covcol = ensemble.covcol
+
+                    # Count samples with coverage and valid labels.
+                    # Missing columns are skipped silently — the combination simply
+                    # produces no row in the output table.
+                    if covcol in targets.columns and taskcol in targets.columns:
+                        # presumably covcol is a boolean coverage mask — TODO confirm
+                        valid_coverage = targets[covcol].sum()
+                        valid_labels = targets[taskcol].notna().sum()
+                        # Intersection: covered AND labelled.
+                        valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
+
+                        sample_data.append(
+                            {
+                                "Grid": f"{grid}-{level}",
+                                "Grid Type": grid,
+                                "Level": level,
+                                "Target": target.replace("darts_", ""),
+                                "Task": task.capitalize(),
+                                "Samples (Coverage)": valid_coverage,
+                                "Samples (Labels)": valid_labels,
+                                "Samples (Both)": valid_both,
+                                # Zero-padded sort key so hex/healpix levels order
+                                # lexicographically (e.g. healpix_06 < healpix_10).
+                                "Grid_Level_Sort": f"{grid}_{level:02d}",
+                            }
+                        )
+
+    sample_df = pd.DataFrame(sample_data)
+
+    # Create tabs for different views
+    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
+
+    with tab1:
+        st.markdown("### Sample Counts Heatmap")
+        st.markdown("Showing counts of samples with both coverage and valid labels")
+
+        # Create heatmap for each target dataset
+        for target in target_datasets:
+            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
+
+            # Pivot for heatmap: Grid x Task
+            # NOTE(review): each (Grid, Task) pair should be unique here, so
+            # aggfunc="mean" is a pass-through; if duplicates ever appear, "mean"
+            # would silently average counts — "sum" or an assertion may be safer.
+            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
+
+            # Sort index by grid type and level
+            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
+            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)
+
+            # Get color palette for sample counts
+            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
+
+            fig = px.imshow(
+                pivot_df,
+                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
+                x=pivot_df.columns,
+                y=pivot_df.index,
+                color_continuous_scale=sample_colors,
+                aspect="auto",
+                title=f"Target: {target}",
+            )
+
+            # Add text annotations (thousands separator via d3-format ",")
+            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
+
+            fig.update_layout(height=400)
+            st.plotly_chart(fig, use_container_width=True)
+
+    with tab2:
+        st.markdown("### Sample Counts Bar Chart")
+        st.markdown("Showing counts of samples with both coverage and valid labels")
+
+        # Create a faceted bar chart showing both targets side by side
+        # Get color palette for tasks
+        n_tasks = sample_df["Task"].nunique()
+        task_colors = get_palette("task_types", n_colors=n_tasks)
+
+        fig = px.bar(
+            sample_df,
+            x="Grid",
+            y="Samples (Both)",
+            color="Task",
+            facet_col="Target",
+            barmode="group",
+            title="Sample Counts by Grid Configuration and Target Dataset",
+            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
+            color_discrete_sequence=task_colors,
+            height=500,
+        )
+
+        # Update facet labels to be cleaner: "Target=rts" -> "rts"
+        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
+        fig.update_xaxes(tickangle=-45)
+        st.plotly_chart(fig, use_container_width=True)
+
+    with tab3:
+        st.markdown("### Detailed Sample Counts")
+
+        # Display full table with formatting
+        display_df = sample_df[
+            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
+        ].copy()
+
+        # Format numbers with commas (turns the columns into strings)
+        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
+            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
+
+        st.dataframe(display_df, hide_index=True, use_container_width=True)
+
+
+@st.fragment
+def render_feature_count_fragment():
+    """Render interactive feature count visualization using fragments."""
+    # NOTE(review): the grid_configs list and the disable_alphaearth / members
+    # selection below are duplicated three times inside this function (and once
+    # in render_sample_count_overview). Extracting a helper such as
+    # _members_for(grid, level) would remove the drift risk.
+    st.subheader("🔢 Feature Counts by Dataset Configuration")
+
+    st.markdown(
+        """
+        This visualization shows the total number of features that would be generated
+        for different combinations of data sources and grid configurations.
+        """
+    )
+
+    # First section: Comparison across all grid configurations
+    st.markdown("### Feature Count Comparison Across Grid Configurations")
+    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
+
+    # Define all possible grid configurations
+    grid_configs = [
+        ("hex", 3),
+        ("hex", 4),
+        ("hex", 5),
+        ("hex", 6),
+        ("healpix", 6),
+        ("healpix", 7),
+        ("healpix", 8),
+        ("healpix", 9),
+        ("healpix", 10),
+    ]
+
+    # Collect feature statistics for all configurations
+    feature_comparison_data = []
+
+    with st.spinner("Computing feature counts for all grid configurations..."):
+        for grid, level in grid_configs:
+            # Determine which members are available for this configuration.
+            # presumably AlphaEarth data is not produced at the two finest grid
+            # resolutions — TODO confirm against the data pipeline.
+            disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
+
+            if disable_alphaearth:
+                members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+            else:
+                members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+
+            # Use darts_rts as default target for comparison
+            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
+            stats = ensemble.get_stats()
+
+            # Calculate minimum cells across all data sources.
+            # NOTE(review): min() of per-source cell counts is a lower bound, not
+            # the size of the cell-set intersection (or union) — only equal if the
+            # smallest source's cells are a subset of every other source's.
+            min_cells = min(
+                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
+                for member_stats in stats["members"].values()
+            )
+
+            feature_comparison_data.append(
+                {
+                    "Grid": f"{grid}-{level}",
+                    "Grid Type": grid,
+                    "Level": level,
+                    "Total Features": stats["total_features"],
+                    "Data Sources": len(members),
+                    "Inference Cells": min_cells,
+                    "Total Samples": stats["num_target_samples"],
+                    "AlphaEarth": "AlphaEarth" in members,
+                    # Zero-padded sort key for lexicographic grid/level ordering.
+                    "Grid_Level_Sort": f"{grid}_{level:02d}",
+                }
+            )
+
+    comparison_df = pd.DataFrame(feature_comparison_data)
+
+    # Create tabs for different comparison views
+    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
+
+    with comp_tab1:
+        st.markdown("#### Total Features by Grid Configuration")
+
+        # Collect breakdown data for stacked bar chart.
+        # NOTE(review): this loop re-instantiates DatasetEnsemble and calls
+        # get_stats() for each row even though the same stats were just computed
+        # in the spinner loop above — caching the per-grid stats dict would
+        # avoid redundant (potentially I/O-heavy) work. `idx` is unused.
+        stacked_data = []
+
+        for idx, row in comparison_df.iterrows():
+            grid_config = row["Grid"]
+            grid, level_str = grid_config.split("-")
+            level = int(level_str)
+            disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
+
+            if disable_alphaearth:
+                members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+            else:
+                members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+
+            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
+            stats = ensemble.get_stats()
+
+            # Add data for each member
+            for member, member_stats in stats["members"].items():
+                stacked_data.append(
+                    {
+                        "Grid": grid_config,
+                        "Data Source": member,
+                        "Number of Features": member_stats["num_features"],
+                        "Grid_Level_Sort": row["Grid_Level_Sort"],
+                    }
+                )
+
+            # Add lon/lat (two extra coordinate features when enabled)
+            if ensemble.add_lonlat:
+                stacked_data.append(
+                    {
+                        "Grid": grid_config,
+                        "Data Source": "Lon/Lat",
+                        "Number of Features": 2,
+                        "Grid_Level_Sort": row["Grid_Level_Sort"],
+                    }
+                )
+
+        stacked_df = pd.DataFrame(stacked_data)
+        stacked_df = stacked_df.sort_values("Grid_Level_Sort")
+
+        # Get color palette for data sources
+        unique_sources = stacked_df["Data Source"].unique()
+        n_sources = len(unique_sources)
+        source_colors = get_palette("data_sources", n_colors=n_sources)
+
+        # Create stacked bar chart
+        fig = px.bar(
+            stacked_df,
+            x="Grid",
+            y="Number of Features",
+            color="Data Source",
+            barmode="stack",
+            title="Total Features by Data Source Across Grid Configurations",
+            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
+            color_discrete_sequence=source_colors,
+            text_auto=False,
+        )
+
+        fig.update_layout(height=500, xaxis_tickangle=-45)
+        st.plotly_chart(fig, use_container_width=True)
+
+        # Add secondary metrics
+        col1, col2 = st.columns(2)
+        with col1:
+            # Get color palette for grid configs.
+            # NOTE(review): grid_colors is also used in the col2 branch below —
+            # legal (with-blocks don't scope names) but fragile if col1 is removed.
+            n_grids = len(comparison_df)
+            grid_colors = get_palette("grid_configs", n_colors=n_grids)
+
+            fig_cells = px.bar(
+                comparison_df,
+                x="Grid",
+                y="Inference Cells",
+                color="Grid",
+                title="Inference Cells by Grid Configuration",
+                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
+                color_discrete_sequence=grid_colors,
+                text="Inference Cells",
+            )
+            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
+            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
+            st.plotly_chart(fig_cells, use_container_width=True)
+
+        with col2:
+            fig_samples = px.bar(
+                comparison_df,
+                x="Grid",
+                y="Total Samples",
+                color="Grid",
+                title="Total Samples by Grid Configuration",
+                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
+                color_discrete_sequence=grid_colors,
+                text="Total Samples",
+            )
+            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
+            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
+            st.plotly_chart(fig_samples, use_container_width=True)
+
+    with comp_tab2:
+        st.markdown("#### Feature Breakdown by Data Source")
+        st.markdown("Showing percentage contribution of each data source across all grid configurations")
+
+        # Collect breakdown data for all grid configurations.
+        # NOTE(review): third repetition of the ensemble/get_stats loop — see
+        # the caching note in comp_tab1. The "Percentage" column computed here
+        # is never plotted (px.pie derives percentages from the raw counts).
+        all_breakdown_data = []
+
+        for idx, row in comparison_df.iterrows():
+            grid_config = row["Grid"]
+            grid, level_str = grid_config.split("-")
+            level = int(level_str)
+            disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
+
+            if disable_alphaearth:
+                members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+            else:
+                members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+
+            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
+            stats = ensemble.get_stats()
+
+            total_features = stats["total_features"]
+
+            # Add data for each member with percentage
+            for member, member_stats in stats["members"].items():
+                percentage = (member_stats["num_features"] / total_features) * 100  # type: ignore[operator]
+                all_breakdown_data.append(
+                    {
+                        "Grid": grid_config,
+                        "Data Source": member,
+                        "Percentage": percentage,
+                        "Number of Features": member_stats["num_features"],
+                        "Grid_Level_Sort": row["Grid_Level_Sort"],
+                    }
+                )
+
+            # Add lon/lat
+            if ensemble.add_lonlat:
+                percentage = (2 / total_features) * 100  # type: ignore[operator]
+                all_breakdown_data.append(
+                    {
+                        "Grid": grid_config,
+                        "Data Source": "Lon/Lat",
+                        "Percentage": percentage,
+                        "Number of Features": 2,
+                        "Grid_Level_Sort": row["Grid_Level_Sort"],
+                    }
+                )
+
+        breakdown_all_df = pd.DataFrame(all_breakdown_data)
+
+        # Sort by grid configuration
+        breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")
+
+        # Get color palette for data sources
+        unique_sources = breakdown_all_df["Data Source"].unique()
+        n_sources = len(unique_sources)
+        source_colors = get_palette("data_sources", n_colors=n_sources)
+
+        # Create donut charts for each grid configuration
+        # Organize in a grid layout (3 charts per row, ceil division for rows)
+        num_grids = len(comparison_df)
+        cols_per_row = 3
+        num_rows = (num_grids + cols_per_row - 1) // cols_per_row
+
+        for row_idx in range(num_rows):
+            cols = st.columns(cols_per_row)
+            for col_idx in range(cols_per_row):
+                grid_idx = row_idx * cols_per_row + col_idx
+                if grid_idx < num_grids:
+                    grid_config = comparison_df.iloc[grid_idx]["Grid"]
+                    grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]
+
+                    with cols[col_idx]:
+                        fig = px.pie(
+                            grid_data,
+                            names="Data Source",
+                            values="Number of Features",
+                            title=grid_config,
+                            hole=0.4,
+                            color_discrete_sequence=source_colors,
+                        )
+                        fig.update_traces(textposition="inside", textinfo="percent")
+                        fig.update_layout(showlegend=True, height=350)
+                        st.plotly_chart(fig, use_container_width=True)
+
+    with comp_tab3:
+        st.markdown("#### Detailed Feature Count Comparison")
+
+        # Display full comparison table with formatting
+        display_df = comparison_df[
+            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
+        ].copy()
+
+        # Format numbers with commas
+        for col in ["Total Features", "Inference Cells", "Total Samples"]:
+            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
+
+        # Format boolean as Yes/No
+        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")
+
+        st.dataframe(display_df, hide_index=True, use_container_width=True)
+
+    st.divider()
+
+    # Second section: Detailed configuration with user selection
+    st.markdown("### Detailed Configuration Explorer")
+    st.markdown("Select specific grid configuration and data sources for detailed statistics")
+
+    # Grid selection
+    grid_options = [
+        "hex-3",
+        "hex-4",
+        "hex-5",
+        "hex-6",
+        "healpix-6",
+        "healpix-7",
+        "healpix-8",
+        "healpix-9",
+        "healpix-10",
+    ]
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        grid_level_combined = st.selectbox(
+            "Grid Configuration",
+            options=grid_options,
+            index=0,
+            help="Select the grid system and resolution level",
+            key="feature_grid_select",
+        )
+
+    with col2:
+        target = st.selectbox(
+            "Target Dataset",
+            options=["darts_rts", "darts_mllabels"],
+            index=0,
+            help="Select the target dataset",
+            key="feature_target_select",
+        )
+
+    # Parse grid type and level
+    grid, level_str = grid_level_combined.split("-")
+    level = int(level_str)
+
+    # Members selection
+    st.markdown("#### Select Data Sources")
+
+    # Check if AlphaEarth should be disabled
+    disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
+
+    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+
+    # Use columns for checkboxes
+    cols = st.columns(len(all_members))
+    selected_members = []
+
+    for idx, member in enumerate(all_members):
+        with cols[idx]:
+            if member == "AlphaEarth" and disable_alphaearth:
+                # Render disabled checkbox so the user sees why the source is missing.
+                st.checkbox(
+                    member,
+                    value=False,
+                    disabled=True,
+                    help=f"Not available for {grid} level {level}",
+                    key=f"feature_member_{member}",
+                )
+            else:
+                if st.checkbox(member, value=True, key=f"feature_member_{member}"):
+                    selected_members.append(member)
+
+    # Show results if at least one member is selected
+    if selected_members:
+        st.markdown("---")
+
+        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
+
+        with st.spinner("Computing dataset statistics..."):
+            stats = ensemble.get_stats()
+
+        # High-level metrics
+        col1, col2, col3, col4, col5 = st.columns(5)
+        with col1:
+            st.metric("Total Features", f"{stats['total_features']:,}")
+        with col2:
+            # Calculate minimum cells across all data sources (for inference capability)
+            min_cells = min(
+                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
+                for member_stats in stats["members"].values()
+            )
+            # NOTE(review): help text says "union of cells" but the value is
+            # min() of per-source cell counts — the wording should say
+            # "minimum" (or the computation should build an actual cell-set
+            # intersection). Flagged only; string left unchanged here.
+            st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
+        with col3:
+            st.metric("Data Sources", len(selected_members))
+        with col4:
+            st.metric("Total Samples", f"{stats['num_target_samples']:,}")
+        with col5:
+            # Calculate total data points
+            total_points = stats["total_features"] * stats["num_target_samples"]
+            st.metric("Total Data Points", f"{total_points:,}")
+
+        # Feature breakdown by source
+        st.markdown("#### Feature Breakdown by Data Source")
+
+        breakdown_data = []
+        for member, member_stats in stats["members"].items():
+            breakdown_data.append(
+                {
+                    "Data Source": member,
+                    "Number of Features": member_stats["num_features"],
+                    "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
+                }
+            )
+
+        # Add lon/lat
+        if ensemble.add_lonlat:
+            breakdown_data.append(
+                {
+                    "Data Source": "Lon/Lat",
+                    "Number of Features": 2,
+                    "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
+                }
+            )
+
+        breakdown_df = pd.DataFrame(breakdown_data)
+
+        # Get color palette for data sources
+        n_sources = len(breakdown_df)
+        source_colors = get_palette("data_sources", n_colors=n_sources)
+
+        # Create pie chart
+        fig = px.pie(
+            breakdown_df,
+            names="Data Source",
+            values="Number of Features",
+            title="Feature Distribution by Data Source",
+            hole=0.4,
+            color_discrete_sequence=source_colors,
+        )
+        fig.update_traces(textposition="inside", textinfo="percent+label")
+        st.plotly_chart(fig, use_container_width=True)
+
+        # Show detailed table
+        st.dataframe(breakdown_df, hide_index=True, use_container_width=True)
+
+        # Detailed member information
+        with st.expander("📦 Detailed Source Information", expanded=False):
+            for member, member_stats in stats["members"].items():
+                st.markdown(f"### {member}")
+
+                metric_cols = st.columns(4)
+                with metric_cols[0]:
+                    st.metric("Features", member_stats["num_features"])
+                with metric_cols[1]:
+                    st.metric("Variables", member_stats["num_variables"])
+                with metric_cols[2]:
+                    dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
+                    st.metric("Shape", dim_str)
+                with metric_cols[3]:
+                    # Product of all dimension sizes = raw data points for this source.
+                    total_points = 1
+                    for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
+                        total_points *= dim_size
+                    st.metric("Data Points", f"{total_points:,}")
+
+                # Variables
+                # NOTE(review): the f-strings below appear to have had their HTML
+                # badge markup (e.g. <span style="...">) stripped during diff
+                # extraction — as written they emit bare text/empty strings while
+                # unsafe_allow_html=True suggests styled spans were intended.
+                # Verify against the original source before applying this patch.
+                st.markdown("**Variables:**")
+                vars_html = " ".join(
+                    [
+                        f'{v}'
+                        for v in member_stats["variables"]  # type: ignore[union-attr]
+                    ]
+                )
+                st.markdown(vars_html, unsafe_allow_html=True)
+
+                # Dimensions
+                st.markdown("**Dimensions:**")
+                dim_html = " ".join(
+                    [
+                        f''
+                        f"{dim_name}: {dim_size}"
+                        for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
+                    ]
+                )
+                st.markdown(dim_html, unsafe_allow_html=True)
+
+                st.markdown("---")
+    else:
+        st.info("👆 Select at least one data source to see feature statistics")
+
+
+def render_dataset_analysis():
+    """Render the dataset analysis section with sample and feature counts.
+
+    Composes the two dataset views defined above into a tabbed "Dataset
+    Analysis" section; intended to be called from render_overview_page.
+    """
+    st.header("📈 Dataset Analysis")
+
+    # Create tabs for the two different analyses
+    analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
+
+    with analysis_tabs[0]:
+        render_sample_count_overview()
+
+    with analysis_tabs[1]:
+        # Runs as an @st.fragment, so widget interactions inside it rerun only
+        # the fragment rather than the whole page.
+        render_feature_count_fragment()
def render_overview_page():
@@ -20,8 +650,10 @@ def render_overview_page():
st.write(f"Found **{len(training_results)}** training result(s)")
+ st.divider()
+
# Summary statistics at the top
- st.subheader("Summary Statistics")
+ st.header("📊 Training Results Summary")
col1, col2, col3, col4 = st.columns(4)
with col1:
@@ -43,8 +675,14 @@ def render_overview_page():
st.divider()
+ # Add dataset analysis section
+ render_dataset_analysis()
+
+ st.divider()
+
# Detailed results table
- st.subheader("Training Results")
+ st.header("🎯 Experiment Results")
+ st.subheader("Results Table")
# Build a summary dataframe
summary_data = []
@@ -93,7 +731,7 @@ def render_overview_page():
st.divider()
# Expandable details for each result
- st.subheader("Detailed Results")
+ st.subheader("Individual Experiment Details")
for tr in training_results:
with st.expander(tr.get_display_name("task_first")):