"""Overview page: List of available result directories with some summary statistics."""

# NOTE(review): This module was recovered from a whitespace-mangled unified diff
# (`diff --git a/src/entropice/dashboard/overview_page.py`). The text below is a
# best-effort reconstruction of the diff's POST-image. Regions the diff did not
# show (unchanged lines between hunks) are marked with "ELIDED" comments and
# must be restored from the original file before this module is runnable.

from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict

import pandas as pd
import plotly.express as px

# ELIDED (original lines ~7-10): unchanged imports not visible in the diff —
# the code below uses `st` (presumably `import streamlit as st`) and a
# `get_palette` helper; confirm against the original file.

from entropice.dashboard.utils.data import load_all_training_results
from entropice.dataset import DatasetEnsemble


# Type definitions for dataset statistics
class GridConfig(TypedDict):
    """Grid configuration specification with metadata."""

    grid: str
    level: int
    grid_name: str
    grid_sort_key: str
    disable_alphaearth: bool


class SampleCountData(TypedDict):
    """Sample count statistics for a specific grid/target/task combination."""

    grid_config: GridConfig
    target: str
    task: str
    samples_coverage: int
    samples_labels: int
    samples_both: int


class FeatureCountData(TypedDict):
    """Feature count statistics for a specific grid configuration."""

    grid_config: GridConfig
    total_features: int
    data_sources: list[str]
    inference_cells: int
    total_samples: int
    member_breakdown: dict[str, int]


@dataclass(frozen=True)
class DatasetAnalysisCache:
    """Cache for dataset analysis data to avoid redundant computations."""

    grid_configs: list[GridConfig]
    sample_counts: list[SampleCountData]
    feature_counts: list[FeatureCountData]

    def get_sample_count_df(self) -> pd.DataFrame:
        """Convert sample count data to DataFrame."""
        rows = []
        for item in self.sample_counts:
            rows.append(
                {
                    "Grid": item["grid_config"]["grid_name"],
                    "Grid Type": item["grid_config"]["grid"],
                    "Level": item["grid_config"]["level"],
                    "Target": item["target"].replace("darts_", ""),
                    "Task": item["task"].capitalize(),
                    "Samples (Coverage)": item["samples_coverage"],
                    "Samples (Labels)": item["samples_labels"],
                    "Samples (Both)": item["samples_both"],
                    "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
                }
            )
        return pd.DataFrame(rows)

    def get_feature_count_df(self) -> pd.DataFrame:
        """Convert feature count data to DataFrame."""
        rows = []
        for item in self.feature_counts:
            rows.append(
                {
                    "Grid": item["grid_config"]["grid_name"],
                    "Grid Type": item["grid_config"]["grid"],
                    "Level": item["grid_config"]["level"],
                    "Total Features": item["total_features"],
                    "Data Sources": len(item["data_sources"]),
                    "Inference Cells": item["inference_cells"],
                    "Total Samples": item["total_samples"],
                    "AlphaEarth": "AlphaEarth" in item["data_sources"],
                    "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
                }
            )
        return pd.DataFrame(rows)

    def get_feature_breakdown_df(self) -> pd.DataFrame:
        """Convert feature breakdown data to DataFrame for stacked/donut charts."""
        rows = []
        for item in self.feature_counts:
            for source, count in item["member_breakdown"].items():
                rows.append(
                    {
                        "Grid": item["grid_config"]["grid_name"],
                        "Data Source": source,
                        "Number of Features": count,
                        "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
                    }
                )
        return pd.DataFrame(rows)


@st.cache_data(show_spinner=False)
def load_dataset_analysis_data() -> DatasetAnalysisCache:
    """Load and cache all dataset analysis data.

    This function computes both sample counts and feature counts for all grid configurations.
    Results are cached to avoid redundant computations across different tabs.
    """
    # Define all possible grid configurations
    grid_configs_raw = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]

    # Create structured grid config objects
    grid_configs: list[GridConfig] = []
    for grid, level in grid_configs_raw:
        disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
        grid_configs.append(
            {
                "grid": grid,
                "level": level,
                "grid_name": f"{grid}-{level}",
                "grid_sort_key": f"{grid}_{level:02d}",
                "disable_alphaearth": disable_alphaearth,
            }
        )

    # Compute sample counts
    sample_counts: list[SampleCountData] = []
    target_datasets = ["darts_rts", "darts_mllabels"]
    tasks = ["binary", "count", "density"]

    for grid_config in grid_configs:
        for target in target_datasets:
            # Create minimal ensemble just to get target data
            ensemble = DatasetEnsemble(
                grid=grid_config["grid"],
                level=grid_config["level"],
                target=target,
                members=[],  # type: ignore[arg-type]
            )
            targets = ensemble._read_target()

            for task in tasks:
                # Get task-specific column
                taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
                covcol = ensemble.covcol

                # Count samples with coverage and valid labels
                if covcol in targets.columns and taskcol in targets.columns:
                    valid_coverage = targets[covcol].sum()
                    valid_labels = targets[taskcol].notna().sum()
                    valid_both = (targets[covcol] & targets[taskcol].notna()).sum()

                    sample_counts.append(
                        {
                            "grid_config": grid_config,
                            "target": target,
                            "task": task,
                            "samples_coverage": valid_coverage,
                            "samples_labels": valid_labels,
                            "samples_both": valid_both,
                        }
                    )

    # Compute feature counts
    feature_counts: list[FeatureCountData] = []

    for grid_config in grid_configs:
        # Determine which members are available for this configuration
        if grid_config["disable_alphaearth"]:
            members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
        else:
            members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]

        # Use darts_rts as default target for comparison
        ensemble = DatasetEnsemble(
            grid=grid_config["grid"],
            level=grid_config["level"],
            target="darts_rts",
            members=members,  # type: ignore[arg-type]
        )
        stats = ensemble.get_stats()

        # Calculate minimum cells across all data sources
        min_cells = min(
            member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
            for member_stats in stats["members"].values()
        )

        # Build member breakdown including lon/lat
        member_breakdown = {}
        for member, member_stats in stats["members"].items():
            member_breakdown[member] = member_stats["num_features"]

        if ensemble.add_lonlat:
            member_breakdown["Lon/Lat"] = 2

        feature_counts.append(
            {
                "grid_config": grid_config,
                "total_features": stats["total_features"],
                "data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
                "inference_cells": min_cells,
                "total_samples": stats["num_target_samples"],
                "member_breakdown": member_breakdown,
            }
        )

    return DatasetAnalysisCache(
        grid_configs=grid_configs,
        sample_counts=sample_counts,
        feature_counts=feature_counts,
    )


def render_sample_count_overview(cache: DatasetAnalysisCache):
    """Render overview of sample counts per task+target+grid+level combination."""
    st.subheader("📊 Sample Counts by Configuration")

    # ELIDED: unchanged st.markdown(...) block describing this section.

    # Get sample count DataFrame from cache
    sample_df = cache.get_sample_count_df()
    target_datasets = ["darts_rts", "darts_mllabels"]

    # Create tabs for different views
    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])

    # ELIDED: unchanged tab contents (heatmap, bar chart and data table, ending with
    #     st.dataframe(display_df, hide_index=True, use_container_width=True)
    # ); they consume `sample_df`, `target_datasets` and the tab handles above.


def render_feature_count_comparison(cache: DatasetAnalysisCache):
    """Render static comparison of feature counts across all grid configurations."""
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")

    # Get data from cache
    comparison_df = cache.get_feature_count_df()
    breakdown_df = cache.get_feature_breakdown_df()
    breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")

    # Create tabs for different comparison views
    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])

    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")

        # Get color palette for data sources
        unique_sources = breakdown_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # Create stacked bar chart
        fig = px.bar(
            breakdown_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            # ELIDED: remaining unchanged px.bar keyword arguments and the rest
            # of the bar-chart tab (uses `source_colors` / `fig`).
        )

    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        # Get color palette for data sources
        unique_sources = breakdown_df["Data Source"].unique()
        n_sources = len(unique_sources)
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # ELIDED: unchanged donut-chart grid loop. Inside it, the per-grid data
        # selection now reads from the cached breakdown frame:
        #     grid_idx = row_idx * cols_per_row + col_idx
        #     if grid_idx < num_grids:
        #         grid_config = comparison_df.iloc[grid_idx]["Grid"]
        #         grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
        #         with cols[col_idx]:
        #             fig = px.pie(...)

    with comp_tab3:
        # ELIDED: unchanged data-table tab, ending with
        #     st.dataframe(display_df, hide_index=True, use_container_width=True)
        pass


@st.fragment
def render_feature_count_explorer(cache: DatasetAnalysisCache):
    """Render interactive detailed configuration explorer using fragments."""
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    # Grid selection
    grid_options = [gc["grid_name"] for gc in cache.grid_configs]

    col1, col2 = st.columns(2)

    # ELIDED: unchanged selectbox widgets in col1/col2 that bind
    # `grid_level_combined` (grid choice) and the target choice; the target
    # selectbox ends with key="feature_target_select".

    # Find the selected grid config
    selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
    grid = selected_grid_config["grid"]
    level = selected_grid_config["level"]
    disable_alphaearth = selected_grid_config["disable_alphaearth"]

    # Members selection
    st.markdown("#### Select Data Sources")

    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]

    # Use columns for checkboxes
    # ELIDED: unchanged checkbox layout and detailed statistics rendering
    # (uses `grid`, `level`, `disable_alphaearth`, `all_members`), ending with:
    #     st.info("👆 Select at least one data source to see feature statistics")


def render_feature_count_section(cache: DatasetAnalysisCache):
    """Render the feature count section with comparison and explorer."""
    st.subheader("🔢 Feature Counts by Dataset Configuration")

    st.markdown(
        """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """
    )

    # Static comparison across all grids
    render_feature_count_comparison(cache)

    st.divider()

    # Interactive explorer for detailed analysis
    render_feature_count_explorer(cache)


def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts."""
    st.header("📈 Dataset Analysis")

    # Load all data once and cache it
    with st.spinner("Loading dataset analysis data..."):
        cache = load_dataset_analysis_data()

    # Create tabs for the two different analyses
    analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])

    with analysis_tabs[0]:
        render_sample_count_overview(cache)

    with analysis_tabs[1]:
        render_feature_count_section(cache)


def render_training_results_summary(training_results):
    """Render summary metrics for training results."""
    st.header("📊 Training Results Summary")

    col1, col2, col3, col4 = st.columns(4)

    # ELIDED: unchanged metric widgets for the four columns; the last one is:
    #     latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
    #     st.metric("Latest Run", latest_date)


def render_experiment_results(training_results):
    """Render detailed experiment results table and expandable details."""
    st.header("🎯 Experiment Results")

    st.subheader("Results Table")

    # ELIDED: unchanged results table and expandable per-run details, ending with:
    #     st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
    #     st.write(f"\n**Path:** `{tr.path}`")


def render_overview_page():
    """Render the Overview page of the dashboard."""
    st.title("🏡 Training Results Overview")

    # Load training results
    training_results = load_all_training_results()

    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        return

    st.write(f"Found **{len(training_results)}** training result(s)")

    st.divider()

    # Render training results sections
    render_training_results_summary(training_results)

    st.divider()

    render_experiment_results(training_results)

    st.divider()

    # Render dataset analysis section
    render_dataset_analysis()

    st.balloons()