From fca232da9128ea8155042615d48007627e8f45bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 4 Jan 2026 03:14:48 +0100 Subject: [PATCH] Overview Page refactor done --- src/entropice/dashboard/app.py | 23 +- src/entropice/dashboard/utils/loaders.py | 14 +- src/entropice/dashboard/utils/stats.py | 115 +++-- .../dashboard/views/inference_page.py | 2 +- .../dashboard/views/model_state_page.py | 18 +- .../dashboard/views/overview_page.py | 403 ++++-------------- src/entropice/utils/types.py | 58 +++ 7 files changed, 270 insertions(+), 363 deletions(-) diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 713f7d0..8e4555f 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -11,11 +11,12 @@ Pages: import streamlit as st -from entropice.dashboard.views.inference_page import render_inference_page -from entropice.dashboard.views.model_state_page import render_model_state_page +# from entropice.dashboard.views.inference_page import render_inference_page +# from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.overview_page import render_overview_page -from entropice.dashboard.views.training_analysis_page import render_training_analysis_page -from entropice.dashboard.views.training_data_page import render_training_data_page + +# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page +# from entropice.dashboard.views.training_data_page import render_training_data_page def main(): @@ -24,17 +25,17 @@ def main(): # Setup Navigation overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) - training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") - training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") - model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") - inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") + # training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") + # training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") + # model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") + # inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") pg = st.navigation( { "Overview": [overview_page], - "Training": [training_data_page, training_analysis_page], - "Model State": [model_state_page], - "Inference": [inference_page], + # "Training": [training_data_page, training_analysis_page], + # "Model State": [model_state_page], + # "Inference": [inference_page], } ) pg.run() diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index e19b2d1..c7f0645 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -44,8 +44,12 @@ class TrainingResult: result_file = result_path / "search_results.parquet" preds_file = result_path / "predicted_probabilities.parquet" settings_file = result_path / "search_settings.toml" - if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]): - raise FileNotFoundError(f"Missing required files in {result_path}") + if not result_file.exists(): + raise FileNotFoundError(f"Missing results file in {result_path}") + if not settings_file.exists(): + raise FileNotFoundError(f"Missing settings file in 
{result_path}") + if not preds_file.exists(): + raise FileNotFoundError(f"Missing predictions file in {result_path}") created_at = result_path.stat().st_ctime settings = TrainingSettings(**(toml.load(settings_file)["settings"])) @@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]: for result_path in results_dir.iterdir(): if not result_path.is_dir(): continue - training_result = TrainingResult.from_path(result_path) + try: + training_result = TrainingResult.from_path(result_path) + except FileNotFoundError as e: + st.warning(f"Skipping incomplete training result at {result_path}: {e}") + continue training_results.append(training_result) # Sort by creation time (most recent first) diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py index b0f6172..2d35e2f 100644 --- a/src/entropice/dashboard/utils/stats.py +++ b/src/entropice/dashboard/utils/stats.py @@ -5,7 +5,7 @@ from collections import defaultdict from dataclasses import asdict, dataclass -from typing import Literal, cast, get_args +from typing import Literal import geopandas as gpd import pandas as pd @@ -17,7 +17,17 @@ import entropice.spatial.grids import entropice.utils.paths from entropice.dashboard.utils.loaders import TrainingResult from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol -from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task +from entropice.utils.types import ( + Grid, + GridLevel, + L2SourceDataset, + TargetDataset, + Task, + all_l2_source_datasets, + all_target_datasets, + all_tasks, + grid_configs, +) @dataclass(frozen=True) @@ -87,8 +97,7 @@ class TargetStatistics: training_coverage: dict[Task, float] = {} class_counts: dict[Task, dict[str, int]] = {} class_distribution: dict[Task, dict[str, float]] = {} - tasks = cast(list[Task], get_args(Task)) - for task in tasks: + for task in all_tasks: task_col = taskcol[task][target] cov_col = covcol[target] @@ -132,43 +141,97 @@ class DatasetStatistics: members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset + @staticmethod + def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + """Convert sample count data to DataFrame.""" + rows = [] + for grid_config in grid_configs: + stats = all_stats[grid_config.id] + for target_name, target_stats in stats.target.items(): + for task in all_tasks: + training_cells = target_stats.training_cells[task] + coverage = target_stats.coverage[task] + rows.append( + { + "Grid": grid_config.display_name, + "Target": target_name.replace("darts_", ""), + "Task": task.capitalize(), + "Samples (Coverage)": training_cells, + "Coverage %": coverage, + "Grid_Level_Sort": grid_config.sort_key, + } + ) + return pd.DataFrame(rows) + + @staticmethod + def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + """Convert feature count data to DataFrame.""" + rows = [] + for grid_config in grid_configs: + stats = all_stats[grid_config.id] + data_sources = list(stats.members.keys()) + + # Determine minimum cells across all data sources + min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) + + # Get sample count from first target dataset (darts_rts) + first_target = stats.target["darts_rts"] + total_samples = first_target.training_cells["binary"] + + rows.append( + { + "Grid": grid_config.display_name, + "Total 
Features": stats.total_features, + "Data Sources": len(data_sources), + "Inference Cells": min_cells, + "Total Samples": total_samples, + "Grid_Level_Sort": grid_config.sort_key, + } + ) + return pd.DataFrame(rows) + + @staticmethod + def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + """Convert feature breakdown data to DataFrame for stacked/donut charts.""" + rows = [] + for grid_config in grid_configs: + stats = all_stats[grid_config.id] + for member_name, member_stats in stats.members.items(): + rows.append( + { + "Grid": grid_config.display_name, + "Data Source": member_name, + "Number of Features": member_stats.feature_count, + "Grid_Level_Sort": grid_config.sort_key, + } + ) + return pd.DataFrame(rows) + @st.cache_data def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]: dataset_stats: dict[GridLevel, DatasetStatistics] = {} - grid_levels: set[tuple[Grid, int]] = { - ("hex", 3), - ("hex", 4), - ("hex", 5), - ("hex", 6), - ("healpix", 6), - ("healpix", 7), - ("healpix", 8), - ("healpix", 9), - ("healpix", 10), - } - for grid, level in grid_levels: - with stopwatch(f"Loading statistics for grid={grid}, level={level}"): - grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered + for grid_config in grid_configs: + with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"): + grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered total_cells = len(grid_gdf) assert total_cells > 0, "Grid must contain at least one cell." target_statistics: dict[TargetDataset, TargetStatistics] = {} - targets = cast(list[TargetDataset], get_args(TargetDataset)) - for target in targets: + for target in all_target_datasets: target_statistics[target] = TargetStatistics.compute( - grid=grid, level=level, target=target, total_cells=total_cells + grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells ) member_statistics: dict[L2SourceDataset, MemberStatistics] = {} - members = cast(list[L2SourceDataset], get_args(L2SourceDataset)) - for member in members: - member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member) + for member in all_l2_source_datasets: + member_statistics[member] = MemberStatistics.compute( + grid=grid_config.grid, level=grid_config.level, member=member + ) total_features = sum(ms.feature_count for ms in member_statistics.values()) total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum( ts.size_bytes for ts in target_statistics.values() ) - grid_level: GridLevel = cast(GridLevel, f"{grid}{level}") - dataset_stats[grid_level] = DatasetStatistics( + dataset_stats[grid_config.id] = DatasetStatistics( total_features=total_features, total_cells=total_cells, size_bytes=total_size_bytes, diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py index 5cdd277..86958b2 100644 --- a/src/entropice/dashboard/views/inference_page.py +++ b/src/entropice/dashboard/views/inference_page.py @@ -1,6 +1,7 @@ """Inference page: Visualization of model inference results across the study region.""" import streamlit as st +from entropice.dashboard.utils.data import load_all_training_results from entropice.dashboard.plots.inference import ( render_class_comparison, @@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import ( render_inference_statistics, render_spatial_distribution_stats, ) 
-from entropice.dashboard.utils.data import load_all_training_results def render_inference_page(): diff --git a/src/entropice/dashboard/views/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py index 1dd3c62..5cd90f5 100644 --- a/src/entropice/dashboard/views/model_state_page.py +++ b/src/entropice/dashboard/views/model_state_page.py @@ -2,6 +2,15 @@ import streamlit as st import xarray as xr +from entropice.dashboard.utils.data import ( + extract_arcticdem_features, + extract_common_features, + extract_embedding_features, + extract_era5_features, + get_members_from_settings, + load_all_training_results, +) +from entropice.dashboard.utils.training import load_model_state from entropice.dashboard.plots.model_state import ( plot_arcticdem_heatmap, @@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import ( plot_top_features, ) from entropice.dashboard.utils.colors import generate_unified_colormap -from entropice.dashboard.utils.data import ( - extract_arcticdem_features, - extract_common_features, - extract_embedding_features, - extract_era5_features, - get_members_from_settings, - load_all_training_results, -) -from entropice.dashboard.utils.training import load_model_state def render_model_state_page(): diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index de5e2de..d9676a3 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -1,238 +1,20 @@ """Overview page: List of available result directories with some summary statistics.""" -from dataclasses import dataclass from datetime import datetime -from typing import TypedDict +from typing import cast import pandas as pd import plotly.express as px import streamlit as st +from stopuhr import stopwatch from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.loaders import load_all_training_results -from entropice.ml.dataset import DatasetEnsemble -from entropice.utils.types import Grid, TargetDataset, Task +from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics +from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs -# Type definitions for dataset statistics -class GridConfig(TypedDict): - """Grid configuration specification with metadata.""" - - grid: Grid - level: int - grid_name: str - grid_sort_key: str - disable_alphaearth: bool - - -class SampleCountData(TypedDict): - """Sample count statistics for a specific grid/target/task combination.""" - - grid_config: GridConfig - target: str - task: str - samples_coverage: int - samples_labels: int - samples_both: int - - -class FeatureCountData(TypedDict): - """Feature count statistics for a specific grid configuration.""" - - grid_config: GridConfig - total_features: int - data_sources: list[str] - inference_cells: int - total_samples: int - member_breakdown: dict[str, int] - - -@dataclass(frozen=True) -class DatasetAnalysisCache: - """Cache for dataset analysis data to avoid redundant computations.""" - - grid_configs: list[GridConfig] - sample_counts: list[SampleCountData] - feature_counts: list[FeatureCountData] - - def get_sample_count_df(self) -> pd.DataFrame: - """Convert sample count data to DataFrame.""" - rows = [] - for item in self.sample_counts: - rows.append( - { - "Grid": item["grid_config"]["grid_name"], - "Grid Type": item["grid_config"]["grid"], - "Level": item["grid_config"]["level"], - "Target": 
item["target"].replace("darts_", ""), - "Task": item["task"].capitalize(), - "Samples (Coverage)": item["samples_coverage"], - "Samples (Labels)": item["samples_labels"], - "Samples (Both)": item["samples_both"], - "Grid_Level_Sort": item["grid_config"]["grid_sort_key"], - } - ) - return pd.DataFrame(rows) - - def get_feature_count_df(self) -> pd.DataFrame: - """Convert feature count data to DataFrame.""" - rows = [] - for item in self.feature_counts: - rows.append( - { - "Grid": item["grid_config"]["grid_name"], - "Grid Type": item["grid_config"]["grid"], - "Level": item["grid_config"]["level"], - "Total Features": item["total_features"], - "Data Sources": len(item["data_sources"]), - "Inference Cells": item["inference_cells"], - "Total Samples": item["total_samples"], - "AlphaEarth": "AlphaEarth" in item["data_sources"], - "Grid_Level_Sort": item["grid_config"]["grid_sort_key"], - } - ) - return pd.DataFrame(rows) - - def get_feature_breakdown_df(self) -> pd.DataFrame: - """Convert feature breakdown data to DataFrame for stacked/donut charts.""" - rows = [] - for item in self.feature_counts: - for source, count in item["member_breakdown"].items(): - rows.append( - { - "Grid": item["grid_config"]["grid_name"], - "Data Source": source, - "Number of Features": count, - "Grid_Level_Sort": item["grid_config"]["grid_sort_key"], - } - ) - return pd.DataFrame(rows) - - -@st.cache_data(show_spinner=False) -def load_dataset_analysis_data() -> DatasetAnalysisCache: - """Load and cache all dataset analysis data. - - This function computes both sample counts and feature counts for all grid configurations. - Results are cached to avoid redundant computations across different tabs. - """ - # Define all possible grid configurations - grid_configs_raw: list[tuple[Grid, int]] = [ - ("hex", 3), - ("hex", 4), - ("hex", 5), - ("hex", 6), - ("healpix", 6), - ("healpix", 7), - ("healpix", 8), - ("healpix", 9), - ("healpix", 10), - ] - - # Create structured grid config objects - grid_configs: list[GridConfig] = [] - for grid, level in grid_configs_raw: - disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6) - grid_configs.append( - { - "grid": grid, - "level": level, - "grid_name": f"{grid}-{level}", - "grid_sort_key": f"{grid}_{level:02d}", - "disable_alphaearth": disable_alphaearth, - } - ) - - # Compute sample counts - sample_counts: list[SampleCountData] = [] - target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"] - tasks: list[Task] = ["binary", "count", "density"] - - for grid_config in grid_configs: - for target in target_datasets: - # Create minimal ensemble just to get target data - ensemble = DatasetEnsemble( - grid=grid_config["grid"], - level=grid_config["level"], - target=target, - members=[], # type: ignore[arg-type] - ) - targets = ensemble._read_target() - - for task in tasks: - # Get task-specific column - taskcol = ensemble.taskcol(task) # type: ignore[arg-type] - covcol = ensemble.covcol - - # Count samples with coverage and valid labels - if covcol in targets.columns and taskcol in targets.columns: - valid_coverage = targets[covcol].sum() - valid_labels = targets[taskcol].notna().sum() - valid_both = (targets[covcol] & targets[taskcol].notna()).sum() - - sample_counts.append( - { - "grid_config": grid_config, - "target": target, - "task": task, - "samples_coverage": valid_coverage, - "samples_labels": valid_labels, - "samples_both": valid_both, - } - ) - - # Compute feature counts - feature_counts: list[FeatureCountData] = [] - - for 
grid_config in grid_configs: - # Determine which members are available for this configuration - if grid_config["disable_alphaearth"]: - members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] - else: - members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] - - # Use darts_rts as default target for comparison - ensemble = DatasetEnsemble( - grid=grid_config["grid"], - level=grid_config["level"], - target="darts_rts", - members=members, # type: ignore[arg-type] - ) - stats = ensemble.get_stats() - - # Calculate minimum cells across all data sources - min_cells = min( - member_stats["dimensions"]["cell_ids"] # type: ignore[index] - for member_stats in stats["members"].values() - ) - - # Build member breakdown including lon/lat - member_breakdown = {} - for member, member_stats in stats["members"].items(): - member_breakdown[member] = member_stats["num_features"] - - if ensemble.add_lonlat: - member_breakdown["Lon/Lat"] = 2 - - feature_counts.append( - { - "grid_config": grid_config, - "total_features": stats["total_features"], - "data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []), - "inference_cells": min_cells, - "total_samples": stats["num_target_samples"], - "member_breakdown": member_breakdown, - } - ) - - return DatasetAnalysisCache( - grid_configs=grid_configs, - sample_counts=sample_counts, - feature_counts=feature_counts, - ) - - -def render_sample_count_overview(cache: DatasetAnalysisCache): +def render_sample_count_overview(): """Render overview of sample counts per task+target+grid+level combination.""" st.subheader("📊 Sample Counts by Configuration") @@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): ) # Get sample count DataFrame from cache - sample_df = cache.get_sample_count_df() + all_stats = load_all_default_dataset_statistics() + sample_df = DatasetStatistics.get_sample_count_df(all_stats) target_datasets = ["darts_rts", "darts_mllabels"] # Create tabs for different views @@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): with tab1: st.markdown("### Sample Counts Heatmap") - st.markdown("Showing counts of samples with both coverage and valid labels") + st.markdown("Showing counts of samples with coverage") # Create heatmap for each target dataset for target in target_datasets: target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")] # Pivot for heatmap: Grid x Task - pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean") + pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean") # Sort index by grid type and level sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid") @@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): with tab2: st.markdown("### Sample Counts Bar Chart") - st.markdown("Showing counts of samples with both coverage and valid labels") + st.markdown("Showing counts of samples with coverage") # Create a faceted bar chart showing both targets side by side # Get color palette for tasks @@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): fig = px.bar( sample_df, x="Grid", - y="Samples (Both)", + y="Samples (Coverage)", color="Task", facet_col="Target", barmode="group", title="Sample Counts by Grid Configuration and Target Dataset", - labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"}, + 
labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"}, color_discrete_sequence=task_colors, height=500, ) @@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): st.markdown("### Detailed Sample Counts") # Display full table with formatting - display_df = sample_df[ - ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"] - ].copy() + display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy() # Format numbers with commas - for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]: - display_df[col] = display_df[col].apply(lambda x: f"{x:,}") + display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}") + # Format coverage as percentage with 2 decimal places + display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%") st.dataframe(display_df, hide_index=True, width="stretch") -def render_feature_count_comparison(cache: DatasetAnalysisCache): +def render_feature_count_comparison(): """Render static comparison of feature counts across all grid configurations.""" st.markdown("### Feature Count Comparison Across Grid Configurations") st.markdown("Comparing feature counts for all grid configurations with all data sources enabled") # Get data from cache - comparison_df = cache.get_feature_count_df() - breakdown_df = cache.get_feature_breakdown_df() + all_stats = load_all_default_dataset_statistics() + comparison_df = DatasetStatistics.get_feature_count_df(all_stats) + breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) breakdown_df = breakdown_df.sort_values("Grid_Level_Sort") # Create tabs for different comparison views @@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): # Display full comparison table with formatting display_df = comparison_df[ - ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"] + ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"] ].copy() # Format numbers with commas for col in ["Total Features", "Inference Cells", "Total Samples"]: display_df[col] = display_df[col].apply(lambda x: f"{x:,}") - # Format boolean as Yes/No - display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗") - st.dataframe(display_df, hide_index=True, width="stretch") @st.fragment -def render_feature_count_explorer(cache: DatasetAnalysisCache): +def render_feature_count_explorer(): """Render interactive detailed configuration explorer using fragments.""" st.markdown("### Detailed Configuration Explorer") st.markdown("Select specific grid configuration and data sources for detailed statistics") # Grid selection - grid_options = [gc["grid_name"] for gc in cache.grid_configs] + grid_options = [gc.display_name for gc in grid_configs] col1, col2 = st.columns(2) @@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache): ) # Find the selected grid config - selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined) - grid = selected_grid_config["grid"] - level = selected_grid_config["level"] - disable_alphaearth = selected_grid_config["disable_alphaearth"] + selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined) + + # Get available members from the stats + all_stats = load_all_default_dataset_statistics() + stats = all_stats[selected_grid_config.id] + 
available_members = cast(list[L2SourceDataset], list(stats.members.keys())) # Members selection st.markdown("#### Select Data Sources") - all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] + all_members = cast( + list[L2SourceDataset], + ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"], + ) # Use columns for checkboxes cols = st.columns(len(all_members)) - selected_members = [] + selected_members: list[L2SourceDataset] = [] for idx, member in enumerate(all_members): with cols[idx]: - if member == "AlphaEarth" and disable_alphaearth: - st.checkbox( - member, - value=False, - disabled=True, - help=f"Not available for {grid} level {level}", - key=f"feature_member_{member}", - ) - else: - if st.checkbox(member, value=True, key=f"feature_member_{member}"): - selected_members.append(member) + default_value = member in available_members + if st.checkbox(member, value=default_value, key=f"feature_member_{member}"): + selected_members.append(cast(L2SourceDataset, member)) # Show results if at least one member is selected if selected_members: st.markdown("---") - ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members) + # Get statistics from cache (already loaded) + grid_stats = all_stats[selected_grid_config.id] - with st.spinner("Computing dataset statistics..."): - stats = ensemble.get_stats() + # Filter to selected members only + selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members} + + # Calculate total features for selected members + total_features = sum(ms.feature_count for ms in selected_member_stats.values()) + + # Get target stats + target_stats = grid_stats.target[cast(TargetDataset, target)] # High-level metrics col1, col2, col3, col4, col5 = st.columns(5) with col1: - st.metric("Total Features", f"{stats['total_features']:,}") + st.metric("Total Features", f"{total_features:,}") with col2: # Calculate minimum cells across all data sources (for inference capability) - min_cells = min( - member_stats["dimensions"]["cell_ids"] # type: ignore[index] - for member_stats in stats["members"].values() - ) + min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values()) st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources") with col3: st.metric("Data Sources", len(selected_members)) with col4: - st.metric("Total Samples", f"{stats['num_target_samples']:,}") + # Use binary task training cells as sample count + st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}") with col5: # Calculate total data points - total_points = stats["total_features"] * stats["num_target_samples"] + total_points = total_features * target_stats.training_cells["binary"] st.metric("Total Data Points", f"{total_points:,}") # Feature breakdown by source st.markdown("#### Feature Breakdown by Data Source") breakdown_data = [] - for member, member_stats in stats["members"].items(): + for member, member_stats in selected_member_stats.items(): breakdown_data.append( { "Data Source": member, - "Number of Features": member_stats["num_features"], - "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator] - } - ) - - # Add lon/lat - if ensemble.add_lonlat: - breakdown_data.append( - { - "Data Source": "Lon/Lat", - "Number of Features": 2, - "Percentage": f"{2 / stats['total_features'] * 100:.1f}%", + "Number of Features": 
member_stats.feature_count, + "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%", } ) @@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache): # Detailed member information with st.expander("📦 Detailed Source Information", expanded=False): - for member, member_stats in stats["members"].items(): + for member, member_stats in selected_member_stats.items(): st.markdown(f"### {member}") metric_cols = st.columns(4) with metric_cols[0]: - st.metric("Features", member_stats["num_features"]) + st.metric("Features", member_stats.feature_count) with metric_cols[1]: - st.metric("Variables", member_stats["num_variables"]) + st.metric("Variables", len(member_stats.variable_names)) with metric_cols[2]: - dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr] + dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()]) st.metric("Shape", dim_str) with metric_cols[3]: total_points = 1 - for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr] + for dim_size in member_stats.dimensions.values(): total_points *= dim_size st.metric("Data Points", f"{total_points:,}") @@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache): [ f'{v}' - for v in member_stats["variables"] # type: ignore[union-attr] + for v in member_stats.variable_names ] ) st.markdown(vars_html, unsafe_allow_html=True) @@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache): f'' f"{dim_name}: {dim_size}" - for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr] + for dim_name, dim_size in member_stats.dimensions.items() ] ) st.markdown(dim_html, unsafe_allow_html=True) + # Size on disk + size_mb = member_stats.size_bytes / (1024 * 1024) + st.markdown(f"**Size on Disk:** {size_mb:.2f} MB") + st.markdown("---") else: st.info("👆 Select at least one data source to see feature statistics") -def render_feature_count_section(cache: DatasetAnalysisCache): +def render_feature_count_section(): """Render the feature count section with comparison and explorer.""" st.subheader("🔢 Feature Counts by Dataset Configuration") @@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache): ) # Static comparison across all grids - render_feature_count_comparison(cache) + render_feature_count_comparison() st.divider() # Interactive explorer for detailed analysis - render_feature_count_explorer(cache) + render_feature_count_explorer() def render_dataset_analysis(): """Render the dataset analysis section with sample and feature counts.""" st.header("📈 Dataset Analysis") - # Load all data once and cache it - with st.spinner("Loading dataset analysis data..."): - cache = load_dataset_analysis_data() - # Create tabs for the two different analyses analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"]) with analysis_tabs[0]: - render_sample_count_overview(cache) + render_sample_count_overview() with analysis_tabs[1]: - render_feature_count_section(cache) + render_feature_count_section() def render_training_results_summary(training_results): @@ -678,15 +450,15 @@ def render_training_results_summary(training_results): col1, col2, col3, col4 = st.columns(4) with col1: - tasks = {tr.settings.get("task", "Unknown") for tr in training_results} + tasks = {tr.settings.task for tr in training_results} st.metric("Tasks", len(tasks)) with col2: - grids = {tr.settings.get("grid", "Unknown") for tr in training_results} + grids 
= {tr.settings.grid for tr in training_results} st.metric("Grid Types", len(grids)) with col3: - models = {tr.settings.get("model", "Unknown") for tr in training_results} + models = {tr.settings.model for tr in training_results} st.metric("Model Types", len(models)) with col4: @@ -725,10 +497,10 @@ def render_experiment_results(training_results): summary_data.append( { "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"), - "Task": tr.settings.get("task", "Unknown"), - "Grid": tr.settings.get("grid", "Unknown"), - "Level": tr.settings.get("level", "Unknown"), - "Model": tr.settings.get("model", "Unknown"), + "Task": tr.settings.task, + "Grid": tr.settings.grid, + "Level": tr.settings.level, + "Model": tr.settings.model, f"Best {primary_metric.title()}": f"{primary_score:.4f}", "Trials": len(tr.results), "Path": str(tr.path.name), @@ -750,17 +522,20 @@ def render_experiment_results(training_results): st.subheader("Individual Experiment Details") for tr in training_results: - with st.expander(tr.get_display_name("task_first")): + display_name = ( + f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}" + ) + with st.expander(display_name): col1, col2 = st.columns([1, 2]) with col1: st.write("**Configuration:**") - st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}") - st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}") - st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}") - st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}") - st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}") - st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}") + st.write(f"- **Task:** {tr.settings.task}") + st.write(f"- **Grid:** {tr.settings.grid}") + st.write(f"- **Level:** {tr.settings.level}") + st.write(f"- **Model:** {tr.settings.model}") + st.write(f"- **CV Splits:** {tr.settings.cv_splits}") + st.write(f"- **Classes:** {tr.settings.classes}") st.write("\n**Files:**") st.write("- 📊 search_results.parquet") @@ -844,3 +619,5 @@ def render_overview_page(): render_dataset_analysis() st.balloons() + + stopwatch.summary() diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py index 87c32dd..eb7eb91 100644 --- a/src/entropice/utils/types.py +++ b/src/entropice/utils/types.py @@ -1,5 +1,6 @@ """Shared types used across the entropice codebase.""" +from dataclasses import dataclass from typing import Literal type Grid = Literal["hex", "healpix"] @@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] type Task = Literal["binary", "count", "density"] type Model = Literal["espa", "xgboost", "rf", "knn"] + + +@dataclass(frozen=True) +class GridConfig: + """Grid configuration specification with metadata.""" + + grid: Grid + level: int + id: GridLevel + display_name: str + sort_key: str + + @classmethod + def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig": + """Create a GridConfig from a GridLevel string.""" + if grid_level.startswith("hex"): + grid = "hex" + level = int(grid_level[3:]) + elif grid_level.startswith("healpix"): + grid = "healpix" + level = int(grid_level[7:]) + else: + raise ValueError(f"Invalid grid level: {grid_level}") + + display_name = f"{grid.capitalize()}-{level}" + + return cls( + grid=grid, + level=level, + id=grid_level, + display_name=display_name, + sort_key=f"{grid}_{level:02d}", + ) + + +# Note: 
get_args() returns an empty tuple for a Python 3.12+ type-statement alias unless the alias is unwrapped via __value__, so we define explicit lists
+all_tasks: list[Task] = ["binary", "count", "density"]
+all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
+all_l2_source_datasets: list[L2SourceDataset] = [
+    "ArcticDEM",
+    "ERA5-shoulder",
+    "ERA5-seasonal",
+    "ERA5-yearly",
+    "AlphaEarth",
+]
+
+grid_configs: list[GridConfig] = [
+    GridConfig.from_grid_level("hex3"),
+    GridConfig.from_grid_level("hex4"),
+    GridConfig.from_grid_level("hex5"),
+    GridConfig.from_grid_level("hex6"),
+    GridConfig.from_grid_level("healpix6"),
+    GridConfig.from_grid_level("healpix7"),
+    GridConfig.from_grid_level("healpix8"),
+    GridConfig.from_grid_level("healpix9"),
+    GridConfig.from_grid_level("healpix10"),
+]
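
Note on the explicit lists in src/entropice/utils/types.py: a minimal sketch of the
get_args() behavior on Python 3.12+ type-statement aliases that motivates them
(the Status alias and all_statuses name below are illustrative only):

    from typing import Literal, get_args

    type Status = Literal["ok", "warn", "fail"]

    # get_args() sees only the TypeAliasType wrapper, not the aliased Literal:
    print(get_args(Status))            # ()

    # Unwrapping the alias via __value__ recovers the literal members:
    print(get_args(Status.__value__))  # ('ok', 'warn', 'fail')

    # A derived list therefore needs the unwrap; the explicit lists in
    # types.py avoid depending on this detail:
    all_statuses: list[Status] = list(get_args(Status.__value__))

With the GridConfig added in this patch, GridConfig.from_grid_level("healpix10")
round-trips the grid ids used by the dashboard, yielding
GridConfig(grid="healpix", level=10, id="healpix10", display_name="Healpix-10",
sort_key="healpix_10").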