From f8df10f6872cbf450de56dfd208b662dc4f0b333 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 28 Dec 2025 20:11:11 +0100 Subject: [PATCH] Add some docs for copilot --- .github/agents/Dashboard.agent.md | 176 ++ .github/copilot-instructions.md | 193 ++ .github/python.instructions.md | 56 + ARCHITECTURE.md | 217 ++ CONTRIBUTING.md | 116 + README.md | 172 +- src/entropice/dashboard/overview_page.py | 20 +- src/entropice/dashboard/plots/inference.py | 8 +- src/entropice/training_analysis_dashboard.py | 1977 ------------------ 9 files changed, 908 insertions(+), 2027 deletions(-) create mode 100644 .github/agents/Dashboard.agent.md create mode 100644 .github/copilot-instructions.md create mode 100644 .github/python.instructions.md create mode 100644 ARCHITECTURE.md create mode 100644 CONTRIBUTING.md delete mode 100644 src/entropice/training_analysis_dashboard.py diff --git a/.github/agents/Dashboard.agent.md b/.github/agents/Dashboard.agent.md new file mode 100644 index 0000000..b9c1d71 --- /dev/null +++ b/.github/agents/Dashboard.agent.md @@ -0,0 +1,176 @@ +--- +description: 'Specialized agent for developing and enhancing the Streamlit dashboard for data and training analysis.' 
+name: Dashboard-Developer +argument-hint: 'Describe dashboard features, pages, visualizations, or improvements you want to add or modify' +tools: ['edit', 'runNotebooks', 'search', 'runCommands', 'usages', 'problems', 'changes', 'testFailure', 'fetch', 'githubRepo', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'ms-toolsai.jupyter/configureNotebook', 'ms-toolsai.jupyter/listNotebookPackages', 'ms-toolsai.jupyter/installNotebookPackages', 'todos', 'runSubagent', 'runTests'] +--- + +# Dashboard Development Agent + +You are a specialized agent for incrementally developing and enhancing the **Entropice Streamlit Dashboard** used to analyze geospatial machine learning data and training experiments. + +## Your Responsibilities + +### What You Should Do + +1. **Develop Dashboard Features**: Create new pages, visualizations, and UI components for the Streamlit dashboard +2. **Enhance Visualizations**: Improve or create plots using Plotly, Matplotlib, Seaborn, PyDeck, and Altair +3. **Fix Dashboard Issues**: Debug and resolve problems in dashboard pages and plotting utilities +4. **Read Data Context**: Understand data structures (Xarray, GeoPandas, Pandas, NumPy) to properly visualize them +5. **Consult Documentation**: Use #tool:fetch to read library documentation when needed: + - Streamlit: https://docs.streamlit.io/ + - Plotly: https://plotly.com/python/ + - PyDeck: https://deckgl.readthedocs.io/ + - Deck.gl: https://deck.gl/ + - Matplotlib: https://matplotlib.org/ + - Seaborn: https://seaborn.pydata.org/ + - Xarray: https://docs.xarray.dev/ + - GeoPandas: https://geopandas.org/ + - Pandas: https://pandas.pydata.org/pandas-docs/ + - NumPy: https://numpy.org/doc/stable/ + +6. 
**Understand Data Sources**: Read data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`, `dataset.py`, `training.py`, `inference.py`) to understand data structures—but **NEVER edit them** + +### What You Should NOT Do + +1. **Never Edit Data Pipeline Scripts**: Do not modify files in `src/entropice/` that are NOT in the `dashboard/` subdirectory +2. **Never Edit Training Scripts**: Do not modify `training.py`, `dataset.py`, or any model-related code outside the dashboard +3. **Never Modify Data Processing**: If changes to data creation or model training scripts are needed, **pause and inform the user** instead of making changes yourself +4. **Never Edit Configuration Files**: Do not modify `pyproject.toml`, pipeline scripts in `scripts/`, or configuration files + +### Boundaries + +If you identify that a dashboard improvement requires changes to: +- Data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`) +- Dataset assembly (`dataset.py`) +- Model training (`training.py`, `inference.py`) +- Pipeline automation scripts (`scripts/*.sh`) + +**Stop immediately** and inform the user: +``` +⚠️ This dashboard feature requires changes to the data pipeline/training code. +Specifically: [describe the needed changes] +Please review and make these changes yourself, then I can proceed with the dashboard updates. 
+``` + +## Dashboard Structure + +The dashboard is located in `src/entropice/dashboard/` with the following structure: + +``` +dashboard/ +├── app.py # Main Streamlit app with navigation +├── overview_page.py # Overview of training results +├── training_data_page.py # Training data visualizations +├── training_analysis_page.py # CV results and hyperparameter analysis +├── model_state_page.py # Feature importance and model state +├── inference_page.py # Spatial prediction visualizations +├── plots/ # Reusable plotting utilities +│ ├── colors.py # Color schemes +│ ├── hyperparameter_analysis.py +│ ├── inference.py +│ ├── model_state.py +│ ├── source_data.py +│ └── training_data.py +└── utils/ # Data loading and processing + ├── data.py + └── training.py +``` + +## Key Technologies + +- **Streamlit**: Web app framework +- **Plotly**: Interactive plots (preferred for most visualizations) +- **Matplotlib/Seaborn**: Statistical plots +- **PyDeck/Deck.gl**: Geospatial visualizations +- **Altair**: Declarative visualizations +- **Bokeh**: Alternative interactive plotting (already used in some places) + +## Critical Code Standards + +### Streamlit Best Practices + +**❌ INCORRECT** (deprecated): +```python +st.plotly_chart(fig, use_container_width=True) +``` + +**✅ CORRECT** (current API): +```python +st.plotly_chart(fig, width='stretch') +``` + +**Common width values**: +- `width='stretch'` - Use full container width (replaces `use_container_width=True`) +- `width='content'` - Use content width (replaces `use_container_width=False`) + +This applies to: +- `st.plotly_chart()` +- `st.altair_chart()` +- `st.vega_lite_chart()` +- `st.dataframe()` +- `st.image()` + +### Data Structure Patterns + +When working with Entropice data: + +1. **Grid Data**: GeoDataFrames with H3/HEALPix cell IDs +2. **L2 Datasets**: Xarray datasets with XDGGS dimensions +3. **Training Results**: Pickled models, Parquet/NetCDF CV results +4. 
**Predictions**: GeoDataFrames with predicted classes/probabilities + +### Visualization Guidelines + +1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps +2. **Time Series**: Prefer Plotly for interactivity +3. **Distributions**: Use Plotly or Seaborn +4. **Feature Importance**: Use Plotly bar charts +5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates + +## Workflow + +1. **Understand the Request**: Clarify what visualization or feature is needed +2. **Search for Context**: Use #tool:search to find relevant dashboard code and data structures +3. **Read Data Pipeline**: If needed, read (but don't edit) data pipeline scripts to understand data formats +4. **Consult Documentation**: Use #tool:fetch for library documentation when needed +5. **Implement Changes**: Edit dashboard files only +6. **Test Assumptions**: Check for errors with #tool:problems after edits +7. **Track Progress**: Use #tool:todos for multi-step dashboard development + +## Example Interactions + +### ✅ Good Requests (Within Scope) + +- "Add a new page to visualize feature correlations" +- "Create a PyDeck map showing RTS predictions by grid cell" +- "Improve the hyperparameter analysis plot to show confidence intervals" +- "Add a Plotly histogram showing the distribution of RTS density" +- "Fix the deprecation warning about use_container_width" + +### ⚠️ Boundary Cases (Requires User Approval) + +User: "Add a new climate variable to the dashboard" +Agent Response: +``` +⚠️ This requires changes to the data pipeline (era5.py) to extract the new variable. +Please add the variable to the ERA5 processing pipeline first, then I can add it to the dashboard visualizations. +``` + +## Progress Reporting + +For complex dashboard development tasks: + +1. Use #tool:todos to create a task list +2. Mark tasks as in-progress before starting +3. Mark completed immediately after finishing +4. 
Keep the user informed of progress + +## Remember + +- **Read-only for data pipeline**: You can read any file to understand data structures, but only edit `dashboard/` files +- **Documentation first**: When unsure about Streamlit/Plotly/PyDeck APIs, fetch documentation +- **Modern Streamlit API**: Always use `width='stretch'` instead of `use_container_width=True` +- **Pause when needed**: If data pipeline changes are required, stop and inform the user + +You are here to make the dashboard better, not to change how data is created or models are trained. Stay within these boundaries and you'll be most helpful! diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..3bcf718 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,193 @@ +# Entropice - GitHub Copilot Instructions + +## Project Overview + +Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. The system processes multi-source geospatial data, aggregates it into discrete global grids (H3/HEALPix), and trains probabilistic classifiers to estimate RTS occurrence patterns. + +For detailed architecture information, see [ARCHITECTURE.md](../ARCHITECTURE.md). +For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md). +For project goals and setup, see [README.md](../README.md). 
+ +## Core Technologies + +- **Python**: 3.13 (strict version requirement) +- **Package Manager**: [Pixi](https://pixi.sh/) (not pip/conda directly) +- **GPU**: CUDA 12 with RAPIDS (CuPy, cuML) +- **Geospatial**: xarray, xdggs, GeoPandas, H3, Rasterio +- **ML**: scikit-learn, XGBoost, entropy (eSPA), cuML +- **Storage**: Zarr, Icechunk, Parquet, NetCDF +- **Visualization**: Streamlit, Bokeh, Matplotlib, Cartopy + +## Code Style Guidelines + +### Python Standards + +- Follow **PEP 8** conventions +- Use **type hints** for all function signatures +- Write **numpy-style docstrings** for public functions +- Keep functions **focused and modular** +- Prefer descriptive variable names over abbreviations + +### Geospatial Best Practices + +- Use **EPSG:3413** (Arctic Stereographic) for computations +- Use **EPSG:4326** (WGS84) for visualization and library compatibility +- Store gridded data using **xarray with XDGGS** indexing +- Store tabular data as **Parquet**, array data as **Zarr** +- Leverage **Dask** for lazy evaluation of large datasets +- Use **GeoPandas** for vector operations +- Handle **antimeridian** correctly for polar regions + +### Data Pipeline Conventions + +- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → `02alphaearth.sh` → `03era5.sh` → `04arcticdem.sh` → `05train.sh` +- Each pipeline stage should produce **reproducible intermediate outputs** +- Use `src/entropice/paths.py` for consistent path management +- Environment variable `FAST_DATA_DIR` controls data directory location (default: `./data`) + +### Storage Hierarchy + +All data follows this structure: +``` +DATA_DIR/ +├── grids/ # H3/HEALPix tessellations (GeoParquet) +├── darts/ # RTS labels (GeoParquet) +├── era5/ # Climate data (Zarr) +├── arcticdem/ # Terrain data (Icechunk Zarr) +├── alphaearth/ # Satellite embeddings (Zarr) +├── datasets/ # L2 XDGGS datasets (Zarr) +├── training-results/ # Models, CV results, predictions +└── watermask/ # Ocean mask (GeoParquet) +``` + 
+## Module Organization + +### Core Modules (`src/entropice/`) + +- **`grids.py`**: H3/HEALPix spatial grid generation with watermask +- **`darts.py`**: RTS label extraction from DARTS v2 dataset +- **`era5.py`**: Climate data processing from ERA5 (Arctic-aligned years: Oct 1 - Sep 30) +- **`arcticdem.py`**: Terrain analysis from 32m Arctic elevation data +- **`alphaearth.py`**: Satellite image embeddings via Google Earth Engine +- **`aggregators.py`**: Raster-to-vector spatial aggregation engine +- **`dataset.py`**: Multi-source data integration and feature engineering +- **`training.py`**: Model training with eSPA, XGBoost, Random Forest, KNN +- **`inference.py`**: Batch prediction pipeline for trained models +- **`paths.py`**: Centralized path management + +### Dashboard (`src/entropice/dashboard/`) + +- Streamlit-based interactive visualization +- Modular pages: overview, training data, analysis, model state, inference +- Bokeh-based geospatial plotting utilities +- Run with: `pixi run dashboard` + +### Scripts (`scripts/`) + +- Numbered pipeline scripts (`00grids.sh` through `05train.sh`) +- Run entire pipeline for multiple grid configurations +- Each script uses CLIs from core modules + +### Notebooks (`notebooks/`) + +- Exploratory analysis and validation +- **NOT committed to git** +- Keep production code in `src/entropice/` + +## Development Workflow + +### Setup + +```bash +pixi install # NOT pip install or conda install +``` + +### Running Tests + +```bash +pixi run pytest +``` + +### Common Tasks + +- **Generate grids**: Use `grids.py` CLI +- **Process labels**: Use `darts.py` CLI +- **Train models**: Use `training.py` CLI with TOML config +- **Run inference**: Use `inference.py` CLI +- **View results**: `pixi run dashboard` + +## Key Design Patterns + +### 1. XDGGS Indexing + +All geospatial data uses discrete global grid systems (H3 or HEALPix) via `xdggs` library for consistent spatial indexing across sources. + +### 2. 
Lazy Evaluation + +Use Xarray/Dask for out-of-core computation with Zarr/Icechunk chunked storage to manage large datasets. + +### 3. GPU Acceleration + +Prefer GPU-accelerated operations: +- CuPy for array operations +- cuML for Random Forest and KNN +- XGBoost GPU training +- PyTorch tensors when applicable + +### 4. Configuration as Code + +- Use dataclasses for typed configuration +- TOML files for training hyperparameters +- Cyclopts for CLI argument parsing + +## Data Sources + +- **DARTS v2**: RTS labels (year, area, count, density) +- **ERA5**: Climate data (40-year history, Arctic-aligned years) +- **ArcticDEM**: 32m resolution terrain (slope, aspect, indices) +- **AlphaEarth**: 256-dimensional satellite embeddings +- **Watermask**: Ocean exclusion layer + +## Model Support + +Primary: **eSPA** (entropy-optimal Scalable Probabilistic Approximations) +Alternatives: XGBoost, Random Forest, K-Nearest Neighbors + +Training features: +- Randomized hyperparameter search +- K-Fold cross-validation +- Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall) + +## Extension Points + +To extend Entropice: + +- **New data source**: Follow patterns in `era5.py` or `arcticdem.py` +- **Custom aggregations**: Add to `_Aggregations` dataclass in `aggregators.py` +- **Alternative labels**: Implement extractor following `darts.py` pattern +- **New models**: Add scikit-learn compatible estimators to `training.py` +- **Dashboard pages**: Add Streamlit pages to `dashboard/` module + +## Important Notes + +- Grid resolutions: **H3** (3-6), **HEALPix** (6-10) +- Arctic years run **October 1 to September 30** (not calendar years) +- Handle **antimeridian crossing** in polar regions +- Use **batch processing** for GPU memory management +- Notebooks are for exploration only - **keep production code in `src/`** +- Always use **absolute paths** or paths from `paths.py` + +## Common Issues + +- **Memory**: Use batch processing and Dask chunking for large datasets +- **GPU 
OOM**: Reduce batch size in inference or training +- **Antimeridian**: Use proper handling in `aggregators.py` for polar grids +- **Temporal alignment**: ERA5 uses Arctic-aligned years (Oct-Sep) +- **CRS**: Compute in EPSG:3413, visualize in EPSG:4326 + +## References + +For more details, consult: +- [ARCHITECTURE.md](../ARCHITECTURE.md) - System architecture and design patterns +- [CONTRIBUTING.md](../CONTRIBUTING.md) - Development workflow and standards +- [README.md](../README.md) - Project goals and setup instructions diff --git a/.github/python.instructions.md b/.github/python.instructions.md new file mode 100644 index 0000000..3ae0290 --- /dev/null +++ b/.github/python.instructions.md @@ -0,0 +1,56 @@ +--- +description: 'Python coding conventions and guidelines' +applyTo: '**/*.py,**/*.ipynb' +--- + +# Python Coding Conventions + +## Python Instructions + +- Write clear and concise comments for each function. +- Ensure functions have descriptive names and include type hints. +- Provide docstrings following PEP 257 conventions. +- Use the `typing` module for type annotations (e.g., `List[str]`, `Dict[str, int]`). +- Break down complex functions into smaller, more manageable functions. + +## General Instructions + +- Always prioritize readability and clarity. +- For algorithm-related code, include explanations of the approach used. +- Write code with good maintainability practices, including comments on why certain design decisions were made. +- Handle edge cases and write clear exception handling. +- For libraries or external dependencies, mention their usage and purpose in comments. +- Use consistent naming conventions and follow language-specific best practices. +- Write concise, efficient, and idiomatic code that is also easily understandable. + +## Code Style and Formatting + +- Follow the **PEP 8** style guide for Python. +- Maintain proper indentation (use 4 spaces for each level of indentation). +- Ensure lines do not exceed 79 characters. 
+- Place function and class docstrings immediately after the `def` or `class` keyword. +- Use blank lines to separate functions, classes, and code blocks where appropriate. + +## Edge Cases and Testing + +- Always include test cases for critical paths of the application. +- Account for common edge cases like empty inputs, invalid data types, and large datasets. +- Include comments for edge cases and the expected behavior in those cases. +- Write unit tests for functions and document them with docstrings explaining the test cases. + +## Example of Proper Documentation + +```python +def calculate_area(radius: float) -> float: + """ + Calculate the area of a circle given the radius. + + Parameters: + radius (float): The radius of the circle. + + Returns: + float: The area of the circle, calculated as π * radius^2. + """ + import math + return math.pi * radius ** 2 +``` diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..8da0fab --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,217 @@ +# Entropice Architecture + +## Overview + +Entropice is a geospatial machine learning system for predicting Retrogressive Thaw Slump (RTS) density across the Arctic using entropy-optimal Scalable Probabilistic Approximations (eSPA). +The system processes multi-source geospatial data, aggregates it into discrete global grids, and trains probabilistic classifiers to estimate RTS occurrence patterns. +It allows for rigorous experimentation with the data and the models, modularizing and abstracting most parts of the data flow pipeline. +Thus, alternative models, datasets and target labels can be used to compare, but also to explore new methodologies other than RTS prediction or eSPA modelling. 
+ +## Core Architecture + +### Data Flow Pipeline + +```txt +Data Sources → Grid Aggregation → Feature Engineering → Dataset Assembly → Model Training → Inference → Visualization +``` + +The pipeline follows a sequential processing approach where each stage produces intermediate datasets that feed into subsequent stages: + +1. **Grid Generation** (H3/HEALPix) → Parquet files +2. **Label Extraction** (DARTS) → Grid-enriched labels +3. **Feature Extraction** (ERA5, ArcticDEM, AlphaEarth) → L2 Datasets (XDGGS Xarray) +4. **Dataset Ensemble** → Training-ready tabular data +5. **Model Training** → Pickled classifiers + CV results +6. **Inference** → GeoDataFrames with predictions +7. **Dashboard** → Interactive visualizations + +### System Components + +#### 1. Spatial Grid System (`grids.py`) + +- **Purpose**: Creates global tessellations (discrete global grid systems) for spatial aggregation +- **Grid Types & Levels**: H3 hexagonal grids and HEALPix grids + - Hex: 3, 4, 5, 6 + - Healpix: 6, 7, 8, 9, 10 +- **Functionality**: + - Generates cell IDs and geometries at configurable resolutions + - Applies watermask to exclude ocean areas + - Provides spatial indexing for efficient data aggregation +- **Output**: GeoDataFrames with cell IDs, geometries, and land areas + +#### 2. Label Management (`darts.py`) + +- **Purpose**: Extracts RTS labels from DARTS v2 dataset +- **Processing**: + - Spatial overlay of DARTS polygons with grid cells + - Temporal aggregation across multiple observation years + - Computes RTS count, area, density, and coverage per cell + - Supports both standard DARTS and ML training labels +- **Output**: Grid-enriched parquet files with RTS metrics + +#### 3. 
Feature Extractors + +**ERA5 Climate Data (`era5.py`)** + +- Downloads hourly climate variables from Copernicus Climate Data Store +- Computes daily aggregates: temperature extrema, precipitation, snow metrics +- Derives thaw/freeze metrics: degree days, thawing period length +- Temporal aggregations: yearly, seasonal, shoulder seasons +- Uses Arctic-aligned years (October 1 - September 30) + +**ArcticDEM Terrain (`arcticdem.py`)** + +- Processes 32m resolution Arctic elevation data +- Computes terrain derivatives: slope, aspect, curvature +- Calculates terrain indices: TPI, TRI, ruggedness, VRM +- Applies watermask clipping and GPU-accelerated convolutions +- Aggregates terrain statistics per grid cell + +**AlphaEarth Embeddings (`alphaearth.py`)** + +- Extracts 256-dimensional satellite image embeddings via Google Earth Engine +- Uses foundation models to capture visual patterns +- Partitions large grids using KMeans clustering +- Temporal sampling across multiple years + +#### 4. Spatial Aggregation Framework (`aggregators.py`) + +- **Core Capability**: Raster-to-vector aggregation engine +- **Methods**: + - Exact geometry-based aggregation using polygon overlay + - Grid-aligned batch processing for memory efficiency + - Statistical aggregations: mean, sum, std, min, max, median, quantiles +- **Optimization**: + - Antimeridian handling for polar regions + - Parallel processing with worker pools + - GPU acceleration via CuPy/CuML where applicable + +#### 5. 
Dataset Assembly (`dataset.py`) + +- **DatasetEnsemble Class**: Orchestrates multi-source data integration +- **L2 Datasets**: Standardized XDGGS Xarray datasets per data source +- **Features**: + - Dynamic dataset configuration via dataclasses + - Multi-task support: binary classification, count, density estimation + - Target binning: quantile-based categorical bins + - Train/test splitting with spatial awareness + - GPU-accelerated data loading (PyTorch/CuPy) +- **Output**: Tabular feature matrices ready for scikit-learn API + +#### 6. Model Training (`training.py`) + +- **Supported Models**: + - **eSPA**: Entropy-optimal probabilistic classifier (primary) + - **XGBoost**: Gradient boosting (GPU-accelerated) + - **Random Forest**: cuML implementation + - **K-Nearest Neighbors**: cuML implementation +- **Training Strategy**: + - Randomized hyperparameter search (Scipy distributions) + - K-Fold cross-validation + - Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall) +- **Configuration**: TOML-based configuration with Cyclopts CLI +- **Output**: Pickled models, CV results, feature importance + +#### 7. Inference (`inference.py`) + +- Batch prediction pipeline for trained classifiers +- GPU memory management with configurable batch sizes +- Outputs GeoDataFrames with predicted classes and probabilities +- Supports all trained model types via unified interface + +#### 8. Interactive Dashboard (`dashboard/`) + +- **Technology**: Streamlit-based web application +- **Pages**: + - **Overview**: Result directory summary and metrics + - **Training Data**: Feature distribution visualizations + - **Training Analysis**: CV results, hyperparameter analysis + - **Model State**: Feature importance, decision boundaries + - **Inference**: Spatial prediction maps +- **Plotting Modules**: Bokeh-based interactive geospatial plots + +## Key Design Patterns + +### 1. 
XDGGS Integration + +All geospatial data is indexed using discrete global grid systems (H3 or HEALPix) via the `xdggs` library, enabling: + +- Consistent spatial indexing across data sources +- Efficient spatial joins and aggregations +- Multi-resolution analysis capabilities + +### 2. Lazy Evaluation & Chunking + +- Xarray/Dask for out-of-core computation +- Zarr/Icechunk for chunked storage +- Batch processing to manage GPU memory + +### 3. GPU Acceleration + +- CuPy for array operations +- cuML for machine learning (RF, KNN) +- XGBoost GPU training +- PyTorch tensor operations + +### 4. Modular CLI Design + +Each module exposes a Cyclopts-based CLI for standalone execution. Super-scripts utilize them to run parts of the data flow for all grid-level combinations: + +```bash +scripts/00grids.sh # Grid generation +scripts/01darts.sh # Label extraction +scripts/02alphaearth.sh # Satellite embeddings +scripts/03era5.sh # Climate features +scripts/04arcticdem.sh # Terrain features +scripts/05train.sh # Model training +``` + +### 5. 
Configuration as Code + +- Dataclasses for typed configuration +- TOML files for training hyperparameters +- Environment-based path management (`paths.py`) + +## Data Storage Hierarchy + +```sh +DATA_DIR/ +├── grids/ # Hexagonal/HEALPix tessellations (GeoParquet) +├── darts/ # RTS labels (GeoParquet) +├── era5/ # Climate data (Zarr) +├── arcticdem/ # Terrain data (Icechunk Zarr) +├── alphaearth/ # Satellite embeddings (Zarr) +├── datasets/ # L2 XDGGS datasets (Zarr) +├── training-results/ # Models, CV results, predictions (Parquet, NetCDF, Pickle) +└── watermask/ # Ocean mask (GeoParquet) +``` + +## Technology Stack + +**Core**: Python 3.13, NumPy, Pandas, Xarray, GeoPandas +**Spatial**: H3, xdggs, xvec, Shapely, Rasterio +**ML**: scikit-learn, XGBoost, cuML, entropy (eSPA) +**GPU**: CuPy, PyTorch, CUDA +**Storage**: Zarr, Icechunk, Parquet, NetCDF +**Visualization**: Matplotlib, Seaborn, Bokeh, Streamlit, Cartopy, PyDeck, Altair, Plotly +**CLI**: Cyclopts, Rich +**External APIs**: Google Earth Engine, Copernicus Climate Data Store + +## Performance Considerations + +- **Memory Management**: Batch processing with configurable chunk sizes +- **Parallelization**: Multi-process aggregation, Dask distributed computing +- **GPU Utilization**: CuPy/cuML for array operations and ML training +- **Storage Optimization**: Blosc compression, Zarr chunking strategies +- **Spatial Indexing**: Grid-based partitioning for large-scale operations + +## Extension Points + +The architecture supports extension through: + +- **New Data Sources**: Implement feature extractor following ERA5/ArcticDEM patterns +- **Custom Aggregations**: Add methods to `_Aggregations` dataclass +- **Alternative Targets**: Implement label extractor following DARTS pattern +- **Alternative Models**: Extend training CLI with new scikit-learn compatible estimators +- **Dashboard Pages**: Add Streamlit pages to `dashboard/` module +- **Grid Systems**: Support additional DGGS via xdggs integration diff --git 
a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..89f06f0 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,116 @@ +# Contributing to Entropice + +Thank you for your interest in contributing to Entropice! This document provides guidelines for contributing to the project. + +## Getting Started + +### Prerequisites + +- Python 3.13 +- CUDA 12 compatible GPU (for full functionality) +- [Pixi package manager](https://pixi.sh/) + +### Setup + +```bash +pixi install +``` + +This will set up the complete environment including RAPIDS, PyTorch, and all geospatial dependencies. + +## Development Workflow + +### Code Organization + +- **`src/entropice/`**: Core modules (grids, data sources, training, inference) +- **`src/entropice/dashboard/`**: Streamlit visualization dashboard +- **`scripts/`**: Data processing pipeline scripts (numbered 00-05) +- **`notebooks/`**: Exploratory analysis and validation notebooks +- **`tests/`**: Unit tests + +### Key Modules + +- `grids.py`: H3/HEALPix spatial grid systems +- `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`: Data source processors +- `dataset.py`: Dataset assembly and feature engineering +- `training.py`: Model training with eSPA/SPARTAn +- `inference.py`: Prediction generation + +## Coding Standards + +### Python Style + +- Follow PEP 8 conventions +- Use type hints for function signatures +- Prefer numpy-style docstrings for public functions +- Keep functions focused and modular + +### Geospatial Best Practices + +- Use **xarray** with XDGGS for gridded data storage +- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays) +- Leverage **Dask** for lazy evaluation of large datasets +- Use **GeoPandas** for vector operations +- Use EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data and EPSG:4326 (WGS84) for data visualization and compatability with some libraries + +### Data Pipeline + +- Follow the numbered script sequence: 
`00grids.sh` → `01darts.sh` → ... → `05train.sh` +- Each stage should produce reproducible intermediate outputs +- Document data dependencies in module docstrings +- Use `paths.py` for consistent path management + +## Testing + +Run tests for specific modules: + +```bash +pixi run pytest +``` + +When adding features, include tests that verify: + +- Correct handling of geospatial coordinates and projections +- Proper aggregation to grid cells +- Data integrity through pipeline stages + +## Submitting Changes + +### Pull Request Process + +1. **Branch**: Create a feature branch from `main` +2. **Commit**: Write clear, descriptive commit messages +3. **Test**: Verify your changes don't break existing functionality +4. **Document**: Update relevant docstrings and documentation +5. **PR**: Submit a pull request with a clear description of changes + +### Commit Messages + +- Use present tense: "Add feature" not "Added feature" +- Reference issues when applicable: "Fix #123: Correct grid aggregation" +- Keep first line under 72 characters + +## Working with Data + +### Local Development + +- Set `FAST_DATA_DIR` environment variable for data directory (default: `./data`) + +### Notebooks + +- Notebooks in `notebooks/` are for exploration and validation, they are not commited to git +- Keep production code in `src/entropice/` + +## Dashboard Development + +Run the dashboard locally: + +```bash +pixi run dashboard +``` + +Dashboard code is in `src/entropice/dashboard/` with modular pages and plotting utilities. + +## Questions? + +For questions about the architecture, see `ARCHITECTURE.md`. For scientific background, see `README.md`. 
diff --git a/README.md b/README.md index a3cb03c..f266b10 100755 --- a/README.md +++ b/README.md @@ -1,45 +1,145 @@ -# eSPA for RTS +# Entropice -Goal of this project is to utilize the entropy-optimal Scalable Probabilistic Approximations algorithm (eSPA) to create a model which can estimate the density of Retrogressive-Thaw-Slumps (RTS) across the globe with different levels of detail. -Hoping, that a successful training could gain new knowledge about RTX-proxies. +**Geospatial Machine Learning for Arctic Permafrost Degradation.** +Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. +The system integrates multi-source geospatial data (climate, terrain, satellite imagery) into discrete global grids and trains probabilistic classifiers to estimate RTS occurrence patterns at multiple spatial resolutions. -## Setup +## Scientific Background -```sh -uv sync +### Retrogressive Thaw Slumps + +Retrogressive Thaw Slumps are Arctic landslides caused by permafrost degradation. As ice-rich permafrost thaws, ground collapses create distinctive bowl-shaped features that retreat upslope over time. RTS are: + +- **Climate indicators**: Sensitive to warming temperatures and changing precipitation +- **Ecological disruptors**: Release sediment, nutrients, and greenhouse gases into Arctic waterways +- **Infrastructure hazards**: Threaten communities and industrial facilities in permafrost regions +- **Feedback mechanisms**: Accelerate local warming through albedo changes and carbon release + +Understanding RTS distribution patterns is critical for predicting permafrost stability under climate change. + +### The Challenge + +Current remote sensing approaches try to map a specific landscape feature and then try to extract spatio-temporal statistical information from that dataset. 
+ +Traditional RTS mapping relies on manual digitization from satellite imagery (e.g., the DARTS v2 training-dataset), which is: + +- Labor-intensive and limited in spatial/temporal coverage +- Challenging due to cloud cover and seasonal visibility +- Insufficient for pan-Arctic prediction at decision-relevant scales + +Modern mapping approaches utilize machine learning to create segmented labels from satellite imagery (e.g. the DARTS dataset), which comes with its own problems: + +- Huge data transfer needed between satellite imagery providers and HPC where the models are run +- Large energy consumption in both data transfer and inference +- Uncertainty about the quality of the results +- Potential compute waste when running inference on regions where it is clear that the searched landscape feature does not exist + +### Our Approach + +Instead of global mapping followed by calculation of spatio-temporal statistics, Entropice tries to learn spatio-temporal patterns from a small subset based on a large variety of data features to get an educated guess about the spatio-temporal statistics of a landscape feature. + +Entropice addresses this by: + +1. **Spatial Discretization across scales**: Representing the Arctic using discrete global grid systems (H3 hexagonal grids, HEALPix) on different low to mid resolutions (levels) +2. **Multi-Source Integration**: Aggregating climate (ERA5), terrain (ArcticDEM), and satellite embeddings (AlphaEarth) into feature-rich datasets to obtain environmental proxies across spatio-temporal scales +3. 
**Probabilistic Modeling**: Training eSPA classifiers to predict RTS density classes based on environmental proxies + +This hopefully leads to the following advances in permafrost research: + +- Better understanding of RTS occurrence + - Potential proxy for ice-rich permafrost +- Reduction of compute waste of image segmentation pipelines +- Better modelling by providing better starting conditions + +### Entropy-Optimal Scalable Probabilistic Approximations (eSPA) + +eSPA is a probabilistic classification framework that: + +- Provides calibrated probability estimates (not just point predictions) +- Handles imbalanced datasets common in geospatial phenomena +- Captures uncertainty in predictions across poorly-sampled regions +- Enables interpretable feature importance analysis + +This approach aims to discover which environmental variables best predict RTS occurrence, potentially revealing new proxies for permafrost vulnerability. + +## Key Features + +- **Modular Data Pipeline**: Sequential processing stages from raw data to trained models +- **Multiple Grid Systems**: H3 (resolutions 3-6) and HEALPix (resolutions 6-10) +- **GPU-Accelerated**: RAPIDS (CuPy, cuML) and PyTorch for large-scale computation +- **Interactive Dashboard**: Streamlit-based visualization of training data, results, and predictions +- **Reproducible Workflows**: Configuration-as-code with TOML files and CLI tools +- **Extensible Architecture**: Support for alternative models (XGBoost, Random Forest, KNN) and data sources + +## Quick Start + +### Installation + +Requires Python 3.13 and CUDA 12 compatible GPU. + +```bash +pixi install ``` -## Project Plan +This sets up the complete environment including RAPIDS, PyTorch, and geospatial libraries. -1. Create global hexagon grids with h3 -2. Enrich the grids with data from various sources and with labels from DARTS v2 -3. Use eSPA for simple classification: hex has [many slumps / some slumps / few slumps / no slumps] -4. 
use SPARTAn for regression: one for slumps density (area) and one for total number of slumps +### Running the Pipeline -### Data Sources and Engineering +Execute the numbered scripts to process data and train models: -- Labels - - `"year"`: Year of observation - - `"area"`: Total land-area of the hexagon - - `"rts_density"`: Area of RTS divided by total land-area - - `"rts_count"`: Number of single RTS instances -- ERA5 (starting 40 years from `"year"`) - - `"temp_yearXXXX_qY"`: Y-th quantile temperature of year XXXX. Used to enter the temperature distribution into the model. - - `"thawing_days_yearXXXX"`: Number of thawing-days of year XXXX. - - `"precip_yearXXXX_qY"`: Y-th quantile precipitation of year XXXX. Similar to temperature. - - `"temp_5year_diff_XXXXtoXXXX_qY"`: Difference of the Y-th quantile temperature between year XXXX and XXXX. Always 5 years difference. - - `"temp_10year_diff_XXXXtoXXXX_qY"`: Difference of the Y-th quantile temperature between year XXXX and XXXX. Always 10 years difference. - - `"temp_diff_qY"`: Difference of the Y-th quantile temperature between year XXXX and XXXX. Always 10 years difference. -- ArcticDEM - - `"dissection_index"`: Dissection Index, (max - min) / max - - `"max_elevation"`: Maximum elevation - - `"elevationX_density"`: Area where the elevation is larger than X divided by the total land-area -- TCVIS - - ??? -- Wildfire??? -- Permafrost??? -- GroundIceContent??? -- Biome +```bash +scripts/00grids.sh # Generate spatial grids +scripts/01darts.sh # Extract RTS labels from DARTS v2 +scripts/02alphaearth.sh # Extract satellite embeddings +scripts/03era5.sh # Process climate data +scripts/04arcticdem.sh # Compute terrain features +scripts/05train.sh # Train models +``` -**About temporals** Every label has its own year - all temporal dependent data features, e.g. `"temp_5year_diff_XXXXtoXXXX_qY"` are calculated respective to that year. -The number of years added from a dataset is always the same, e.g. 
for ERA5 for an observation in 2024 the ERA5 data would start in 1984 and for an observation from 2023 in 1983. +### Visualizing Results + +Launch the interactive dashboard: + +```bash +pixi run dashboard +``` + +Explore training data distributions, cross-validation results, feature importance, and spatial predictions. + +## Data Sources + +- **DARTS v2**: RTS labels (polygons with year, area, count) +- **ERA5**: Climate reanalysis (40-year history, Arctic-aligned years) +- **ArcticDEM**: 32m resolution terrain elevation +- **AlphaEarth**: 64-dimensional satellite image embeddings + +## Project Structure + +- `src/entropice/`: Core modules (grids, data processors, training, inference) +- `src/entropice/dashboard/`: Streamlit visualization application +- `scripts/`: Data processing pipeline automation +- `notebooks/`: Exploratory analysis (not version-controlled) + +## Documentation + +- **[ARCHITECTURE.md](ARCHITECTURE.md)**: System design, components, and data flow +- **[CONTRIBUTING.md](CONTRIBUTING.md)**: Development guidelines and standards + +## Research Goals + +1. **Predictive Modeling**: Estimate RTS density at unobserved locations +2. **Proxy Discovery**: Identify environmental variables most predictive of RTS occurrence +3. **Multi-Scale Analysis**: Compare model performance across spatial resolutions +4. 
**Uncertainty Quantification**: Provide calibrated probabilities for decision-making + +## License + +TODO + +## Citation + +If you use Entropice in your research, please cite: + +```txt +TODO +``` diff --git a/src/entropice/dashboard/overview_page.py b/src/entropice/dashboard/overview_page.py index 64ebf34..bb43dd9 100644 --- a/src/entropice/dashboard/overview_page.py +++ b/src/entropice/dashboard/overview_page.py @@ -284,7 +284,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10) fig.update_layout(height=400) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") with tab2: st.markdown("### Sample Counts Bar Chart") @@ -311,7 +311,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): # Update facet labels to be cleaner fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1])) fig.update_xaxes(tickangle=-45) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") with tab3: st.markdown("### Detailed Sample Counts") @@ -325,7 +325,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache): for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]: display_df[col] = display_df[col].apply(lambda x: f"{x:,}") - st.dataframe(display_df, hide_index=True, use_container_width=True) + st.dataframe(display_df, hide_index=True, width="stretch") def render_feature_count_comparison(cache: DatasetAnalysisCache): @@ -363,7 +363,7 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): ) fig.update_layout(height=500, xaxis_tickangle=-45) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") # Add secondary metrics col1, col2 = st.columns(2) @@ -384,7 +384,7 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): ) fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside") 
fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False) - st.plotly_chart(fig_cells, use_container_width=True) + st.plotly_chart(fig_cells, width="stretch") with col2: fig_samples = px.bar( @@ -399,7 +399,7 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): ) fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside") fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False) - st.plotly_chart(fig_samples, use_container_width=True) + st.plotly_chart(fig_samples, width="stretch") with comp_tab2: st.markdown("#### Feature Breakdown by Data Source") @@ -435,7 +435,7 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): ) fig.update_traces(textposition="inside", textinfo="percent") fig.update_layout(showlegend=True, height=350) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") with comp_tab3: st.markdown("#### Detailed Feature Count Comparison") @@ -452,7 +452,7 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache): # Format boolean as Yes/No display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗") - st.dataframe(display_df, hide_index=True, use_container_width=True) + st.dataframe(display_df, hide_index=True, width="stretch") @st.fragment @@ -581,10 +581,10 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache): color_discrete_sequence=source_colors, ) fig.update_traces(textposition="inside", textinfo="percent+label") - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") # Show detailed table - st.dataframe(breakdown_df, hide_index=True, use_container_width=True) + st.dataframe(breakdown_df, hide_index=True, width="stretch") # Detailed member information with st.expander("📦 Detailed Source Information", expanded=False): diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py index b6c8720..3c2b57f 100644 --- 
a/src/entropice/dashboard/plots/inference.py +++ b/src/entropice/dashboard/plots/inference.py @@ -108,7 +108,7 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: xaxis={"tickangle": -45 if len(categories) > 3 else 0}, ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") # Show percentages in a table with st.expander("📋 Detailed Class Distribution", expanded=False): @@ -119,7 +119,7 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: "Percentage": (class_counts.to_numpy() / len(predictions_gdf) * 100).round(2), } ) - st.dataframe(distribution_df, hide_index=True, use_container_width=True) + st.dataframe(distribution_df, hide_index=True, width="stretch") def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame): @@ -394,7 +394,7 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str): showlegend=True, ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") with col2: st.markdown("**Cumulative Distribution") @@ -426,4 +426,4 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str): yaxis={"range": [0, 105]}, ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width="stretch") diff --git a/src/entropice/training_analysis_dashboard.py b/src/entropice/training_analysis_dashboard.py deleted file mode 100644 index d1f7952..0000000 --- a/src/entropice/training_analysis_dashboard.py +++ /dev/null @@ -1,1977 +0,0 @@ -"""Streamlit dashboard for training analysis results visualization.""" - -from datetime import datetime -from pathlib import Path - -import altair as alt -import geopandas as gpd -import matplotlib.colors as mcolors -import matplotlib.path -import numpy as np -import pandas as pd -import streamlit as st -import streamlit_folium as st_folium -import toml -import ultraplot as uplt -import xarray as xr -import xdggs -from matplotlib.patches import 
PathPatch - -from entropice.paths import RESULTS_DIR -from entropice.training import create_xy_data - - -def generate_unified_colormap(settings: dict): - """Generate unified colormaps for all plotting libraries. - - This function creates consistent color schemes across Matplotlib/Ultraplot, - Folium/Leaflet, and Altair/Vega-Lite by determining the task type and number - of classes from the settings, then generating appropriate colormaps for each library. - - Args: - settings: Settings dictionary containing task type, classes, and other configuration. - - Returns: - Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where: - - matplotlib_cmap: matplotlib ListedColormap object - - folium_cmap: matplotlib ListedColormap object (for geopandas.explore) - - altair_colors: list of hex color strings for Altair - - """ - # Determine task type and number of classes from settings - task = settings.get("task", "binary") - n_classes = len(settings.get("classes", [])) - - # Check theme - is_dark_theme = st.context.theme.type == "dark" - - # Define base colormaps for different tasks - if task == "binary": - # For binary: use a simple two-color scheme - if is_dark_theme: - base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme - else: - base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme - else: - # For multi-class: use a sequential colormap - # Use matplotlib's viridis colormap as base (better perceptual uniformity than inferno) - cmap = uplt.Colormap("viridis") - # Sample colors evenly across the colormap - indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends - base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices] - - # Create matplotlib colormap (for ultraplot and geopandas) - matplotlib_cmap = mcolors.ListedColormap(base_colors) - - # Create Folium/Leaflet colormap (geopandas.explore uses matplotlib colormaps) - folium_cmap = mcolors.ListedColormap(base_colors) - - # Create Altair color list (Altair uses hex 
color strings in range) - altair_colors = base_colors - - return matplotlib_cmap, folium_cmap, altair_colors - - -def format_metric_name(metric: str) -> str: - """Format metric name for display. - - Args: - metric: Raw metric name (e.g., 'f1_micro', 'precision_macro'). - - Returns: - Formatted metric name (e.g., 'F1 Micro', 'Precision Macro'). - - """ - # Split by underscore and capitalize each part - parts = metric.split("_") - # Special handling for F1 - formatted_parts = [] - for part in parts: - if part.lower() == "f1": - formatted_parts.append("F1") - else: - formatted_parts.append(part.capitalize()) - return " ".join(formatted_parts) - - -def get_available_result_files() -> list[Path]: - """Get all available result files from RESULTS_DIR.""" - if not RESULTS_DIR.exists(): - return [] - - result_files = [] - for search_dir in RESULTS_DIR.iterdir(): - if not search_dir.is_dir(): - continue - - result_file = search_dir / "search_results.parquet" - state_file = search_dir / "best_estimator_state.nc" - preds_file = search_dir / "predicted_probabilities.parquet" - settings_file = search_dir / "search_settings.toml" - if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists(): - result_files.append(search_dir) - - def _key_func(path: Path): - return path.stat().st_mtime - - return sorted(result_files, key=_key_func, reverse=True) # Most recent first - - -def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame: - """Load results file and prepare binned columns. - - Args: - file_path: Path to the results parquet file. - settings: Dictionary of search settings. - k_bin_width: Width of bins for initial_K parameter. - - Returns: - DataFrame with added binned columns. 
- - """ - results = pd.read_parquet(file_path) - - # Automatically determine bin width for initial_K based on data range - k_min = settings["param_grid"]["initial_K"]["low"] - k_max = settings["param_grid"]["initial_K"]["high"] - # Use configurable bin width, adapted to actual data range - k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) - results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) - - # Automatically create logarithmic bins for epsilon parameters based on data range - # Use 10 bins spanning the actual data range - eps_cl_min = np.log10(settings["param_grid"]["eps_cl"]["low"]) - eps_cl_max = np.log10(settings["param_grid"]["eps_cl"]["high"]) - eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=int(eps_cl_max - eps_cl_min + 1)) - - eps_e_min = np.log10(settings["param_grid"]["eps_e"]["low"]) - eps_e_max = np.log10(settings["param_grid"]["eps_e"]["high"]) - eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=int(eps_e_max - eps_e_min + 1)) - - results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) - results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) - - return results - - -def load_and_prepare_model_state(file_path: Path) -> xr.Dataset: - """Load a model state from a NetCDF file. - - Args: - file_path (Path): The path to the NetCDF file. - - Returns: - xr.Dataset: The model state as an xarray Dataset. - - """ - return xr.open_dataset(file_path, engine="h5netcdf") - - -def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None: - """Extract embedding features from the model state. - - Args: - model_state: The xarray Dataset containing the model state. - - Returns: - xr.DataArray: The extracted embedding features. This DataArray has dimensions - ('agg', 'band', 'year') corresponding to the different components of the embedding features. - Returns None if no embedding features are found. 
- - """ - - def _is_embedding_feature(feature: str) -> bool: - return feature.startswith("embeddings_") - - embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)] - if len(embedding_features) == 0: - return None - - # Split the single feature dimension of embedding features into separate dimensions (agg, band, year) - embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"] - embedding_feature_array = embedding_feature_array.assign_coords( - agg=("feature", [f.split("_")[1] for f in embedding_features]), - band=("feature", [f.split("_")[2] for f in embedding_features]), - year=("feature", [f.split("_")[3] for f in embedding_features]), - ) - embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010 - return embedding_feature_array - - -def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None: - """Extract ERA5 features from the model state. - - Args: - model_state: The xarray Dataset containing the model state. - - Returns: - xr.DataArray: The extracted ERA5 features. This DataArray has dimensions - ('variable', 'time') corresponding to the different components of the ERA5 features. - Returns None if no ERA5 features are found. 
- - """ - - def _is_era5_feature(feature: str) -> bool: - return feature.startswith("era5_") - - def _extract_var_name(feature: str) -> str: - feature = feature.replace("era5_", "") - if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]): - return feature.rsplit("_", 2)[0] - else: - return feature.rsplit("_", 1)[0] - - def _extract_time_name(feature: str) -> str: - feature = feature.replace("era5_", "") - if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]): - return "_".join(feature.rsplit("_", 2)[-2:]) - else: - return feature.rsplit("_", 1)[-1] - - era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)] - if len(era5_features) == 0: - return None - # Split the single feature dimension of era5 features into separate dimensions (variable, time) - era5_features_array = model_state.sel(feature=era5_features)["feature_weights"] - era5_features_array = era5_features_array.assign_coords( - variable=("feature", [_extract_var_name(f) for f in era5_features]), - time=("feature", [_extract_time_name(f) for f in era5_features]), - ) - era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 - return era5_features_array - - -def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None: - """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state. - - Args: - model_state: The xarray Dataset containing the model state. - - Returns: - xr.DataArray: The extracted common features with a single 'feature' dimension. - Returns None if no common features are found. 
- - """ - common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"] - - def _is_common_feature(feature: str) -> bool: - return feature in common_feature_names - - common_features = [f for f in model_state.feature.to_numpy() if _is_common_feature(f)] - if len(common_features) == 0: - return None - - # Extract the feature weights for common features - common_feature_array = model_state.sel(feature=common_features)["feature_weights"] - return common_feature_array - - -def _plot_prediction_map_static(preds: gpd.GeoDataFrame, matplotlib_cmap: mcolors.ListedColormap): - """Create a static map of predictions using ultraplot and cartopy. - - Args: - preds: GeoDataFrame with predicted classes (intervals as strings). - matplotlib_cmap: Matplotlib ListedColormap object. - - Returns: - Matplotlib figure with predictions colored by class. - - """ - # Create a copy to avoid modifying the original data - preds_plot = preds.copy() - - # Replace the special (-1, 0] interval with "No RTS" - preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(class_str.split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - # Get unique classes and sort them - unique_classes = sorted(preds_plot["predicted_class"].unique(), key=sort_key) - - # Create categorical with ordered categories - preds_plot["predicted_class"] = pd.Categorical( - preds_plot["predicted_class"], categories=unique_classes, ordered=True - ) - - # Detect theme for styling - theme = st.context.theme.type - if theme == "light": - figcolor = "#f8f9fa" - axcolor = "#ffffff" - oceancolor = "#e3f2fd" - landcolor = "#ecf0f1" - gridcolor = "#666666" 
- fgcolor = "#2c3e50" - coastcolor = "#34495e" - else: - figcolor = "#0e1117" - axcolor = "#0e1117" - oceancolor = "#1a1d29" - landcolor = "#262730" - gridcolor = "#ffffff" - fgcolor = "#fafafa" - coastcolor = "#d0d0d0" - - # Create figure with North Polar Stereographic projection - # proj = uplt.Proj("npaeqd") - proj = uplt.Proj("npstere", lon_0=-45) - fig, ax = uplt.subplots(proj=proj, figsize=(12, 12), facecolor=figcolor) - - # Put a background on the figure. - bgpatch = PathPatch(matplotlib.path.Path.unit_rectangle(), transform=fig.transFigure, color=figcolor, zorder=-1) - fig.patches.append(bgpatch) - - # Apply theme-appropriate styling with enhanced aesthetics - ax.format( - boundinglat=50, - coast=True, - coastcolor=coastcolor, - coastlinewidth=0.8, - land=True, - landcolor=landcolor, - ocean=True, - oceancolor=oceancolor, - title="Predicted RTS Classes", - titlecolor=fgcolor, - titleweight="bold", - titlesize=14, - color=fgcolor, - gridcolor=gridcolor, - gridlinewidth=0.5, - gridlinestyle="-", - gridalpha=0.3, - facecolor=axcolor, - labels=True, - labelcolor=fgcolor, - latlines=10, - lonlines=30, - ) - - # Plot the predictions using the provided colormap with enhanced styling - preds_plot.to_crs(proj.proj4_init).plot( - ax=ax, - column="predicted_class", - cmap=matplotlib_cmap, - legend=True, - legend_kwds={ - "loc": "lower left", - "frameon": True, - "framealpha": 0.95, - "edgecolor": gridcolor, - "facecolor": figcolor, - "fancybox": True, - "shadow": False, - "title": "RTS Class", - "title_fontsize": 11, - "fontsize": 9, - "labelspacing": 0.8, - "borderpad": 1.0, - "columnspacing": 1.0, - }, - edgecolor="none", - linewidth=0, - ) - - # Enhance legend styling after creation - legend = ax.get_legend() - if legend: - legend.set_title("RTS Class", prop={"size": 11, "weight": "bold"}) - legend.get_frame().set_linewidth(0.5) - legend.get_frame().set_edgecolor(gridcolor) - legend.get_frame().set_facecolor(figcolor) - legend.get_frame().set_alpha(0.95) - # Set 
text color for legend labels - for text in legend.get_texts(): - text.set_color(fgcolor) - # Set title color - legend.get_title().set_color(fgcolor) - - return fig - - -def _plot_prediction_map(preds: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap): - """Create an interactive map of predictions with properly sorted classes. - - Args: - preds: GeoDataFrame with predicted classes (intervals as strings). - folium_cmap: Matplotlib ListedColormap object (for geopandas.explore). - - Returns: - Folium map with predictions colored by class. - - """ - # Create a copy to avoid modifying the original data - preds_plot = preds.copy() - - # Replace the special (-1, 0] interval with "No RTS" - preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(class_str.split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - # Get unique classes and sort them - unique_classes = sorted(preds_plot["predicted_class"].unique(), key=sort_key) - - # Create categorical with ordered categories - preds_plot["predicted_class"] = pd.Categorical( - preds_plot["predicted_class"], categories=unique_classes, ordered=True - ) - - # Select tiles based on theme - tiles = "CartoDB dark_matter" if st.context.theme.type == "dark" else "CartoDB positron" - return preds_plot.explore(column="predicted_class", cmap=folium_cmap, legend=True, tiles=tiles) - - -def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame, altair_colors: list[str]): - """Create a bar chart showing the count of each predicted class. - - Args: - preds: GeoDataFrame with predicted classes. - altair_colors: List of hex color strings for altair. 
- - Returns: - Altair chart showing class distribution. - - """ - # Create a copy and apply the same transformations as the map - preds_plot = preds.copy() - preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(class_str.split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - df = pd.DataFrame({"predicted_class": preds_plot["predicted_class"].to_numpy()}) - counts = df["predicted_class"].value_counts().reset_index() - counts.columns = ["class", "count"] - counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2) - - # Sort counts by the same key - counts["sort_key"] = counts["class"].apply(sort_key) - counts = counts.sort_values("sort_key") - - # Create an ordered list of classes for consistent color mapping - class_order = counts["class"].tolist() - - chart = ( - alt.Chart(counts) - .mark_bar() - .encode( - x=alt.X("class:N", title="Predicted Class", sort=class_order, axis=alt.Axis(labelAngle=0)), - y=alt.Y("count:Q", title="Number of Cells"), - color=alt.Color( - "class:N", - title="Class", - scale=alt.Scale(domain=class_order, range=altair_colors), - legend=None, - ), - tooltip=[ - alt.Tooltip("class:N", title="Class"), - alt.Tooltip("count:Q", title="Count"), - alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"), - ], - ) - .properties( - width=400, - height=300, - title="Predicted Class Distribution", - ) - ) - - return chart - - -def _plot_k_binned( - results: pd.DataFrame, - target: str, - *, - vmin_percentile: float | None = None, - vmax_percentile: float | None = None, -): - """Plot K-binned results with epsilon parameters.""" - assert vmin_percentile is None or 
vmax_percentile is None, ( - "Only one of vmin_percentile or vmax_percentile can be set." - ) - assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results." - assert target in results.columns, f"{target} column not found in results." - assert "eps_e" in results.columns, "eps_e column not found in results." - assert "eps_cl" in results.columns, "eps_cl column not found in results." - - # Prepare data - plot_data = results[["eps_e", "eps_cl", "initial_K_binned", target]].copy() - - # Sort bins by their left value and convert to string with sorted categories - plot_data = plot_data.sort_values("initial_K_binned") - plot_data["initial_K_binned"] = plot_data["initial_K_binned"].astype(str) - bin_order = plot_data["initial_K_binned"].unique().tolist() - - # Determine color scale domain - if vmin_percentile is not None: - vmin = np.percentile(results[target], vmin_percentile) - color_scale = alt.Scale(scheme="viridis", domain=[vmin, plot_data[target].max()]) - elif vmax_percentile is not None: - vmax = np.percentile(results[target], vmax_percentile) - color_scale = alt.Scale(scheme="viridis", domain=[plot_data[target].min(), vmax]) - else: - color_scale = alt.Scale(scheme="viridis") - - # Create the chart - chart = ( - alt.Chart(plot_data) - .mark_circle(size=60, opacity=0.7) - .encode( - x=alt.X( - "eps_e:Q", - scale=alt.Scale(type="log"), - axis=alt.Axis(title="eps_e", grid=True, gridOpacity=0.5), - ), - y=alt.Y( - "eps_cl:Q", - scale=alt.Scale(type="log"), - axis=alt.Axis(title="eps_cl", grid=True, gridOpacity=0.5), - ), - color=alt.Color(f"{target}:Q", scale=color_scale, title=target), - tooltip=["eps_e:Q", "eps_cl:Q", alt.Tooltip(f"{target}:Q", format=".4f"), "initial_K_binned:N"], - ) - .properties(width=200, height=200) - .facet(facet=alt.Facet("initial_K_binned:N", title="Initial K", sort=bin_order), columns=5) - ) - - return chart - - -def _plot_params_binned(results: pd.DataFrame, x: str, hue: str, col: str, metric: str): - 
"""Plot epsilon-binned results with K parameter.""" - assert metric in results.columns, f"{metric} not found in results." - assert x in ["eps_cl", "eps_e", "initial_K"] - assert hue in ["eps_cl", "eps_e", "initial_K"] - assert col in ["eps_cl_binned", "eps_e_binned", "initial_K_binned"] - - assert x in results.columns, f"{x} column not found in results." - assert hue in results.columns, f"{hue} column not found in results." - assert col in results.columns, f"{col} column not found in results." - - # Prepare data - plot_data = results[[x, metric, hue, col]].copy() - - # Sort bins by their left value and convert to string with sorted categories - plot_data = plot_data.sort_values(col) - plot_data[col] = plot_data[col].astype(str) - bin_order = plot_data[col].unique().tolist() - xscale = alt.Scale(type="log") if x in ["eps_cl", "eps_e"] else alt.Scale() - cscheme = "bluepurple" if hue == "eps_e" else "purplered" if hue == "eps_cl" else "greenblue" - # Create the chart - chart = ( - alt.Chart(plot_data) - .mark_circle(size=60, opacity=0.7) - .encode( - x=alt.X(f"{x}:Q", title=x, scale=xscale), - y=alt.Y(f"{metric}:Q", title=metric), - color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme=cscheme), title=hue), - tooltip=[ - f"{x}:Q", - alt.Tooltip(f"{metric}:Q", format=".4f"), - alt.Tooltip(f"{hue}:Q", format=".2e"), - f"{col}:N", - ], - ) - .properties(width=200, height=200) - .facet(facet=alt.Facet(f"{col}:N", title=col.replace("_binned", ""), sort=bin_order), columns=5) - ) - - return chart - - -def _parse_results_dir_name(results_dir: Path) -> str: - gridname, date = results_dir.name.replace("_binary", "").replace("_multi", "").split("_random_search_cv") - gridname = gridname.lstrip("permafrost_") - date = datetime.strptime(date, "%Y%m%d-%H%M%S") - date = date.strftime("%Y-%m-%d %H:%M:%S") - - settings = toml.load(results_dir / "search_settings.toml")["settings"] - task = settings.get("task", "binary") - return f"[{task.capitalize()}] 
{gridname.capitalize()} ({date})" - - -def _plot_top_features(model_state: xr.Dataset, top_n: int = 10): - """Plot the top N most important features based on feature weights. - - Args: - model_state: The xarray Dataset containing the model state. - top_n: Number of top features to display. - - Returns: - Altair chart showing the top features by importance. - - """ - # Extract feature weights - feature_weights = model_state["feature_weights"].to_pandas() - - # Sort by absolute weight and take top N - top_features = feature_weights.abs().nlargest(top_n).sort_values(ascending=True) - - # Create DataFrame for plotting with original (signed) weights - plot_data = pd.DataFrame( - { - "feature": top_features.index, - "weight": feature_weights.loc[top_features.index].to_numpy(), - "abs_weight": top_features.to_numpy(), - } - ) - - # Create horizontal bar chart - chart = ( - alt.Chart(plot_data) - .mark_bar() - .encode( - y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), - x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"), - color=alt.condition( - alt.datum.weight > 0, - alt.value("steelblue"), # Positive weights - alt.value("coral"), # Negative weights - ), - tooltip=[ - alt.Tooltip("feature:N", title="Feature"), - alt.Tooltip("weight:Q", format=".4f", title="Weight"), - alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"), - ], - ) - .properties( - width=600, - height=400, - title=f"Top {top_n} Most Important Features", - ) - ) - - return chart - - -def _plot_embedding_heatmap(embedding_array: xr.DataArray): - """Create a heatmap showing embedding feature weights across bands and years. - - Args: - embedding_array: DataArray with dimensions (agg, band, year) containing feature weights. - - Returns: - Altair chart showing the heatmap. 
- - """ - # Convert to DataFrame for plotting - df = embedding_array.to_dataframe(name="weight").reset_index() - - # Create faceted heatmap - chart = ( - alt.Chart(df) - .mark_rect() - .encode( - x=alt.X("year:O", title="Year"), - y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")), - color=alt.Color( - "weight:Q", - scale=alt.Scale(scheme="redblue", domainMid=0), - title="Weight", - ), - tooltip=[ - alt.Tooltip("agg:N", title="Aggregation"), - alt.Tooltip("band:N", title="Band"), - alt.Tooltip("year:O", title="Year"), - alt.Tooltip("weight:Q", format=".4f", title="Weight"), - ], - ) - .properties(width=200, height=200) - .facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11) - ) - - return chart - - -def _plot_embedding_aggregation_summary(embedding_array: xr.DataArray): - """Create bar charts summarizing embedding weights by aggregation, band, and year. - - Args: - embedding_array: DataArray with dimensions (agg, band, year) containing feature weights. - - Returns: - Tuple of three Altair charts (by_agg, by_band, by_year). 
- - """ - # Aggregate by different dimensions - by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs() - by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs() - by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs() - - # Create DataFrames - df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()}) - df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()}) - df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()}) - - # Sort by weight - df_agg = df_agg.sort_values("mean_abs_weight", ascending=True) - df_band = df_band.sort_values("mean_abs_weight", ascending=True) - df_year = df_year.sort_values("mean_abs_weight", ascending=True) - - # Create charts with different colors - chart_agg = ( - alt.Chart(df_agg) - .mark_bar() - .encode( - y=alt.Y("dimension:N", title="Aggregation", sort="-x"), - x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), - color=alt.Color( - "mean_abs_weight:Q", - scale=alt.Scale(scheme="blues"), - legend=None, - ), - tooltip=[ - alt.Tooltip("dimension:N", title="Aggregation"), - alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), - ], - ) - .properties(width=250, height=200, title="By Aggregation") - ) - - chart_band = ( - alt.Chart(df_band) - .mark_bar() - .encode( - y=alt.Y("dimension:N", title="Band", sort="-x"), - x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), - color=alt.Color( - "mean_abs_weight:Q", - scale=alt.Scale(scheme="greens"), - legend=None, - ), - tooltip=[ - alt.Tooltip("dimension:N", title="Band"), - alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), - ], - ) - .properties(width=250, height=200, title="By Band") - ) - - chart_year = ( - alt.Chart(df_year) - .mark_bar() - .encode( - y=alt.Y("dimension:O", title="Year", sort="-x"), - x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), - color=alt.Color( - 
"mean_abs_weight:Q", - scale=alt.Scale(scheme="oranges"), - legend=None, - ), - tooltip=[ - alt.Tooltip("dimension:O", title="Year"), - alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), - ], - ) - .properties(width=250, height=200, title="By Year") - ) - - return chart_agg, chart_band, chart_year - - -def _plot_era5_heatmap(era5_array: xr.DataArray): - """Create a heatmap showing ERA5 feature weights across variables and time. - - Args: - era5_array: DataArray with dimensions (variable, time) containing feature weights. - - Returns: - Altair chart showing the heatmap. - - """ - # Convert to DataFrame for plotting - df = era5_array.to_dataframe(name="weight").reset_index() - - # Create heatmap - chart = ( - alt.Chart(df) - .mark_rect() - .encode( - x=alt.X("time:N", title="Time", sort=None), - y=alt.Y("variable:N", title="Variable", sort="-color"), - color=alt.Color( - "weight:Q", - scale=alt.Scale(scheme="redblue", domainMid=0), - title="Weight", - ), - tooltip=[ - alt.Tooltip("variable:N", title="Variable"), - alt.Tooltip("time:N", title="Time"), - alt.Tooltip("weight:Q", format=".4f", title="Weight"), - ], - ) - .properties( - height=400, - title="ERA5 Feature Weights Heatmap", - ) - ) - - return chart - - -def _plot_era5_summary(era5_array: xr.DataArray): - """Create bar charts summarizing ERA5 weights by variable and time. - - Args: - era5_array: DataArray with dimensions (variable, time) containing feature weights. - - Returns: - Tuple of two Altair charts (by_variable, by_time). 
- - """ - # Aggregate by different dimensions - by_variable = era5_array.mean(dim="time").to_pandas().abs() - by_time = era5_array.mean(dim="variable").to_pandas().abs() - - # Create DataFrames - df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()}) - df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()}) - - # Sort by weight - df_variable = df_variable.sort_values("mean_abs_weight", ascending=True) - df_time = df_time.sort_values("mean_abs_weight", ascending=True) - - # Create charts with different colors - chart_variable = ( - alt.Chart(df_variable) - .mark_bar() - .encode( - y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), - x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), - color=alt.Color( - "mean_abs_weight:Q", - scale=alt.Scale(scheme="purples"), - legend=None, - ), - tooltip=[ - alt.Tooltip("dimension:N", title="Variable"), - alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), - ], - ) - .properties(width=400, height=300, title="By Variable") - ) - - chart_time = ( - alt.Chart(df_time) - .mark_bar() - .encode( - y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)), - x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), - color=alt.Color( - "mean_abs_weight:Q", - scale=alt.Scale(scheme="teals"), - legend=None, - ), - tooltip=[ - alt.Tooltip("dimension:N", title="Time"), - alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), - ], - ) - .properties(width=400, height=300, title="By Time") - ) - - return chart_variable, chart_time - - -def _plot_box_assignments(model_state: xr.Dataset): - """Create a heatmap showing which boxes are assigned to which labels/classes. - - Args: - model_state: The xarray Dataset containing the model state with box_assignments. - - Returns: - Altair chart showing the box-to-label assignment heatmap. 
- - """ - # Extract box assignments - box_assignments = model_state["box_assignments"] - - # Convert to DataFrame for plotting - df = box_assignments.to_dataframe(name="assignment").reset_index() - - # Create heatmap - chart = ( - alt.Chart(df) - .mark_rect() - .encode( - x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)), - y=alt.Y("class:N", title="Class Label"), - color=alt.Color( - "assignment:Q", - scale=alt.Scale(scheme="viridis"), - title="Assignment Strength", - ), - tooltip=[ - alt.Tooltip("class:N", title="Class"), - alt.Tooltip("box:O", title="Box"), - alt.Tooltip("assignment:Q", format=".4f", title="Assignment"), - ], - ) - .properties( - height=150, - title="Box-to-Label Assignments (Lambda Matrix)", - ) - ) - - return chart - - -def _plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]): - """Create a bar chart showing how many boxes are assigned to each class. - - Args: - model_state: The xarray Dataset containing the model state with box_assignments. - altair_colors: List of hex color strings for altair. - - Returns: - Altair chart showing count of boxes per class. 
- - """ - # Extract box assignments - box_assignments = model_state["box_assignments"] - - # Convert to DataFrame - df = box_assignments.to_dataframe(name="assignment").reset_index() - - # For each box, find which class it's most strongly assigned to - box_to_class = df.groupby("box")["assignment"].idxmax() - primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True) - - # Count boxes per class - counts = primary_classes.groupby("class").size().reset_index(name="count") - - # Replace the special (-1, 0] interval with "No RTS" if present - counts["class"] = counts["class"].replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(str(class_str).split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - # Sort counts by the same key - counts["sort_key"] = counts["class"].apply(sort_key) - counts = counts.sort_values("sort_key") - - # Create an ordered list of classes for consistent color mapping - class_order = counts["class"].tolist() - - # Create bar chart - chart = ( - alt.Chart(counts) - .mark_bar() - .encode( - x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)), - y=alt.Y("count:Q", title="Number of Boxes"), - color=alt.Color( - "class:N", - title="Class", - scale=alt.Scale(domain=class_order, range=altair_colors), - legend=None, - ), - tooltip=[ - alt.Tooltip("class:N", title="Class"), - alt.Tooltip("count:Q", title="Number of Boxes"), - ], - ) - .properties( - width=600, - height=300, - title="Number of Boxes Assigned to Each Class (by Primary Assignment)", - ) - ) - - return chart - - -def _plot_common_features(common_array: xr.DataArray): - """Create a bar chart showing the weights of common 
features. - - Args: - common_array: DataArray with dimension (feature) containing feature weights. - - Returns: - Altair chart showing the common feature weights. - - """ - # Convert to DataFrame for plotting - df = common_array.to_dataframe(name="weight").reset_index() - - # Sort by absolute weight - df["abs_weight"] = df["weight"].abs() - df = df.sort_values("abs_weight", ascending=True) - - # Create bar chart - chart = ( - alt.Chart(df) - .mark_bar() - .encode( - y=alt.Y("feature:N", title="Feature", sort="-x"), - x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"), - color=alt.condition( - alt.datum.weight > 0, - alt.value("steelblue"), # Positive weights - alt.value("coral"), # Negative weights - ), - tooltip=[ - alt.Tooltip("feature:N", title="Feature"), - alt.Tooltip("weight:Q", format=".4f", title="Weight"), - alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"), - ], - ) - .properties( - width=600, - height=300, - title="Common Feature Weights", - ) - ) - - return chart - - -def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str, refit_metric: str): - """Create a scatter plot comparing two metrics with parameter-based coloring. - - Args: - results: DataFrame containing the results with metric columns. - x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy'). - y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard'). - color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e'). - refit_metric: The metric used for refitting (e.g., 'f1' or 'f1_macro'). - - Returns: - Altair chart showing the metric comparison. 
- - """ - # Determine color scale based on parameter - if color_param in ["eps_cl", "eps_e"]: - color_scale = alt.Scale(type="log", scheme="viridis") - else: - color_scale = alt.Scale(scheme="viridis") - - chart = ( - alt.Chart(results) - .mark_circle(size=60, opacity=0.7) - .encode( - x=alt.X(f"mean_test_{x_metric}:Q", title=format_metric_name(x_metric), scale=alt.Scale(zero=False)), - y=alt.Y(f"mean_test_{y_metric}:Q", title=format_metric_name(y_metric), scale=alt.Scale(zero=False)), - color=alt.Color( - f"{color_param}:Q", - scale=color_scale, - title=color_param, - ), - tooltip=[ - alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=format_metric_name(x_metric)), - alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=format_metric_name(y_metric)), - alt.Tooltip(f"mean_test_{refit_metric}:Q", format=".4f", title=format_metric_name(refit_metric)), - alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"), - alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"), - alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"), - ], - ) - .properties( - width=400, - height=400, - ) - .interactive() - ) - - return chart - - -def _load_training_data(grid: str, level: int, task: str): - """Load training data for analysis. - - Args: - grid: Grid type (hex or healpix). - level: Grid level. - task: Classification task type (binary or multi). - - Returns: - Tuple of (GeoDataFrame with training data, feature DataFrame, labels Series, label names list). 
- - """ - # Use create_xy_data to get X and y data - data, x_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task) - - # Convert y_data to string labels for visualization - if task == "binary": - y_data_str = y_data.map({False: "No RTS", True: "RTS"}) - else: - # For multi-class, reconstruct the interval labels from the codes - y_data_counts = data.loc[x_data.index, "darts_rts_count"] - n_categories = 5 - bins = pd.qcut(y_data_counts, q=n_categories, duplicates="drop").unique().categories - bins = pd.IntervalIndex.from_tuples( - [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins] - ) - y_data_str = pd.cut(y_data_counts, bins=bins).astype(str) - - # Create a GeoDataFrame with geometry and labels - training_gdf = data.loc[x_data.index, ["geometry"]].copy() - training_gdf["target_class"] = y_data_str - - return training_gdf, x_data, y_data_str, labels - - -def _plot_training_data_map(training_gdf: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap): - """Create an interactive map of training data with properly sorted classes. - - Args: - training_gdf: GeoDataFrame with target classes (intervals as strings or category names). - folium_cmap: Matplotlib ListedColormap object (for geopandas.explore). - - Returns: - Folium map with training data colored by class. 
- - """ - # Create a copy to avoid modifying the original data - training_plot = training_gdf.copy() - - # Replace the special (-1, 0] interval with "No RTS" if present - training_plot["target_class"] = training_plot["target_class"].replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(class_str.split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - # Get unique classes and sort them - unique_classes = sorted(training_plot["target_class"].unique(), key=sort_key) - - # Create categorical with ordered categories - training_plot["target_class"] = pd.Categorical( - training_plot["target_class"], categories=unique_classes, ordered=True - ) - - # Select tiles based on theme - tiles = "CartoDB dark_matter" if st.context.theme.type == "dark" else "CartoDB positron" - return training_plot.explore(column="target_class", cmap=folium_cmap, legend=True, tiles=tiles) - - -def _plot_training_class_distribution(y_data: pd.Series, altair_colors: list[str]): - """Create a bar chart showing the count of each class in training data. - - Args: - y_data: Series with target classes. - altair_colors: List of hex color strings for altair. - - Returns: - Altair chart showing class distribution. 
- - """ - # Create a copy and apply the same transformations as the map - y_data_plot = y_data.copy() - y_data_plot = y_data_plot.replace("(-1, 0]", "No RTS") - - # Sort the classes: "No RTS" first, then by the lower bound of intervals - def sort_key(class_str): - if class_str == "No RTS": - return -1 # Put "No RTS" first - # Parse interval string like "(0, 4]" or "(4, 36]" - try: - lower = float(class_str.split(",")[0].strip("([ ")) - return lower - except (ValueError, IndexError): - return float("inf") # Put unparseable values at the end - - df = pd.DataFrame({"target_class": y_data_plot.to_numpy()}) - counts = df["target_class"].value_counts().reset_index() - counts.columns = ["class", "count"] - counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2) - - # Sort counts by the same key - counts["sort_key"] = counts["class"].apply(sort_key) - counts = counts.sort_values("sort_key") - - # Create an ordered list of classes for consistent color mapping - class_order = counts["class"].tolist() - - chart = ( - alt.Chart(counts) - .mark_bar() - .encode( - x=alt.X("class:N", title="Target Class", sort=class_order, axis=alt.Axis(labelAngle=0)), - y=alt.Y("count:Q", title="Number of Samples"), - color=alt.Color( - "class:N", - title="Class", - scale=alt.Scale(domain=class_order, range=altair_colors), - legend=None, - ), - tooltip=[ - alt.Tooltip("class:N", title="Class"), - alt.Tooltip("count:Q", title="Count"), - alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"), - ], - ) - .properties( - width=400, - height=300, - title="Training Data Class Distribution", - ) - ) - - return chart - - -def main(): - """Run Streamlit dashboard application.""" - st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") - - st.title("Training Analysis Dashboard") - st.markdown("Interactive visualization of RandomizedSearchCV results") - - # Sidebar for file and parameter selection - st.sidebar.header("Configuration") - - # Get 
available result files - result_dirs = get_available_result_files() - - if not result_dirs: - st.error(f"No result files found in {RESULTS_DIR}") - st.info("Please run a random CV search first to generate results.") - return - - # Directory selection - dir_options = {_parse_results_dir_name(f): f for f in result_dirs} - selected_dir_name = st.sidebar.selectbox( - "Select Result Directory", - options=list(dir_options.keys()), - help="Choose a search result directory to visualize", - ) - results_dir = dir_options[selected_dir_name] - - # Load and prepare data with default bin width (will be reloaded with custom width later) - with st.spinner("Loading results..."): - settings = toml.load(results_dir / "search_settings.toml")["settings"] - results = load_and_prepare_results(results_dir / "search_results.parquet", settings, k_bin_width=40) - model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc") - n_features = model_state.sizes["feature"] - model_state["feature_weights"] *= n_features - embedding_feature_array = extract_embedding_features(model_state) - era5_feature_array = extract_era5_features(model_state) - common_feature_array = extract_common_features(model_state) - predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") - - # Determine task type and available metrics - task = settings.get("task", "binary") - available_metrics = settings.get("metrics", ["accuracy", "recall", "precision", "f1", "jaccard"]) - refit_metric = "f1" if task == "binary" else "f1_macro" - - # Generate unified colormaps once for all visualizations - matplotlib_cmap, folium_cmap, altair_colors = generate_unified_colormap(settings) - - st.sidebar.success(f"Loaded {len(results)} results") - st.sidebar.info(f"Task: {task.capitalize()} Classification") - # Dump the settings into the sidebar - with st.sidebar.expander("Search Settings", expanded=True): - st.json(settings) - - # Display some basic statistics first 
(lightweight) - st.header("Parameter-Search Overview") - - # Show total runs and best model info - col1, col2 = st.columns(2) - with col1: - st.metric("Total Runs", len(results)) - - with col2: - # Best model based on refit metric - best_idx = results[f"mean_test_{refit_metric}"].idxmax() - st.metric(f"Best Model Index (by {format_metric_name(refit_metric)})", f"#{best_idx}") - - # Show best parameters for the best model - best_params = results.loc[ - best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{refit_metric}", f"std_test_{refit_metric}"] - ] - - with st.container(border=True): - st.subheader(":abacus: Best Model Parameters") - st.caption(f"Parameters of retrained best model (selected by {format_metric_name(refit_metric)} score)") - col1, col2, col3 = st.columns(3) - with col1: - st.metric("initial_K", f"{best_params['initial_K']:.0f}") - with col2: - st.metric("eps_cl", f"{best_params['eps_cl']:.2e}") - with col3: - st.metric("eps_e", f"{best_params['eps_e']:.2e}") - - # Show all metrics for the best model in a container - st.subheader(":bar_chart: Performance Across All Metrics") - st.caption( - f"Complete performance profile of the best model (selected by {format_metric_name(refit_metric)} score)" - ) - - cols = st.columns(len(available_metrics)) - - for idx, metric in enumerate(available_metrics): - with cols[idx]: - best_score = results.loc[best_idx, f"mean_test_{metric}"] - best_std = results.loc[best_idx, f"std_test_{metric}"] - # Highlight refit metric since that's what we optimized for - st.metric( - format_metric_name(metric), - f"{best_score:.4f}", - delta=f"±{best_std:.4f}", - help="Mean ± std across cross-validation folds", - ) - - # Create tabs for different visualizations - tab0, tab1, tab2, tab3 = st.tabs(["Training Data", "Search Results", "Model State", "Inference Analysis"]) - - with tab0: - # Training Data Analysis - st.header("Training Data Analysis") - st.markdown("Comprehensive analysis of the training dataset used for model 
development") - - # Load training data - with st.spinner("Loading training data..."): - training_gdf, X_data, y_data, _ = _load_training_data( - grid=settings["grid"], level=settings["level"], task=task - ) - - # Summary statistics - st.subheader("Dataset Statistics") - - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric("Total Samples", f"{len(training_gdf):,}") - - with col2: - st.metric("Number of Features", f"{X_data.shape[1]:,}") - - with col3: - n_classes = y_data.nunique() - st.metric("Number of Classes", n_classes) - - with col4: - missing_pct = X_data.isnull().sum().sum() / (X_data.shape[0] * X_data.shape[1]) * 100 - st.metric("Missing Values", f"{missing_pct:.2f}%") - - # Class distribution visualization - st.subheader("Class Distribution") - - with st.spinner("Generating class distribution..."): - class_dist_chart = _plot_training_class_distribution(y_data, altair_colors) - st.altair_chart(class_dist_chart, use_container_width=True) - - st.markdown( - """ - **Interpretation:** - - Shows the balance between different classes in the training dataset - - Class imbalance affects model learning and may require special handling - - Each bar represents the count of training samples for that class - """ - ) - - # Interactive map - st.subheader("Interactive Training Data Map") - st.markdown("Explore the spatial distribution of training samples by class") - - with st.spinner("Generating interactive map..."): - training_map = _plot_training_data_map(training_gdf, folium_cmap) - st_folium.st_folium(training_map, width="100%", height=600, returned_objects=[]) - - # Additional statistics in expander - with st.expander("Detailed Training Data Statistics"): - st.write("**Class Distribution:**") - class_counts = y_data.value_counts().sort_index() - - # Create columns for better layout - n_cols = min(5, len(class_counts)) - cols = st.columns(n_cols) - - for idx, (class_label, count) in enumerate(class_counts.items()): - percentage = count / len(y_data) * 
100 - with cols[idx % n_cols]: - st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)") - - # Show detailed table - st.write("**Detailed Class Breakdown:**") - class_df = pd.DataFrame( - { - "Class": class_counts.index, - "Count": class_counts.to_numpy(), - "Percentage": (class_counts.to_numpy() / len(y_data) * 100).round(2), - } - ) - st.dataframe(class_df, width="stretch", hide_index=True) - - # Feature statistics - st.write("**Feature Statistics:**") - st.markdown(f"- Total number of features: **{X_data.shape[1]}**") - st.markdown(f"- Features with missing values: **{X_data.isnull().any().sum()}**") - - # Show feature types breakdown - feature_types = X_data.columns.to_series().apply(lambda x: x.split("_")[0]).value_counts() - st.write("**Feature Type Distribution:**") - feature_type_df = pd.DataFrame( - { - "Feature Type": feature_types.index, - "Count": feature_types.to_numpy(), - } - ) - st.dataframe(feature_type_df, width="stretch", hide_index=True) - - with tab1: - # Metric selection - only used in this tab - metric_display_names = {metric: format_metric_name(metric) for metric in available_metrics} - selected_metric_display = st.selectbox( - "Select Metric", - options=[format_metric_name(m) for m in available_metrics], - help="Choose which metric to visualize", - ) - # Convert back to raw metric name - selected_metric = next(k for k, v in metric_display_names.items() if v == selected_metric_display) - - # Show best parameters - with st.expander("Best Parameters"): - best_idx = results[f"mean_test_{selected_metric}"].idxmax() - best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] - st.dataframe(best_params.to_frame().T, width="content") - - # Main plots - st.header(f"Visualization for {format_metric_name(selected_metric)}") - - # K-binned plot configuration - @st.fragment - def render_k_binned_plots(): - col_toggle, col_slider = st.columns([1, 1]) - - with col_toggle: - # Percentile normalization 
toggle for K-binned plots - use_percentile = st.toggle( - "Use Percentile Normalization", - value=True, - help="Apply percentile-based color normalization to K-binned parameter space plots", - ) - - with col_slider: - # Bin width slider for K-binned plots - k_min = int(results["initial_K"].min()) - k_max = int(results["initial_K"].max()) - k_range = k_max - k_min - - k_bin_width = st.slider( - "Initial K Bin Width", - min_value=10, - max_value=max(100, k_range // 2), - value=40, - step=10, - help=f"Width of bins for initial_K facets (range: {k_min}-{k_max})", - ) - - # Show estimated number of bins - estimated_bins = int(np.ceil(k_range / k_bin_width)) - st.caption(f"Creating approximately {estimated_bins} bins for initial_K") - - # Reload data if bin width changed from default - results_binned = results - if k_bin_width != 40: - with st.spinner("Re-binning data..."): - results_binned = load_and_prepare_results( - results_dir / "search_results.parquet", settings, k_bin_width=k_bin_width - ) - - # K-binned plots - col1, col2 = st.columns(2) - - with col1: - st.subheader("K-Binned Parameter Space (Mean)") - with st.spinner("Generating mean plot..."): - if use_percentile: - chart1 = _plot_k_binned(results_binned, f"mean_test_{selected_metric}", vmin_percentile=50) - else: - chart1 = _plot_k_binned(results_binned, f"mean_test_{selected_metric}") - st.altair_chart(chart1, use_container_width=True) - - with col2: - st.subheader("K-Binned Parameter Space (Std)") - with st.spinner("Generating std plot..."): - if use_percentile: - chart2 = _plot_k_binned(results_binned, f"std_test_{selected_metric}", vmax_percentile=50) - else: - chart2 = _plot_k_binned(results_binned, f"std_test_{selected_metric}") - st.altair_chart(chart2, use_container_width=True) - - # Epsilon-binned plots - col1, col2 = st.columns(2) - - with col1: - st.subheader(f"K vs {selected_metric} (binned by eps_e)") - with st.spinner("Generating plot..."): - chart3 = _plot_params_binned( - results_binned, 
"initial_K", "eps_cl", "eps_e_binned", f"mean_test_{selected_metric}" - ) - st.altair_chart(chart3, use_container_width=True) - - st.subheader(f"eps_e vs {selected_metric} (binned by K)") - with st.spinner("Generating plot..."): - chart4 = _plot_params_binned( - results_binned, "eps_e", "eps_cl", "initial_K_binned", f"mean_test_{selected_metric}" - ) - st.altair_chart(chart4, use_container_width=True) - - with col2: - st.subheader(f"K vs {selected_metric} (binned by eps_cl)") - with st.spinner("Generating plot..."): - chart5 = _plot_params_binned( - results_binned, "initial_K", "eps_e", "eps_cl_binned", f"mean_test_{selected_metric}" - ) - st.altair_chart(chart5, use_container_width=True) - - st.subheader(f"eps_cl vs {selected_metric} (binned by K)") - with st.spinner("Generating plot..."): - chart6 = _plot_params_binned( - results_binned, "eps_cl", "eps_e", "initial_K_binned", f"mean_test_{selected_metric}" - ) - st.altair_chart(chart6, use_container_width=True) - - render_k_binned_plots() - - # Metric comparison plots - st.header("Metric Comparisons") - - @st.fragment - def render_metric_comparisons(): - # Color parameter selection - color_param = st.selectbox( - "Select Color Parameter", - options=["initial_K", "eps_cl", "eps_e"], - help="Choose which parameter to use for coloring the scatter plots", - ) - - # Dynamically determine which metrics to compare based on available metrics - if task == "binary": - # For binary: show recall vs precision and accuracy vs jaccard - comparisons = [ - ("precision", "recall", "Recall vs Precision"), - ("accuracy", "jaccard", "Accuracy vs Jaccard"), - ] - else: - # For multiclass: show micro vs macro variants - comparisons = [ - ("precision_macro", "recall_macro", "Recall Macro vs Precision Macro"), - ("accuracy", "jaccard_macro", "Accuracy vs Jaccard Macro"), - ] - - col1, col2 = st.columns(2) - - with col1: - st.subheader(comparisons[0][2]) - with st.spinner(f"Generating {comparisons[0][2]} plot..."): - chart1 = 
_plot_metric_comparison( - results, comparisons[0][0], comparisons[0][1], color_param, refit_metric - ) - st.altair_chart(chart1, use_container_width=True) - - with col2: - st.subheader(comparisons[1][2]) - with st.spinner(f"Generating {comparisons[1][2]} plot..."): - chart2 = _plot_metric_comparison( - results, comparisons[1][0], comparisons[1][1], color_param, refit_metric - ) - st.altair_chart(chart2, use_container_width=True) - - render_metric_comparisons() - - # Optional: Raw data table - with st.expander("View Raw Results Data"): - st.dataframe(results, width="stretch") - - with tab2: - # Model state visualization - st.header("Best Estimator Model State") - - # Show basic model state info - with st.expander("Model State Information"): - st.write(f"**Variables:** {list(model_state.data_vars)}") - st.write(f"**Dimensions:** {dict(model_state.sizes)}") - st.write(f"**Coordinates:** {list(model_state.coords)}") - - # Show statistics - st.write("**Feature Weight Statistics:**") - feature_weights = model_state["feature_weights"].to_pandas() - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Mean Weight", f"{feature_weights.mean():.4f}") - with col2: - st.metric("Max Weight", f"{feature_weights.max():.4f}") - with col3: - st.metric("Total Features", len(feature_weights)) - - # Feature importance plot - st.subheader("Feature Importance") - st.markdown("The most important features based on learned feature weights from the best estimator.") - - @st.fragment - def render_feature_importance(): - # Slider to control number of features to display - top_n = st.slider( - "Number of top features to display", - min_value=5, - max_value=50, - value=10, - step=5, - help="Select how many of the most important features to visualize", - ) - - with st.spinner("Generating feature importance plot..."): - feature_chart = _plot_top_features(model_state, top_n=top_n) - st.altair_chart(feature_chart, use_container_width=True) - - st.markdown( - """ - **Interpretation:** - - 
**Magnitude**: Larger absolute values indicate more important features - """ - ) - - render_feature_importance() - - # Box-to-Label Assignment Visualization - st.subheader("Box-to-Label Assignments") - st.markdown( - """ - This visualization shows how the learned boxes (prototypes in feature space) are - assigned to different class labels. The ESPA classifier learns K boxes and assigns - them to classes through the Lambda matrix. Higher values indicate stronger assignment - of a box to a particular class. - """ - ) - - with st.spinner("Generating box assignment visualizations..."): - col1, col2 = st.columns([0.7, 0.3]) - - with col1: - st.markdown("### Assignment Heatmap") - box_assignment_heatmap = _plot_box_assignments(model_state) - st.altair_chart(box_assignment_heatmap, use_container_width=True) - - with col2: - st.markdown("### Box Count by Class") - box_assignment_bars = _plot_box_assignment_bars(model_state, altair_colors) - st.altair_chart(box_assignment_bars, use_container_width=True) - - # Show statistics - with st.expander("Box Assignment Statistics"): - box_assignments = model_state["box_assignments"].to_pandas() - st.write("**Assignment Matrix Statistics:**") - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Boxes", len(box_assignments.columns)) - with col2: - st.metric("Number of Classes", len(box_assignments.index)) - with col3: - st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}") - with col4: - st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}") - - # Show which boxes are most strongly assigned to each class - st.write("**Top Box Assignments per Class:**") - for class_label in box_assignments.index: - top_boxes = box_assignments.loc[class_label].nlargest(5) - st.write( - f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} " - f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})" - ) - - st.markdown( - """ - **Interpretation:** - - Each box 
can be assigned to multiple classes with different strengths - - Boxes with higher assignment values for a class contribute more to that class's predictions - - The distribution shows how the model partitions the feature space for classification - """ - ) - - # Embedding features analysis (if present) - if embedding_feature_array is not None: - with st.container(border=True): - st.header(":artificial_satellite: Embedding Feature Analysis") - st.markdown( - """ - Analysis of embedding features showing which aggregations, bands, and years - are most important for the model predictions. - """ - ) - - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating dimension summaries..."): - chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array) - col1, col2, col3 = st.columns(3) - with col1: - st.altair_chart(chart_agg, use_container_width=True) - with col2: - st.altair_chart(chart_band, use_container_width=True) - with col3: - st.altair_chart(chart_year, use_container_width=True) - - # Detailed heatmap - st.markdown("### Detailed Heatmap by Aggregation") - st.markdown("Shows the weight of each band-year combination for each aggregation type.") - with st.spinner("Generating heatmap..."): - heatmap_chart = _plot_embedding_heatmap(embedding_feature_array) - st.altair_chart(heatmap_chart, use_container_width=True) - - # Statistics - with st.expander("Embedding Feature Statistics"): - st.write("**Overall Statistics:**") - n_emb_features = embedding_feature_array.size - mean_weight = float(embedding_feature_array.mean().values) - max_weight = float(embedding_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total Embedding Features", n_emb_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - - # Show top embedding features - st.write("**Top 10 Embedding Features:**") - emb_df = 
embedding_feature_array.to_dataframe(name="weight").reset_index() - top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] - st.dataframe(top_emb, width="stretch") - else: - st.info("No embedding features found in this model.") - - # ERA5 features analysis (if present) - if era5_feature_array is not None: - with st.container(border=True): - st.header(":partly_sunny: ERA5 Feature Analysis") - st.markdown( - """ - Analysis of ERA5 climate features showing which variables and time periods - are most important for the model predictions. - """ - ) - - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating ERA5 dimension summaries..."): - chart_variable, chart_time = _plot_era5_summary(era5_feature_array) - col1, col2 = st.columns(2) - with col1: - st.altair_chart(chart_variable, use_container_width=True) - with col2: - st.altair_chart(chart_time, use_container_width=True) - - # Detailed heatmap - st.markdown("### Detailed Heatmap") - st.markdown("Shows the weight of each variable-time combination.") - with st.spinner("Generating ERA5 heatmap..."): - era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array) - st.altair_chart(era5_heatmap_chart, use_container_width=True) - - # Statistics - with st.expander("ERA5 Feature Statistics"): - st.write("**Overall Statistics:**") - n_era5_features = era5_feature_array.size - mean_weight = float(era5_feature_array.mean().values) - max_weight = float(era5_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total ERA5 Features", n_era5_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - - # Show top ERA5 features - st.write("**Top 10 ERA5 Features:**") - era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() - top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] - st.dataframe(top_era5, width="stretch") - else: - 
st.info("No ERA5 features found in this model.") - - # Common features analysis (if present) - if common_feature_array is not None: - with st.container(border=True): - st.header(":world_map: Common Feature Analysis") - st.markdown( - """ - Analysis of common features including cell area, water area, land area, land ratio, - longitude, and latitude. These features provide spatial and geographic context. - """ - ) - - # Bar chart showing all common feature weights - with st.spinner("Generating common features chart..."): - common_chart = _plot_common_features(common_feature_array) - st.altair_chart(common_chart, use_container_width=True) - - # Statistics - with st.expander("Common Feature Statistics"): - st.write("**Overall Statistics:**") - n_common_features = common_feature_array.size - mean_weight = float(common_feature_array.mean().values) - max_weight = float(common_feature_array.max().values) - min_weight = float(common_feature_array.min().values) - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Common Features", n_common_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - with col4: - st.metric("Min Weight", f"{min_weight:.4f}") - - # Show all common features sorted by importance - st.write("**All Common Features (by absolute weight):**") - common_df = common_feature_array.to_dataframe(name="weight").reset_index() - common_df["abs_weight"] = common_df["weight"].abs() - common_df = common_df.sort_values("abs_weight", ascending=False) - st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") - - st.markdown( - """ - **Interpretation:** - - **cell_area, water_area, land_area**: Spatial extent features that may indicate - size-related patterns - - **land_ratio**: Proportion of land vs water in each cell - - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns - - Positive weights indicate features that 
increase the probability of the positive class - - Negative weights indicate features that decrease the probability of the positive class - """ - ) - else: - st.info("No common features found in this model.") - - with tab3: - # Inference analysis - st.header("Inference Analysis") - st.markdown("Comprehensive analysis of model predictions on the evaluation dataset") - - # Summary statistics - st.subheader("Prediction Statistics") - col1, col2 = st.columns(2) - - with col1: - total_cells = len(predictions) - st.metric("Total Cells", f"{total_cells:,}") - - with col2: - n_classes = predictions["predicted_class"].nunique() - st.metric("Number of Classes", n_classes) - - # Class distribution and static map side by side - st.subheader("Prediction Overview") - - col_map, col_dist = st.columns([0.6, 0.4]) - - with col_map: - st.markdown("#### Static Prediction Map") - st.markdown("High-quality map with North Polar Stereographic projection") - - with st.spinner("Generating static map..."): - static_fig = _plot_prediction_map_static(predictions, matplotlib_cmap) - st.pyplot(static_fig, width="stretch") - - st.markdown( - """ - **Map Features:** - - North Polar Azimuthal Equidistant projection (optimized for Arctic) - - Predictions colored by class using inferno colormap - - Includes coastlines and 50°N latitude boundary - """ - ) - - with col_dist: - st.markdown("#### Class Distribution") - - with st.spinner("Generating class distribution..."): - class_dist_chart = _plot_prediction_class_distribution(predictions, altair_colors) - st.altair_chart(class_dist_chart, use_container_width=True) - - st.markdown( - """ - **Interpretation:** - - Balance between predicted classes - - Class imbalance may indicate regional patterns - - Each bar shows cell count per class - """ - ) - - # Additional statistics in expander - with st.expander("Detailed Prediction Statistics"): - st.write("**Class Distribution:**") - class_counts = predictions["predicted_class"].value_counts().sort_index() - 
- # Create columns for better layout - n_cols = min(5, len(class_counts)) - cols = st.columns(n_cols) - - for idx, (class_label, count) in enumerate(class_counts.items()): - percentage = count / len(predictions) * 100 - with cols[idx % n_cols]: - st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)") - - # Show detailed table - st.write("**Detailed Class Breakdown:**") - class_df = pd.DataFrame( - { - "Class": class_counts.index, - "Count": class_counts.to_numpy(), - "Percentage": (class_counts.to_numpy() / len(predictions) * 100).round(2), - } - ) - st.dataframe(class_df, width="stretch", hide_index=True) - - # Interactive map - st.subheader("Interactive Prediction Map") - st.markdown("Explore predictions spatially with the interactive map below") - - # with st.spinner("Generating interactive map..."): - # chart_map = _plot_prediction_map(predictions, folium_colors) - # st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[]) - st.text("Interactive map functionality is currently disabled.") - - st.balloons() - - -if __name__ == "__main__": - main()