diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 713f7d0..8e4555f 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -11,11 +11,12 @@ Pages:
import streamlit as st
-from entropice.dashboard.views.inference_page import render_inference_page
-from entropice.dashboard.views.model_state_page import render_model_state_page
+# from entropice.dashboard.views.inference_page import render_inference_page
+# from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
-from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
-from entropice.dashboard.views.training_data_page import render_training_data_page
+
+# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
+# from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@@ -24,17 +25,17 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
- training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
- training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
- model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
- inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
+ # training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
+ # training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
+ # model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
+ # inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
- "Training": [training_data_page, training_analysis_page],
- "Model State": [model_state_page],
- "Inference": [inference_page],
+ # "Training": [training_data_page, training_analysis_page],
+ # "Model State": [model_state_page],
+ # "Inference": [inference_page],
}
)
pg.run()
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index e19b2d1..c7f0645 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -44,8 +44,12 @@ class TrainingResult:
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
- if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
- raise FileNotFoundError(f"Missing required files in {result_path}")
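+        # Check each artifact separately so the error names exactly which file is missing.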
+ if not result_file.exists():
+ raise FileNotFoundError(f"Missing results file in {result_path}")
+ if not settings_file.exists():
+ raise FileNotFoundError(f"Missing settings file in {result_path}")
+ if not preds_file.exists():
+ raise FileNotFoundError(f"Missing predictions file in {result_path}")
created_at = result_path.stat().st_ctime
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
@@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]:
for result_path in results_dir.iterdir():
if not result_path.is_dir():
continue
- training_result = TrainingResult.from_path(result_path)
+ try:
+ training_result = TrainingResult.from_path(result_path)
+ except FileNotFoundError as e:
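+            # An incomplete result directory (e.g. from an interrupted run) is surfaced as a warning and skipped.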
+ st.warning(f"Skipping incomplete training result at {result_path}: {e}")
+ continue
training_results.append(training_result)
# Sort by creation time (most recent first)
diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py
index b0f6172..2d35e2f 100644
--- a/src/entropice/dashboard/utils/stats.py
+++ b/src/entropice/dashboard/utils/stats.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from dataclasses import asdict, dataclass
-from typing import Literal, cast, get_args
+from typing import Literal
import geopandas as gpd
import pandas as pd
@@ -17,7 +17,17 @@ import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
-from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
+from entropice.utils.types import (
+ Grid,
+ GridLevel,
+ L2SourceDataset,
+ TargetDataset,
+ Task,
+ all_l2_source_datasets,
+ all_target_datasets,
+ all_tasks,
+ grid_configs,
+)
@dataclass(frozen=True)
@@ -87,8 +97,7 @@ class TargetStatistics:
training_coverage: dict[Task, float] = {}
class_counts: dict[Task, dict[str, int]] = {}
class_distribution: dict[Task, dict[str, float]] = {}
- tasks = cast(list[Task], get_args(Task))
- for task in tasks:
+ for task in all_tasks:
task_col = taskcol[task][target]
cov_col = covcol[target]
@@ -132,43 +141,97 @@ class DatasetStatistics:
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
+ @staticmethod
+ def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ """Convert sample count data to DataFrame."""
+ rows = []
+ for grid_config in grid_configs:
+ stats = all_stats[grid_config.id]
+ for target_name, target_stats in stats.target.items():
+ for task in all_tasks:
+ training_cells = target_stats.training_cells[task]
+ coverage = target_stats.coverage[task]
+ rows.append(
+ {
+ "Grid": grid_config.display_name,
+ "Target": target_name.replace("darts_", ""),
+ "Task": task.capitalize(),
+ "Samples (Coverage)": training_cells,
+ "Coverage %": coverage,
+ "Grid_Level_Sort": grid_config.sort_key,
+ }
+ )
+ return pd.DataFrame(rows)
+
+ @staticmethod
+ def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ """Convert feature count data to DataFrame."""
+ rows = []
+ for grid_config in grid_configs:
+ stats = all_stats[grid_config.id]
+ data_sources = list(stats.members.keys())
+
+ # Determine minimum cells across all data sources
+ min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
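+            # The sparsest source is the binding constraint on how many cells can be served at inference time.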
+
+        # Use the darts_rts target's binary-task training cells as the sample count
+ first_target = stats.target["darts_rts"]
+ total_samples = first_target.training_cells["binary"]
+
+ rows.append(
+ {
+ "Grid": grid_config.display_name,
+ "Total Features": stats.total_features,
+ "Data Sources": len(data_sources),
+ "Inference Cells": min_cells,
+ "Total Samples": total_samples,
+ "Grid_Level_Sort": grid_config.sort_key,
+ }
+ )
+ return pd.DataFrame(rows)
+
+ @staticmethod
+ def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ """Convert feature breakdown data to DataFrame for stacked/donut charts."""
+ rows = []
+ for grid_config in grid_configs:
+ stats = all_stats[grid_config.id]
+ for member_name, member_stats in stats.members.items():
+ rows.append(
+ {
+ "Grid": grid_config.display_name,
+ "Data Source": member_name,
+ "Number of Features": member_stats.feature_count,
+ "Grid_Level_Sort": grid_config.sort_key,
+ }
+ )
+ return pd.DataFrame(rows)
+
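+# Cached via st.cache_data: reruns and other callers reuse one computation.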
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
- grid_levels: set[tuple[Grid, int]] = {
- ("hex", 3),
- ("hex", 4),
- ("hex", 5),
- ("hex", 6),
- ("healpix", 6),
- ("healpix", 7),
- ("healpix", 8),
- ("healpix", 9),
- ("healpix", 10),
- }
- for grid, level in grid_levels:
- with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
- grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered
+ for grid_config in grid_configs:
+ with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
+ grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
total_cells = len(grid_gdf)
assert total_cells > 0, "Grid must contain at least one cell."
target_statistics: dict[TargetDataset, TargetStatistics] = {}
- targets = cast(list[TargetDataset], get_args(TargetDataset))
- for target in targets:
+ for target in all_target_datasets:
target_statistics[target] = TargetStatistics.compute(
- grid=grid, level=level, target=target, total_cells=total_cells
+ grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
)
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
- members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
- for member in members:
- member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
+ for member in all_l2_source_datasets:
+ member_statistics[member] = MemberStatistics.compute(
+ grid=grid_config.grid, level=grid_config.level, member=member
+ )
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
ts.size_bytes for ts in target_statistics.values()
)
- grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
- dataset_stats[grid_level] = DatasetStatistics(
+ dataset_stats[grid_config.id] = DatasetStatistics(
total_features=total_features,
total_cells=total_cells,
size_bytes=total_size_bytes,
diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py
index 5cdd277..86958b2 100644
--- a/src/entropice/dashboard/views/inference_page.py
+++ b/src/entropice/dashboard/views/inference_page.py
@@ -1,6 +1,7 @@
"""Inference page: Visualization of model inference results across the study region."""
import streamlit as st
+from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.plots.inference import (
render_class_comparison,
@@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics,
render_spatial_distribution_stats,
)
-from entropice.dashboard.utils.data import load_all_training_results
def render_inference_page():
diff --git a/src/entropice/dashboard/views/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py
index 1dd3c62..5cd90f5 100644
--- a/src/entropice/dashboard/views/model_state_page.py
+++ b/src/entropice/dashboard/views/model_state_page.py
@@ -2,6 +2,15 @@
import streamlit as st
import xarray as xr
+from entropice.dashboard.utils.data import (
+ extract_arcticdem_features,
+ extract_common_features,
+ extract_embedding_features,
+ extract_era5_features,
+ get_members_from_settings,
+ load_all_training_results,
+)
+from entropice.dashboard.utils.training import load_model_state
from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap,
@@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import (
plot_top_features,
)
from entropice.dashboard.utils.colors import generate_unified_colormap
-from entropice.dashboard.utils.data import (
- extract_arcticdem_features,
- extract_common_features,
- extract_embedding_features,
- extract_era5_features,
- get_members_from_settings,
- load_all_training_results,
-)
-from entropice.dashboard.utils.training import load_model_state
def render_model_state_page():
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py
index de5e2de..d9676a3 100644
--- a/src/entropice/dashboard/views/overview_page.py
+++ b/src/entropice/dashboard/views/overview_page.py
@@ -1,238 +1,20 @@
"""Overview page: List of available result directories with some summary statistics."""
-from dataclasses import dataclass
from datetime import datetime
-from typing import TypedDict
+from typing import cast
import pandas as pd
import plotly.express as px
import streamlit as st
+from stopuhr import stopwatch
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
-from entropice.ml.dataset import DatasetEnsemble
-from entropice.utils.types import Grid, TargetDataset, Task
+from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
+from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
-# Type definitions for dataset statistics
-class GridConfig(TypedDict):
- """Grid configuration specification with metadata."""
-
- grid: Grid
- level: int
- grid_name: str
- grid_sort_key: str
- disable_alphaearth: bool
-
-
-class SampleCountData(TypedDict):
- """Sample count statistics for a specific grid/target/task combination."""
-
- grid_config: GridConfig
- target: str
- task: str
- samples_coverage: int
- samples_labels: int
- samples_both: int
-
-
-class FeatureCountData(TypedDict):
- """Feature count statistics for a specific grid configuration."""
-
- grid_config: GridConfig
- total_features: int
- data_sources: list[str]
- inference_cells: int
- total_samples: int
- member_breakdown: dict[str, int]
-
-
-@dataclass(frozen=True)
-class DatasetAnalysisCache:
- """Cache for dataset analysis data to avoid redundant computations."""
-
- grid_configs: list[GridConfig]
- sample_counts: list[SampleCountData]
- feature_counts: list[FeatureCountData]
-
- def get_sample_count_df(self) -> pd.DataFrame:
- """Convert sample count data to DataFrame."""
- rows = []
- for item in self.sample_counts:
- rows.append(
- {
- "Grid": item["grid_config"]["grid_name"],
- "Grid Type": item["grid_config"]["grid"],
- "Level": item["grid_config"]["level"],
- "Target": item["target"].replace("darts_", ""),
- "Task": item["task"].capitalize(),
- "Samples (Coverage)": item["samples_coverage"],
- "Samples (Labels)": item["samples_labels"],
- "Samples (Both)": item["samples_both"],
- "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
- }
- )
- return pd.DataFrame(rows)
-
- def get_feature_count_df(self) -> pd.DataFrame:
- """Convert feature count data to DataFrame."""
- rows = []
- for item in self.feature_counts:
- rows.append(
- {
- "Grid": item["grid_config"]["grid_name"],
- "Grid Type": item["grid_config"]["grid"],
- "Level": item["grid_config"]["level"],
- "Total Features": item["total_features"],
- "Data Sources": len(item["data_sources"]),
- "Inference Cells": item["inference_cells"],
- "Total Samples": item["total_samples"],
- "AlphaEarth": "AlphaEarth" in item["data_sources"],
- "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
- }
- )
- return pd.DataFrame(rows)
-
- def get_feature_breakdown_df(self) -> pd.DataFrame:
- """Convert feature breakdown data to DataFrame for stacked/donut charts."""
- rows = []
- for item in self.feature_counts:
- for source, count in item["member_breakdown"].items():
- rows.append(
- {
- "Grid": item["grid_config"]["grid_name"],
- "Data Source": source,
- "Number of Features": count,
- "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
- }
- )
- return pd.DataFrame(rows)
-
-
-@st.cache_data(show_spinner=False)
-def load_dataset_analysis_data() -> DatasetAnalysisCache:
- """Load and cache all dataset analysis data.
-
- This function computes both sample counts and feature counts for all grid configurations.
- Results are cached to avoid redundant computations across different tabs.
- """
- # Define all possible grid configurations
- grid_configs_raw: list[tuple[Grid, int]] = [
- ("hex", 3),
- ("hex", 4),
- ("hex", 5),
- ("hex", 6),
- ("healpix", 6),
- ("healpix", 7),
- ("healpix", 8),
- ("healpix", 9),
- ("healpix", 10),
- ]
-
- # Create structured grid config objects
- grid_configs: list[GridConfig] = []
- for grid, level in grid_configs_raw:
- disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
- grid_configs.append(
- {
- "grid": grid,
- "level": level,
- "grid_name": f"{grid}-{level}",
- "grid_sort_key": f"{grid}_{level:02d}",
- "disable_alphaearth": disable_alphaearth,
- }
- )
-
- # Compute sample counts
- sample_counts: list[SampleCountData] = []
- target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
- tasks: list[Task] = ["binary", "count", "density"]
-
- for grid_config in grid_configs:
- for target in target_datasets:
- # Create minimal ensemble just to get target data
- ensemble = DatasetEnsemble(
- grid=grid_config["grid"],
- level=grid_config["level"],
- target=target,
- members=[], # type: ignore[arg-type]
- )
- targets = ensemble._read_target()
-
- for task in tasks:
- # Get task-specific column
- taskcol = ensemble.taskcol(task) # type: ignore[arg-type]
- covcol = ensemble.covcol
-
- # Count samples with coverage and valid labels
- if covcol in targets.columns and taskcol in targets.columns:
- valid_coverage = targets[covcol].sum()
- valid_labels = targets[taskcol].notna().sum()
- valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
-
- sample_counts.append(
- {
- "grid_config": grid_config,
- "target": target,
- "task": task,
- "samples_coverage": valid_coverage,
- "samples_labels": valid_labels,
- "samples_both": valid_both,
- }
- )
-
- # Compute feature counts
- feature_counts: list[FeatureCountData] = []
-
- for grid_config in grid_configs:
- # Determine which members are available for this configuration
- if grid_config["disable_alphaearth"]:
- members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
- else:
- members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
-
- # Use darts_rts as default target for comparison
- ensemble = DatasetEnsemble(
- grid=grid_config["grid"],
- level=grid_config["level"],
- target="darts_rts",
- members=members, # type: ignore[arg-type]
- )
- stats = ensemble.get_stats()
-
- # Calculate minimum cells across all data sources
- min_cells = min(
- member_stats["dimensions"]["cell_ids"] # type: ignore[index]
- for member_stats in stats["members"].values()
- )
-
- # Build member breakdown including lon/lat
- member_breakdown = {}
- for member, member_stats in stats["members"].items():
- member_breakdown[member] = member_stats["num_features"]
-
- if ensemble.add_lonlat:
- member_breakdown["Lon/Lat"] = 2
-
- feature_counts.append(
- {
- "grid_config": grid_config,
- "total_features": stats["total_features"],
- "data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
- "inference_cells": min_cells,
- "total_samples": stats["num_target_samples"],
- "member_breakdown": member_breakdown,
- }
- )
-
- return DatasetAnalysisCache(
- grid_configs=grid_configs,
- sample_counts=sample_counts,
- feature_counts=feature_counts,
- )
-
-
-def render_sample_count_overview(cache: DatasetAnalysisCache):
+def render_sample_count_overview():
"""Render overview of sample counts per task+target+grid+level combination."""
st.subheader("📊 Sample Counts by Configuration")
@@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
)
# Get sample count DataFrame from cache
- sample_df = cache.get_sample_count_df()
+ all_stats = load_all_default_dataset_statistics()
+ sample_df = DatasetStatistics.get_sample_count_df(all_stats)
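+    # The statistics loader is wrapped in st.cache_data, so rebuilding these frames on each rerun is cheap.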
target_datasets = ["darts_rts", "darts_mllabels"]
# Create tabs for different views
@@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab1:
st.markdown("### Sample Counts Heatmap")
- st.markdown("Showing counts of samples with both coverage and valid labels")
+ st.markdown("Showing counts of samples with coverage")
# Create heatmap for each target dataset
for target in target_datasets:
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
# Pivot for heatmap: Grid x Task
- pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
+ pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
# Sort index by grid type and level
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab2:
st.markdown("### Sample Counts Bar Chart")
- st.markdown("Showing counts of samples with both coverage and valid labels")
+ st.markdown("Showing counts of samples with coverage")
# Create a faceted bar chart showing both targets side by side
# Get color palette for tasks
@@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
fig = px.bar(
sample_df,
x="Grid",
- y="Samples (Both)",
+ y="Samples (Coverage)",
color="Task",
facet_col="Target",
barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset",
- labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
+ labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
color_discrete_sequence=task_colors,
height=500,
)
@@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
st.markdown("### Detailed Sample Counts")
# Display full table with formatting
- display_df = sample_df[
- ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
- ].copy()
+ display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
# Format numbers with commas
- for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
- display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
+ display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
+ # Format coverage as percentage with 2 decimal places
+ display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
st.dataframe(display_df, hide_index=True, width="stretch")
-def render_feature_count_comparison(cache: DatasetAnalysisCache):
+def render_feature_count_comparison():
"""Render static comparison of feature counts across all grid configurations."""
st.markdown("### Feature Count Comparison Across Grid Configurations")
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
# Get data from cache
- comparison_df = cache.get_feature_count_df()
- breakdown_df = cache.get_feature_breakdown_df()
+ all_stats = load_all_default_dataset_statistics()
+ comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
+ breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
# Create tabs for different comparison views
@@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache):
# Display full comparison table with formatting
display_df = comparison_df[
- ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
+ ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
].copy()
# Format numbers with commas
for col in ["Total Features", "Inference Cells", "Total Samples"]:
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
- # Format boolean as Yes/No
- display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")
-
st.dataframe(display_df, hide_index=True, width="stretch")
@st.fragment
-def render_feature_count_explorer(cache: DatasetAnalysisCache):
+def render_feature_count_explorer():
"""Render interactive detailed configuration explorer using fragments."""
st.markdown("### Detailed Configuration Explorer")
st.markdown("Select specific grid configuration and data sources for detailed statistics")
# Grid selection
- grid_options = [gc["grid_name"] for gc in cache.grid_configs]
+ grid_options = [gc.display_name for gc in grid_configs]
col1, col2 = st.columns(2)
@@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
)
# Find the selected grid config
- selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
- grid = selected_grid_config["grid"]
- level = selected_grid_config["level"]
- disable_alphaearth = selected_grid_config["disable_alphaearth"]
+ selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
+
+ # Get available members from the stats
+ all_stats = load_all_default_dataset_statistics()
+ stats = all_stats[selected_grid_config.id]
+ available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
# Members selection
st.markdown("#### Select Data Sources")
- all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+ all_members = cast(
+ list[L2SourceDataset],
+ ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
+ )
# Use columns for checkboxes
cols = st.columns(len(all_members))
- selected_members = []
+ selected_members: list[L2SourceDataset] = []
for idx, member in enumerate(all_members):
with cols[idx]:
- if member == "AlphaEarth" and disable_alphaearth:
- st.checkbox(
- member,
- value=False,
- disabled=True,
- help=f"Not available for {grid} level {level}",
- key=f"feature_member_{member}",
- )
- else:
- if st.checkbox(member, value=True, key=f"feature_member_{member}"):
- selected_members.append(member)
+ default_value = member in available_members
+ if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
+ selected_members.append(cast(L2SourceDataset, member))
# Show results if at least one member is selected
if selected_members:
st.markdown("---")
- ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
+ # Get statistics from cache (already loaded)
+ grid_stats = all_stats[selected_grid_config.id]
- with st.spinner("Computing dataset statistics..."):
- stats = ensemble.get_stats()
+ # Filter to selected members only
+ selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
+
+ # Calculate total features for selected members
+ total_features = sum(ms.feature_count for ms in selected_member_stats.values())
+
+ # Get target stats
+ target_stats = grid_stats.target[cast(TargetDataset, target)]
# High-level metrics
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
- st.metric("Total Features", f"{stats['total_features']:,}")
+ st.metric("Total Features", f"{total_features:,}")
with col2:
# Calculate minimum cells across all data sources (for inference capability)
- min_cells = min(
- member_stats["dimensions"]["cell_ids"] # type: ignore[index]
- for member_stats in stats["members"].values()
- )
+ min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
with col3:
st.metric("Data Sources", len(selected_members))
with col4:
- st.metric("Total Samples", f"{stats['num_target_samples']:,}")
+ # Use binary task training cells as sample count
+ st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
with col5:
# Calculate total data points
- total_points = stats["total_features"] * stats["num_target_samples"]
+ total_points = total_features * target_stats.training_cells["binary"]
st.metric("Total Data Points", f"{total_points:,}")
# Feature breakdown by source
st.markdown("#### Feature Breakdown by Data Source")
breakdown_data = []
- for member, member_stats in stats["members"].items():
+ for member, member_stats in selected_member_stats.items():
breakdown_data.append(
{
"Data Source": member,
- "Number of Features": member_stats["num_features"],
- "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator]
- }
- )
-
- # Add lon/lat
- if ensemble.add_lonlat:
- breakdown_data.append(
- {
- "Data Source": "Lon/Lat",
- "Number of Features": 2,
- "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
+ "Number of Features": member_stats.feature_count,
+ "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
}
)
@@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
# Detailed member information
with st.expander("📦 Detailed Source Information", expanded=False):
- for member, member_stats in stats["members"].items():
+ for member, member_stats in selected_member_stats.items():
st.markdown(f"### {member}")
metric_cols = st.columns(4)
with metric_cols[0]:
- st.metric("Features", member_stats["num_features"])
+ st.metric("Features", member_stats.feature_count)
with metric_cols[1]:
- st.metric("Variables", member_stats["num_variables"])
+ st.metric("Variables", len(member_stats.variable_names))
with metric_cols[2]:
- dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
+ dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
total_points = 1
- for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
+ for dim_size in member_stats.dimensions.values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
@@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
[
f'{v}'
- for v in member_stats["variables"] # type: ignore[union-attr]
+ for v in member_stats.variable_names
]
)
st.markdown(vars_html, unsafe_allow_html=True)
@@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
f''
f"{dim_name}: {dim_size}"
- for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
+ for dim_name, dim_size in member_stats.dimensions.items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
+ # Size on disk
+            size_mib = member_stats.size_bytes / (1024 * 1024)
+            st.markdown(f"**Size on Disk:** {size_mib:.2f} MiB")
+
st.markdown("---")
else:
st.info("👆 Select at least one data source to see feature statistics")
-def render_feature_count_section(cache: DatasetAnalysisCache):
+def render_feature_count_section():
"""Render the feature count section with comparison and explorer."""
st.subheader("🔢 Feature Counts by Dataset Configuration")
@@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache):
)
# Static comparison across all grids
- render_feature_count_comparison(cache)
+ render_feature_count_comparison()
st.divider()
# Interactive explorer for detailed analysis
- render_feature_count_explorer(cache)
+ render_feature_count_explorer()
def render_dataset_analysis():
"""Render the dataset analysis section with sample and feature counts."""
st.header("📈 Dataset Analysis")
- # Load all data once and cache it
- with st.spinner("Loading dataset analysis data..."):
- cache = load_dataset_analysis_data()
-
# Create tabs for the two different analyses
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
with analysis_tabs[0]:
- render_sample_count_overview(cache)
+ render_sample_count_overview()
with analysis_tabs[1]:
- render_feature_count_section(cache)
+ render_feature_count_section()
def render_training_results_summary(training_results):
@@ -678,15 +450,15 @@ def render_training_results_summary(training_results):
col1, col2, col3, col4 = st.columns(4)
with col1:
- tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
+ tasks = {tr.settings.task for tr in training_results}
st.metric("Tasks", len(tasks))
with col2:
- grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
+ grids = {tr.settings.grid for tr in training_results}
st.metric("Grid Types", len(grids))
with col3:
- models = {tr.settings.get("model", "Unknown") for tr in training_results}
+ models = {tr.settings.model for tr in training_results}
st.metric("Model Types", len(models))
with col4:
@@ -725,10 +497,10 @@ def render_experiment_results(training_results):
summary_data.append(
{
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
- "Task": tr.settings.get("task", "Unknown"),
- "Grid": tr.settings.get("grid", "Unknown"),
- "Level": tr.settings.get("level", "Unknown"),
- "Model": tr.settings.get("model", "Unknown"),
+ "Task": tr.settings.task,
+ "Grid": tr.settings.grid,
+ "Level": tr.settings.level,
+ "Model": tr.settings.model,
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
"Trials": len(tr.results),
"Path": str(tr.path.name),
@@ -750,17 +522,20 @@ def render_experiment_results(training_results):
st.subheader("Individual Experiment Details")
for tr in training_results:
- with st.expander(tr.get_display_name("task_first")):
+ display_name = (
+ f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
+ )
+ with st.expander(display_name):
col1, col2 = st.columns([1, 2])
with col1:
st.write("**Configuration:**")
- st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
- st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
- st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
- st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
- st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
- st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
+ st.write(f"- **Task:** {tr.settings.task}")
+ st.write(f"- **Grid:** {tr.settings.grid}")
+ st.write(f"- **Level:** {tr.settings.level}")
+ st.write(f"- **Model:** {tr.settings.model}")
+ st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
+ st.write(f"- **Classes:** {tr.settings.classes}")
st.write("\n**Files:**")
st.write("- 📊 search_results.parquet")
@@ -844,3 +619,5 @@ def render_overview_page():
render_dataset_analysis()
st.balloons()
+
+ stopwatch.summary()
diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py
index 87c32dd..eb7eb91 100644
--- a/src/entropice/utils/types.py
+++ b/src/entropice/utils/types.py
@@ -1,5 +1,6 @@
"""Shared types used across the entropice codebase."""
+from dataclasses import dataclass
from typing import Literal
type Grid = Literal["hex", "healpix"]
@@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"]
type Model = Literal["espa", "xgboost", "rf", "knn"]
+
+
+@dataclass(frozen=True)
+class GridConfig:
+ """Grid configuration specification with metadata."""
+
+ grid: Grid
+ level: int
+ id: GridLevel
+ display_name: str
+ sort_key: str
+
+ @classmethod
+ def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
+ """Create a GridConfig from a GridLevel string."""
+        grid: Grid
+        if grid_level.startswith("hex"):
+ grid = "hex"
+ level = int(grid_level[3:])
+ elif grid_level.startswith("healpix"):
+ grid = "healpix"
+ level = int(grid_level[7:])
+ else:
+ raise ValueError(f"Invalid grid level: {grid_level}")
+
+ display_name = f"{grid.capitalize()}-{level}"
+
+ return cls(
+ grid=grid,
+ level=level,
+ id=grid_level,
+ display_name=display_name,
+ sort_key=f"{grid}_{level:02d}",
+ )
+
+
+# Note: get_args() can't see through aliases declared with the Python 3.12+ `type` statement
+# (a TypeAliasType exposes no __args__), so we define explicit runtime lists instead.
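+# For example (stdlib behavior, not project code):
+#   get_args(Task)            # -> ()
+#   get_args(Task.__value__)  # -> ("binary", "count", "density")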
+all_tasks: list[Task] = ["binary", "count", "density"]
+all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
+all_l2_source_datasets: list[L2SourceDataset] = [
+ "ArcticDEM",
+ "ERA5-shoulder",
+ "ERA5-seasonal",
+ "ERA5-yearly",
+ "AlphaEarth",
+]
+
+grid_configs: list[GridConfig] = [
+ GridConfig.from_grid_level("hex3"),
+ GridConfig.from_grid_level("hex4"),
+ GridConfig.from_grid_level("hex5"),
+ GridConfig.from_grid_level("hex6"),
+ GridConfig.from_grid_level("healpix6"),
+ GridConfig.from_grid_level("healpix7"),
+ GridConfig.from_grid_level("healpix8"),
+ GridConfig.from_grid_level("healpix9"),
+ GridConfig.from_grid_level("healpix10"),
+]
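+
+# Sketch of the intended lookup, assuming callers key off GridConfig.id:
+#   cfg = next(gc for gc in grid_configs if gc.id == "hex5")
+#   cfg.display_name  # "Hex-5"
+#   cfg.sort_key      # "hex_05"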