From 36f8737075acc808347a5f7613934dfcde625f56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 4 Jan 2026 02:12:53 +0100 Subject: [PATCH] Cleanup dashboard utils and move pages to views --- src/entropice/dashboard/app.py | 10 +- .../plots/hyperparameter_analysis.py | 2 +- src/entropice/dashboard/plots/inference.py | 2 +- src/entropice/dashboard/plots/source_data.py | 2 +- .../dashboard/plots/training_data.py | 2 +- .../dashboard/{plots => utils}/colors.py | 2 +- src/entropice/dashboard/utils/formatters.py | 96 ++++ src/entropice/dashboard/utils/loaders.py | 163 ++++++ src/entropice/dashboard/utils/stats.py | 482 ++++++++++++++++++ src/entropice/dashboard/utils/training.py | 232 --------- .../dashboard/utils/{data.py => unsembler.py} | 164 +----- .../dashboard/{ => views}/inference_page.py | 0 .../dashboard/{ => views}/model_state_page.py | 2 +- .../dashboard/{ => views}/overview_page.py | 38 +- .../{ => views}/training_analysis_page.py | 0 .../{ => views}/training_data_page.py | 15 +- src/entropice/ingest/alphaearth.py | 10 +- src/entropice/ingest/arcticdem.py | 4 +- src/entropice/ingest/darts.py | 9 +- src/entropice/ingest/era5.py | 7 +- src/entropice/ml/dataset.py | 88 ++-- src/entropice/ml/training.py | 32 +- src/entropice/spatial/aggregators.py | 5 +- src/entropice/spatial/grids.py | 12 +- src/entropice/utils/paths.py | 28 +- src/entropice/utils/types.py | 11 + 26 files changed, 904 insertions(+), 514 deletions(-) rename src/entropice/dashboard/{plots => utils}/colors.py (98%) create mode 100644 src/entropice/dashboard/utils/formatters.py create mode 100644 src/entropice/dashboard/utils/loaders.py create mode 100644 src/entropice/dashboard/utils/stats.py delete mode 100644 src/entropice/dashboard/utils/training.py rename src/entropice/dashboard/utils/{data.py => unsembler.py} (73%) rename src/entropice/dashboard/{ => views}/inference_page.py (100%) rename src/entropice/dashboard/{ => views}/model_state_page.py (99%) rename src/entropice/dashboard/{ => views}/overview_page.py (96%) rename src/entropice/dashboard/{ => views}/training_analysis_page.py (100%) rename src/entropice/dashboard/{ => views}/training_data_page.py (95%) create mode 100644 src/entropice/utils/types.py diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 3edb288..713f7d0 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -11,11 +11,11 @@ Pages: import streamlit as st -from entropice.dashboard.inference_page import render_inference_page -from entropice.dashboard.model_state_page import render_model_state_page -from entropice.dashboard.overview_page import render_overview_page -from entropice.dashboard.training_analysis_page import render_training_analysis_page -from entropice.dashboard.training_data_page import render_training_data_page +from entropice.dashboard.views.inference_page import render_inference_page +from entropice.dashboard.views.model_state_page import render_model_state_page +from entropice.dashboard.views.overview_page import render_overview_page +from entropice.dashboard.views.training_analysis_page import render_training_analysis_page +from entropice.dashboard.views.training_data_page import render_training_data_page def main(): diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index 2fff6b5..9141f95 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ 
-12,7 +12,7 @@ import pydeck as pdk import streamlit as st from shapely.geometry import shape -from entropice.dashboard.plots.colors import get_cmap, get_palette +from entropice.dashboard.utils.colors import get_cmap, get_palette from entropice.ml.dataset import DatasetEnsemble diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py index 3c2b57f..a229161 100644 --- a/src/entropice/dashboard/plots/inference.py +++ b/src/entropice/dashboard/plots/inference.py @@ -7,7 +7,7 @@ import pydeck as pdk import streamlit as st from shapely.geometry import shape -from entropice.dashboard.plots.colors import get_palette +from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.data import TrainingResult diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py index 294fc65..ca11da9 100644 --- a/src/entropice/dashboard/plots/source_data.py +++ b/src/entropice/dashboard/plots/source_data.py @@ -10,7 +10,7 @@ import streamlit as st import xarray as xr from shapely.geometry import shape -from entropice.dashboard.plots.colors import get_cmap +from entropice.dashboard.utils.colors import get_cmap # TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py index 916a91f..1148b58 100644 --- a/src/entropice/dashboard/plots/training_data.py +++ b/src/entropice/dashboard/plots/training_data.py @@ -7,7 +7,7 @@ import pydeck as pdk import streamlit as st from shapely.geometry import shape -from entropice.dashboard.plots.colors import get_palette +from entropice.dashboard.utils.colors import get_palette from entropice.ml.dataset import CategoricalTrainingDataset diff --git a/src/entropice/dashboard/plots/colors.py b/src/entropice/dashboard/utils/colors.py similarity index 98% rename from src/entropice/dashboard/plots/colors.py rename to src/entropice/dashboard/utils/colors.py index 3924348..cd77d3e 100644 --- a/src/entropice/dashboard/plots/colors.py +++ b/src/entropice/dashboard/utils/colors.py @@ -132,7 +132,7 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m cmap = plt.get_cmap("viridis") # Sample colors evenly across the colormap indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends - base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices] + base_colors = [mcolors.rgb2hex(cmap(idx).tolist()[:3]) for idx in indices] # Create matplotlib colormap (for ultraplot and geopandas) matplotlib_cmap = mcolors.ListedColormap(base_colors) diff --git a/src/entropice/dashboard/utils/formatters.py b/src/entropice/dashboard/utils/formatters.py new file mode 100644 index 0000000..042a490 --- /dev/null +++ b/src/entropice/dashboard/utils/formatters.py @@ -0,0 +1,96 @@ +"""Formatters for dashboard display.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Literal + +from entropice.utils.types import Grid, Model, Task + + +@dataclass +class ModelDisplayInfo: + keyword: Model + short: str + long: str + + def __str__(self) -> str: + return f"{self.short} ({self.long})" + + +model_display_infos: dict[Model, ModelDisplayInfo] = { + "espa": ModelDisplayInfo( + keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm" + ), + "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"), + "rf": 
ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"), + "knn": ModelDisplayInfo(keyword="knn", short="k-NN", long="k-Nearest Neighbors Classifier"), +} + + +@dataclass +class TaskDisplayInfo: + keyword: Task + display_name: str + explanation: str + + +task_display_infos: dict[Task, TaskDisplayInfo] = { + "binary": TaskDisplayInfo( + keyword="binary", + display_name="Binary Classification", + explanation="Classify each grid cell as containing or not containing the target feature.", + ), + "count": TaskDisplayInfo( + keyword="count", + display_name="Count Prediction", + explanation="Predict the number of target features present within each grid cell.", + ), + "density": TaskDisplayInfo( + keyword="density", + display_name="Density Estimation", + explanation="Estimate the density of target features within each grid cell.", + ), +} + + +@dataclass +class TrainingResultDisplayInfo: + task: Task + model: Model + grid: Grid + level: int + timestamp: datetime + + def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str: + task = self.task.capitalize() + model = self.model.upper() + grid = self.grid.capitalize() + level = self.level + timestamp = self.timestamp.strftime("%Y-%m-%d %H:%M") + + if format_type == "model_first": + return f"{model} - {task} - {grid}-{level} ({timestamp})" + else: # task_first + return f"{task} - {model} - {grid}-{level} ({timestamp})" + + +def format_metric_name(metric: str) -> str: + """Format metric name for display. + + Args: + metric: Raw metric name (e.g., 'f1_micro', 'precision_macro'). + + Returns: + Formatted metric name (e.g., 'F1 Micro', 'Precision Macro'). + + """ + # Split by underscore and capitalize each part + parts = metric.split("_") + # Special handling for F1 + formatted_parts = [] + for part in parts: + if part.lower() == "f1": + formatted_parts.append("F1") + else: + formatted_parts.append(part.capitalize()) + return " ".join(formatted_parts) diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py new file mode 100644 index 0000000..e19b2d1 --- /dev/null +++ b/src/entropice/dashboard/utils/loaders.py @@ -0,0 +1,163 @@ +"""Data utilities for Entropice dashboard.""" + +import pickle +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import antimeridian +import geopandas as gpd +import pandas as pd +import streamlit as st +import toml +import xarray as xr +from shapely.geometry import shape + +import entropice.spatial.grids +import entropice.utils.paths +from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo +from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble +from entropice.ml.training import TrainingSettings +from entropice.utils.types import L2SourceDataset, Task + + +def _fix_hex_geometry(geom): + """Fix hexagon geometry crossing the antimeridian.""" + try: + return shape(antimeridian.fix_shape(geom)) + except ValueError as e: + st.error(f"Error fixing geometry: {e}") + return geom + + +@dataclass +class TrainingResult: + path: Path + settings: TrainingSettings + results: pd.DataFrame + created_at: float + available_metrics: list[str] + + @classmethod + def from_path(cls, result_path: Path) -> "TrainingResult": + """Load a TrainingResult from a given result directory path.""" + result_file = result_path / "search_results.parquet" + preds_file = result_path / "predicted_probabilities.parquet" + settings_file = result_path / "search_settings.toml" 
+ if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]): + raise FileNotFoundError(f"Missing required files in {result_path}") + + created_at = result_path.stat().st_ctime + settings = TrainingSettings(**(toml.load(settings_file)["settings"])) + results = pd.read_parquet(result_file) + + available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")] + + return cls( + path=result_path, + settings=settings, + results=results, + created_at=created_at, + available_metrics=available_metrics, + ) + + @property + def display_info(self) -> TrainingResultDisplayInfo: + return TrainingResultDisplayInfo( + task=self.settings.task, + model=self.settings.model, + grid=self.settings.grid, + level=self.settings.level, + timestamp=datetime.fromtimestamp(self.created_at), + ) + + def load_best_model(self) -> object | None: + """Load the best model from a training result.""" + model_file = self.path / "best_estimator_model.pkl" + if not model_file.exists(): + return None + + try: + with open(model_file, "rb") as f: + model = pickle.load(f) + return model + except Exception as e: + st.error(f"Error loading model: {e}") + return None + + def load_model_state(self) -> xr.Dataset | None: + """Load the model state from a training result.""" + state_file = self.path / "best_estimator_state.nc" + if not state_file.exists(): + return None + + try: + state = xr.open_dataset(state_file, engine="h5netcdf") + return state + except Exception as e: + st.error(f"Error loading model state: {e}") + return None + + def load_predictions(self) -> pd.DataFrame | None: + """Load predictions from a training result.""" + preds_file = self.path / "predicted_probabilities.parquet" + if not preds_file.exists(): + return None + + try: + preds = pd.read_parquet(preds_file) + return preds + except Exception as e: + st.error(f"Error loading predictions: {e}") + return None + + +@st.cache_data +def load_all_training_results() -> list[TrainingResult]: + results_dir = entropice.utils.paths.RESULTS_DIR + training_results: list[TrainingResult] = [] + for result_path in results_dir.iterdir(): + if not result_path.is_dir(): + continue + training_result = TrainingResult.from_path(result_path) + training_results.append(training_result) + + # Sort by creation time (most recent first) + training_results.sort(key=lambda tr: tr.created_at, reverse=True) + return training_results + + +def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]: + """Load training data for all three tasks. + + Args: + e: DatasetEnsemble object. + + Returns: + Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. + + """ + return { + "binary": e.create_cat_training_dataset("binary", device="cpu"), + "count": e.create_cat_training_dataset("count", device="cpu"), + "density": e.create_cat_training_dataset("density", device="cpu"), + } + + +def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]: + """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5). + + Args: + e: DatasetEnsemble object. + source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'. + + Returns: + xarray.Dataset with the raw data for the specified source. 
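+        geopandas.GeoDataFrame with the target labels (returned together with
+        the dataset as a tuple, matching the function signature).
+
+    Example:
+        A minimal sketch; the ensemble configuration is illustrative::
+
+            e = DatasetEnsemble(grid="hex", level=5, target="darts_rts")
+            ds, targets = load_source_data(e, "ArcticDEM")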
+ + """ + targets = e._read_target() + + # Load the member data lazily to get metadata + ds = e._read_member(source, targets, lazy=False) + + return ds, targets diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py new file mode 100644 index 0000000..b0f6172 --- /dev/null +++ b/src/entropice/dashboard/utils/stats.py @@ -0,0 +1,482 @@ +"""General Statistics shared across multiple dashboard pages. + +- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination +""" + +from collections import defaultdict +from dataclasses import asdict, dataclass +from typing import Literal, cast, get_args + +import geopandas as gpd +import pandas as pd +import streamlit as st +import xarray as xr +from stopuhr import stopwatch + +import entropice.spatial.grids +import entropice.utils.paths +from entropice.dashboard.utils.loaders import TrainingResult +from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol +from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task + + +@dataclass(frozen=True) +class MemberStatistics: + """Statistics for a specific dataset member.""" + + feature_count: int # Number of features from this member + variable_names: list[str] # Names of variables from this member + dimensions: dict[str, int] # Dimension name to size mapping + coordinates: list[str] # Names of coordinates from this member + size_bytes: int # Size of this member's data on disk in bytes + + @classmethod + def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics": + if member == "AlphaEarth": + store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level) + elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: + era5_agg = member.split("-")[1] + store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level) # ty:ignore[invalid-argument-type] + elif member == "ArcticDEM": + store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level) + else: + raise NotImplementedError(f"Member {member} not implemented.") + + size_bytes = store.stat().st_size + ds = xr.open_zarr(store, consolidated=False) + + # Delete all coordinates which are not in the dimension + for coord in ds.coords: + if coord not in ds.dims: + ds = ds.drop_vars(coord) + n_cols_member = len(ds.data_vars) + for dim in ds.sizes: + if dim != "cell_ids": + n_cols_member *= ds.sizes[dim] + + return cls( + feature_count=n_cols_member, + variable_names=list(ds.data_vars), + dimensions=dict(ds.sizes), + coordinates=list(ds.coords), + size_bytes=size_bytes, + ) + + +@dataclass(frozen=True) +class TargetStatistics: + """Statistics for a specific target dataset.""" + + training_cells: dict[Task, int] # Number of cells used for training + coverage: dict[Task, float] # Percentage of total cells covered by training data + class_counts: dict[Task, dict[str, int]] # Class name to count mapping per task + class_distribution: dict[Task, dict[str, float]] # Class name to percentage mapping per task + size_bytes: int # Size of the target dataset on disk in bytes + + @classmethod + def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics": + if target == "darts_rts": + target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level) + elif target == "darts_mllabels": + target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True) + else: + raise 
NotImplementedError(f"Target {target} not implemented.") + target_gdf = gpd.read_parquet(target_store) + size_bytes = target_store.stat().st_size + training_cells: dict[Task, int] = {} + training_coverage: dict[Task, float] = {} + class_counts: dict[Task, dict[str, int]] = {} + class_distribution: dict[Task, dict[str, float]] = {} + tasks = cast(list[Task], get_args(Task)) + for task in tasks: + task_col = taskcol[task][target] + cov_col = covcol[target] + + task_gdf = target_gdf[target_gdf[cov_col]] + training_cells[task] = len(task_gdf) + training_coverage[task] = len(task_gdf) / total_cells * 100 + + model_labels = task_gdf[task_col].dropna() + if task == "binary": + binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category") + elif task == "count": + binned = bin_values(model_labels.astype(int), task=task) + elif task == "density": + binned = bin_values(model_labels, task=task) + else: + raise ValueError("Invalid task.") + counts = binned.value_counts() + distribution = counts / counts.sum() * 100 + class_counts[task] = counts.to_dict() # ty:ignore[invalid-assignment] + class_distribution[task] = distribution.to_dict() + return TargetStatistics( + training_cells=training_cells, + coverage=training_coverage, + class_counts=class_counts, + class_distribution=class_distribution, + size_bytes=size_bytes, + ) + + +@dataclass(frozen=True) +class DatasetStatistics: + """Statistics for a potential dataset at a specific grid and level. + + These statistics are meant to be easily compute without loading a full dataset into memory. + Further, it is meant to give an overview of all potential potential dataset compositions for a given grid and level. + """ + + total_features: int # Total number of features available in the dataset + total_cells: int # Total number of grid cells potentially covered + size_bytes: int # Size of the dataset on disk in bytes + members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member + target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset + + +@st.cache_data +def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]: + dataset_stats: dict[GridLevel, DatasetStatistics] = {} + grid_levels: set[tuple[Grid, int]] = { + ("hex", 3), + ("hex", 4), + ("hex", 5), + ("hex", 6), + ("healpix", 6), + ("healpix", 7), + ("healpix", 8), + ("healpix", 9), + ("healpix", 10), + } + for grid, level in grid_levels: + with stopwatch(f"Loading statistics for grid={grid}, level={level}"): + grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered + total_cells = len(grid_gdf) + assert total_cells > 0, "Grid must contain at least one cell." 
+            target_statistics: dict[TargetDataset, TargetStatistics] = {}
+            targets = cast(list[TargetDataset], get_args(TargetDataset))
+            for target in targets:
+                target_statistics[target] = TargetStatistics.compute(
+                    grid=grid, level=level, target=target, total_cells=total_cells
+                )
+            member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
+            members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
+            for member in members:
+                member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
+
+            total_features = sum(ms.feature_count for ms in member_statistics.values())
+            total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
+                ts.size_bytes for ts in target_statistics.values()
+            )
+            grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
+            dataset_stats[grid_level] = DatasetStatistics(
+                total_features=total_features,
+                total_cells=total_cells,
+                size_bytes=total_size_bytes,
+                members=member_statistics,
+                target=target_statistics,
+            )
+
+    return dataset_stats
+
+
+@dataclass(frozen=True)
+class EnsembleMemberStatistics:
+    n_features: int  # Number of features from this member in the ensemble
+    p_features: float  # Fraction (0-1) of features from this member in the ensemble
+    n_nanrows: int  # Number of rows which contain any NaN
+    size_bytes: int  # Size of this member's data in the ensemble in bytes
+
+    @classmethod
+    def compute(
+        cls,
+        dataset: gpd.GeoDataFrame,
+        member: L2SourceDataset,
+        n_features: int,
+    ) -> "EnsembleMemberStatistics":
+        # Select member columns with .loc[:, mask]; the boolean masks below are
+        # over columns, not rows.
+        if member == "AlphaEarth":
+            member_dataset = dataset.loc[:, dataset.columns.str.startswith("embeddings_")]
+        elif member == "ArcticDEM":
+            member_dataset = dataset.loc[:, dataset.columns.str.startswith("arcticdem_")]
+        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
+            era5_cols = dataset.columns.str.startswith("era5_")
+            cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
+            cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
+            cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
+            if member == "ERA5-yearly":
+                member_dataset = dataset.loc[:, cols_with_three_splits]
+            elif member == "ERA5-seasonal":
+                member_dataset = dataset.loc[:, cols_with_summer_winter & cols_with_four_splits]
+            elif member == "ERA5-shoulder":
+                member_dataset = dataset.loc[:, ~cols_with_summer_winter & cols_with_four_splits]
+        else:
+            raise NotImplementedError(f"Member {member} not implemented.")
+        size_bytes_member = member_dataset.memory_usage(deep=True).sum()
+        n_features_member = len(member_dataset.columns)
+        p_features_member = n_features_member / n_features
+        n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
+        return EnsembleMemberStatistics(
+            n_features=n_features_member,
+            p_features=p_features_member,
+            n_nanrows=n_rows_with_nan_member,
+            size_bytes=size_bytes_member,
+        )
+
+
+@dataclass(frozen=True)
+class EnsembleDatasetStatistics:
+    """Statistics for a specified composition / ensemble at a specific grid and level.
+
+    These statistics are only computed on user demand, since they require loading the dataset into memory.
+    That way, the real number of features and cells after NaN filtering can be reported.
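+
+    Example:
+        A minimal sketch; the ensemble configuration is illustrative::
+
+            ensemble = DatasetEnsemble(grid="hex", level=5, target="darts_rts")
+            stats = EnsembleDatasetStatistics.compute(ensemble)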
+ """ + + n_features: int # Number of features in the dataset + n_cells: int # Number of grid cells covered + n_nanrows: int # Number of rows which contain any NaN + size_bytes: int # Size of the dataset on disk in bytes + members: dict[L2SourceDataset, EnsembleMemberStatistics] # Statistics per source dataset member + + @classmethod + def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics": + dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol) + # Assert that no column in all-nan + assert not dataset.isna().all("index").any(), "Some input columns are all NaN" + + size_bytes = dataset.memory_usage(deep=True).sum() + n_features = len(dataset.columns) + n_cells = len(dataset) + + # Number of rows which contain any NaN + n_rows_with_nan = dataset.isna().any(axis=1).sum() + + member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {} + for member in ensemble.members: + member_statistics[member] = EnsembleMemberStatistics.compute( + dataset=dataset, + member=member, + n_features=n_features, + ) + return cls( + n_features=n_features, + n_cells=n_cells, + n_nanrows=n_rows_with_nan, + size_bytes=size_bytes, + members=member_statistics, + ) + + +@dataclass(frozen=True) +class TrainingDatasetStatistics: + n_samples: int # Total number of samples in the dataset + n_features: int # Number of features + feature_names: list[str] # Names of all features + n_training_samples: int # Number of cells used for training + n_test_samples: int # Number of cells used for testing + train_test_ratio: float # Ratio of training to test samples + + class_labels: list[str] # Ordered list of class labels + class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class + n_classes: int # Number of classes + + training_class_counts: dict[str, int] # Class counts in training set + training_class_distribution: dict[str, float] # Class percentages in training set + test_class_counts: dict[str, int] # Class counts in test set + test_class_distribution: dict[str, float] # Class percentages in test set + + raw_value_min: float # Minimum raw target value + raw_value_max: float # Maximum raw target value + raw_value_mean: float # Mean raw target value + raw_value_median: float # Median raw target value + raw_value_std: float # Standard deviation of raw target values + + imbalance_ratio: float # Smallest class count / largest class count (overall) + size_bytes: int # Total memory usage of features in bytes + + @classmethod + def compute( + cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None + ) -> "TrainingDatasetStatistics": + dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol) + categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu") + + # Sample counts + n_samples = len(categorical_dataset) + n_training_samples = len(categorical_dataset.y.train) + n_test_samples = len(categorical_dataset.y.test) + train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0 + + # Feature statistics + n_features = len(categorical_dataset.X.data.columns) + feature_names = list(categorical_dataset.X.data.columns) + size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum() + + # Class information + class_labels = categorical_dataset.y.labels + class_intervals = categorical_dataset.y.intervals + n_classes = len(class_labels) + + # Training class distribution + train_y_series = 
pd.Series(categorical_dataset.y.train) + train_counts = train_y_series.value_counts().sort_index() + training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)} + train_total = sum(training_class_counts.values()) + training_class_distribution = { + k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items() + } + + # Test class distribution + test_y_series = pd.Series(categorical_dataset.y.test) + test_counts = test_y_series.value_counts().sort_index() + test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)} + test_total = sum(test_class_counts.values()) + test_class_distribution = { + k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items() + } + + # Raw value statistics + raw_values = categorical_dataset.y.raw_values + raw_value_min = float(raw_values.min()) + raw_value_max = float(raw_values.max()) + raw_value_mean = float(raw_values.mean()) + raw_value_median = float(raw_values.median()) + raw_value_std = float(raw_values.std()) + + # Imbalance ratio (smallest class / largest class across both splits) + all_counts = list(training_class_counts.values()) + list(test_class_counts.values()) + nonzero_counts = [c for c in all_counts if c > 0] + imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0 + + return cls( + n_samples=n_samples, + n_features=n_features, + feature_names=feature_names, + n_training_samples=n_training_samples, + n_test_samples=n_test_samples, + train_test_ratio=train_test_ratio, + class_labels=class_labels, + class_intervals=class_intervals, + n_classes=n_classes, + training_class_counts=training_class_counts, + training_class_distribution=training_class_distribution, + test_class_counts=test_class_counts, + test_class_distribution=test_class_distribution, + raw_value_min=raw_value_min, + raw_value_max=raw_value_max, + raw_value_mean=raw_value_mean, + raw_value_median=raw_value_median, + raw_value_std=raw_value_std, + imbalance_ratio=imbalance_ratio, + size_bytes=size_bytes, + ) + + +@dataclass(frozen=True) +class CVMetricStatistics: + best_score: float + mean_score: float + std_score: float + worst_score: float + median_score: float + mean_cv_std: float | None + + @classmethod + def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics": + """Get cross-validation statistics for a metric.""" + score_col = f"mean_test_{metric}" + std_col = f"std_test_{metric}" + + if score_col not in result.results.columns: + raise ValueError(f"Metric {metric} not found in results.") + + best_score = result.results[score_col].max() + mean_score = result.results[score_col].mean() + std_score = result.results[score_col].std() + worst_score = result.results[score_col].min() + median_score = result.results[score_col].median() + + mean_cv_std = None + if std_col in result.results.columns: + mean_cv_std = result.results[std_col].mean() + + return CVMetricStatistics( + best_score=best_score, + mean_score=mean_score, + std_score=std_score, + worst_score=worst_score, + median_score=median_score, + mean_cv_std=mean_cv_std, + ) + + +@dataclass(frozen=True) +class ParameterSpaceSummary: + parameter: str + type: Literal["Numeric", "Categorical"] + min: float | None + max: float | None + mean: float | None + unique_values: int + + @classmethod + def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary": + param_name = param_col.replace("param_", "") + param_values = 
result.results[param_col].dropna() + + if pd.api.types.is_numeric_dtype(param_values): + return ParameterSpaceSummary( + parameter=param_name, + type="Numeric", + min=param_values.min(), + max=param_values.max(), + mean=param_values.mean(), + unique_values=param_values.nunique(), + ) + else: + unique_vals = param_values.unique() + return ParameterSpaceSummary( + parameter=param_name, + type="Categorical", + min=None, + max=None, + mean=None, + unique_values=len(unique_vals), + ) + + +@dataclass(frozen=True) +class CVResultsStatistics: + metrics: dict[str, CVMetricStatistics] + parameter_summary: list[ParameterSpaceSummary] + + @classmethod + def compute(cls, result: TrainingResult) -> "CVResultsStatistics": + """Get cross-validation statistics for a metric.""" + metrics = result.available_metrics + metric_stats: dict[str, CVMetricStatistics] = {} + for metric in metrics: + metric_stats[metric] = CVMetricStatistics.compute(result, metric) + + param_cols = [col for col in result.results.columns if col.startswith("param_") and col != "params"] + summary_data = [] + for param_col in param_cols: + summary_data.append(ParameterSpaceSummary.compute(result, param_col)) + + return CVResultsStatistics(metrics=metric_stats, parameter_summary=summary_data) + + def metrics_to_dataframe(self) -> pd.DataFrame: + """Convert metric statistics to a DataFrame.""" + data = defaultdict(list) + for metric, stats in self.metrics.items(): + data["Metric"].append(metric) + data["Best Score"].append(stats.best_score) + data["Mean Score"].append(stats.mean_score) + data["Std Dev"].append(stats.std_score) + data["Worst Score"].append(stats.worst_score) + data["Median Score"].append(stats.median_score) + data["Mean CV Std Dev"].append(stats.mean_cv_std) + + return pd.DataFrame(data) + + def parameters_to_dataframe(self) -> pd.DataFrame: + """Convert parameter summary to a DataFrame.""" + return pd.DataFrame([asdict(p) for p in self.parameter_summary]) diff --git a/src/entropice/dashboard/utils/training.py b/src/entropice/dashboard/utils/training.py deleted file mode 100644 index 70c09b3..0000000 --- a/src/entropice/dashboard/utils/training.py +++ /dev/null @@ -1,232 +0,0 @@ -"""Training utilities for dashboard.""" - -import pickle - -import numpy as np -import pandas as pd -import streamlit as st -import xarray as xr - -from entropice.dashboard.utils.data import TrainingResult - - -def format_metric_name(metric: str) -> str: - """Format metric name for display. - - Args: - metric: Raw metric name (e.g., 'f1_micro', 'precision_macro'). - - Returns: - Formatted metric name (e.g., 'F1 Micro', 'Precision Macro'). - - """ - # Split by underscore and capitalize each part - parts = metric.split("_") - # Special handling for F1 - formatted_parts = [] - for part in parts: - if part.lower() == "f1": - formatted_parts.append("F1") - else: - formatted_parts.append(part.capitalize()) - return " ".join(formatted_parts) - - -def get_available_metrics(results: pd.DataFrame) -> list[str]: - """Get list of available metrics from results. - - Args: - results: DataFrame with CV results. - - Returns: - List of metric names (without 'mean_test_' prefix). - - """ - score_cols = [col for col in results.columns if col.startswith("mean_test_")] - return [col.replace("mean_test_", "") for col in score_cols] - - -def load_best_model(result: TrainingResult): - """Load the best model from a training result. - - Args: - result: TrainingResult object. - - Returns: - The loaded model object, or None if loading fails. 
- - """ - model_file = result.path / "best_estimator_model.pkl" - if not model_file.exists(): - return None - - try: - with open(model_file, "rb") as f: - model = pickle.load(f) - return model - except Exception as e: - st.error(f"Error loading model: {e}") - return None - - -def load_model_state(result: TrainingResult) -> xr.Dataset | None: - """Load the model state from a training result. - - Args: - result: TrainingResult object. - - Returns: - xarray Dataset with model state, or None if not available. - - """ - state_file = result.path / "best_estimator_state.nc" - if not state_file.exists(): - return None - - try: - state = xr.open_dataset(state_file, engine="h5netcdf") - return state - except Exception as e: - st.error(f"Error loading model state: {e}") - return None - - -def load_predictions(result: TrainingResult) -> pd.DataFrame | None: - """Load predictions from a training result. - - Args: - result: TrainingResult object. - - Returns: - DataFrame with predictions, or None if not available. - - """ - preds_file = result.path / "predicted_probabilities.parquet" - if not preds_file.exists(): - return None - - try: - preds = pd.read_parquet(preds_file) - return preds - except Exception as e: - st.error(f"Error loading predictions: {e}") - return None - - -def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame: - """Get summary of parameter space explored. - - Args: - results: DataFrame with CV results. - - Returns: - DataFrame with parameter ranges and statistics. - - """ - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - - summary_data = [] - for param_col in param_cols: - param_name = param_col.replace("param_", "") - param_values = results[param_col].dropna() - - if pd.api.types.is_numeric_dtype(param_values): - summary_data.append( - { - "Parameter": param_name, - "Type": "Numeric", - "Min": f"{param_values.min():.2e}", - "Max": f"{param_values.max():.2e}", - "Mean": f"{param_values.mean():.2e}", - "Unique Values": param_values.nunique(), - } - ) - else: - unique_vals = param_values.unique() - summary_data.append( - { - "Parameter": param_name, - "Type": "Categorical", - "Min": "-", - "Max": "-", - "Mean": "-", - "Unique Values": len(unique_vals), - } - ) - - return pd.DataFrame(summary_data) - - -def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict: - """Get cross-validation statistics for a metric. - - Args: - results: DataFrame with CV results. - metric: Metric name (without 'mean_test_' prefix). - - Returns: - Dictionary with CV statistics. - - """ - score_col = f"mean_test_{metric}" - std_col = f"std_test_{metric}" - - if score_col not in results.columns: - return {} - - stats = { - "best_score": results[score_col].max(), - "mean_score": results[score_col].mean(), - "std_score": results[score_col].std(), - "worst_score": results[score_col].min(), - "median_score": results[score_col].median(), - } - - if std_col in results.columns: - stats["mean_cv_std"] = results[std_col].mean() - - return stats - - -def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame: - """Prepare results dataframe with binned columns for plotting. - - Args: - results: DataFrame with CV results. - k_bin_width: Width of bins for initial_K parameter. - - Returns: - DataFrame with added binned columns. 
- - """ - results_copy = results.copy() - - # Check if we have the parameters - if "param_initial_K" in results.columns: - # Bin initial_K - k_values = results["param_initial_K"].dropna() - if len(k_values) > 0: - k_min = k_values.min() - k_max = k_values.max() - k_bins = range(int(k_min), int(k_max) + k_bin_width, k_bin_width) - results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False) - - if "param_eps_cl" in results.columns: - # Create logarithmic bins for eps_cl - eps_cl_values = results["param_eps_cl"].dropna() - if len(eps_cl_values) > 0 and eps_cl_values.min() > 0: - eps_cl_min = eps_cl_values.min() - eps_cl_max = eps_cl_values.max() - eps_cl_bins = np.logspace(np.log10(eps_cl_min), np.log10(eps_cl_max), num=10) - results_copy["eps_cl_binned"] = pd.cut(results["param_eps_cl"], bins=eps_cl_bins) - - if "param_eps_e" in results.columns: - # Create logarithmic bins for eps_e - eps_e_values = results["param_eps_e"].dropna() - if len(eps_e_values) > 0 and eps_e_values.min() > 0: - eps_e_min = eps_e_values.min() - eps_e_max = eps_e_values.max() - eps_e_bins = np.logspace(np.log10(eps_e_min), np.log10(eps_e_max), num=10) - results_copy["eps_e_binned"] = pd.cut(results["param_eps_e"], bins=eps_e_bins) - - return results_copy diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/unsembler.py similarity index 73% rename from src/entropice/dashboard/utils/data.py rename to src/entropice/dashboard/utils/unsembler.py index b57ef37..6f9e154 100644 --- a/src/entropice/dashboard/utils/data.py +++ b/src/entropice/dashboard/utils/unsembler.py @@ -1,147 +1,4 @@ -"""Data utilities for Entropice dashboard.""" - -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path - -import antimeridian -import pandas as pd -import streamlit as st -import toml import xarray as xr -from shapely.geometry import shape - -import entropice.utils.paths -from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble - - -@dataclass -class TrainingResult: - """Simple wrapper of training result data.""" - - name: str - path: Path - settings: dict - results: pd.DataFrame - created_at: float - - def get_display_name(self, format_type: str = "task_first") -> str: - """Get formatted display name for the training result. 
- - Args: - format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state) - - Returns: - Formatted name string - - """ - task = self.settings.get("task", "Unknown").capitalize() - model = self.settings.get("model", "Unknown").upper() - grid = self.settings.get("grid", "Unknown").capitalize() - level = self.settings.get("level", "Unknown") - timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M") - - if format_type == "model_first": - return f"{model} - {task} - {grid}-{level} ({timestamp})" - else: # task_first - return f"{task} - {model} - {grid}-{level} ({timestamp})" - - @classmethod - def from_path(cls, result_path: Path) -> "TrainingResult": - """Load a TrainingResult from a given result directory path.""" - result_file = result_path / "search_results.parquet" - preds_file = result_path / "predicted_probabilities.parquet" - settings_file = result_path / "search_settings.toml" - if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]): - raise FileNotFoundError(f"Missing required files in {result_path}") - - created_at = result_path.stat().st_ctime - settings = toml.load(settings_file)["settings"] - results = pd.read_parquet(result_file) - - # Name should be "task model grid-level (created_at)" - model = settings.get("model", "Unknown").upper() - name = ( - f"{settings.get('task', 'Unknown').capitalize()} -" - f" {model} -" - f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}" - f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})" - ) - - return cls( - name=name, - path=result_path, - settings=settings, - results=results, - created_at=created_at, - ) - - -def _fix_hex_geometry(geom): - """Fix hexagon geometry crossing the antimeridian.""" - try: - return shape(antimeridian.fix_shape(geom)) - except ValueError as e: - st.error(f"Error fixing geometry: {e}") - return geom - - -@st.cache_data -def load_all_training_results() -> list[TrainingResult]: - """Load all training results from the results directory.""" - results_dir = entropice.utils.paths.RESULTS_DIR - training_results: list[TrainingResult] = [] - for result_path in results_dir.iterdir(): - if not result_path.is_dir(): - continue - - try: - training_result = TrainingResult.from_path(result_path) - training_results.append(training_result) - except FileNotFoundError: - continue - - # Sort by creation time (most recent first) - training_results.sort(key=lambda tr: tr.created_at, reverse=True) - return training_results - - -@st.cache_data -def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]: - """Load training data for all three tasks. - - Args: - e: DatasetEnsemble object. - - Returns: - Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. - - """ - return { - "binary": e.create_cat_training_dataset("binary", device="cpu"), - "count": e.create_cat_training_dataset("count", device="cpu"), - "density": e.create_cat_training_dataset("density", device="cpu"), - } - - -@st.cache_data -def load_source_data(e: DatasetEnsemble, source: str): - """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5). - - Args: - e: DatasetEnsemble object. - source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'. - - Returns: - xarray.Dataset with the raw data for the specified source. 
- - """ - targets = e._read_target() - - # Load the member data lazily to get metadata - ds = e._read_member(source, targets, lazy=False) - - return ds, targets def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray: @@ -203,7 +60,7 @@ def extract_embedding_features( band=("feature", [f.split("_")[2] for f in embedding_features]), year=("feature", [f.split("_")[3] for f in embedding_features]), ) - embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010 + embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") return embedding_feature_array @@ -419,9 +276,9 @@ def extract_era5_features( era5_features_array = era5_features_array.assign_coords( agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]), ) - era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010 + era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") else: - era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 + era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") return era5_features_array @@ -472,7 +329,7 @@ def extract_arcticdem_features( variable=("feature", [_extract_var_name(f) for f in arcticdem_features]), agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]), ) - arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") # noqa: PD010 + arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") return arcticdem_feature_array @@ -503,16 +360,3 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea # Extract the feature importance for common features common_feature_array = importance_array.sel(feature=common_features) return common_feature_array - - -def get_members_from_settings(settings: dict) -> list[str]: - """Extract the list of dataset members used in training from settings. - - Args: - settings: Training settings dictionary. - - Returns: - List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']). 
- - """ - return settings.get("members", []) diff --git a/src/entropice/dashboard/inference_page.py b/src/entropice/dashboard/views/inference_page.py similarity index 100% rename from src/entropice/dashboard/inference_page.py rename to src/entropice/dashboard/views/inference_page.py diff --git a/src/entropice/dashboard/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py similarity index 99% rename from src/entropice/dashboard/model_state_page.py rename to src/entropice/dashboard/views/model_state_page.py index 031df78..1dd3c62 100644 --- a/src/entropice/dashboard/model_state_page.py +++ b/src/entropice/dashboard/views/model_state_page.py @@ -3,7 +3,6 @@ import streamlit as st import xarray as xr -from entropice.dashboard.plots.colors import generate_unified_colormap from entropice.dashboard.plots.model_state import ( plot_arcticdem_heatmap, plot_arcticdem_summary, @@ -17,6 +16,7 @@ from entropice.dashboard.plots.model_state import ( plot_era5_time_heatmap, plot_top_features, ) +from entropice.dashboard.utils.colors import generate_unified_colormap from entropice.dashboard.utils.data import ( extract_arcticdem_features, extract_common_features, diff --git a/src/entropice/dashboard/overview_page.py b/src/entropice/dashboard/views/overview_page.py similarity index 96% rename from src/entropice/dashboard/overview_page.py rename to src/entropice/dashboard/views/overview_page.py index 830845a..de5e2de 100644 --- a/src/entropice/dashboard/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -8,16 +8,17 @@ import pandas as pd import plotly.express as px import streamlit as st -from entropice.dashboard.plots.colors import get_palette -from entropice.dashboard.utils.data import load_all_training_results +from entropice.dashboard.utils.colors import get_palette +from entropice.dashboard.utils.loaders import load_all_training_results from entropice.ml.dataset import DatasetEnsemble +from entropice.utils.types import Grid, TargetDataset, Task # Type definitions for dataset statistics class GridConfig(TypedDict): """Grid configuration specification with metadata.""" - grid: str + grid: Grid level: int grid_name: str grid_sort_key: str @@ -116,7 +117,7 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache: Results are cached to avoid redundant computations across different tabs. """ # Define all possible grid configurations - grid_configs_raw = [ + grid_configs_raw: list[tuple[Grid, int]] = [ ("hex", 3), ("hex", 4), ("hex", 5), @@ -144,8 +145,8 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache: # Compute sample counts sample_counts: list[SampleCountData] = [] - target_datasets = ["darts_rts", "darts_mllabels"] - tasks = ["binary", "count", "density"] + target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"] + tasks: list[Task] = ["binary", "count", "density"] for grid_config in grid_configs: for target in target_datasets: @@ -810,25 +811,32 @@ def render_experiment_results(training_results): def render_overview_page(): """Render the Overview page of the dashboard.""" - st.title("🏡 Training Results Overview") + st.title("🏡 Entropice Overview") + st.markdown( + """ + Welcome to the Entropice Dashboard! This overview page provides a summary of your + training experiments and dataset analyses. + Use the sections below to explore training results, dataset sample counts, + and feature configurations. 
+ """ + ) # Load training results training_results = load_all_training_results() if not training_results: st.warning("No training results found. Please run some training experiments first.") - return + else: + st.write(f"Found **{len(training_results)}** training result(s)") - st.write(f"Found **{len(training_results)}** training result(s)") + st.divider() - st.divider() + # Render training results sections + render_training_results_summary(training_results) - # Render training results sections - render_training_results_summary(training_results) + st.divider() - st.divider() - - render_experiment_results(training_results) + render_experiment_results(training_results) st.divider() diff --git a/src/entropice/dashboard/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py similarity index 100% rename from src/entropice/dashboard/training_analysis_page.py rename to src/entropice/dashboard/views/training_analysis_page.py diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py similarity index 95% rename from src/entropice/dashboard/training_data_page.py rename to src/entropice/dashboard/views/training_data_page.py index fe91036..275d3c0 100644 --- a/src/entropice/dashboard/training_data_page.py +++ b/src/entropice/dashboard/views/training_data_page.py @@ -15,7 +15,7 @@ from entropice.dashboard.plots.source_data import ( render_era5_plots, ) from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map -from entropice.dashboard.utils.data import load_all_training_data, load_source_data +from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data from entropice.ml.dataset import DatasetEnsemble from entropice.spatial import grids @@ -60,21 +60,12 @@ def render_training_data_page(): # Members selection st.subheader("Dataset Members") - # Check if AlphaEarth should be disabled - disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6) - all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] selected_members = [] for member in all_members: - if member == "AlphaEarth" and disable_alphaearth: - # Show disabled checkbox with explanation - st.checkbox( - member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}" - ) - else: - if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): - selected_members.append(member) + if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): + selected_members.append(member) # Form submit button load_button = st.form_submit_button( diff --git a/src/entropice/ingest/alphaearth.py b/src/entropice/ingest/alphaearth.py index da164b7..55098fe 100644 --- a/src/entropice/ingest/alphaearth.py +++ b/src/entropice/ingest/alphaearth.py @@ -6,7 +6,6 @@ Date: October 2025 import warnings from collections.abc import Generator -from typing import Literal import cudf import cuml.cluster @@ -25,6 +24,7 @@ from rich.progress import track from entropice.spatial import grids from entropice.utils import codecs from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store +from entropice.utils.types import Grid # Filter out the GeoDataFrame.swapaxes deprecation warning warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning) @@ -55,11 +55,11 @@ def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.D 
@cli.command() -def download(grid: Literal["hex", "healpix"], level: int): +def download(grid: Grid, level: int): """Extract satellite embeddings from Google Earth Engine and map them to a grid. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. """ @@ -126,11 +126,11 @@ def download(grid: Literal["hex", "healpix"], level: int): @cli.command() -def combine_to_zarr(grid: Literal["hex", "healpix"], level: int): +def combine_to_zarr(grid: Grid, level: int): """Combine yearly embeddings parquet files into a single zarr store. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. """ diff --git a/src/entropice/ingest/arcticdem.py b/src/entropice/ingest/arcticdem.py index b98e356..7837111 100644 --- a/src/entropice/ingest/arcticdem.py +++ b/src/entropice/ingest/arcticdem.py @@ -2,7 +2,6 @@ import datetime import multiprocessing as mp from dataclasses import dataclass from math import ceil -from typing import Literal import cupy as cp import cupy_xarray @@ -32,6 +31,7 @@ from entropice.spatial import grids, watermask from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid from entropice.utils import codecs from entropice.utils.paths import get_arcticdem_stores +from entropice.utils.types import Grid traceback.install(show_locals=True, suppress=[cyclopts]) pretty.install() @@ -330,7 +330,7 @@ def _open_adem(): @cli.command() -def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20): +def aggregate(grid: Grid, level: int, concurrent_partitions: int = 20): mp.set_start_method("forkserver", force=True) with ( dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster, diff --git a/src/entropice/ingest/darts.py b/src/entropice/ingest/darts.py index 1a4087c..19c6760 100644 --- a/src/entropice/ingest/darts.py +++ b/src/entropice/ingest/darts.py @@ -6,8 +6,6 @@ Author: Tobias Hölzer Date: October 2025 """ -from typing import Literal - import cyclopts import geopandas as gpd import pandas as pd @@ -17,6 +15,7 @@ from stopuhr import stopwatch from entropice.spatial import grids from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file +from entropice.utils.types import Grid traceback.install() pretty.install() @@ -25,11 +24,11 @@ cli = cyclopts.App(name="darts-rts") @cli.command() -def extract_darts_rts(grid: Literal["hex", "healpix"], level: int): +def extract_darts_rts(grid: Grid, level: int): """Extract RTS labels from DARTS dataset. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. 
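     Example:
         Illustrative invocation of the cyclopts CLI (the console-script name
         is assumed; the exact entry point may differ)::

             darts-rts extract-darts-rts hex 5
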
""" @@ -91,7 +90,7 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int): @cli.command() -def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int): +def extract_darts_mllabels(grid: Grid, level: int): with stopwatch("Load data"): grid_gdf = grids.open(grid, level) darts_mllabels = ( diff --git a/src/entropice/ingest/era5.py b/src/entropice/ingest/era5.py index 68cc95c..9859f65 100644 --- a/src/entropice/ingest/era5.py +++ b/src/entropice/ingest/era5.py @@ -101,6 +101,7 @@ from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_g from entropice.spatial.xvec import to_xvec from entropice.utils import codecs from entropice.utils.paths import FIGURES_DIR, get_era5_stores +from entropice.utils.types import Grid traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile]) pretty.install() @@ -579,7 +580,7 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily @cli.command def viz( - grid: Literal["hex", "healpix"] | None = None, + grid: Grid | None = None, level: int | None = None, agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly", high_qual: bool = False, @@ -587,7 +588,7 @@ def viz( """Visualize a small overview of ERA5 variables for a given aggregation. Args: - grid (Literal["hex", "healpix"], optional): Grid type for spatial representation. + grid (Grid, optional): Grid type for spatial representation. If provided along with level, the ERA5 data will be decoded onto the specified grid. level (int, optional): Level of the grid for spatial representation. If provided along with grid, the ERA5 data will be decoded onto the specified grid. @@ -653,7 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset: @cli.command def spatial_agg( - grid: Literal["hex", "healpix"], + grid: Grid, level: int, concurrent_partitions: int = 20, ): diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index eb6fcb8..0344e39 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -16,7 +16,7 @@ import hashlib import json from collections.abc import Generator from dataclasses import asdict, dataclass, field -from functools import cached_property, lru_cache +from functools import cached_property from typing import Literal, TypedDict import cupy as cp @@ -32,6 +32,7 @@ from sklearn import set_config from sklearn.model_selection import train_test_split import entropice.utils.paths +from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task traceback.install() pretty.install() @@ -41,6 +42,27 @@ set_config(array_api_dispatch=True) sns.set_theme("talk", "whitegrid") +covcol: dict[TargetDataset, str] = { + "darts_rts": "darts_has_coverage", + "darts_mllabels": "dartsml_has_coverage", +} + +taskcol: dict[Task, dict[TargetDataset, str]] = { + "binary": { + "darts_rts": "darts_has_rts", + "darts_mllabels": "dartsml_has_rts", + }, + "count": { + "darts_rts": "darts_rts_count", + "darts_mllabels": "dartsml_rts_count", + }, + "density": { + "darts_rts": "darts_rts_density", + "darts_mllabels": "dartsml_rts_density", + }, +} + + def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): time_index = pd.DatetimeIndex(df.index.get_level_values("time")) if temporal == "yearly": @@ -53,10 +75,6 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", " return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_") -type 
L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] -type Task = Literal["binary", "count", "density"] - - def bin_values( values: pd.Series, task: Literal["count", "density"], @@ -144,7 +162,7 @@ class DatasetInputs: @dataclass(frozen=True) class CategoricalTrainingDataset: - dataset: pd.DataFrame + dataset: gpd.GeoDataFrame X: DatasetInputs y: DatasetLabels z: pd.Series @@ -162,12 +180,12 @@ class DatasetStats(TypedDict): @cyclopts.Parameter("*") -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class DatasetEnsemble: - grid: Literal["hex", "healpix"] + grid: Grid level: int target: Literal["darts_rts", "darts_mllabels"] - members: list[L2Dataset] = field( + members: list[L2SourceDataset] = field( default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] ) dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) @@ -186,19 +204,12 @@ class DatasetEnsemble: @property def covcol(self) -> str: - return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage" + return covcol[self.target] def taskcol(self, task: Task) -> str: - if task == "binary": - return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts" - elif task == "count": - return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count" - elif task == "density": - return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density" - else: - raise ValueError(f"Invalid task: {task}") + return taskcol[task][self.target] - def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: + def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: if member == "AlphaEarth": store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level) elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: @@ -340,8 +351,9 @@ class DatasetEnsemble: print(f"=== Total number of features in dataset: {stats['total_features']}") - @lru_cache(maxsize=1) - def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: + def create( + self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r" + ) -> gpd.GeoDataFrame: # n: no cache, o: overwrite cache, r: read cache if exists cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col) if cache_mode == "r" and cache_file.exists(): @@ -425,26 +437,14 @@ class DatasetEnsemble: print(f"Saved dataset to cache at {cache_file}.") yield dataset - def create_cat_training_dataset( - self, task: Task, device: Literal["cpu", "cuda", "torch"] + def _cat_and_split( + self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"] ) -> CategoricalTrainingDataset: - """Create a categorical dataset for training. - - Args: - task (Task): Task type. - device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto. - - Returns: - CategoricalTrainingDataset: The prepared categorical training dataset. 
- - """ - covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage" - dataset = self.create(filter_target_col=covcol) taskcol = self.taskcol(task) valid_labels = dataset[taskcol].notna() - cols_to_drop = {"geometry", taskcol, covcol} + cols_to_drop = {"geometry", taskcol, self.covcol} cols_to_drop |= { col for col in dataset.columns @@ -505,3 +505,19 @@ class DatasetEnsemble: z=model_labels, split=split, ) + + def create_cat_training_dataset( + self, task: Task, device: Literal["cpu", "cuda", "torch"] + ) -> CategoricalTrainingDataset: + """Create a categorical dataset for training. + + Args: + task (Task): Task type. + device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto. + + Returns: + CategoricalTrainingDataset: The prepared categorical training dataset. + + """ + dataset = self.create(filter_target_col=self.covcol) + return self._cat_and_split(dataset, task, device) diff --git a/src/entropice/ml/training.py b/src/entropice/ml/training.py index d307854..ae8c563 100644 --- a/src/entropice/ml/training.py +++ b/src/entropice/ml/training.py @@ -2,7 +2,6 @@ import pickle from dataclasses import asdict, dataclass -from typing import Literal import cupy as cp import cyclopts @@ -23,6 +22,7 @@ from xgboost.sklearn import XGBClassifier from entropice.ml.dataset import DatasetEnsemble from entropice.ml.inference import predict_proba from entropice.utils.paths import get_cv_results_dir +from entropice.utils.types import Model, Task traceback.install() pretty.install() @@ -50,12 +50,20 @@ _metrics = { @cyclopts.Parameter("*") -@dataclass +@dataclass(frozen=True, kw_only=True) class CVSettings: n_iter: int = 2000 robust: bool = False - task: Literal["binary", "count", "density"] = "binary" - model: Literal["espa", "xgboost", "rf", "knn"] = "espa" + task: Task = "binary" + model: Model = "espa" + + +@dataclass(frozen=True, kw_only=True) +class TrainingSettings(DatasetEnsemble, CVSettings): + param_grid: dict + cv_splits: int + metrics: list[str] + classes: list[str] def _create_clf( @@ -141,7 +149,7 @@ def random_cv( """Perform random cross-validation on the training dataset. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. robust (bool, optional): Whether to use robust training. Defaults to False. 
@@ -202,18 +210,18 @@ def random_cv( # Store the search settings # First convert the param_grid distributions to a serializable format param_grid_serializable = _serialize_param_grid(param_grid) - combined_settings = { + combined_settings = TrainingSettings( **asdict(settings), **asdict(dataset_ensemble), - "param_grid": param_grid_serializable, - "cv_splits": cv.get_n_splits(), - "metrics": metrics, - "classes": training_data.y.labels, - } + param_grid=param_grid_serializable, + cv_splits=cv.get_n_splits(), + metrics=metrics, + classes=training_data.y.labels, + ) settings_file = results_dir / "search_settings.toml" print(f"Storing search settings to {settings_file}") with open(settings_file, "w") as f: - toml.dump({"settings": combined_settings}, f) + toml.dump({"settings": asdict(combined_settings)}, f) # Store the best estimator model best_model_file = results_dir / "best_estimator_model.pkl" diff --git a/src/entropice/spatial/aggregators.py b/src/entropice/spatial/aggregators.py index caa2641..d775e87 100644 --- a/src/entropice/spatial/aggregators.py +++ b/src/entropice/spatial/aggregators.py @@ -33,6 +33,7 @@ from stopuhr import stopwatch from xdggs.healpix import HealpixInfo from entropice.spatial import grids +from entropice.utils.types import Grid @dataclass(frozen=True) @@ -585,7 +586,7 @@ def aggregate_raster_into_grid( raster: xr.Dataset | Callable[[], xr.Dataset], grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame], aggregations: _Aggregations | Literal["interpolate"], - grid: Literal["hex", "healpix"], + grid: Grid, level: int, n_partitions: int | None = 20, concurrent_partitions: int = 5, @@ -599,7 +600,7 @@ def aggregate_raster_into_grid( If a list of GeoDataFrames is provided, each will be processed as a separate partition. No further partitioning will be done and the n_partitions argument will be ignored. aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform. - grid (Literal["hex", "healpix"]): The type of grid to use. + grid (Grid): The type of grid to use. level (int): The level of the grid. n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20. concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions. diff --git a/src/entropice/spatial/grids.py b/src/entropice/spatial/grids.py index fb4b15f..3f99867 100644 --- a/src/entropice/spatial/grids.py +++ b/src/entropice/spatial/grids.py @@ -5,7 +5,6 @@ Date: 09. June 2025 """ from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Literal import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -26,16 +25,17 @@ from stopuhr import stopwatch from xdggs.healpix import HealpixInfo from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file +from entropice.utils.types import Grid traceback.install() pretty.install() -def open(grid: Literal["hex", "healpix"], level: int): +def open(grid: Grid, level: int): """Open a saved grid from parquet file. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. Returns: @@ -47,11 +47,11 @@ def open(grid: Literal["hex", "healpix"], level: int): return grid -def get_cell_ids(grid: Literal["hex", "healpix"], level: int): +def get_cell_ids(grid: Grid, level: int): """Get the cell IDs of a saved grid. Args: - grid (Literal["hex", "healpix"]): The grid type to use. + grid (Grid): The grid type to use. level (int): The grid level to use. 
Returns: @@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure: return fig -def cli(grid: Literal["hex", "healpix"], level: int): +def cli(grid: Grid, level: int): """CLI entry point.""" print(f"Creating {grid} grid at level {level}...") if grid == "hex": diff --git a/src/entropice/utils/paths.py b/src/entropice/utils/paths.py index 5a8a5a2..3c688c3 100644 --- a/src/entropice/utils/paths.py +++ b/src/entropice/utils/paths.py @@ -6,6 +6,8 @@ import os from pathlib import Path from typing import Literal +from entropice.utils.types import Grid, Task + DATA_DIR = ( Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice" ) @@ -42,23 +44,23 @@ dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.par darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps" -def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str: +def _get_gridname(grid: Grid, level: int) -> str: return f"permafrost_{grid}{level}" -def get_grid_file(grid: Literal["hex", "healpix"], level: int) -> Path: +def get_grid_file(grid: Grid, level: int) -> Path: gridname = _get_gridname(grid, level) gridfile = GRIDS_DIR / f"{gridname}_grid.parquet" return gridfile -def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path: +def get_grid_viz_file(grid: Grid, level: int) -> Path: gridname = _get_gridname(grid, level) vizfile = FIGURES_DIR / f"{gridname}_grid.png" return vizfile -def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path: +def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path: gridname = _get_gridname(grid, level) if labels: rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet" @@ -67,13 +69,13 @@ def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool return rtsfile -def get_annual_embeddings_file(grid: Literal["hex", "healpix"], level: int, year: int) -> Path: +def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path: gridname = _get_gridname(grid, level) embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet" return embfile -def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path: +def get_embeddings_store(grid: Grid, level: int) -> Path: gridname = _get_gridname(grid, level) embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr" return embstore @@ -81,9 +83,9 @@ def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path: def get_era5_stores( agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily", - grid: Literal["hex", "healpix"] | None = None, + grid: Grid | None = None, level: int | None = None, -): +) -> Path: if grid is None or level is None: (ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True) return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr" @@ -94,9 +96,9 @@ def get_era5_stores( def get_arcticdem_stores( - grid: Literal["hex", "healpix"] | None = None, + grid: Grid | None = None, level: int | None = None, -): +) -> Path: if grid is None or level is None: return DATA_DIR / "arcticdem32m.icechunk.zarr" gridname = _get_gridname(grid, level) @@ -104,7 +106,7 @@ def get_arcticdem_stores( return aligned_path -def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: +def get_train_dataset_file(grid: Grid, level: int) -> Path: gridname = _get_gridname(grid, level) dataset_file 
= TRAINING_DIR / f"{gridname}_train_dataset.parquet" return dataset_file @@ -122,9 +124,9 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int def get_cv_results_dir( name: str, - grid: Literal["hex", "healpix"], + grid: Grid, level: int, - task: Literal["binary", "count", "density"], + task: Task, ) -> Path: gridname = _get_gridname(grid, level) now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py new file mode 100644 index 0000000..87c32dd --- /dev/null +++ b/src/entropice/utils/types.py @@ -0,0 +1,11 @@ +"""Shared types used across the entropice codebase.""" + +from typing import Literal + +type Grid = Literal["hex", "healpix"] +type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"] +type TargetDataset = Literal["darts_rts", "darts_mllabels"] +type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] +type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] +type Task = Literal["binary", "count", "density"] +type Model = Literal["espa", "xgboost", "rf", "knn"]
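
Usage note (illustrative commentary, not part of the patch): the PEP 695 aliases in
src/entropice/utils/types.py are plain strings at runtime, so swapping them in for the
repeated Literal["hex", "healpix"] annotations changes no behaviour; it only centralises
the vocabulary in one module. The CLI commands in this patch annotate parameters with
Grid directly, which suggests cyclopts resolves the alias to its underlying Literal
choices. A minimal sketch of how the aliases and the table-driven column lookup from
ml/dataset.py are meant to be consumed; label_column and grid_name below are
hypothetical helpers for illustration, not code from this patch:

    from entropice.utils.types import Grid, Task, TargetDataset

    # Same table shape as the module-level `taskcol` mapping added in
    # ml/dataset.py: resolve a label column per (task, target) pair
    # instead of walking an if/elif chain.
    _taskcol: dict[Task, dict[TargetDataset, str]] = {
        "binary": {"darts_rts": "darts_has_rts", "darts_mllabels": "dartsml_has_rts"},
        "count": {"darts_rts": "darts_rts_count", "darts_mllabels": "dartsml_rts_count"},
        "density": {"darts_rts": "darts_rts_density", "darts_mllabels": "dartsml_rts_density"},
    }

    def label_column(task: Task, target: TargetDataset) -> str:
        # Note: an invalid task now surfaces as a KeyError rather than the
        # ValueError the removed if/elif branch used to raise.
        return _taskcol[task][target]

    def grid_name(grid: Grid, level: int) -> str:
        # Mirrors the f"permafrost_{grid}{level}" scheme in utils/paths.py.
        return f"permafrost_{grid}{level}"

    assert label_column("binary", "darts_rts") == "darts_has_rts"
    assert grid_name("healpix", 8) == "permafrost_healpix8"

Type checkers still see the underlying Literal choices through the alias, so call
sites keep exhaustive-choice checking while dropping the duplicated string unions.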