diff --git a/.github/agents/Dashboard.agent.md b/.github/agents/Dashboard.agent.md
index b9c1d71..5915dc2 100644
--- a/.github/agents/Dashboard.agent.md
+++ b/.github/agents/Dashboard.agent.md
@@ -1,56 +1,54 @@
---
-description: 'Specialized agent for developing and enhancing the Streamlit dashboard for data and training analysis.'
-name: Dashboard-Developer
-argument-hint: 'Describe dashboard features, pages, visualizations, or improvements you want to add or modify'
-tools: ['edit', 'runNotebooks', 'search', 'runCommands', 'usages', 'problems', 'changes', 'testFailure', 'fetch', 'githubRepo', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'ms-toolsai.jupyter/configureNotebook', 'ms-toolsai.jupyter/listNotebookPackages', 'ms-toolsai.jupyter/installNotebookPackages', 'todos', 'runSubagent', 'runTests']
+description: Develop and refactor Streamlit dashboard pages and visualizations
+name: Dashboard
+argument-hint: Describe dashboard features, pages, or visualizations to add or modify
+tools: ['vscode', 'execute', 'read', 'edit', 'search', 'web', 'agent', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'todo']
+model: Claude Sonnet 4.5
+infer: true
---
# Dashboard Development Agent
-You are a specialized agent for incrementally developing and enhancing the **Entropice Streamlit Dashboard** used to analyze geospatial machine learning data and training experiments.
+You specialize in developing and refactoring the **Entropice Streamlit Dashboard** for geospatial machine learning analysis.
-## Your Responsibilities
+## Scope
-### What You Should Do
+**You can edit:** Files in `src/entropice/dashboard/` only
+
+**You cannot edit:** Data pipeline scripts, training code, or configuration files
-1. **Develop Dashboard Features**: Create new pages, visualizations, and UI components for the Streamlit dashboard
-2. **Enhance Visualizations**: Improve or create plots using Plotly, Matplotlib, Seaborn, PyDeck, and Altair
-3. **Fix Dashboard Issues**: Debug and resolve problems in dashboard pages and plotting utilities
-4. **Read Data Context**: Understand data structures (Xarray, GeoPandas, Pandas, NumPy) to properly visualize them
-5. **Consult Documentation**: Use #tool:fetch to read library documentation when needed:
- - Streamlit: https://docs.streamlit.io/
- - Plotly: https://plotly.com/python/
- - PyDeck: https://deckgl.readthedocs.io/
- - Deck.gl: https://deck.gl/
- - Matplotlib: https://matplotlib.org/
- - Seaborn: https://seaborn.pydata.org/
- - Xarray: https://docs.xarray.dev/
- - GeoPandas: https://geopandas.org/
- - Pandas: https://pandas.pydata.org/pandas-docs/
- - NumPy: https://numpy.org/doc/stable/
+**Primary reference:** Always consult `views/overview_page.py` for current code patterns
-6. **Understand Data Sources**: Read data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`, `dataset.py`, `training.py`, `inference.py`) to understand data structures—but **NEVER edit them**
+## Responsibilities
-### What You Should NOT Do
+### ✅ What You Do
-1. **Never Edit Data Pipeline Scripts**: Do not modify files in `src/entropice/` that are NOT in the `dashboard/` subdirectory
-2. **Never Edit Training Scripts**: Do not modify `training.py`, `dataset.py`, or any model-related code outside the dashboard
-3. **Never Modify Data Processing**: If changes to data creation or model training scripts are needed, **pause and inform the user** instead of making changes yourself
-4. **Never Edit Configuration Files**: Do not modify `pyproject.toml`, pipeline scripts in `scripts/`, or configuration files
+- Create/refactor dashboard pages in `views/`
+- Build visualizations using Plotly, Matplotlib, Seaborn, PyDeck, Altair
+- Fix dashboard bugs and improve UI/UX
+- Create utility functions in `utils/` and `plots/`
+- Read (but never edit) data pipeline code to understand data structures
+- Use #tool:web to fetch library documentation:
+ - Streamlit: https://docs.streamlit.io/
+ - Plotly: https://plotly.com/python/
+ - PyDeck: https://deckgl.readthedocs.io/
+ - Xarray: https://docs.xarray.dev/
+ - GeoPandas: https://geopandas.org/
-### Boundaries
+### ❌ What You Don't Do
-If you identify that a dashboard improvement requires changes to:
-- Data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`)
-- Dataset assembly (`dataset.py`)
-- Model training (`training.py`, `inference.py`)
-- Pipeline automation scripts (`scripts/*.sh`)
+- Edit files outside `src/entropice/dashboard/`
+- Modify data pipeline (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`)
+- Change training code (`training.py`, `dataset.py`, `inference.py`)
+- Edit configuration (`pyproject.toml`, `scripts/*.sh`)
+
+### When to Stop
+
+If a dashboard feature requires changes outside `dashboard/`, stop and inform the user:
-**Stop immediately** and inform the user:
```
-⚠️ This dashboard feature requires changes to the data pipeline/training code.
-Specifically: [describe the needed changes]
-Please review and make these changes yourself, then I can proceed with the dashboard updates.
+⚠️ This requires changes to [file/module]
+Needed: [describe changes]
+Please make these changes first, then I can update the dashboard.
```
## Dashboard Structure
@@ -60,23 +58,28 @@ The dashboard is located in `src/entropice/dashboard/` with the following struct
```
dashboard/
├── app.py # Main Streamlit app with navigation
-├── overview_page.py # Overview of training results
-├── training_data_page.py # Training data visualizations
-├── training_analysis_page.py # CV results and hyperparameter analysis
-├── model_state_page.py # Feature importance and model state
-├── inference_page.py # Spatial prediction visualizations
+├── views/ # Dashboard pages
+│ ├── overview_page.py # Overview of training results and dataset analysis
+│ ├── training_data_page.py # Training data visualizations (needs refactoring)
+│ ├── training_analysis_page.py # CV results and hyperparameter analysis (needs refactoring)
+│ ├── model_state_page.py # Feature importance and model state (needs refactoring)
+│ └── inference_page.py # Spatial prediction visualizations (needs refactoring)
├── plots/ # Reusable plotting utilities
-│ ├── colors.py # Color schemes
│ ├── hyperparameter_analysis.py
│ ├── inference.py
│ ├── model_state.py
│ ├── source_data.py
│ └── training_data.py
-└── utils/ # Data loading and processing
- ├── data.py
- └── training.py
+└── utils/ # Data loading and processing utilities
+ ├── loaders.py # Data loaders (training results, grid data, predictions)
+ ├── stats.py # Dataset statistics computation and caching
+ ├── colors.py # Color palette management
+ ├── formatters.py # Display formatting utilities
+ └── unsembler.py # Dataset ensemble utilities
```
+**Note:** Currently only `overview_page.py` has been refactored to follow the new patterns. Other pages need updating to match this structure.
+
## Key Technologies
- **Streamlit**: Web app framework
@@ -120,6 +123,79 @@ When working with Entropice data:
3. **Training Results**: Pickled models, Parquet/NetCDF CV results
4. **Predictions**: GeoDataFrames with predicted classes/probabilities
+### Dashboard Code Patterns
+
+**Follow these patterns when developing or refactoring dashboard pages:**
+
+1. **Modular Render Functions**: Break pages into focused render functions
+ ```python
+ def render_sample_count_overview():
+ """Render overview of sample counts per task+target+grid+level combination."""
+ # Implementation
+
+ def render_feature_count_section():
+ """Render the feature count section with comparison and explorer."""
+ # Implementation
+ ```
+
+2. **Use `@st.fragment` for Interactive Components**: Isolate reactive UI elements
+ ```python
+ @st.fragment
+ def render_feature_count_explorer():
+ """Render interactive detailed configuration explorer using fragments."""
+ # Interactive selectboxes and checkboxes that re-run independently
+ ```
+
+3. **Cached Data Loading via Utilities**: Use centralized loaders from `utils/loaders.py`
+ ```python
+ from entropice.dashboard.utils.loaders import load_all_training_results
+ from entropice.dashboard.utils.stats import load_all_default_dataset_statistics
+
+ training_results = load_all_training_results() # Cached via @st.cache_data
+ all_stats = load_all_default_dataset_statistics() # Cached via @st.cache_data
+ ```
+
+4. **Consistent Color Palettes**: Use `get_palette()` from `utils/colors.py`
+ ```python
+ from entropice.dashboard.utils.colors import get_palette
+
+ task_colors = get_palette("task_types", n_colors=n_tasks)
+ source_colors = get_palette("data_sources", n_colors=n_sources)
+ ```
+
+5. **Type Hints and Type Casting**: Use types from `entropice.utils.types`
+ ```python
+ from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
+
+ selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
+ selected_members: list[L2SourceDataset] = []
+ ```
+
+6. **Tab-Based Organization**: Use tabs to organize complex visualizations
+ ```python
+ tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
+ with tab1:
+ # Heatmap visualization
+ with tab2:
+ # Bar chart visualization
+ ```
+
+7. **Layout with Columns**: Use columns for metrics and side-by-side content
+ ```python
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Total Features", f"{total_features:,}")
+ with col2:
+ st.metric("Data Sources", len(selected_members))
+ ```
+
+8. **Comprehensive Docstrings**: Document render functions clearly
+ ```python
+ def render_training_results_summary(training_results):
+ """Render summary metrics for training results."""
+ # Implementation
+ ```
+
### Visualization Guidelines
1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps
@@ -127,50 +203,79 @@ When working with Entropice data:
3. **Distributions**: Use Plotly or Seaborn
4. **Feature Importance**: Use Plotly bar charts
5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates
+6. **Heatmaps**: Use `px.imshow()` with color palettes from `get_palette()`
+7. **Interactive Tables**: Use `st.dataframe()` with `width='stretch'` and formatting
+
+### Key Utility Modules
+
+**`utils/loaders.py`**: Data loading with Streamlit caching
+- `load_all_training_results()`: Load all training result directories
+- `load_training_result(path)`: Load specific training result
+- `TrainingResult` dataclass: Structured training result data
+
+**`utils/stats.py`**: Dataset statistics computation
+- `load_all_default_dataset_statistics()`: Load/compute stats for all grid configs
+- `DatasetStatistics` class: Statistics per grid configuration
+- `MemberStatistics` class: Statistics per L2 source dataset
+- `TargetStatistics` class: Statistics per target dataset
+- Helper methods: `get_sample_count_df()`, `get_feature_count_df()`, `get_feature_breakdown_df()`
+
+**`utils/colors.py`**: Consistent color palette management
+- `get_palette(variable, n_colors)`: Get color palette by semantic variable name
+- `get_cmap(variable)`: Get matplotlib colormap
+- Uses pypalettes material design palettes with deterministic mapping
+
+**`utils/formatters.py`**: Display formatting utilities
+- `ModelDisplayInfo`: Model name formatting
+- `TaskDisplayInfo`: Task name formatting
+- `TrainingResultDisplayInfo`: Training result display names
## Workflow
-1. **Understand the Request**: Clarify what visualization or feature is needed
-2. **Search for Context**: Use #tool:search to find relevant dashboard code and data structures
-3. **Read Data Pipeline**: If needed, read (but don't edit) data pipeline scripts to understand data formats
-4. **Consult Documentation**: Use #tool:fetch for library documentation when needed
-5. **Implement Changes**: Edit dashboard files only
-6. **Test Assumptions**: Check for errors with #tool:problems after edits
-7. **Track Progress**: Use #tool:todos for multi-step dashboard development
+1. Check `views/overview_page.py` for current patterns
+2. Use #tool:search to find relevant code and data structures
+3. Read data pipeline code if needed (read-only)
+4. Leverage existing utilities from `utils/`
+5. Use #tool:web to fetch documentation when needed
+6. Implement changes following overview_page.py patterns
+7. Use #tool:todo for multi-step tasks
-## Example Interactions
+## Refactoring Checklist
-### ✅ Good Requests (Within Scope)
+When updating pages to match new patterns:
-- "Add a new page to visualize feature correlations"
-- "Create a PyDeck map showing RTS predictions by grid cell"
-- "Improve the hyperparameter analysis plot to show confidence intervals"
-- "Add a Plotly histogram showing the distribution of RTS density"
-- "Fix the deprecation warning about use_container_width"
+1. Move to `views/` subdirectory
+2. Use cached loaders from `utils/loaders.py` and `utils/stats.py`
+3. Split into focused `render_*()` functions
+4. Wrap interactive UI with `@st.fragment`
+5. Replace hardcoded colors with `get_palette()`
+6. Add type hints from `entropice.utils.types`
+7. Organize with tabs for complex views
+8. Use `width='stretch'` for charts/tables
+9. Add comprehensive docstrings
+10. Reference `overview_page.py` patterns
-### ⚠️ Boundary Cases (Requires User Approval)
+## Example Tasks
-User: "Add a new climate variable to the dashboard"
-Agent Response:
-```
-⚠️ This requires changes to the data pipeline (era5.py) to extract the new variable.
-Please add the variable to the ERA5 processing pipeline first, then I can add it to the dashboard visualizations.
-```
+**✅ In Scope:**
+- "Add feature correlation heatmap to overview page"
+- "Create PyDeck map for RTS predictions"
+- "Refactor training_data_page.py to match overview_page.py patterns"
+- "Fix use_container_width deprecation warnings"
+- "Add temporal statistics tab"
-## Progress Reporting
+**⚠️ Out of Scope:**
+- "Add new climate variable" → Requires changes to `era5.py`
+- "Change training metrics" → Requires changes to `training.py`
+- "Modify grid generation" → Requires changes to `grids.py`
-For complex dashboard development tasks:
+## Key Reminders
-1. Use #tool:todos to create a task list
-2. Mark tasks as in-progress before starting
-3. Mark completed immediately after finishing
-4. Keep the user informed of progress
-
-## Remember
-
-- **Read-only for data pipeline**: You can read any file to understand data structures, but only edit `dashboard/` files
-- **Documentation first**: When unsure about Streamlit/Plotly/PyDeck APIs, fetch documentation
-- **Modern Streamlit API**: Always use `width='stretch'` instead of `use_container_width=True`
-- **Pause when needed**: If data pipeline changes are required, stop and inform the user
-
-You are here to make the dashboard better, not to change how data is created or models are trained. Stay within these boundaries and you'll be most helpful!
+- Only edit files in `dashboard/`
+- Use `width='stretch'` not `use_container_width=True`
+- Always reference `overview_page.py` for patterns
+- Use #tool:web for documentation
+- Use #tool:todo for complex multi-step work
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index f856fe3..7bd5a5d 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -1,226 +1,70 @@
-# Entropice - GitHub Copilot Instructions
+# Entropice - Copilot Instructions
-## Project Overview
+## Project Context
-Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. The system processes multi-source geospatial data, aggregates it into discrete global grids (H3/HEALPix), and trains probabilistic classifiers to estimate RTS occurrence patterns.
+This is a geospatial machine learning system for predicting Arctic permafrost degradation (Retrogressive Thaw Slumps) using entropy-optimal Scalable Probabilistic Approximations (eSPA). The system processes multi-source geospatial data through discrete global grid systems (H3, HEALPix) and trains probabilistic classifiers.
-For detailed architecture information, see [ARCHITECTURE.md](../ARCHITECTURE.md).
-For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
-For project goals and setup, see [README.md](../README.md).
+## Code Style
-## Core Technologies
+- Follow PEP 8 conventions with 120 character line length
+- Use type hints for all function signatures
+- Use google-style docstrings for public functions
+- Keep functions focused and modular
+- Use `ruff` for linting and formatting, `ty` for type checking
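A minimal sketch combining these conventions (type hints, a google-style docstring, a focused single-purpose function); the function name and signature are illustrative, not taken from the codebase:

```python
def density_per_km2(count: int, area_km2: float) -> float:
    """Compute feature density per square kilometre.

    Args:
        count: Number of features observed in the cell.
        area_km2: Cell area in square kilometres; must be positive.

    Returns:
        Features per square kilometre.

    Raises:
        ValueError: If ``area_km2`` is not positive.
    """
    if area_km2 <= 0:
        raise ValueError("area_km2 must be positive")
    return count / area_km2


print(density_per_km2(12, 3.0))  # 4.0
```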
-- **Python**: 3.13 (strict version requirement)
-- **Package Manager**: [Pixi](https://pixi.sh/) (not pip/conda directly)
-- **GPU**: CUDA 12 with RAPIDS (CuPy, cuML)
-- **Geospatial**: xarray, xdggs, GeoPandas, H3, Rasterio
-- **ML**: scikit-learn, XGBoost, entropy (eSPA), cuML
-- **Storage**: Zarr, Icechunk, Parquet, NetCDF
-- **Visualization**: Streamlit, Bokeh, Matplotlib, Cartopy
+## Technology Stack
-## Code Style Guidelines
+- **Core**: Python 3.13, NumPy, Pandas, Xarray, GeoPandas
+- **Spatial**: H3, xdggs, xvec for discrete global grid systems
+- **ML**: scikit-learn, XGBoost, cuML, entropy (eSPA)
+- **GPU**: CuPy, PyTorch, CUDA 12 - prefer GPU-accelerated operations
+- **Storage**: Zarr, Icechunk, Parquet for intermediate data
+- **CLI**: Cyclopts with dataclass-based configurations
-### Python Standards
+## Execution Guidelines
-- Follow **PEP 8** conventions
-- Use **type hints** for all function signatures
-- Write **numpy-style docstrings** for public functions
-- Keep functions **focused and modular**
-- Prefer descriptive variable names over abbreviations
+- Always use `pixi run` to execute Python commands and scripts
+- Environment variables: `SCIPY_ARRAY_API=1`, `FAST_DATA_DIR=./data`
-### Geospatial Best Practices
+## Geospatial Best Practices
-- Use **EPSG:3413** (Arctic Stereographic) for computations
-- Use **EPSG:4326** (WGS84) for visualization and library compatibility
-- Store gridded data using **xarray with XDGGS** indexing
-- Store tabular data as **Parquet**, array data as **Zarr**
-- Leverage **Dask** for lazy evaluation of large datasets
-- Use **GeoPandas** for vector operations
-- Handle **antimeridian** correctly for polar regions
+- Use EPSG:3413 (Arctic Stereographic) for computations
+- Use EPSG:4326 (WGS84) for visualization and compatibility
+- Store gridded data as XDGGS Xarray datasets (Zarr format)
+- Store tabular data as GeoParquet
+- Handle antimeridian issues in polar regions
+- Leverage Xarray/Dask for lazy evaluation and chunked processing
-### Data Pipeline Conventions
+## Architecture Patterns
-- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → `02alphaearth.sh` → `03era5.sh` → `04arcticdem.sh` → `05train.sh`
-- Each pipeline stage should produce **reproducible intermediate outputs**
-- Use `src/entropice/utils/paths.py` for consistent path management
-- Environment variable `FAST_DATA_DIR` controls data directory location (default: `./data`)
+- Modular CLI design: each module exposes a standalone Cyclopts CLI
+- Configuration as code: use dataclasses for typed configs, TOML for hyperparameters
+- GPU acceleration: use CuPy for arrays, cuML for ML, batch processing for memory management
+- Data flow: Raw sources → Grid aggregation → L2 datasets → Training → Inference → Visualization
-### Storage Hierarchy
+## Data Storage Hierarchy
-All data follows this structure:
```
DATA_DIR/
-├── grids/ # H3/HEALPix tessellations (GeoParquet)
-├── darts/ # RTS labels (GeoParquet)
-├── era5/ # Climate data (Zarr)
-├── arcticdem/ # Terrain data (Icechunk Zarr)
-├── alphaearth/ # Satellite embeddings (Zarr)
-├── datasets/ # L2 XDGGS datasets (Zarr)
-├── training-results/ # Models, CV results, predictions
-└── watermask/ # Ocean mask (GeoParquet)
+├── grids/ # H3/HEALPix tessellations (GeoParquet)
+├── darts/ # RTS labels (GeoParquet)
+├── era5/ # Climate data (Zarr)
+├── arcticdem/ # Terrain data (Icechunk Zarr)
+├── alphaearth/ # Satellite embeddings (Zarr)
+├── datasets/ # L2 XDGGS datasets (Zarr)
+└── training-results/ # Models, CV results, predictions
```
-## Module Organization
+## Key Modules
-### Core Modules (`src/entropice/`)
+- `entropice.spatial`: Grid generation and raster-to-vector aggregation
+- `entropice.ingest`: Data extractors (DARTS, ERA5, ArcticDEM, AlphaEarth)
+- `entropice.ml`: Dataset assembly, training, inference
+- `entropice.dashboard`: Streamlit visualization app
+- `entropice.utils`: Paths, codecs, types
-The codebase is organized into four main packages:
+## Testing & Notebooks
-- **`entropice.ingest`**: Data ingestion from external sources
-- **`entropice.spatial`**: Spatial operations and grid management
-- **`entropice.ml`**: Machine learning workflows
-- **`entropice.utils`**: Common utilities
-
-#### Data Ingestion (`src/entropice/ingest/`)
-
-- **`darts.py`**: RTS label extraction from DARTS v2 dataset
-- **`era5.py`**: Climate data processing from ERA5 (Arctic-aligned years: Oct 1 - Sep 30)
-- **`arcticdem.py`**: Terrain analysis from 32m Arctic elevation data
-- **`alphaearth.py`**: Satellite image embeddings via Google Earth Engine
-
-#### Spatial Operations (`src/entropice/spatial/`)
-
-- **`grids.py`**: H3/HEALPix spatial grid generation with watermask
-- **`aggregators.py`**: Raster-to-vector spatial aggregation engine
-- **`watermask.py`**: Ocean masking utilities
-- **`xvec.py`**: Extended vector operations for xarray
-
-#### Machine Learning (`src/entropice/ml/`)
-
-- **`dataset.py`**: Multi-source data integration and feature engineering
-- **`training.py`**: Model training with eSPA, XGBoost, Random Forest, KNN
-- **`inference.py`**: Batch prediction pipeline for trained models
-
-#### Utilities (`src/entropice/utils/`)
-
-- **`paths.py`**: Centralized path management
-- **`codecs.py`**: Custom codecs for data serialization
-
-### Dashboard (`src/entropice/dashboard/`)
-
-- Streamlit-based interactive visualization
-- Modular pages: overview, training data, analysis, model state, inference
-- Bokeh-based geospatial plotting utilities
-- Run with: `pixi run dashboard`
-
-### Scripts (`scripts/`)
-
-- Numbered pipeline scripts (`00grids.sh` through `05train.sh`)
-- Run entire pipeline for multiple grid configurations
-- Each script uses CLIs from core modules
-
-### Notebooks (`notebooks/`)
-
-- Exploratory analysis and validation
-- **NOT committed to git**
-- Keep production code in `src/entropice/`
-
-## Development Workflow
-
-### Setup
-
-```bash
-pixi install # NOT pip install or conda install
-```
-
-### Running Tests
-
-```bash
-pixi run pytest
-```
-
-### Running Python Commands
-
-Always use `pixi run` to execute Python commands to use the correct environment:
-
-```bash
-pixi run python script.py
-pixi run python -c "import entropice"
-```
-
-### Common Tasks
-
-**Important**: Always use `pixi run` prefix for Python commands to ensure correct environment.
-
-- **Generate grids**: Use `pixi run create-grid` or `spatial/grids.py` CLI
-- **Process labels**: Use `pixi run darts` or `ingest/darts.py` CLI
-- **Train models**: Use `pixi run train` with TOML config or `ml/training.py` CLI
-- **Run inference**: Use `ml/inference.py` CLI
-- **View results**: `pixi run dashboard`
-
-## Key Design Patterns
-
-### 1. XDGGS Indexing
-
-All geospatial data uses discrete global grid systems (H3 or HEALPix) via `xdggs` library for consistent spatial indexing across sources.
-
-### 2. Lazy Evaluation
-
-Use Xarray/Dask for out-of-core computation with Zarr/Icechunk chunked storage to manage large datasets.
-
-### 3. GPU Acceleration
-
-Prefer GPU-accelerated operations:
-- CuPy for array operations
-- cuML for Random Forest and KNN
-- XGBoost GPU training
-- PyTorch tensors when applicable
-
-### 4. Configuration as Code
-
-- Use dataclasses for typed configuration
-- TOML files for training hyperparameters
-- Cyclopts for CLI argument parsing
-
-## Data Sources
-
-- **DARTS v2**: RTS labels (year, area, count, density)
-- **ERA5**: Climate data (40-year history, Arctic-aligned years)
-- **ArcticDEM**: 32m resolution terrain (slope, aspect, indices)
-- **AlphaEarth**: 64-dimensional satellite embeddings
-- **Watermask**: Ocean exclusion layer
-
-## Model Support
-
-Primary: **eSPA** (entropy-optimal Scalable Probabilistic Approximations)
-Alternatives: XGBoost, Random Forest, K-Nearest Neighbors
-
-Training features:
-- Randomized hyperparameter search
-- K-Fold cross-validation
-- Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall)
-
-## Extension Points
-
-To extend Entropice:
-
-- **New data source**: Follow patterns in `ingest/era5.py` or `ingest/arcticdem.py`
-- **Custom aggregations**: Add to `_Aggregations` dataclass in `spatial/aggregators.py`
-- **Alternative labels**: Implement extractor following `ingest/darts.py` pattern
-- **New models**: Add scikit-learn compatible estimators to `ml/training.py`
-- **Dashboard pages**: Add Streamlit pages to `dashboard/` module
-
-## Important Notes
-
-- **Always use `pixi run` prefix** for Python commands (not plain `python`)
-- Grid resolutions: **H3** (3-6), **HEALPix** (6-10)
-- Arctic years run **October 1 to September 30** (not calendar years)
-- Handle **antimeridian crossing** in polar regions
-- Use **batch processing** for GPU memory management
-- Notebooks are for exploration only - **keep production code in `src/`**
-- Always use **absolute paths** or paths from `utils/paths.py`
-
-## Common Issues
-
-- **Memory**: Use batch processing and Dask chunking for large datasets
-- **GPU OOM**: Reduce batch size in inference or training
-- **Antimeridian**: Use proper handling in `spatial/aggregators.py` for polar grids
-- **Temporal alignment**: ERA5 uses Arctic-aligned years (Oct-Sep)
-- **CRS**: Compute in EPSG:3413, visualize in EPSG:4326
-
-## References
-
-For more details, consult:
-- [ARCHITECTURE.md](../ARCHITECTURE.md) - System architecture and design patterns
-- [CONTRIBUTING.md](../CONTRIBUTING.md) - Development workflow and standards
-- [README.md](../README.md) - Project goals and setup instructions
+- Production code belongs in `src/entropice/`, not notebooks
+- Notebooks in `notebooks/` are for exploration only (not version-controlled)
+- Use `pytest` for testing geospatial correctness and data integrity
diff --git a/.github/python.instructions.md b/.github/python.instructions.md
index 3ae0290..8b7a5d2 100644
--- a/.github/python.instructions.md
+++ b/.github/python.instructions.md
@@ -10,7 +10,7 @@ applyTo: '**/*.py,**/*.ipynb'
- Write clear and concise comments for each function.
- Ensure functions have descriptive names and include type hints.
- Provide docstrings following PEP 257 conventions.
-- Use the `typing` module for type annotations (e.g., `List[str]`, `Dict[str, int]`).
+- Use the `typing` module for advanced type annotations (e.g., `TypedDict`, `Literal["a", "b", ...]`).
- Break down complex functions into smaller, more manageable functions.
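A short sketch of the `TypedDict`/`Literal` usage mentioned above (the `GridConfig` type is illustrative, not from any real codebase):

```python
from typing import Literal, TypedDict


class GridConfig(TypedDict):
    """Typed mapping describing a spatial grid configuration."""

    grid_type: Literal["h3", "healpix"]  # only these two values are valid
    level: int


def describe_grid(config: GridConfig) -> str:
    """Return a human-readable summary of a grid configuration."""
    return f"{config['grid_type']} grid at level {config['level']}"


config: GridConfig = {"grid_type": "h3", "level": 4}
print(describe_grid(config))  # h3 grid at level 4
```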
## General Instructions
@@ -27,7 +27,7 @@ applyTo: '**/*.py,**/*.ipynb'
- Follow the **PEP 8** style guide for Python.
- Maintain proper indentation (use 4 spaces for each level of indentation).
-- Ensure lines do not exceed 79 characters.
+- Ensure lines do not exceed 120 characters.
- Place function and class docstrings immediately after the `def` or `class` keyword.
- Use blank lines to separate functions, classes, and code blocks where appropriate.
@@ -41,6 +41,8 @@ applyTo: '**/*.py,**/*.ipynb'
## Example of Proper Documentation
```python
+import math
+
def calculate_area(radius: float) -> float:
"""
Calculate the area of a circle given the radius.
@@ -51,6 +53,5 @@ def calculate_area(radius: float) -> float:
Returns:
float: The area of the circle, calculated as π * radius^2.
"""
- import math
return math.pi * radius ** 2
```
diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md
index db19c12..04ffe73 100644
--- a/ARCHITECTURE.md
+++ b/ARCHITECTURE.md
@@ -27,12 +27,20 @@ The pipeline follows a sequential processing approach where each stage produces
### System Components
-The codebase is organized into four main packages:
+The project is organized into four top-level directories:
+
+- **`src/entropice/`**: The main processing codebase
+- **`scripts/`**: Bash wrapper scripts for the data processing pipeline
+- **`notebooks/`**: Exploratory analysis and validation notebooks (not committed to git)
+- **`tests/`**: Unit tests and manual validation scripts
+
+The processing codebase is organized into five main packages:
- **`entropice.ingest`**: Data ingestion from external sources (DARTS, ERA5, ArcticDEM, AlphaEarth)
- **`entropice.spatial`**: Spatial operations and grid management
- **`entropice.ml`**: Machine learning workflows (dataset, training, inference)
- **`entropice.utils`**: Common utilities (paths, codecs)
+- **`entropice.dashboard`**: Streamlit Dashboard for interactive visualization
#### 1. Spatial Grid System (`spatial/grids.py`)
@@ -179,6 +187,14 @@ scripts/05train.sh # Model training
- TOML files for training hyperparameters
- Environment-based path management (`utils/paths.py`)
+### 6. Geospatial Best Practices
+
+- Use **xarray** with XDGGS for gridded data storage
+- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
+- Leverage **Dask** for lazy evaluation of large datasets
+- Use **GeoPandas** for vector operations
+- Use EPSG:3413 (Arctic Stereographic) as the coordinate reference system (CRS) for any computation on the data, and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries
+
## Data Storage Hierarchy
```sh
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ec4fe38..5cdaa31 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -6,7 +6,6 @@ Thank you for your interest in contributing to Entropice! This document provides
### Prerequisites
-- Python 3.13
- CUDA 12 compatible GPU (for full functionality)
- [Pixi package manager](https://pixi.sh/)
@@ -16,55 +15,36 @@ Thank you for your interest in contributing to Entropice! This document provides
pixi install
```
-This will set up the complete environment including RAPIDS, PyTorch, and all geospatial dependencies.
+This will set up the complete environment including Python, RAPIDS, PyTorch, and all geospatial dependencies.
## Development Workflow
-**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies.
+> Read about the code organisation and key modules in the [Architecture Guide](ARCHITECTURE.md)
-### Code Organization
+**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies:
-- **`src/entropice/ingest/`**: Data ingestion modules (darts, era5, arcticdem, alphaearth)
-- **`src/entropice/spatial/`**: Spatial operations (grids, aggregators, watermask, xvec)
-- **`src/entropice/ml/`**: Machine learning components (dataset, training, inference)
-- **`src/entropice/utils/`**: Utilities (paths, codecs)
-- **`src/entropice/dashboard/`**: Streamlit visualization dashboard
-- **`scripts/`**: Data processing pipeline scripts (numbered 00-05)
-- **`notebooks/`**: Exploratory analysis and validation notebooks
-- **`tests/`**: Unit tests
+```bash
+pixi run python script.py
+pixi run python -c "import entropice"
+```
-### Key Modules
-
-- `spatial/grids.py`: H3/HEALPix spatial grid systems
-- `ingest/darts.py`, `ingest/era5.py`, `ingest/arcticdem.py`, `ingest/alphaearth.py`: Data source processors
-- `ml/dataset.py`: Dataset assembly and feature engineering
-- `ml/training.py`: Model training with eSPA, XGBoost, Random Forest, KNN
-- `ml/inference.py`: Prediction generation
-- `utils/paths.py`: Centralized path management
-
-## Coding Standards
-
-### Python Style
+### Python Style and Formatting
- Follow PEP 8 conventions
- Use type hints for function signatures
-- Prefer numpy-style docstrings for public functions
+- Prefer google-style docstrings for public functions
- Keep functions focused and modular
-### Geospatial Best Practices
+`ty` and `ruff` are used for type checking, linting, and formatting.
+Make sure to resolve any warnings from both:
-- Use **xarray** with XDGGS for gridded data storage
-- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
-- Leverage **Dask** for lazy evaluation of large datasets
-- Use **GeoPandas** for vector operations
-- Use EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data and EPSG:4326 (WGS84) for data visualization and compatability with some libraries
+```sh
+pixi run ty check # For type checks
+pixi run ruff check # For linting
+pixi run ruff format # For formatting
+```
-### Data Pipeline
-
-- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → ... → `05train.sh`
-- Each stage should produce reproducible intermediate outputs
-- Document data dependencies in module docstrings
-- Use `utils/paths.py` for consistent path management
+A single file can be checked by appending its path to the command, e.g. `pixi run ty check src/entropice/dashboard/app.py`
## Testing
@@ -74,13 +54,6 @@ Run tests for specific modules:
pixi run pytest
```
-When running Python scripts or commands, always use `pixi run`:
-
-```bash
-pixi run python script.py
-pixi run python -c "import entropice"
-```
-
When adding features, include tests that verify:
- Correct handling of geospatial coordinates and projections
@@ -100,15 +73,8 @@ When adding features, include tests that verify:
### Commit Messages
- Use present tense: "Add feature" not "Added feature"
-- Reference issues when applicable: "Fix #123: Correct grid aggregation"
- Keep first line under 72 characters
-## Working with Data
-
-### Local Development
-
-- Set `FAST_DATA_DIR` environment variable for data directory (default: `./data`)
-
### Notebooks
- Notebooks in `notebooks/` are for exploration and validation; they are not committed to git
diff --git a/pyproject.toml b/pyproject.toml
index 1643f13..6f51207 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,9 @@ cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" }
+[tool.ruff]
+line-length = 120
+
[tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports
allowed-unused-imports = [
diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py
index 9141f95..e63620f 100644
--- a/src/entropice/dashboard/plots/hyperparameter_analysis.py
+++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py
@@ -224,7 +224,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(value_counts)
.mark_bar(color=bar_color)
.encode(
- alt.X(f"{formatted_col}:N", title=param_name, sort=None),
+ alt.X(
+ f"{formatted_col}:N",
+ title=param_name,
+ sort=None,
+ ),
alt.Y("count:Q", title="Count"),
tooltip=[
alt.Tooltip(param_name, format=".2e"),
@@ -238,7 +242,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(value_counts)
.mark_bar(color=bar_color)
.encode(
- alt.X(f"{param_name}:Q", title=param_name, scale=x_scale),
+ alt.X(
+ f"{param_name}:Q",
+ title=param_name,
+ scale=x_scale,
+ ),
alt.Y("count:Q", title="Count"),
tooltip=[
alt.Tooltip(param_name, format=".3f"),
@@ -301,10 +309,18 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(df_plot)
.mark_bar(color=bar_color)
.encode(
- alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name),
+ alt.X(
+ f"{param_name}:Q",
+ bin=alt.Bin(maxbins=n_bins),
+ title=param_name,
+ ),
alt.Y("count()", title="Count"),
tooltip=[
- alt.Tooltip(f"{param_name}:Q", format=format_str, bin=True),
+ alt.Tooltip(
+ f"{param_name}:Q",
+ format=format_str,
+ bin=True,
+ ),
"count()",
],
)
@@ -396,9 +412,15 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
- tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
+ tooltip=[
+ alt.Tooltip(param_name, format=".2e"),
+ alt.Tooltip(metric, format=".4f"),
+ ],
+ )
+ .properties(
+ height=300,
+ title=f"{metric} vs {param_name} (log scale)",
)
- .properties(height=300, title=f"{metric} vs {param_name} (log scale)")
)
else:
chart = (
@@ -412,7 +434,10 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
- tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
+ tooltip=[
+ alt.Tooltip(param_name, format=".3f"),
+ alt.Tooltip(metric, format=".4f"),
+ ],
)
.properties(height=300, title=f"{metric} vs {param_name}")
)
@@ -568,7 +593,16 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
if len(param_names) == 2:
# Simple case: just one pair
x_param, y_param = param_names_sorted
- _render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500)
+ _render_2d_param_plot(
+ results_binned,
+ x_param,
+ y_param,
+ score_col,
+ bin_info,
+ hex_colors,
+ metric,
+ height=500,
+ )
else:
# Multiple parameters: create structured plots
st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
@@ -599,7 +633,14 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
with cols[col_idx]:
_render_2d_param_plot(
- results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350
+ results_binned,
+ x_param,
+ y_param,
+ score_col,
+ bin_info,
+ hex_colors,
+ metric,
+ height=350,
)
@@ -708,7 +749,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
# Get colormap for score evolution
evolution_cmap = get_cmap("score_evolution")
- evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))]
+ evolution_colors = [
+ mcolors.rgb2hex(evolution_cmap(0.3)),
+ mcolors.rgb2hex(evolution_cmap(0.7)),
+ ]
# Create line chart
chart = (
@@ -717,13 +761,21 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
.encode(
alt.X("Iteration", title="Iteration"),
alt.Y("value", title=metric.replace("_", " ").title()),
- alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)),
+ alt.Color(
+ "Type",
+ legend=alt.Legend(title=""),
+ scale=alt.Scale(range=evolution_colors),
+ ),
strokeDash=alt.StrokeDash(
"Type",
legend=None,
scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
),
- tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
+ tooltip=[
+ "Iteration",
+ "Type",
+ alt.Tooltip("value", format=".4f", title="Score"),
+ ],
)
.properties(height=400)
)
@@ -1195,7 +1247,13 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
with col1:
# Filter by confusion category
if task == "binary":
- categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
+ categories = [
+ "All",
+ "True Positive",
+ "False Positive",
+ "True Negative",
+ "False Negative",
+ ]
else:
categories = ["All", "Correct", "Incorrect"]
@@ -1206,7 +1264,14 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
)
with col2:
- opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")
+ opacity = st.slider(
+ "Opacity",
+ min_value=0.1,
+ max_value=1.0,
+ value=0.7,
+ step=0.1,
+ key="confusion_map_opacity",
+ )
# Filter data if needed
if selected_category != "All":
diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py
index 2305fda..c0b6981 100644
--- a/src/entropice/dashboard/plots/model_state.py
+++ b/src/entropice/dashboard/plots/model_state.py
@@ -85,7 +85,11 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
.mark_rect()
.encode(
x=alt.X("year:O", title="Year"),
- y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
+ y=alt.Y(
+ "band:O",
+ title="Band",
+ sort=alt.SortField(field="band", order="ascending"),
+ ),
color=alt.Color(
"weight:Q",
scale=alt.Scale(scheme="blues"),
@@ -105,7 +109,9 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
return chart
-def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
+def plot_embedding_aggregation_summary(
+ embedding_array: xr.DataArray,
+) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
"""Create bar charts summarizing embedding weights by aggregation, band, and year.
Args:
@@ -345,8 +351,18 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
# Create DataFrames
- df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
- df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values})
+ df_variable = pd.DataFrame(
+ {
+ "dimension": by_variable.index.astype(str),
+ "mean_abs_weight": by_variable.values,
+ }
+ )
+ df_season = pd.DataFrame(
+ {
+ "dimension": by_season.index.astype(str),
+ "mean_abs_weight": by_season.values,
+ }
+ )
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
# Sort by weight
@@ -358,7 +374,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_variable)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+ y=alt.Y(
+ "dimension:N",
+ title="Variable",
+ sort="-x",
+ axis=alt.Axis(labelLimit=300),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -377,7 +398,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_season)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)),
+ y=alt.Y(
+ "dimension:N",
+ title="Season",
+ sort="-x",
+ axis=alt.Axis(labelLimit=200),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -396,7 +422,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_year)
.mark_bar()
.encode(
- y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)),
+ y=alt.Y(
+ "dimension:O",
+ title="Year",
+ sort="-x",
+ axis=alt.Axis(labelLimit=200),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -418,7 +449,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
# Create DataFrames, handling potential MultiIndex
- df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
+ df_variable = pd.DataFrame(
+ {
+ "dimension": by_variable.index.astype(str),
+ "mean_abs_weight": by_variable.values,
+ }
+ )
df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
# Sort by weight
@@ -430,7 +466,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_variable)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+ y=alt.Y(
+ "dimension:N",
+ title="Variable",
+ sort="-x",
+ axis=alt.Axis(labelLimit=300),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -449,7 +490,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_time)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)),
+ y=alt.Y(
+ "dimension:N",
+ title="Time",
+ sort="-x",
+ axis=alt.Axis(labelLimit=200),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -508,7 +554,9 @@ def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
return chart
-def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
+def plot_arcticdem_summary(
+ arcticdem_array: xr.DataArray,
+) -> tuple[alt.Chart, alt.Chart]:
"""Create bar charts summarizing ArcticDEM weights by variable and aggregation.
Args:
@@ -523,7 +571,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs()
# Create DataFrames, handling potential MultiIndex
- df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
+ df_variable = pd.DataFrame(
+ {
+ "dimension": by_variable.index.astype(str),
+ "mean_abs_weight": by_variable.values,
+ }
+ )
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
# Sort by weight
@@ -535,7 +588,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
alt.Chart(df_variable)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+ y=alt.Y(
+ "dimension:N",
+ title="Variable",
+ sort="-x",
+ axis=alt.Axis(labelLimit=300),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -554,7 +612,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
alt.Chart(df_agg)
.mark_bar()
.encode(
- y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)),
+ y=alt.Y(
+ "dimension:N",
+ title="Aggregation",
+ sort="-x",
+ axis=alt.Axis(labelLimit=200),
+ ),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
@@ -665,7 +728,12 @@ def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str])
alt.Chart(counts)
.mark_bar()
.encode(
- x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
+ x=alt.X(
+ "class:N",
+ title="Class Label",
+ sort=class_order,
+ axis=alt.Axis(labelAngle=-45),
+ ),
y=alt.Y("count:Q", title="Number of Boxes"),
color=alt.Color(
"class:N",
@@ -767,7 +835,10 @@ def plot_xgboost_feature_importance(
.mark_bar()
.encode(
y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
- x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"),
+ x=alt.X(
+ "importance:Q",
+ title=f"{importance_type.replace('_', ' ').title()} Importance",
+ ),
color=alt.value("steelblue"),
tooltip=[
alt.Tooltip("feature:N", title="Feature"),
@@ -880,7 +951,9 @@ def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.
return chart
-def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
+def plot_rf_tree_statistics(
+ model_state: xr.Dataset,
+) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
"""Plot Random Forest tree statistics.
Args:
@@ -908,7 +981,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"),
y=alt.Y("count()", title="Number of Trees"),
color=alt.value("steelblue"),
- tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")],
+ tooltip=[
+ alt.Tooltip("count()", title="Count"),
+ alt.Tooltip("value:Q", bin=True, title="Depth Range"),
+ ],
)
.properties(width=300, height=200, title="Distribution of Tree Depths")
)
@@ -921,7 +997,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"),
y=alt.Y("count()", title="Number of Trees"),
color=alt.value("forestgreen"),
- tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")],
+ tooltip=[
+ alt.Tooltip("count()", title="Count"),
+ alt.Tooltip("value:Q", bin=True, title="Leaves Range"),
+ ],
)
.properties(width=300, height=200, title="Distribution of Leaf Counts")
)
@@ -934,7 +1013,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"),
y=alt.Y("count()", title="Number of Trees"),
color=alt.value("darkorange"),
- tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")],
+ tooltip=[
+ alt.Tooltip("count()", title="Count"),
+ alt.Tooltip("value:Q", bin=True, title="Nodes Range"),
+ ],
)
.properties(width=300, height=200, title="Distribution of Node Counts")
)
diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py
index ca11da9..12a74c9 100644
--- a/src/entropice/dashboard/plots/source_data.py
+++ b/src/entropice/dashboard/plots/source_data.py
@@ -330,7 +330,10 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
with col3:
time_values = pd.to_datetime(ds["time"].values)
- st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}")
+ st.metric(
+ "Time Steps",
+ f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}",
+ )
with col4:
if has_agg:
@@ -398,11 +401,15 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
selected_var = st.selectbox(
- "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
+ "Select variable to visualize",
+ options=variables,
+ key=f"era5_{temporal_type}_var_select",
)
with col2:
selected_agg = st.selectbox(
- "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select"
+ "Aggregation",
+ options=ds["aggregations"].values,
+ key=f"era5_{temporal_type}_agg_select",
)
with col3:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
@@ -411,7 +418,9 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
col1, col2 = st.columns([3, 1])
with col1:
selected_var = st.selectbox(
- "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
+ "Select variable to visualize",
+ options=variables,
+ key=f"era5_{temporal_type}_var_select",
)
with col2:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
@@ -614,7 +623,10 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
deck = pdk.Deck(
layers=[layer],
initial_view_state=view_state,
- tooltip={"html": "Value: {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
+ tooltip={
+ "html": "Value: {value}",
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
)
@@ -746,7 +758,10 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
deck = pdk.Deck(
layers=[layer],
initial_view_state=view_state,
- tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
+ tooltip={
+ "html": tooltip_html,
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
)
@@ -784,7 +799,14 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
)
with col2:
- opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity")
+ opacity = st.slider(
+ "Opacity",
+ min_value=0.1,
+ max_value=1.0,
+ value=0.7,
+ step=0.1,
+ key="areas_map_opacity",
+ )
# Create GeoDataFrame
gdf = grid_gdf.copy()
@@ -896,7 +918,9 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
selected_agg = st.selectbox(
- "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
+ "Aggregation",
+ options=ds["aggregations"].values,
+ key=f"era5_{temporal_type}_agg",
)
with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
@@ -1022,7 +1046,10 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
deck = pdk.Deck(
layers=[layer],
initial_view_state=view_state,
- tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
+ tooltip={
+ "html": tooltip_html,
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
)
diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py
index 1148b58..73ee446 100644
--- a/src/entropice/dashboard/plots/training_data.py
+++ b/src/entropice/dashboard/plots/training_data.py
@@ -11,7 +11,9 @@ from entropice.dashboard.utils.colors import get_palette
from entropice.ml.dataset import CategoricalTrainingDataset
-def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]):
+def render_all_distribution_histograms(
+ train_data_dict: dict[str, CategoricalTrainingDataset],
+):
"""Render histograms for all three tasks side by side.
Args:
@@ -81,7 +83,13 @@ def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTra
height=400,
margin={"l": 20, "r": 20, "t": 20, "b": 20},
showlegend=True,
- legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
+ legend={
+ "orientation": "h",
+ "yanchor": "bottom",
+ "y": 1.02,
+ "xanchor": "right",
+ "x": 1,
+ },
xaxis_title=None,
yaxis_title="Count",
xaxis={"tickangle": -45},
@@ -137,7 +145,10 @@ def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
elif color_mode == "split":
- split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange
+ split_colors = {
+ "train": [66, 135, 245],
+ "test": [245, 135, 66],
+ } # Blue # Orange
gdf["fill_color"] = gdf["split"].map(split_colors)
return gdf
@@ -168,7 +179,14 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
)
with col2:
- opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity")
+ opacity = st.slider(
+ "Opacity",
+ min_value=0.1,
+ max_value=1.0,
+ value=0.7,
+ step=0.1,
+ key="spatial_map_opacity",
+ )
# Determine which task dataset to use and color mode
if vis_mode == "split":
@@ -258,7 +276,10 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
# Set initial view state (centered on the Arctic)
# Adjust pitch and zoom based on whether we're using elevation
view_state = pdk.ViewState(
- latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45
+ latitude=70,
+ longitude=0,
+ zoom=2 if not use_elevation else 1.5,
+ pitch=0 if not use_elevation else 45,
)
# Create deck
diff --git a/src/entropice/dashboard/utils/colors.py b/src/entropice/dashboard/utils/colors.py
index cd77d3e..97b3838 100644
--- a/src/entropice/dashboard/utils/colors.py
+++ b/src/entropice/dashboard/utils/colors.py
@@ -95,7 +95,9 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
return colors
-def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
+def generate_unified_colormap(
+ settings: dict,
+) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
"""Generate unified colormaps for all plotting libraries.
This function creates consistent color schemes across Matplotlib/Ultraplot,
@@ -125,7 +127,10 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
if is_dark_theme:
base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme
else:
- base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme
+ base_colors = [
+ "#3498db",
+ "#e74c3c",
+ ] # Brighter blue and red for light theme
else:
# For multi-class: use a sequential colormap
# Use matplotlib's viridis colormap
diff --git a/src/entropice/dashboard/utils/formatters.py b/src/entropice/dashboard/utils/formatters.py
index 042a490..5db49a4 100644
--- a/src/entropice/dashboard/utils/formatters.py
+++ b/src/entropice/dashboard/utils/formatters.py
@@ -19,7 +19,9 @@ class ModelDisplayInfo:
model_display_infos: dict[Model, ModelDisplayInfo] = {
"espa": ModelDisplayInfo(
- keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm"
+ keyword="espa",
+ short="eSPA",
+ long="entropy-optimal Scalable Probabilistic Approximations algorithm",
),
"xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
"rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index c7f0645..34233a1 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -135,7 +135,9 @@ def load_all_training_results() -> list[TrainingResult]:
return training_results
-def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
+def load_all_training_data(
+ e: DatasetEnsemble,
+) -> dict[Task, CategoricalTrainingDataset]:
"""Load training data for all three tasks.
Args:
diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py
index 2d35e2f..004c7f5 100644
--- a/src/entropice/dashboard/utils/stats.py
+++ b/src/entropice/dashboard/utils/stats.py
@@ -142,7 +142,9 @@ class DatasetStatistics:
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
@staticmethod
- def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ def get_sample_count_df(
+ all_stats: dict[GridLevel, "DatasetStatistics"],
+ ) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
@@ -164,7 +166,9 @@ class DatasetStatistics:
return pd.DataFrame(rows)
@staticmethod
- def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ def get_feature_count_df(
+ all_stats: dict[GridLevel, "DatasetStatistics"],
+ ) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for grid_config in grid_configs:
@@ -191,7 +195,9 @@ class DatasetStatistics:
return pd.DataFrame(rows)
@staticmethod
- def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+ def get_feature_breakdown_df(
+ all_stats: dict[GridLevel, "DatasetStatistics"],
+ ) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for grid_config in grid_configs:
@@ -219,7 +225,10 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
target_statistics: dict[TargetDataset, TargetStatistics] = {}
for target in all_target_datasets:
target_statistics[target] = TargetStatistics.compute(
- grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
+ grid=grid_config.grid,
+ level=grid_config.level,
+ target=target,
+ total_cells=total_cells,
)
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
for member in all_l2_source_datasets:
@@ -357,7 +366,10 @@ class TrainingDatasetStatistics:
@classmethod
def compute(
- cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
+ cls,
+ ensemble: DatasetEnsemble,
+ task: Task,
+ dataset: gpd.GeoDataFrame | None = None,
) -> "TrainingDatasetStatistics":
dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
diff --git a/src/entropice/dashboard/utils/unsembler.py b/src/entropice/dashboard/utils/unsembler.py
index 6f9e154..3cec383 100644
--- a/src/entropice/dashboard/utils/unsembler.py
+++ b/src/entropice/dashboard/utils/unsembler.py
@@ -65,7 +65,9 @@ def extract_embedding_features(
def extract_era5_features(
- model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None
+ model_state: xr.Dataset,
+ importance_type: str = "feature_weights",
+ temporal_group: str | None = None,
) -> xr.DataArray | None:
"""Extract ERA5 features from the model state.
@@ -100,7 +102,17 @@ def extract_era5_features(
- Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021")
"""
parts = feature.split("_")
- common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+ common_aggs = {
+ "mean",
+ "std",
+ "min",
+ "max",
+ "median",
+ "sum",
+ "count",
+ "q25",
+ "q75",
+ }
# Find where the time part starts (after "era5" and variable name)
# Pattern: era5_variable_time or era5_variable_time_agg
@@ -166,7 +178,17 @@ def extract_era5_features(
def _extract_time_name(feature: str) -> str:
parts = feature.split("_")
- common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+ common_aggs = {
+ "mean",
+ "std",
+ "min",
+ "max",
+ "median",
+ "sum",
+ "count",
+ "q25",
+ "q75",
+ }
if parts[-1] in common_aggs:
# Has aggregation: era5_var_time_agg -> time is second to last
return parts[-2]
@@ -179,11 +201,8 @@ def extract_era5_features(
Pattern: era5_variable_season_year_agg or era5_variable_season_year
"""
- parts = feature.split("_")
- common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
-
# Look through parts to find season/shoulder indicators
- for part in parts:
+ for part in feature.split("_"):
if part.lower() in ["summer", "winter"]:
return part.lower()
elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
@@ -197,11 +216,26 @@ def extract_era5_features(
For seasonal/shoulder features, find the year that comes after the season.
"""
parts = feature.split("_")
- common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+ common_aggs = {
+ "mean",
+ "std",
+ "min",
+ "max",
+ "median",
+ "sum",
+ "count",
+ "q25",
+ "q75",
+ }
# Find the season/shoulder part, then the next part should be the year
for i, part in enumerate(parts):
- if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+ if part.lower() in ["summer", "winter"] or part.upper() in [
+ "JFM",
+ "AMJ",
+ "JAS",
+ "OND",
+ ]:
# Next part should be the year
if i + 1 < len(parts):
next_part = parts[i + 1]
@@ -218,7 +252,17 @@ def extract_era5_features(
def _extract_agg_name(feature: str) -> str | None:
parts = feature.split("_")
- common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+ common_aggs = {
+ "mean",
+ "std",
+ "min",
+ "max",
+ "median",
+ "sum",
+ "count",
+ "q25",
+ "q75",
+ }
if parts[-1] in common_aggs:
return parts[-1]
return None
@@ -255,7 +299,10 @@ def extract_era5_features(
if has_agg:
era5_features_array = era5_features_array.assign_coords(
- agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+ agg=(
+ "feature",
+ [_extract_agg_name(f) or "none" for f in era5_features],
+ ),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack(
"feature"
@@ -274,7 +321,10 @@ def extract_era5_features(
if has_agg:
# Add aggregation dimension
era5_features_array = era5_features_array.assign_coords(
- agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+ agg=(
+ "feature",
+ [_extract_agg_name(f) or "none" for f in era5_features],
+ ),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
else:
@@ -345,7 +395,14 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea
Returns None if no common features are found.
"""
- common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
+ common_feature_names = [
+ "cell_area",
+ "water_area",
+ "land_area",
+ "land_ratio",
+ "lon",
+ "lat",
+ ]
def _is_common_feature(feature: str) -> bool:
return feature in common_feature_names
diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py
index 86958b2..20d1759 100644
--- a/src/entropice/dashboard/views/inference_page.py
+++ b/src/entropice/dashboard/views/inference_page.py
@@ -66,7 +66,10 @@ def render_inference_page():
with col4:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
with col5:
- st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))
+ st.metric(
+ "Target",
+ selected_result.settings.get("target", "Unknown").replace("darts_", ""),
+ )
st.divider()
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py
index d9676a3..55c0220 100644
--- a/src/entropice/dashboard/views/overview_page.py
+++ b/src/entropice/dashboard/views/overview_page.py
@@ -10,8 +10,16 @@ from stopuhr import stopwatch
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
-from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
-from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
+from entropice.dashboard.utils.stats import (
+ DatasetStatistics,
+ load_all_default_dataset_statistics,
+)
+from entropice.utils.types import (
+ GridConfig,
+ L2SourceDataset,
+ TargetDataset,
+ grid_configs,
+)
def render_sample_count_overview():
@@ -45,7 +53,12 @@ def render_sample_count_overview():
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
# Pivot for heatmap: Grid x Task
- pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
+ pivot_df = target_df.pivot_table(
+ index="Grid",
+ columns="Task",
+ values="Samples (Coverage)",
+ aggfunc="mean",
+ )
# Sort index by grid type and level
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@@ -56,7 +69,11 @@ def render_sample_count_overview():
fig = px.imshow(
pivot_df,
- labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
+ labels={
+ "x": "Task",
+ "y": "Grid Configuration",
+ "color": "Sample Count",
+ },
x=pivot_df.columns,
y=pivot_df.index,
color_continuous_scale=sample_colors,
@@ -87,7 +104,10 @@ def render_sample_count_overview():
facet_col="Target",
barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset",
- labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
+ labels={
+ "Grid": "Grid Configuration",
+ "Samples (Coverage)": "Number of Samples",
+ },
color_discrete_sequence=task_colors,
height=500,
)
@@ -141,7 +161,10 @@ def render_feature_count_comparison():
color="Data Source",
barmode="stack",
title="Total Features by Data Source Across Grid Configurations",
- labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
+ labels={
+ "Grid": "Grid Configuration",
+ "Number of Features": "Number of Features",
+ },
color_discrete_sequence=source_colors,
text_auto=False,
)
@@ -162,7 +185,10 @@ def render_feature_count_comparison():
y="Inference Cells",
color="Grid",
title="Inference Cells by Grid Configuration",
- labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
+ labels={
+ "Grid": "Grid Configuration",
+ "Inference Cells": "Number of Cells",
+ },
color_discrete_sequence=grid_colors,
text="Inference Cells",
)
@@ -177,7 +203,10 @@ def render_feature_count_comparison():
y="Total Samples",
color="Grid",
title="Total Samples by Grid Configuration",
- labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
+ labels={
+ "Grid": "Grid Configuration",
+ "Total Samples": "Number of Samples",
+ },
color_discrete_sequence=grid_colors,
text="Total Samples",
)
@@ -226,7 +255,13 @@ def render_feature_count_comparison():
# Display full comparison table with formatting
display_df = comparison_df[
- ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
+ [
+ "Grid",
+ "Total Features",
+ "Data Sources",
+ "Inference Cells",
+ "Total Samples",
+ ]
].copy()
# Format numbers with commas
@@ -314,7 +349,11 @@ def render_feature_count_explorer():
with col2:
# Calculate minimum cells across all data sources (for inference capability)
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
- st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
+ st.metric(
+ "Inference Cells",
+ f"{min_cells:,}",
+                help="Minimum number of cells across all data sources",
+ )
with col3:
st.metric("Data Sources", len(selected_members))
with col4:
diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py
index 275d3c0..4523f4c 100644
--- a/src/entropice/dashboard/views/training_data_page.py
+++ b/src/entropice/dashboard/views/training_data_page.py
@@ -14,7 +14,10 @@ from entropice.dashboard.plots.source_data import (
render_era5_overview,
render_era5_plots,
)
-from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
+from entropice.dashboard.plots.training_data import (
+ render_all_distribution_histograms,
+ render_spatial_map,
+)
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
from entropice.ml.dataset import DatasetEnsemble
from entropice.spatial import grids
@@ -42,7 +45,10 @@ def render_training_data_page():
]
grid_level_combined = st.selectbox(
- "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution level"
+ "Grid Configuration",
+ options=grid_options,
+ index=0,
+ help="Select the grid system and resolution level",
)
# Parse grid type and level
@@ -60,7 +66,13 @@ def render_training_data_page():
# Members selection
st.subheader("Dataset Members")
- all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+ all_members = [
+ "AlphaEarth",
+ "ArcticDEM",
+ "ERA5-yearly",
+ "ERA5-seasonal",
+ "ERA5-shoulder",
+ ]
selected_members = []
for member in all_members:
@@ -69,7 +81,10 @@ def render_training_data_page():
# Form submit button
load_button = st.form_submit_button(
- "Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0
+ "Load Dataset",
+ type="primary",
+ width="stretch",
+ disabled=len(selected_members) == 0,
)
# Create DatasetEnsemble only when form is submitted
diff --git a/src/entropice/ingest/alphaearth.py b/src/entropice/ingest/alphaearth.py
index 55098fe..35562d0 100644
--- a/src/entropice/ingest/alphaearth.py
+++ b/src/entropice/ingest/alphaearth.py
@@ -77,7 +77,20 @@ def download(grid: Grid, level: int):
for year in track(range(2024, 2025), total=1, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
- aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+ aggs = [
+ "mean",
+ "stdDev",
+ "min",
+ "max",
+ "count",
+ "median",
+ "p1",
+ "p5",
+ "p25",
+ "p75",
+ "p95",
+ "p99",
+ ]
bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]
def extract_embedding(feature):
@@ -136,7 +149,20 @@ def combine_to_zarr(grid: Grid, level: int):
"""
cell_ids = grids.get_cell_ids(grid, level)
years = list(range(2018, 2025))
- aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+ aggs = [
+ "mean",
+ "stdDev",
+ "min",
+ "max",
+ "count",
+ "median",
+ "p1",
+ "p5",
+ "p25",
+ "p75",
+ "p95",
+ "p99",
+ ]
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
a = xr.DataArray(
diff --git a/src/entropice/ingest/arcticdem.py b/src/entropice/ingest/arcticdem.py
index 7837111..51fce8e 100644
--- a/src/entropice/ingest/arcticdem.py
+++ b/src/entropice/ingest/arcticdem.py
@@ -125,8 +125,14 @@ def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
cs = 3600
# Calculate safe slice bounds for edge chunks
- y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d)
- x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d)
+ y_start, y_end = (
+ max(0, cs * chunk_loc[0] - d),
+ min(len(y), cs * chunk_loc[0] + cs + d),
+ )
+ x_start, x_end = (
+ max(0, cs * chunk_loc[1] - d),
+ min(len(x), cs * chunk_loc[1] + cs + d),
+ )
# Extract coordinate arrays with safe bounds
y_chunk = cp.asarray(y[y_start:y_end])
@@ -162,7 +168,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
if np.all(np.isnan(chunk)):
# Return an array of NaNs with the expected shape
return np.full(
- (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
+ (
+ 12,
+ chunk.shape[0] - 2 * large_kernels.size_px,
+ chunk.shape[1] - 2 * large_kernels.size_px,
+ ),
np.nan,
dtype=np.float32,
)
@@ -171,7 +181,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
# Interpolate missing values in chunk (for patches smaller than 7x7 pixels)
mask = cp.isnan(chunk)
- mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
+ mask &= ~binary_dilation(
+ binary_erosion(mask, iterations=3, brute_force=True),
+ iterations=3,
+ brute_force=True,
+ )
if cp.any(mask):
# Find indices of valid values
indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
diff --git a/src/entropice/ingest/darts.py b/src/entropice/ingest/darts.py
index 19c6760..4e3dfd1 100644
--- a/src/entropice/ingest/darts.py
+++ b/src/entropice/ingest/darts.py
@@ -14,7 +14,12 @@ from rich.progress import track
from stopuhr import stopwatch
from entropice.spatial import grids
-from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
+from entropice.utils.paths import (
+ darts_ml_training_labels_repo,
+ dartsl2_cov_file,
+ dartsl2_file,
+ get_darts_rts_file,
+)
from entropice.utils.types import Grid
traceback.install()
diff --git a/src/entropice/ingest/era5.py b/src/entropice/ingest/era5.py
index 9859f65..667e43e 100644
--- a/src/entropice/ingest/era5.py
+++ b/src/entropice/ingest/era5.py
@@ -204,15 +204,36 @@ def download_daily_aggregated():
)
# Assign attributes
- daily_raw["t2m_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"}
- daily_raw["t2m_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"}
- daily_raw["t2m_mean"].attrs = {"long_name": "Daily mean 2 metre temperature", "units": "K"}
+ daily_raw["t2m_max"].attrs = {
+ "long_name": "Daily maximum 2 metre temperature",
+ "units": "K",
+ }
+ daily_raw["t2m_min"].attrs = {
+ "long_name": "Daily minimum 2 metre temperature",
+ "units": "K",
+ }
+ daily_raw["t2m_mean"].attrs = {
+ "long_name": "Daily mean 2 metre temperature",
+ "units": "K",
+ }
daily_raw["snowc_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "m"}
daily_raw["sde_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"}
- daily_raw["lblt_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"}
- daily_raw["tp"].attrs = {"long_name": "Daily total precipitation", "units": "m"} # Units are rather m^3 / m^2
- daily_raw["sf"].attrs = {"long_name": "Daily total snow fall", "units": "m"} # Units are rather m^3 / m^2
- daily_raw["sshf"].attrs = {"long_name": "Daily total surface sensible heat flux", "units": "J/m²"}
+ daily_raw["lblt_max"].attrs = {
+ "long_name": "Daily maximum lake ice bottom temperature",
+ "units": "K",
+ }
+ daily_raw["tp"].attrs = {
+ "long_name": "Daily total precipitation",
+ "units": "m",
+    }  # Units are actually m^3 / m^2
+ daily_raw["sf"].attrs = {
+ "long_name": "Daily total snow fall",
+ "units": "m",
+    }  # Units are actually m^3 / m^2
+ daily_raw["sshf"].attrs = {
+ "long_name": "Daily total surface sensible heat flux",
+ "units": "J/m²",
+ }
daily_raw = daily_raw.odc.assign_crs("epsg:4326")
daily_raw = daily_raw.drop_vars(["surface", "number", "depthBelowLandLayer"])
@@ -281,11 +302,17 @@ def daily_enrich():
# Formulas based on Groeke et. al. (2025) Stochastic Weather generation...
daily["t2m_avg"] = (daily.t2m_max + daily.t2m_min) / 2
- daily.t2m_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"}
+ daily.t2m_avg.attrs = {
+ "long_name": "Daily average 2 metre temperature",
+ "units": "K",
+ }
_store("t2m_avg")
daily["t2m_range"] = daily.t2m_max - daily.t2m_min
- daily.t2m_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"}
+ daily.t2m_range.attrs = {
+ "long_name": "Daily range of 2 metre temperature",
+ "units": "K",
+ }
_store("t2m_range")
with np.errstate(invalid="ignore"):
@@ -298,7 +325,10 @@ def daily_enrich():
_store("thawing_degree_days")
daily["freezing_degree_days"] = (273.15 - daily.t2m_avg).clip(min=0)
- daily.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"}
+ daily.freezing_degree_days.attrs = {
+ "long_name": "Freezing degree days",
+ "units": "K",
+ }
_store("freezing_degree_days")
daily["thawing_days"] = (daily.t2m_avg > 273.15).where(~daily.t2m_avg.isnull())
@@ -425,7 +455,12 @@ def multi_monthly_aggregate(agg: Literal["yearly", "seasonal", "shoulder"] = "ye
multimonthly_store = get_era5_stores(agg)
print(f"Saving empty multi-monthly ERA5 data to {multimonthly_store}.")
- multimonthly.to_zarr(multimonthly_store, mode="w", encoding=codecs.from_ds(multimonthly), consolidated=False)
+ multimonthly.to_zarr(
+ multimonthly_store,
+ mode="w",
+ encoding=codecs.from_ds(multimonthly),
+ consolidated=False,
+ )
def _store(var):
nonlocal multimonthly
@@ -512,14 +547,20 @@ def yearly_thaw_periods():
never_thaws = (daily.thawing_days.resample(time="12MS").sum(dim="time") == 0).persist()
first_thaw_day = daily.thawing_days.resample(time="12MS").map(_get_first_day).where(~never_thaws)
- first_thaw_day.attrs = {"long_name": "Day of first thaw in year", "units": "day of year"}
+ first_thaw_day.attrs = {
+ "long_name": "Day of first thaw in year",
+ "units": "day of year",
+ }
_store_in_yearly(first_thaw_day, "day_of_first_thaw")
n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").first()
last_thaw_day = n_days_in_year - daily.thawing_days[::-1].resample(time="12MS").map(_get_first_day).where(
~never_thaws
)
- last_thaw_day.attrs = {"long_name": "Day of last thaw in year", "units": "day of year"}
+ last_thaw_day.attrs = {
+ "long_name": "Day of last thaw in year",
+ "units": "day of year",
+ }
_store_in_yearly(last_thaw_day, "day_of_last_thaw")
@@ -532,14 +573,20 @@ def yearly_freeze_periods():
never_freezes = (daily.freezing_days.resample(time="12MS").sum(dim="time") == 0).persist()
first_freeze_day = daily.freezing_days.resample(time="12MS").map(_get_first_day).where(~never_freezes)
- first_freeze_day.attrs = {"long_name": "Day of first freeze in year", "units": "day of year"}
+ first_freeze_day.attrs = {
+ "long_name": "Day of first freeze in year",
+ "units": "day of year",
+ }
_store_in_yearly(first_freeze_day, "day_of_first_freeze")
n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").last()
last_freeze_day = n_days_in_year - daily.freezing_days[::-1].resample(time="12MS").map(_get_first_day).where(
~never_freezes
)
- last_freeze_day.attrs = {"long_name": "Day of last freeze in year", "units": "day of year"}
+ last_freeze_day.attrs = {
+ "long_name": "Day of last freeze in year",
+ "units": "day of year",
+ }
_store_in_yearly(last_freeze_day, "day_of_last_freeze")
@@ -728,7 +775,12 @@ def spatial_agg(
pxbuffer=10,
)
- aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
+ aggregated = aggregated.chunk(
+ {
+ "cell_ids": min(len(aggregated.cell_ids), 10000),
+ "time": len(aggregated.time),
+ }
+ )
store = get_era5_stores(agg, grid, level)
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py
index 0344e39..8959f0a 100644
--- a/src/entropice/ml/dataset.py
+++ b/src/entropice/ml/dataset.py
@@ -99,7 +99,14 @@ def bin_values(
"""
labels_dict = {
"count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
- "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
+ "density": [
+ "Empty",
+ "Very Sparse",
+ "Sparse",
+ "Moderate",
+ "Dense",
+ "Very Dense",
+ ],
}
labels = labels_dict[task]
@@ -186,7 +193,13 @@ class DatasetEnsemble:
level: int
target: Literal["darts_rts", "darts_mllabels"]
members: list[L2SourceDataset] = field(
- default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+ default_factory=lambda: [
+ "AlphaEarth",
+ "ArcticDEM",
+ "ERA5-yearly",
+ "ERA5-seasonal",
+ "ERA5-shoulder",
+ ]
)
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
variable_filters: dict[str, list[str]] = field(default_factory=dict)
@@ -277,7 +290,9 @@ class DatasetEnsemble:
return targets
def _prep_era5(
- self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
+ self,
+ targets: gpd.GeoDataFrame,
+ temporal: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
era5 = self._read_member("ERA5-" + temporal, targets)
era5_df = era5.to_dataframe()
@@ -352,7 +367,9 @@ class DatasetEnsemble:
print(f"=== Total number of features in dataset: {stats['total_features']}")
def create(
- self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
+ self,
+ filter_target_col: str | None = None,
+ cache_mode: Literal["n", "o", "r"] = "r",
) -> gpd.GeoDataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
@@ -438,7 +455,10 @@ class DatasetEnsemble:
yield dataset
def _cat_and_split(
- self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
+ self,
+ dataset: gpd.GeoDataFrame,
+ task: Task,
+ device: Literal["cpu", "cuda", "torch"],
) -> CategoricalTrainingDataset:
taskcol = self.taskcol(task)
diff --git a/src/entropice/ml/inference.py b/src/entropice/ml/inference.py
index 05c9dae..d285d65 100644
--- a/src/entropice/ml/inference.py
+++ b/src/entropice/ml/inference.py
@@ -20,7 +20,9 @@ set_config(array_api_dispatch=True)
def predict_proba(
- e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list
+ e: DatasetEnsemble,
+ clf: RandomForestClassifier | ESPAClassifier | XGBClassifier,
+ classes: list,
) -> gpd.GeoDataFrame:
"""Get predicted probabilities for each cell.
diff --git a/src/entropice/spatial/aggregators.py b/src/entropice/spatial/aggregators.py
index d775e87..f6e6ba0 100644
--- a/src/entropice/spatial/aggregators.py
+++ b/src/entropice/spatial/aggregators.py
@@ -79,7 +79,19 @@ class _Aggregations:
def aggnames(self) -> list[str]:
if self._common:
- return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+ return [
+ "mean",
+ "std",
+ "min",
+ "max",
+ "median",
+ "p1",
+ "p5",
+ "p25",
+ "p75",
+ "p95",
+ "p99",
+ ]
names = []
if self.mean:
names.append("mean")
@@ -105,7 +117,15 @@ class _Aggregations:
cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
- quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99] # ? Ordering is important here!
+ quantiles_to_compute = [
+ 0.5,
+ 0.01,
+ 0.05,
+ 0.25,
+ 0.75,
+ 0.95,
+ 0.99,
+ ] # ? Ordering is important here!
cell_data[4:, ...] = flattened_var.quantile(
q=quantiles_to_compute,
dim="z",
@@ -156,7 +176,11 @@ class _Aggregations:
return self._agg_cell_data_single(flattened)
others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
- cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
+ cell_data = np.full(
+ (len(flattened.data_vars), len(self), *others_shape),
+ np.nan,
+ dtype=np.float32,
+ )
for i, var in enumerate(flattened.data_vars):
cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
# Transform to numpy arrays
@@ -194,7 +218,9 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
@stopwatch("Correcting geometries", log=False)
-def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
+def _get_corrected_geoms(
+ inp: tuple[Polygon, odc.geo.geobox.GeoBox, str],
+) -> list[odc.geo.Geometry]:
geom, gbox, crs = inp
# cell.geometry is a shapely Polygon
if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
@@ -440,7 +466,12 @@ def _align_partition(
others_shape = tuple(
[raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
)
- ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
+ ongrid_shape = (
+ len(grid_partition_gdf),
+ len(raster.data_vars),
+ len(aggregations),
+ *others_shape,
+ )
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
@@ -455,7 +486,11 @@ def _align_partition(
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
dims = ["cell_ids", "variables", "aggregations"]
- coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
+ coords = {
+ "cell_ids": cell_ids,
+ "variables": list(raster.data_vars),
+ "aggregations": aggregations.aggnames(),
+ }
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
diff --git a/src/entropice/spatial/grids.py b/src/entropice/spatial/grids.py
index 3f99867..3a60d99 100644
--- a/src/entropice/spatial/grids.py
+++ b/src/entropice/spatial/grids.py
@@ -131,7 +131,9 @@ def create_global_hex_grid(resolution):
executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
}
for future in track(
- as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
+ as_completed(future_to_hex),
+ description="Creating hex polygons...",
+ total=len(hex0_cells),
):
hex_batch, hex_id_batch, hex_area_batch = future.result()
hex_list.extend(hex_batch)
@@ -157,7 +159,10 @@ def create_global_hex_grid(resolution):
hex_area_list.append(hex_area)
# Create GeoDataFrame
- grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
+ grid = gpd.GeoDataFrame(
+ {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
+ crs="EPSG:4326",
+ )
return grid
diff --git a/src/entropice/utils/codecs.py b/src/entropice/utils/codecs.py
index c2f0282..7078b94 100644
--- a/src/entropice/utils/codecs.py
+++ b/src/entropice/utils/codecs.py
@@ -5,7 +5,10 @@ from zarr.codecs import BloscCodec
def from_ds(
- ds: xr.Dataset, store_floats_as_float32: bool = True, include_coords: bool = True, filter_existing: bool = True
+ ds: xr.Dataset,
+ store_floats_as_float32: bool = True,
+ include_coords: bool = True,
+ filter_existing: bool = True,
) -> dict:
"""Create compression encoding for zarr dataset storage.
diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py
index eb7eb91..1db121c 100644
--- a/src/entropice/utils/types.py
+++ b/src/entropice/utils/types.py
@@ -4,7 +4,17 @@ from dataclasses import dataclass
from typing import Literal
type Grid = Literal["hex", "healpix"]
-type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
+type GridLevel = Literal[
+ "hex3",
+ "hex4",
+ "hex5",
+ "hex6",
+ "healpix6",
+ "healpix7",
+ "healpix8",
+ "healpix9",
+ "healpix10",
+]
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
diff --git a/tests/l2dataset_validation.py b/tests/l2dataset_validation.py
new file mode 100644
index 0000000..d41ae5a
--- /dev/null
+++ b/tests/l2dataset_validation.py
@@ -0,0 +1,83 @@
+from itertools import product
+
+import xarray as xr
+from rich.progress import track
+
+from entropice.utils.paths import (
+ get_arcticdem_stores,
+ get_embeddings_store,
+ get_era5_stores,
+)
+from entropice.utils.types import Grid, L2SourceDataset
+
+
+def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
+ """Validate if the L2 dataset exists for the given grid and level.
+
+ Args:
+ grid (Grid): The grid type to use.
+ level (int): The grid level to use.
+ l2ds (L2SourceDataset): The L2 source dataset to validate.
+
+ Returns:
+ bool: True if the dataset exists and does not contain NaNs, False otherwise.
+
+ """
+ if l2ds == "ArcticDEM":
+ store = get_arcticdem_stores(grid, level)
+    elif l2ds in ("ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly"):
+ agg = l2ds.split("-")[1]
+ store = get_era5_stores(agg, grid, level) # type: ignore
+ elif l2ds == "AlphaEarth":
+ store = get_embeddings_store(grid, level)
+ else:
+ raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
+
+ if not store.exists():
+ print("\t Dataset store does not exist")
+ return False
+
+ ds = xr.open_zarr(store, consolidated=False)
+ has_nan = False
+ for var in ds.data_vars:
+ n_nans = ds[var].isnull().sum().compute().item()
+ if n_nans > 0:
+ print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
+ has_nan = True
+ if has_nan:
+ return False
+ return True
+
+
+def main():
+ grid_levels: set[tuple[Grid, int]] = {
+ ("hex", 3),
+ ("hex", 4),
+ ("hex", 5),
+ ("hex", 6),
+ ("healpix", 6),
+ ("healpix", 7),
+ ("healpix", 8),
+ ("healpix", 9),
+ ("healpix", 10),
+ }
+ l2_source_datasets: list[L2SourceDataset] = [
+ "ArcticDEM",
+ "ERA5-shoulder",
+ "ERA5-seasonal",
+ "ERA5-yearly",
+ "AlphaEarth",
+ ]
+
+ for (grid, level), l2ds in track(
+ product(grid_levels, l2_source_datasets),
+ total=len(grid_levels) * len(l2_source_datasets),
+ description="Validating L2 datasets...",
+ ):
+ is_valid = validate_l2_dataset(grid, level, l2ds)
+ status = "VALID" if is_valid else "INVALID"
+ print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
+
+
+if __name__ == "__main__":
+ main()