Update docs, instructions and format code

This commit is contained in:
Tobias Hölzer 2026-01-04 17:19:02 +01:00
parent fca232da91
commit 4260b492ab
29 changed files with 987 additions and 467 deletions


@ -1,56 +1,54 @@
--- ---
description: 'Specialized agent for developing and enhancing the Streamlit dashboard for data and training analysis.' description: Develop and refactor Streamlit dashboard pages and visualizations
name: Dashboard-Developer name: Dashboard
argument-hint: 'Describe dashboard features, pages, visualizations, or improvements you want to add or modify' argument-hint: Describe dashboard features, pages, or visualizations to add or modify
tools: ['edit', 'runNotebooks', 'search', 'runCommands', 'usages', 'problems', 'changes', 'testFailure', 'fetch', 'githubRepo', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'ms-toolsai.jupyter/configureNotebook', 'ms-toolsai.jupyter/listNotebookPackages', 'ms-toolsai.jupyter/installNotebookPackages', 'todos', 'runSubagent', 'runTests'] tools: ['vscode', 'execute', 'read', 'edit', 'search', 'web', 'agent', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'todo']
model: Claude Sonnet 4.5
infer: true
--- ---
# Dashboard Development Agent # Dashboard Development Agent
You are a specialized agent for incrementally developing and enhancing the **Entropice Streamlit Dashboard** used to analyze geospatial machine learning data and training experiments. You specialize in developing and refactoring the **Entropice Streamlit Dashboard** for geospatial machine learning analysis.
## Your Responsibilities ## Scope
### What You Should Do **You can edit:** Files in `src/entropice/dashboard/` only
**You cannot edit:** Data pipeline scripts, training code, or configuration files
1. **Develop Dashboard Features**: Create new pages, visualizations, and UI components for the Streamlit dashboard **Primary reference:** Always consult `views/overview_page.py` for current code patterns
2. **Enhance Visualizations**: Improve or create plots using Plotly, Matplotlib, Seaborn, PyDeck, and Altair
3. **Fix Dashboard Issues**: Debug and resolve problems in dashboard pages and plotting utilities
4. **Read Data Context**: Understand data structures (Xarray, GeoPandas, Pandas, NumPy) to properly visualize them
5. **Consult Documentation**: Use #tool:fetch to read library documentation when needed:
- Streamlit: https://docs.streamlit.io/
- Plotly: https://plotly.com/python/
- PyDeck: https://deckgl.readthedocs.io/
- Deck.gl: https://deck.gl/
- Matplotlib: https://matplotlib.org/
- Seaborn: https://seaborn.pydata.org/
- Xarray: https://docs.xarray.dev/
- GeoPandas: https://geopandas.org/
- Pandas: https://pandas.pydata.org/pandas-docs/
- NumPy: https://numpy.org/doc/stable/
6. **Understand Data Sources**: Read data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`, `dataset.py`, `training.py`, `inference.py`) to understand data structures—but **NEVER edit them** ## Responsibilities
### What You Should NOT Do ### ✅ What You Do
1. **Never Edit Data Pipeline Scripts**: Do not modify files in `src/entropice/` that are NOT in the `dashboard/` subdirectory - Create/refactor dashboard pages in `views/`
2. **Never Edit Training Scripts**: Do not modify `training.py`, `dataset.py`, or any model-related code outside the dashboard - Build visualizations using Plotly, Matplotlib, Seaborn, PyDeck, Altair
3. **Never Modify Data Processing**: If changes to data creation or model training scripts are needed, **pause and inform the user** instead of making changes yourself - Fix dashboard bugs and improve UI/UX
4. **Never Edit Configuration Files**: Do not modify `pyproject.toml`, pipeline scripts in `scripts/`, or configuration files - Create utility functions in `utils/` and `plots/`
- Read (but never edit) data pipeline code to understand data structures
- Use #tool:web to fetch library documentation:
- Streamlit: https://docs.streamlit.io/
- Plotly: https://plotly.com/python/
- PyDeck: https://deckgl.readthedocs.io/
- Xarray: https://docs.xarray.dev/
- GeoPandas: https://geopandas.org/
### Boundaries ### ❌ What You Don't Do
If you identify that a dashboard improvement requires changes to: - Edit files outside `src/entropice/dashboard/`
- Data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`) - Modify data pipeline (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`)
- Dataset assembly (`dataset.py`) - Change training code (`training.py`, `dataset.py`, `inference.py`)
- Model training (`training.py`, `inference.py`) - Edit configuration (`pyproject.toml`, `scripts/*.sh`)
- Pipeline automation scripts (`scripts/*.sh`)
### When to Stop
If a dashboard feature requires changes outside `dashboard/`, stop and inform:
**Stop immediately** and inform the user:
``` ```
⚠️ This dashboard feature requires changes to the data pipeline/training code. ⚠️ This requires changes to [file/module]
Specifically: [describe the needed changes] Needed: [describe changes]
Please review and make these changes yourself, then I can proceed with the dashboard updates. Please make these changes first, then I can update the dashboard.
``` ```
## Dashboard Structure ## Dashboard Structure
@ -60,23 +58,28 @@ The dashboard is located in `src/entropice/dashboard/` with the following struct
``` ```
dashboard/ dashboard/
├── app.py # Main Streamlit app with navigation ├── app.py # Main Streamlit app with navigation
├── overview_page.py # Overview of training results ├── views/ # Dashboard pages
├── training_data_page.py # Training data visualizations │ ├── overview_page.py # Overview of training results and dataset analysis
├── training_analysis_page.py # CV results and hyperparameter analysis │ ├── training_data_page.py # Training data visualizations (needs refactoring)
├── model_state_page.py # Feature importance and model state │ ├── training_analysis_page.py # CV results and hyperparameter analysis (needs refactoring)
├── inference_page.py # Spatial prediction visualizations │ ├── model_state_page.py # Feature importance and model state (needs refactoring)
│ └── inference_page.py # Spatial prediction visualizations (needs refactoring)
├── plots/ # Reusable plotting utilities ├── plots/ # Reusable plotting utilities
│ ├── colors.py # Color schemes
│ ├── hyperparameter_analysis.py │ ├── hyperparameter_analysis.py
│ ├── inference.py │ ├── inference.py
│ ├── model_state.py │ ├── model_state.py
│ ├── source_data.py │ ├── source_data.py
│ └── training_data.py │ └── training_data.py
└── utils/ # Data loading and processing └── utils/ # Data loading and processing utilities
├── data.py ├── loaders.py # Data loaders (training results, grid data, predictions)
└── training.py ├── stats.py # Dataset statistics computation and caching
├── colors.py # Color palette management
├── formatters.py # Display formatting utilities
└── unsembler.py # Dataset ensemble utilities
``` ```
**Note:** Currently only `overview_page.py` has been refactored to follow the new patterns. Other pages need updating to match this structure.
## Key Technologies ## Key Technologies
- **Streamlit**: Web app framework - **Streamlit**: Web app framework
@ -120,6 +123,79 @@ When working with Entropice data:
3. **Training Results**: Pickled models, Parquet/NetCDF CV results 3. **Training Results**: Pickled models, Parquet/NetCDF CV results
4. **Predictions**: GeoDataFrames with predicted classes/probabilities 4. **Predictions**: GeoDataFrames with predicted classes/probabilities
### Dashboard Code Patterns
**Follow these patterns when developing or refactoring dashboard pages:**
1. **Modular Render Functions**: Break pages into focused render functions
```python
def render_sample_count_overview():
"""Render overview of sample counts per task+target+grid+level combination."""
# Implementation
def render_feature_count_section():
"""Render the feature count section with comparison and explorer."""
# Implementation
```
2. **Use `@st.fragment` for Interactive Components**: Isolate reactive UI elements
```python
@st.fragment
def render_feature_count_explorer():
"""Render interactive detailed configuration explorer using fragments."""
# Interactive selectboxes and checkboxes that re-run independently
```
3. **Cached Data Loading via Utilities**: Use centralized loaders from `utils/loaders.py`
```python
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import load_all_default_dataset_statistics
training_results = load_all_training_results() # Cached via @st.cache_data
all_stats = load_all_default_dataset_statistics() # Cached via @st.cache_data
```
4. **Consistent Color Palettes**: Use `get_palette()` from `utils/colors.py`
```python
from entropice.dashboard.utils.colors import get_palette
task_colors = get_palette("task_types", n_colors=n_tasks)
source_colors = get_palette("data_sources", n_colors=n_sources)
```
5. **Type Hints and Type Casting**: Use types from `entropice.utils.types`
```python
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
selected_members: list[L2SourceDataset] = []
```
6. **Tab-Based Organization**: Use tabs to organize complex visualizations
```python
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
with tab1:
# Heatmap visualization
with tab2:
# Bar chart visualization
```
7. **Layout with Columns**: Use columns for metrics and side-by-side content
```python
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Features", f"{total_features:,}")
with col2:
st.metric("Data Sources", len(selected_members))
```
8. **Comprehensive Docstrings**: Document render functions clearly
```python
def render_training_results_summary(training_results):
"""Render summary metrics for training results."""
# Implementation
```
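As a rough illustration of the caching these loader patterns rely on, here is a stdlib sketch of what `@st.cache_data` effectively does for a loader — the loader name and its return value are made up for this example:

```python
from functools import lru_cache

CALLS = 0  # counts how often the expensive load actually runs

# Hypothetical stand-in for a loader such as load_all_training_results();
# @st.cache_data behaves similarly: repeated calls with the same arguments
# return the cached value instead of recomputing.
@lru_cache(maxsize=None)
def load_results(results_dir: str) -> tuple[str, ...]:
    global CALLS
    CALLS += 1
    # pretend this scans a directory of training runs
    return (f"{results_dir}/run-a", f"{results_dir}/run-b")

first = load_results("training-results")
second = load_results("training-results")  # served from cache, no recompute
```

On a Streamlit rerun the cached call returns instantly, which is why pages should never load data directly.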
### Visualization Guidelines ### Visualization Guidelines
1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps 1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps
@ -127,50 +203,79 @@ When working with Entropice data:
3. **Distributions**: Use Plotly or Seaborn 3. **Distributions**: Use Plotly or Seaborn
4. **Feature Importance**: Use Plotly bar charts 4. **Feature Importance**: Use Plotly bar charts
5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates 5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates
6. **Heatmaps**: Use `px.imshow()` with color palettes from `get_palette()`
7. **Interactive Tables**: Use `st.dataframe()` with `width='stretch'` and formatting
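For illustration, a minimal stdlib sketch of deterministic palette selection; the real `get_palette()` draws from pypalettes material-design palettes, so the hex values and the cycling rule below are assumptions:

```python
from itertools import islice, cycle

# Assumed base palettes; the real ones come from pypalettes.
BASE_PALETTES = {
    "task_types": ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"],
    "data_sources": ["#9467bd", "#8c564b", "#e377c2"],
}

def get_palette_sketch(variable: str, n_colors: int) -> list[str]:
    """Return n_colors for a semantic variable, cycling deterministically."""
    base = BASE_PALETTES[variable]
    return list(islice(cycle(base), n_colors))

task_colors = get_palette_sketch("task_types", n_colors=6)
```

The point of the deterministic mapping is that the same variable always yields the same colors across pages.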
### Key Utility Modules
**`utils/loaders.py`**: Data loading with Streamlit caching
- `load_all_training_results()`: Load all training result directories
- `load_training_result(path)`: Load specific training result
- `TrainingResult` dataclass: Structured training result data
**`utils/stats.py`**: Dataset statistics computation
- `load_all_default_dataset_statistics()`: Load/compute stats for all grid configs
- `DatasetStatistics` class: Statistics per grid configuration
- `MemberStatistics` class: Statistics per L2 source dataset
- `TargetStatistics` class: Statistics per target dataset
- Helper methods: `get_sample_count_df()`, `get_feature_count_df()`, `get_feature_breakdown_df()`
**`utils/colors.py`**: Consistent color palette management
- `get_palette(variable, n_colors)`: Get color palette by semantic variable name
- `get_cmap(variable)`: Get matplotlib colormap
- Uses pypalettes material design palettes with deterministic mapping

**`utils/formatters.py`**: Display formatting utilities
- `ModelDisplayInfo`: Model name formatting
- `TaskDisplayInfo`: Task name formatting
- `TrainingResultDisplayInfo`: Training result display names

Example requests:
- "Refactor training_data_page.py to match the patterns in overview_page.py"
- "Add a new tab to the overview page showing temporal statistics"
- "Create a reusable plotting function in plots/ for feature importance"
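One of these display-info helpers might look roughly like the following dataclass sketch — the field names and formatting rule are assumptions, not the real API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ModelDisplayInfoSketch:
    """Hypothetical shape of a display-formatting helper."""
    model_id: str
    model_type: str

    @property
    def display_name(self) -> str:
        # e.g. model_type "espa" + model_id "h3-5" -> "ESPA (h3-5)"
        return f"{self.model_type.upper()} ({self.model_id})"

info = ModelDisplayInfoSketch(model_id="h3-5", model_type="espa")
```

Keeping formatting in one place like this is what lets pages show consistent names without duplicating string logic.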
## Workflow ## Workflow
1. **Understand the Request**: Clarify what visualization or feature is needed 1. Check `views/overview_page.py` for current patterns
2. **Search for Context**: Use #tool:search to find relevant dashboard code and data structures 2. Use #tool:search to find relevant code and data structures
3. **Read Data Pipeline**: If needed, read (but don't edit) data pipeline scripts to understand data formats 3. Read data pipeline code if needed (read-only)
4. **Consult Documentation**: Use #tool:fetch for library documentation when needed 4. Leverage existing utilities from `utils/`
5. **Implement Changes**: Edit dashboard files only 5. Use #tool:web to fetch documentation when needed
6. **Test Assumptions**: Check for errors with #tool:problems after edits 6. Implement changes following overview_page.py patterns
7. **Track Progress**: Use #tool:todos for multi-step dashboard development 7. Use #tool:todo for multi-step tasks
## Example Interactions ## Refactoring Checklist
### ✅ Good Requests (Within Scope) When updating pages to match new patterns:
- "Add a new page to visualize feature correlations" 1. Move to `views/` subdirectory
- "Create a PyDeck map showing RTS predictions by grid cell" 2. Use cached loaders from `utils/loaders.py` and `utils/stats.py`
- "Improve the hyperparameter analysis plot to show confidence intervals" 3. Split into focused `render_*()` functions
- "Add a Plotly histogram showing the distribution of RTS density" 4. Wrap interactive UI with `@st.fragment`
- "Fix the deprecation warning about use_container_width" 5. Replace hardcoded colors with `get_palette()`
6. Add type hints from `entropice.utils.types`
7. Organize with tabs for complex views
8. Use `width='stretch'` for charts/tables
9. Add comprehensive docstrings
10. Reference `overview_page.py` patterns
### ⚠️ Boundary Cases (Requires User Approval) ## Example Tasks
User: "Add a new climate variable to the dashboard" **✅ In Scope:**
Agent Response: - "Add feature correlation heatmap to overview page"
``` - "Create PyDeck map for RTS predictions"
⚠️ This requires changes to the data pipeline (era5.py) to extract the new variable. - "Refactor training_data_page.py to match overview_page.py patterns"
Please add the variable to the ERA5 processing pipeline first, then I can add it to the dashboard visualizations. - "Fix use_container_width deprecation warnings"
``` - "Add temporal statistics tab"
## Progress Reporting **⚠️ Out of Scope:**
- "Add new climate variable" → Requires changes to `era5.py`
- "Change training metrics" → Requires changes to `training.py`
- "Modify grid generation" → Requires changes to `grids.py`
For complex dashboard development tasks: ## Key Reminders
1. Use #tool:todos to create a task list - Only edit files in `dashboard/`
2. Mark tasks as in-progress before starting - Use `width='stretch'` not `use_container_width=True`
3. Mark completed immediately after finishing - Always reference `overview_page.py` for patterns
4. Keep the user informed of progress - Use #tool:web for documentation
- Use #tool:todo for complex multi-step work
## Remember
- **Read-only for data pipeline**: You can read any file to understand data structures, but only edit `dashboard/` files
- **Documentation first**: When unsure about Streamlit/Plotly/PyDeck APIs, fetch documentation
- **Modern Streamlit API**: Always use `width='stretch'` instead of `use_container_width=True`
- **Pause when needed**: If data pipeline changes are required, stop and inform the user
You are here to make the dashboard better, not to change how data is created or models are trained. Stay within these boundaries and you'll be most helpful!


@ -1,226 +1,70 @@
# Entropice - GitHub Copilot Instructions # Entropice - Copilot Instructions
## Project Overview ## Project Context
Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. The system processes multi-source geospatial data, aggregates it into discrete global grids (H3/HEALPix), and trains probabilistic classifiers to estimate RTS occurrence patterns. This is a geospatial machine learning system for predicting Arctic permafrost degradation (Retrogressive Thaw Slumps) using entropy-optimal Scalable Probabilistic Approximations (eSPA). The system processes multi-source geospatial data through discrete global grid systems (H3, HEALPix) and trains probabilistic classifiers.
For detailed architecture information, see [ARCHITECTURE.md](../ARCHITECTURE.md). ## Code Style
For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
For project goals and setup, see [README.md](../README.md).
## Core Technologies - Follow PEP 8 conventions with 120 character line length
- Use type hints for all function signatures
- Use google-style docstrings for public functions
- Keep functions focused and modular
- Use `ruff` for linting and formatting, `ty` for type checking
- **Python**: 3.13 (strict version requirement) ## Technology Stack
- **Package Manager**: [Pixi](https://pixi.sh/) (not pip/conda directly)
- **GPU**: CUDA 12 with RAPIDS (CuPy, cuML)
- **Geospatial**: xarray, xdggs, GeoPandas, H3, Rasterio
- **ML**: scikit-learn, XGBoost, entropy (eSPA), cuML
- **Storage**: Zarr, Icechunk, Parquet, NetCDF
- **Visualization**: Streamlit, Bokeh, Matplotlib, Cartopy
## Code Style Guidelines - **Core**: Python 3.13, NumPy, Pandas, Xarray, GeoPandas
- **Spatial**: H3, xdggs, xvec for discrete global grid systems
- **ML**: scikit-learn, XGBoost, cuML, entropy (eSPA)
- **GPU**: CuPy, PyTorch, CUDA 12 - prefer GPU-accelerated operations
- **Storage**: Zarr, Icechunk, Parquet for intermediate data
- **CLI**: Cyclopts with dataclass-based configurations
### Python Standards ## Execution Guidelines
- Follow **PEP 8** conventions - Always use `pixi run` to execute Python commands and scripts
- Use **type hints** for all function signatures - Environment variables: `SCIPY_ARRAY_API=1`, `FAST_DATA_DIR=./data`
- Write **numpy-style docstrings** for public functions
- Keep functions **focused and modular**
- Prefer descriptive variable names over abbreviations
### Geospatial Best Practices ## Geospatial Best Practices
- Use **EPSG:3413** (Arctic Stereographic) for computations - Use EPSG:3413 (Arctic Stereographic) for computations
- Use **EPSG:4326** (WGS84) for visualization and library compatibility - Use EPSG:4326 (WGS84) for visualization and compatibility
- Store gridded data using **xarray with XDGGS** indexing - Store gridded data as XDGGS Xarray datasets (Zarr format)
- Store tabular data as **Parquet**, array data as **Zarr** - Store tabular data as GeoParquet
- Leverage **Dask** for lazy evaluation of large datasets - Handle antimeridian issues in polar regions
- Use **GeoPandas** for vector operations - Leverage Xarray/Dask for lazy evaluation and chunked processing
- Handle **antimeridian** correctly for polar regions
### Data Pipeline Conventions ## Architecture Patterns
- Follow the numbered script sequence: `00grids.sh``01darts.sh``02alphaearth.sh``03era5.sh``04arcticdem.sh``05train.sh` - Modular CLI design: each module exposes standalone Cyclopts CLI
- Each pipeline stage should produce **reproducible intermediate outputs** - Configuration as code: use dataclasses for typed configs, TOML for hyperparameters
- Use `src/entropice/utils/paths.py` for consistent path management - GPU acceleration: use CuPy for arrays, cuML for ML, batch processing for memory management
- Environment variable `FAST_DATA_DIR` controls data directory location (default: `./data`) - Data flow: Raw sources → Grid aggregation → L2 datasets → Training → Inference → Visualization
### Storage Hierarchy ## Data Storage Hierarchy
All data follows this structure:
``` ```
DATA_DIR/ DATA_DIR/
├── grids/ # H3/HEALPix tessellations (GeoParquet) ├── grids/ # H3/HEALPix tessellations (GeoParquet)
├── darts/ # RTS labels (GeoParquet) ├── darts/ # RTS labels (GeoParquet)
├── era5/ # Climate data (Zarr) ├── era5/ # Climate data (Zarr)
├── arcticdem/ # Terrain data (Icechunk Zarr) ├── arcticdem/ # Terrain data (Icechunk Zarr)
├── alphaearth/ # Satellite embeddings (Zarr) ├── alphaearth/ # Satellite embeddings (Zarr)
├── datasets/ # L2 XDGGS datasets (Zarr) ├── datasets/ # L2 XDGGS datasets (Zarr)
├── training-results/ # Models, CV results, predictions └── training-results/ # Models, CV results, predictions
└── watermask/ # Ocean mask (GeoParquet)
``` ```
## Module Organization ## Key Modules
### Core Modules (`src/entropice/`) - `entropice.spatial`: Grid generation and raster-to-vector aggregation
- `entropice.ingest`: Data extractors (DARTS, ERA5, ArcticDEM, AlphaEarth)
- `entropice.ml`: Dataset assembly, training, inference
- `entropice.dashboard`: Streamlit visualization app
- `entropice.utils`: Paths, codecs, types
The codebase is organized into four main packages: ## Testing & Notebooks
- **`entropice.ingest`**: Data ingestion from external sources - Production code belongs in `src/entropice/`, not notebooks
- **`entropice.spatial`**: Spatial operations and grid management - Notebooks in `notebooks/` are for exploration only (not version-controlled)
- **`entropice.ml`**: Machine learning workflows - Use `pytest` for testing geospatial correctness and data integrity
- **`entropice.utils`**: Common utilities
#### Data Ingestion (`src/entropice/ingest/`)
- **`darts.py`**: RTS label extraction from DARTS v2 dataset
- **`era5.py`**: Climate data processing from ERA5 (Arctic-aligned years: Oct 1 - Sep 30)
- **`arcticdem.py`**: Terrain analysis from 32m Arctic elevation data
- **`alphaearth.py`**: Satellite image embeddings via Google Earth Engine
#### Spatial Operations (`src/entropice/spatial/`)
- **`grids.py`**: H3/HEALPix spatial grid generation with watermask
- **`aggregators.py`**: Raster-to-vector spatial aggregation engine
- **`watermask.py`**: Ocean masking utilities
- **`xvec.py`**: Extended vector operations for xarray
#### Machine Learning (`src/entropice/ml/`)
- **`dataset.py`**: Multi-source data integration and feature engineering
- **`training.py`**: Model training with eSPA, XGBoost, Random Forest, KNN
- **`inference.py`**: Batch prediction pipeline for trained models
#### Utilities (`src/entropice/utils/`)
- **`paths.py`**: Centralized path management
- **`codecs.py`**: Custom codecs for data serialization
### Dashboard (`src/entropice/dashboard/`)
- Streamlit-based interactive visualization
- Modular pages: overview, training data, analysis, model state, inference
- Bokeh-based geospatial plotting utilities
- Run with: `pixi run dashboard`
### Scripts (`scripts/`)
- Numbered pipeline scripts (`00grids.sh` through `05train.sh`)
- Run entire pipeline for multiple grid configurations
- Each script uses CLIs from core modules
### Notebooks (`notebooks/`)
- Exploratory analysis and validation
- **NOT committed to git**
- Keep production code in `src/entropice/`
## Development Workflow
### Setup
```bash
pixi install # NOT pip install or conda install
```
### Running Tests
```bash
pixi run pytest
```
### Running Python Commands
Always use `pixi run` to execute Python commands to use the correct environment:
```bash
pixi run python script.py
pixi run python -c "import entropice"
```
### Common Tasks
**Important**: Always use `pixi run` prefix for Python commands to ensure correct environment.
- **Generate grids**: Use `pixi run create-grid` or `spatial/grids.py` CLI
- **Process labels**: Use `pixi run darts` or `ingest/darts.py` CLI
- **Train models**: Use `pixi run train` with TOML config or `ml/training.py` CLI
- **Run inference**: Use `ml/inference.py` CLI
- **View results**: `pixi run dashboard`
## Key Design Patterns
### 1. XDGGS Indexing
All geospatial data uses discrete global grid systems (H3 or HEALPix) via `xdggs` library for consistent spatial indexing across sources.
### 2. Lazy Evaluation
Use Xarray/Dask for out-of-core computation with Zarr/Icechunk chunked storage to manage large datasets.
### 3. GPU Acceleration
Prefer GPU-accelerated operations:
- CuPy for array operations
- cuML for Random Forest and KNN
- XGBoost GPU training
- PyTorch tensors when applicable
### 4. Configuration as Code
- Use dataclasses for typed configuration
- TOML files for training hyperparameters
- Cyclopts for CLI argument parsing
## Data Sources
- **DARTS v2**: RTS labels (year, area, count, density)
- **ERA5**: Climate data (40-year history, Arctic-aligned years)
- **ArcticDEM**: 32m resolution terrain (slope, aspect, indices)
- **AlphaEarth**: 64-dimensional satellite embeddings
- **Watermask**: Ocean exclusion layer
## Model Support
Primary: **eSPA** (entropy-optimal Scalable Probabilistic Approximations)
Alternatives: XGBoost, Random Forest, K-Nearest Neighbors
Training features:
- Randomized hyperparameter search
- K-Fold cross-validation
- Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall)
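The K-Fold splitting used here can be sketched in plain Python — scikit-learn's `KFold` provides the same idea with shuffling and more options:

```python
def kfold_indices(n_samples: int, k: int) -> list[tuple[list[int], list[int]]]:
    """Return (train_idx, test_idx) pairs; each sample is tested exactly once."""
    indices = list(range(n_samples))
    # Distribute the remainder over the first folds, as sklearn does.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        folds.append((train, test))
        start += size
    return folds

splits = kfold_indices(10, k=3)  # fold sizes 4, 3, 3
```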
## Extension Points
To extend Entropice:
- **New data source**: Follow patterns in `ingest/era5.py` or `ingest/arcticdem.py`
- **Custom aggregations**: Add to `_Aggregations` dataclass in `spatial/aggregators.py`
- **Alternative labels**: Implement extractor following `ingest/darts.py` pattern
- **New models**: Add scikit-learn compatible estimators to `ml/training.py`
- **Dashboard pages**: Add Streamlit pages to `dashboard/` module
## Important Notes
- **Always use `pixi run` prefix** for Python commands (not plain `python`)
- Grid resolutions: **H3** (3-6), **HEALPix** (6-10)
- Arctic years run **October 1 to September 30** (not calendar years)
- Handle **antimeridian crossing** in polar regions
- Use **batch processing** for GPU memory management
- Notebooks are for exploration only - **keep production code in `src/`**
- Always use **absolute paths** or paths from `utils/paths.py`
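The batch-processing note above can be sketched as a simple chunking generator; the CuPy/GPU specifics are omitted, only the memory-bounding pattern is shown:

```python
from collections.abc import Iterator, Sequence

def batched(data: Sequence[int], batch_size: int) -> Iterator[Sequence[int]]:
    """Yield fixed-size batches so only one batch is resident at a time."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

batches = list(batched(list(range(7)), batch_size=3))
```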
## Common Issues
- **Memory**: Use batch processing and Dask chunking for large datasets
- **GPU OOM**: Reduce batch size in inference or training
- **Antimeridian**: Use proper handling in `spatial/aggregators.py` for polar grids
- **Temporal alignment**: ERA5 uses Arctic-aligned years (Oct-Sep)
- **CRS**: Compute in EPSG:3413, visualize in EPSG:4326
## References
For more details, consult:
- [ARCHITECTURE.md](../ARCHITECTURE.md) - System architecture and design patterns
- [CONTRIBUTING.md](../CONTRIBUTING.md) - Development workflow and standards
- [README.md](../README.md) - Project goals and setup instructions


@ -10,7 +10,7 @@ applyTo: '**/*.py,**/*.ipynb'
- Write clear and concise comments for each function. - Write clear and concise comments for each function.
- Ensure functions have descriptive names and include type hints. - Ensure functions have descriptive names and include type hints.
- Provide docstrings following PEP 257 conventions. - Provide docstrings following PEP 257 conventions.
- Use the `typing` module for type annotations (e.g., `List[str]`, `Dict[str, int]`). - Use the `typing` module for advanced type annotations (e.g., `TypedDict`, `Literal["a", "b", ...]`).
- Break down complex functions into smaller, more manageable functions. - Break down complex functions into smaller, more manageable functions.
## General Instructions ## General Instructions
@ -27,7 +27,7 @@ applyTo: '**/*.py,**/*.ipynb'
- Follow the **PEP 8** style guide for Python. - Follow the **PEP 8** style guide for Python.
- Maintain proper indentation (use 4 spaces for each level of indentation). - Maintain proper indentation (use 4 spaces for each level of indentation).
- Ensure lines do not exceed 79 characters. - Ensure lines do not exceed 120 characters.
- Place function and class docstrings immediately after the `def` or `class` keyword. - Place function and class docstrings immediately after the `def` or `class` keyword.
- Use blank lines to separate functions, classes, and code blocks where appropriate. - Use blank lines to separate functions, classes, and code blocks where appropriate.
@ -41,6 +41,8 @@ applyTo: '**/*.py,**/*.ipynb'
## Example of Proper Documentation ## Example of Proper Documentation
```python ```python
import math
def calculate_area(radius: float) -> float: def calculate_area(radius: float) -> float:
""" """
Calculate the area of a circle given the radius. Calculate the area of a circle given the radius.
@ -51,6 +53,5 @@ def calculate_area(radius: float) -> float:
Returns: Returns:
float: The area of the circle, calculated as π * radius^2. float: The area of the circle, calculated as π * radius^2.
""" """
import math
return math.pi * radius ** 2 return math.pi * radius ** 2
``` ```


@ -27,12 +27,20 @@ The pipeline follows a sequential processing approach where each stage produces
### System Components ### System Components
The codebase is organized into four main packages: The project in general is organized into four directories:
- **`src/entropice/`**: The main processing codebase
- **`scripts/`**: Bash wrapper scripts for the data processing pipeline
- **`notebooks/`**: Exploratory analysis and validation notebooks (not committed to git)
- **`tests/`**: Unit tests and manual validation scripts
The processing codebase is organized into five main packages:
- **`entropice.ingest`**: Data ingestion from external sources (DARTS, ERA5, ArcticDEM, AlphaEarth) - **`entropice.ingest`**: Data ingestion from external sources (DARTS, ERA5, ArcticDEM, AlphaEarth)
- **`entropice.spatial`**: Spatial operations and grid management - **`entropice.spatial`**: Spatial operations and grid management
- **`entropice.ml`**: Machine learning workflows (dataset, training, inference) - **`entropice.ml`**: Machine learning workflows (dataset, training, inference)
- **`entropice.utils`**: Common utilities (paths, codecs) - **`entropice.utils`**: Common utilities (paths, codecs)
- **`entropice.dashboard`**: Streamlit Dashboard for interactive visualization
#### 1. Spatial Grid System (`spatial/grids.py`) #### 1. Spatial Grid System (`spatial/grids.py`)
@ -179,6 +187,14 @@ scripts/05train.sh # Model training
- TOML files for training hyperparameters - TOML files for training hyperparameters
- Environment-based path management (`utils/paths.py`) - Environment-based path management (`utils/paths.py`)
### 6. Geospatial Best Practices
- Use **xarray** with XDGGS for gridded data storage
- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
- Leverage **Dask** for lazy evaluation of large datasets
- Use **GeoPandas** for vector operations
- Use the EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data, and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries
## Data Storage Hierarchy ## Data Storage Hierarchy
```sh ```sh


@ -6,7 +6,6 @@ Thank you for your interest in contributing to Entropice! This document provides
### Prerequisites ### Prerequisites
- Python 3.13
- CUDA 12 compatible GPU (for full functionality) - CUDA 12 compatible GPU (for full functionality)
- [Pixi package manager](https://pixi.sh/) - [Pixi package manager](https://pixi.sh/)
@ -16,55 +15,36 @@ Thank you for your interest in contributing to Entropice! This document provides
pixi install pixi install
``` ```
This will set up the complete environment including RAPIDS, PyTorch, and all geospatial dependencies. This will set up the complete environment including Python, RAPIDS, PyTorch, and all geospatial dependencies.
## Development Workflow ## Development Workflow
**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies. > Read about the code organisation and key modules in the [Architecture Guide](ARCHITECTURE.md)
### Code Organization **Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies:
- **`src/entropice/ingest/`**: Data ingestion modules (darts, era5, arcticdem, alphaearth) ```bash
- **`src/entropice/spatial/`**: Spatial operations (grids, aggregators, watermask, xvec) pixi run python script.py
- **`src/entropice/ml/`**: Machine learning components (dataset, training, inference) pixi run python -c "import entropice"
- **`src/entropice/utils/`**: Utilities (paths, codecs) ```
- **`src/entropice/dashboard/`**: Streamlit visualization dashboard
- **`scripts/`**: Data processing pipeline scripts (numbered 00-05)
- **`notebooks/`**: Exploratory analysis and validation notebooks
- **`tests/`**: Unit tests
### Key Modules ### Python Style and Formatting
- `spatial/grids.py`: H3/HEALPix spatial grid systems
- `ingest/darts.py`, `ingest/era5.py`, `ingest/arcticdem.py`, `ingest/alphaearth.py`: Data source processors
- `ml/dataset.py`: Dataset assembly and feature engineering
- `ml/training.py`: Model training with eSPA, XGBoost, Random Forest, KNN
- `ml/inference.py`: Prediction generation
- `utils/paths.py`: Centralized path management
## Coding Standards
### Python Style
- Follow PEP 8 conventions - Follow PEP 8 conventions
- Use type hints for function signatures - Use type hints for function signatures
- Prefer numpy-style docstrings for public functions - Prefer google-style docstrings for public functions
- Keep functions focused and modular - Keep functions focused and modular
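The docstring and type-hint conventions above can be illustrated with a small, hypothetical helper (not from the codebase):

```python
def mean_abs_weight(weights: list[float]) -> float:
    """Compute the mean absolute value of a sequence of weights.

    Args:
        weights: Feature weights, possibly signed.

    Returns:
        The mean of the absolute values, or 0.0 for an empty input.
    """
    if not weights:
        return 0.0
    return sum(abs(w) for w in weights) / len(weights)
```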
### Geospatial Best Practices `ty` and `ruff` are used for typing, linting and formatting.
Make sure to check for and resolve any warnings from both of these:
- Use **xarray** with XDGGS for gridded data storage ```sh
- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays) pixi run ty check # For type checks
- Leverage **Dask** for lazy evaluation of large datasets pixi run ruff check # For linting
- Use **GeoPandas** for vector operations pixi run ruff format # For formatting
- Use EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries ```
### Data Pipeline A single file can be checked by appending its path to the command, e.g. `pixi run ty check src/entropice/dashboard/app.py`
- Follow the numbered script sequence: `00grids.sh``01darts.sh` → ... → `05train.sh`
- Each stage should produce reproducible intermediate outputs
- Document data dependencies in module docstrings
- Use `utils/paths.py` for consistent path management
## Testing ## Testing
@ -74,13 +54,6 @@ Run tests for specific modules:
pixi run pytest pixi run pytest
``` ```
When running Python scripts or commands, always use `pixi run`:
```bash
pixi run python script.py
pixi run python -c "import entropice"
```
When adding features, include tests that verify: When adding features, include tests that verify:
- Correct handling of geospatial coordinates and projections - Correct handling of geospatial coordinates and projections
@ -100,15 +73,8 @@ When adding features, include tests that verify:
### Commit Messages ### Commit Messages
- Use present tense: "Add feature" not "Added feature" - Use present tense: "Add feature" not "Added feature"
- Reference issues when applicable: "Fix #123: Correct grid aggregation"
- Keep first line under 72 characters - Keep first line under 72 characters
## Working with Data
### Local Development
- Set `FAST_DATA_DIR` environment variable for data directory (default: `./data`)
### Notebooks ### Notebooks
- Notebooks in `notebooks/` are for exploration and validation; they are not committed to git - Notebooks in `notebooks/` are for exploration and validation; they are not committed to git

View file

@ -100,6 +100,9 @@ cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" } cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" } cuspatial-cu12 = { index = "nvidia" }
[tool.ruff]
line-length = 120
[tool.ruff.lint.pyflakes] [tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports # Ignore libraries when checking for unused imports
allowed-unused-imports = [ allowed-unused-imports = [

View file

@ -224,7 +224,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(value_counts) alt.Chart(value_counts)
.mark_bar(color=bar_color) .mark_bar(color=bar_color)
.encode( .encode(
alt.X(f"{formatted_col}:N", title=param_name, sort=None), alt.X(
f"{formatted_col}:N",
title=param_name,
sort=None,
),
alt.Y("count:Q", title="Count"), alt.Y("count:Q", title="Count"),
tooltip=[ tooltip=[
alt.Tooltip(param_name, format=".2e"), alt.Tooltip(param_name, format=".2e"),
@ -238,7 +242,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(value_counts) alt.Chart(value_counts)
.mark_bar(color=bar_color) .mark_bar(color=bar_color)
.encode( .encode(
alt.X(f"{param_name}:Q", title=param_name, scale=x_scale), alt.X(
f"{param_name}:Q",
title=param_name,
scale=x_scale,
),
alt.Y("count:Q", title="Count"), alt.Y("count:Q", title="Count"),
tooltip=[ tooltip=[
alt.Tooltip(param_name, format=".3f"), alt.Tooltip(param_name, format=".3f"),
@ -301,10 +309,18 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(df_plot) alt.Chart(df_plot)
.mark_bar(color=bar_color) .mark_bar(color=bar_color)
.encode( .encode(
alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name), alt.X(
f"{param_name}:Q",
bin=alt.Bin(maxbins=n_bins),
title=param_name,
),
alt.Y("count()", title="Count"), alt.Y("count()", title="Count"),
tooltip=[ tooltip=[
alt.Tooltip(f"{param_name}:Q", format=format_str, bin=True), alt.Tooltip(
f"{param_name}:Q",
format=format_str,
bin=True,
),
"count()", "count()",
], ],
) )
@ -396,9 +412,15 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
scale=alt.Scale(range=get_palette(metric, n_colors=256)), scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None, legend=None,
), ),
tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")], tooltip=[
alt.Tooltip(param_name, format=".2e"),
alt.Tooltip(metric, format=".4f"),
],
)
.properties(
height=300,
title=f"{metric} vs {param_name} (log scale)",
) )
.properties(height=300, title=f"{metric} vs {param_name} (log scale)")
) )
else: else:
chart = ( chart = (
@ -412,7 +434,10 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
scale=alt.Scale(range=get_palette(metric, n_colors=256)), scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None, legend=None,
), ),
tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")], tooltip=[
alt.Tooltip(param_name, format=".3f"),
alt.Tooltip(metric, format=".4f"),
],
) )
.properties(height=300, title=f"{metric} vs {param_name}") .properties(height=300, title=f"{metric} vs {param_name}")
) )
@ -568,7 +593,16 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
if len(param_names) == 2: if len(param_names) == 2:
# Simple case: just one pair # Simple case: just one pair
x_param, y_param = param_names_sorted x_param, y_param = param_names_sorted
_render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500) _render_2d_param_plot(
results_binned,
x_param,
y_param,
score_col,
bin_info,
hex_colors,
metric,
height=500,
)
else: else:
# Multiple parameters: create structured plots # Multiple parameters: create structured plots
st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}") st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
@ -599,7 +633,14 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
with cols[col_idx]: with cols[col_idx]:
_render_2d_param_plot( _render_2d_param_plot(
results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350 results_binned,
x_param,
y_param,
score_col,
bin_info,
hex_colors,
metric,
height=350,
) )
@ -708,7 +749,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
# Get colormap for score evolution # Get colormap for score evolution
evolution_cmap = get_cmap("score_evolution") evolution_cmap = get_cmap("score_evolution")
evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))] evolution_colors = [
mcolors.rgb2hex(evolution_cmap(0.3)),
mcolors.rgb2hex(evolution_cmap(0.7)),
]
# Create line chart # Create line chart
chart = ( chart = (
@ -717,13 +761,21 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
.encode( .encode(
alt.X("Iteration", title="Iteration"), alt.X("Iteration", title="Iteration"),
alt.Y("value", title=metric.replace("_", " ").title()), alt.Y("value", title=metric.replace("_", " ").title()),
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)), alt.Color(
"Type",
legend=alt.Legend(title=""),
scale=alt.Scale(range=evolution_colors),
),
strokeDash=alt.StrokeDash( strokeDash=alt.StrokeDash(
"Type", "Type",
legend=None, legend=None,
scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]), scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
), ),
tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")], tooltip=[
"Iteration",
"Type",
alt.Tooltip("value", format=".4f", title="Score"),
],
) )
.properties(height=400) .properties(height=400)
) )
@ -1195,7 +1247,13 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
with col1: with col1:
# Filter by confusion category # Filter by confusion category
if task == "binary": if task == "binary":
categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"] categories = [
"All",
"True Positive",
"False Positive",
"True Negative",
"False Negative",
]
else: else:
categories = ["All", "Correct", "Incorrect"] categories = ["All", "Correct", "Incorrect"]
@ -1206,7 +1264,14 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
) )
with col2: with col2:
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity") opacity = st.slider(
"Opacity",
min_value=0.1,
max_value=1.0,
value=0.7,
step=0.1,
key="confusion_map_opacity",
)
# Filter data if needed # Filter data if needed
if selected_category != "All": if selected_category != "All":

View file

@ -85,7 +85,11 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
.mark_rect() .mark_rect()
.encode( .encode(
x=alt.X("year:O", title="Year"), x=alt.X("year:O", title="Year"),
y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")), y=alt.Y(
"band:O",
title="Band",
sort=alt.SortField(field="band", order="ascending"),
),
color=alt.Color( color=alt.Color(
"weight:Q", "weight:Q",
scale=alt.Scale(scheme="blues"), scale=alt.Scale(scheme="blues"),
@ -105,7 +109,9 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
return chart return chart
def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]: def plot_embedding_aggregation_summary(
embedding_array: xr.DataArray,
) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
"""Create bar charts summarizing embedding weights by aggregation, band, and year. """Create bar charts summarizing embedding weights by aggregation, band, and year.
Args: Args:
@ -345,8 +351,18 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs() by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
# Create DataFrames # Create DataFrames
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) df_variable = pd.DataFrame(
df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values}) {
"dimension": by_variable.index.astype(str),
"mean_abs_weight": by_variable.values,
}
)
df_season = pd.DataFrame(
{
"dimension": by_season.index.astype(str),
"mean_abs_weight": by_season.values,
}
)
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values}) df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
# Sort by weight # Sort by weight
@ -358,7 +374,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_variable) alt.Chart(df_variable)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), y=alt.Y(
"dimension:N",
title="Variable",
sort="-x",
axis=alt.Axis(labelLimit=300),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -377,7 +398,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_season) alt.Chart(df_season)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)), y=alt.Y(
"dimension:N",
title="Season",
sort="-x",
axis=alt.Axis(labelLimit=200),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -396,7 +422,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_year) alt.Chart(df_year)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)), y=alt.Y(
"dimension:O",
title="Year",
sort="-x",
axis=alt.Axis(labelLimit=200),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -418,7 +449,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs() by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
# Create DataFrames, handling potential MultiIndex # Create DataFrames, handling potential MultiIndex
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) df_variable = pd.DataFrame(
{
"dimension": by_variable.index.astype(str),
"mean_abs_weight": by_variable.values,
}
)
df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values}) df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
# Sort by weight # Sort by weight
@ -430,7 +466,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_variable) alt.Chart(df_variable)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), y=alt.Y(
"dimension:N",
title="Variable",
sort="-x",
axis=alt.Axis(labelLimit=300),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -449,7 +490,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
alt.Chart(df_time) alt.Chart(df_time)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)), y=alt.Y(
"dimension:N",
title="Time",
sort="-x",
axis=alt.Axis(labelLimit=200),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -508,7 +554,9 @@ def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
return chart return chart
def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]: def plot_arcticdem_summary(
arcticdem_array: xr.DataArray,
) -> tuple[alt.Chart, alt.Chart]:
"""Create bar charts summarizing ArcticDEM weights by variable and aggregation. """Create bar charts summarizing ArcticDEM weights by variable and aggregation.
Args: Args:
@ -523,7 +571,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs() by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs()
# Create DataFrames, handling potential MultiIndex # Create DataFrames, handling potential MultiIndex
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) df_variable = pd.DataFrame(
{
"dimension": by_variable.index.astype(str),
"mean_abs_weight": by_variable.values,
}
)
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values}) df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
# Sort by weight # Sort by weight
@ -535,7 +588,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
alt.Chart(df_variable) alt.Chart(df_variable)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), y=alt.Y(
"dimension:N",
title="Variable",
sort="-x",
axis=alt.Axis(labelLimit=300),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -554,7 +612,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
alt.Chart(df_agg) alt.Chart(df_agg)
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)), y=alt.Y(
"dimension:N",
title="Aggregation",
sort="-x",
axis=alt.Axis(labelLimit=200),
),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color( color=alt.Color(
"mean_abs_weight:Q", "mean_abs_weight:Q",
@ -665,7 +728,12 @@ def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str])
alt.Chart(counts) alt.Chart(counts)
.mark_bar() .mark_bar()
.encode( .encode(
x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)), x=alt.X(
"class:N",
title="Class Label",
sort=class_order,
axis=alt.Axis(labelAngle=-45),
),
y=alt.Y("count:Q", title="Number of Boxes"), y=alt.Y("count:Q", title="Number of Boxes"),
color=alt.Color( color=alt.Color(
"class:N", "class:N",
@ -767,7 +835,10 @@ def plot_xgboost_feature_importance(
.mark_bar() .mark_bar()
.encode( .encode(
y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"), x=alt.X(
"importance:Q",
title=f"{importance_type.replace('_', ' ').title()} Importance",
),
color=alt.value("steelblue"), color=alt.value("steelblue"),
tooltip=[ tooltip=[
alt.Tooltip("feature:N", title="Feature"), alt.Tooltip("feature:N", title="Feature"),
@ -880,7 +951,9 @@ def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.
return chart return chart
def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]: def plot_rf_tree_statistics(
model_state: xr.Dataset,
) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
"""Plot Random Forest tree statistics. """Plot Random Forest tree statistics.
Args: Args:
@ -908,7 +981,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"), x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"),
y=alt.Y("count()", title="Number of Trees"), y=alt.Y("count()", title="Number of Trees"),
color=alt.value("steelblue"), color=alt.value("steelblue"),
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")], tooltip=[
alt.Tooltip("count()", title="Count"),
alt.Tooltip("value:Q", bin=True, title="Depth Range"),
],
) )
.properties(width=300, height=200, title="Distribution of Tree Depths") .properties(width=300, height=200, title="Distribution of Tree Depths")
) )
@ -921,7 +997,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"), x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"),
y=alt.Y("count()", title="Number of Trees"), y=alt.Y("count()", title="Number of Trees"),
color=alt.value("forestgreen"), color=alt.value("forestgreen"),
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")], tooltip=[
alt.Tooltip("count()", title="Count"),
alt.Tooltip("value:Q", bin=True, title="Leaves Range"),
],
) )
.properties(width=300, height=200, title="Distribution of Leaf Counts") .properties(width=300, height=200, title="Distribution of Leaf Counts")
) )
@ -934,7 +1013,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"), x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"),
y=alt.Y("count()", title="Number of Trees"), y=alt.Y("count()", title="Number of Trees"),
color=alt.value("darkorange"), color=alt.value("darkorange"),
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")], tooltip=[
alt.Tooltip("count()", title="Count"),
alt.Tooltip("value:Q", bin=True, title="Nodes Range"),
],
) )
.properties(width=300, height=200, title="Distribution of Node Counts") .properties(width=300, height=200, title="Distribution of Node Counts")
) )

View file

@ -330,7 +330,10 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
with col3: with col3:
time_values = pd.to_datetime(ds["time"].values) time_values = pd.to_datetime(ds["time"].values)
st.metric("Time Steps", f"{time_values.min().strftime('%Y')}{time_values.max().strftime('%Y')}") st.metric(
"Time Steps",
f"{time_values.min().strftime('%Y')}{time_values.max().strftime('%Y')}",
)
with col4: with col4:
if has_agg: if has_agg:
@ -398,11 +401,15 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
col1, col2, col3 = st.columns([2, 2, 1]) col1, col2, col3 = st.columns([2, 2, 1])
with col1: with col1:
selected_var = st.selectbox( selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" "Select variable to visualize",
options=variables,
key=f"era5_{temporal_type}_var_select",
) )
with col2: with col2:
selected_agg = st.selectbox( selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select" "Aggregation",
options=ds["aggregations"].values,
key=f"era5_{temporal_type}_agg_select",
) )
with col3: with col3:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
@ -411,7 +418,9 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
col1, col2 = st.columns([3, 1]) col1, col2 = st.columns([3, 1])
with col1: with col1:
selected_var = st.selectbox( selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" "Select variable to visualize",
options=variables,
key=f"era5_{temporal_type}_var_select",
) )
with col2: with col2:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
@ -614,7 +623,10 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
deck = pdk.Deck( deck = pdk.Deck(
layers=[layer], layers=[layer],
initial_view_state=view_state, initial_view_state=view_state,
tooltip={"html": "<b>Value:</b> {value}", "style": {"backgroundColor": "steelblue", "color": "white"}}, tooltip={
"html": "<b>Value:</b> {value}",
"style": {"backgroundColor": "steelblue", "color": "white"},
},
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
) )
@ -746,7 +758,10 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
deck = pdk.Deck( deck = pdk.Deck(
layers=[layer], layers=[layer],
initial_view_state=view_state, initial_view_state=view_state,
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, tooltip={
"html": tooltip_html,
"style": {"backgroundColor": "steelblue", "color": "white"},
},
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
) )
@ -784,7 +799,14 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
) )
with col2: with col2:
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity") opacity = st.slider(
"Opacity",
min_value=0.1,
max_value=1.0,
value=0.7,
step=0.1,
key="areas_map_opacity",
)
# Create GeoDataFrame # Create GeoDataFrame
gdf = grid_gdf.copy() gdf = grid_gdf.copy()
@ -896,7 +918,9 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2: with col2:
selected_agg = st.selectbox( selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg" "Aggregation",
options=ds["aggregations"].values,
key=f"era5_{temporal_type}_agg",
) )
with col3: with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
@ -1022,7 +1046,10 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
deck = pdk.Deck( deck = pdk.Deck(
layers=[layer], layers=[layer],
initial_view_state=view_state, initial_view_state=view_state,
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, tooltip={
"html": tooltip_html,
"style": {"backgroundColor": "steelblue", "color": "white"},
},
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
) )

View file

@ -11,7 +11,9 @@ from entropice.dashboard.utils.colors import get_palette
from entropice.ml.dataset import CategoricalTrainingDataset from entropice.ml.dataset import CategoricalTrainingDataset
def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]): def render_all_distribution_histograms(
train_data_dict: dict[str, CategoricalTrainingDataset],
):
"""Render histograms for all three tasks side by side. """Render histograms for all three tasks side by side.
Args: Args:
@ -81,7 +83,13 @@ def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTra
height=400, height=400,
margin={"l": 20, "r": 20, "t": 20, "b": 20}, margin={"l": 20, "r": 20, "t": 20, "b": 20},
showlegend=True, showlegend=True,
legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1}, legend={
"orientation": "h",
"yanchor": "bottom",
"y": 1.02,
"xanchor": "right",
"x": 1,
},
xaxis_title=None, xaxis_title=None,
yaxis_title="Count", yaxis_title="Count",
xaxis={"tickangle": -45}, xaxis={"tickangle": -45},
@ -137,7 +145,10 @@ def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb) gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
elif color_mode == "split": elif color_mode == "split":
split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange split_colors = {
"train": [66, 135, 245],
"test": [245, 135, 66],
} # Blue # Orange
gdf["fill_color"] = gdf["split"].map(split_colors) gdf["fill_color"] = gdf["split"].map(split_colors)
return gdf return gdf
@ -168,7 +179,14 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
) )
with col2: with col2:
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity") opacity = st.slider(
"Opacity",
min_value=0.1,
max_value=1.0,
value=0.7,
step=0.1,
key="spatial_map_opacity",
)
# Determine which task dataset to use and color mode # Determine which task dataset to use and color mode
if vis_mode == "split": if vis_mode == "split":
@ -258,7 +276,10 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
# Set initial view state (centered on the Arctic) # Set initial view state (centered on the Arctic)
# Adjust pitch and zoom based on whether we're using elevation # Adjust pitch and zoom based on whether we're using elevation
view_state = pdk.ViewState( view_state = pdk.ViewState(
latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45 latitude=70,
longitude=0,
zoom=2 if not use_elevation else 1.5,
pitch=0 if not use_elevation else 45,
) )
# Create deck # Create deck

View file

@ -95,7 +95,9 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
return colors return colors
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]: def generate_unified_colormap(
settings: dict,
) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
"""Generate unified colormaps for all plotting libraries. """Generate unified colormaps for all plotting libraries.
This function creates consistent color schemes across Matplotlib/Ultraplot, This function creates consistent color schemes across Matplotlib/Ultraplot,
@ -125,7 +127,10 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
if is_dark_theme: if is_dark_theme:
base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme
else: else:
base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme base_colors = [
"#3498db",
"#e74c3c",
] # Brighter blue and red for light theme
else: else:
# For multi-class: use a sequential colormap # For multi-class: use a sequential colormap
# Use matplotlib's viridis colormap # Use matplotlib's viridis colormap

View file

@ -19,7 +19,9 @@ class ModelDisplayInfo:
model_display_infos: dict[Model, ModelDisplayInfo] = { model_display_infos: dict[Model, ModelDisplayInfo] = {
"espa": ModelDisplayInfo( "espa": ModelDisplayInfo(
keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm" keyword="espa",
short="eSPA",
long="entropy-optimal Scalable Probabilistic Approximations algorithm",
), ),
"xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"), "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
"rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"), "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),

View file

@@ -135,7 +135,9 @@ def load_all_training_results() -> list[TrainingResult]:
     return training_results


-def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
+def load_all_training_data(
+    e: DatasetEnsemble,
+) -> dict[Task, CategoricalTrainingDataset]:
     """Load training data for all three tasks.

     Args:
View file

@@ -142,7 +142,9 @@ class DatasetStatistics:
     target: dict[TargetDataset, TargetStatistics]  # Statistics per target dataset

     @staticmethod
-    def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_sample_count_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert sample count data to DataFrame."""
         rows = []
         for grid_config in grid_configs:
@@ -164,7 +166,9 @@ class DatasetStatistics:
         return pd.DataFrame(rows)

     @staticmethod
-    def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_feature_count_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert feature count data to DataFrame."""
         rows = []
         for grid_config in grid_configs:
@@ -191,7 +195,9 @@ class DatasetStatistics:
         return pd.DataFrame(rows)

     @staticmethod
-    def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_feature_breakdown_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert feature breakdown data to DataFrame for stacked/donut charts."""
         rows = []
         for grid_config in grid_configs:
@@ -219,7 +225,10 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
         target_statistics: dict[TargetDataset, TargetStatistics] = {}
         for target in all_target_datasets:
             target_statistics[target] = TargetStatistics.compute(
-                grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
+                grid=grid_config.grid,
+                level=grid_config.level,
+                target=target,
+                total_cells=total_cells,
             )
         member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
         for member in all_l2_source_datasets:
@@ -357,7 +366,10 @@ class TrainingDatasetStatistics:

     @classmethod
     def compute(
-        cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
+        cls,
+        ensemble: DatasetEnsemble,
+        task: Task,
+        dataset: gpd.GeoDataFrame | None = None,
     ) -> "TrainingDatasetStatistics":
         dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
         categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")

View file

@@ -65,7 +65,9 @@ def extract_embedding_features(


 def extract_era5_features(
-    model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None
+    model_state: xr.Dataset,
+    importance_type: str = "feature_weights",
+    temporal_group: str | None = None,
 ) -> xr.DataArray | None:
     """Extract ERA5 features from the model state.

@@ -100,7 +102,17 @@ def extract_era5_features(
         - Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021")
         """
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }

         # Find where the time part starts (after "era5" and variable name)
         # Pattern: era5_variable_time or era5_variable_time_agg
@@ -166,7 +178,17 @@ def extract_era5_features(
     def _extract_time_name(feature: str) -> str:
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }
         if parts[-1] in common_aggs:
             # Has aggregation: era5_var_time_agg -> time is second to last
             return parts[-2]
@@ -179,11 +201,8 @@ def extract_era5_features(
         Pattern: era5_variable_season_year_agg or era5_variable_season_year
         """
-        parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
-
         # Look through parts to find season/shoulder indicators
-        for part in parts:
+        for part in feature.split("_"):
             if part.lower() in ["summer", "winter"]:
                 return part.lower()
             elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
@@ -197,11 +216,26 @@ def extract_era5_features(
         For seasonal/shoulder features, find the year that comes after the season.
         """
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }

         # Find the season/shoulder part, then the next part should be the year
         for i, part in enumerate(parts):
-            if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+            if part.lower() in ["summer", "winter"] or part.upper() in [
+                "JFM",
+                "AMJ",
+                "JAS",
+                "OND",
+            ]:
                 # Next part should be the year
                 if i + 1 < len(parts):
                     next_part = parts[i + 1]
@@ -218,7 +252,17 @@ def extract_era5_features(

     def _extract_agg_name(feature: str) -> str | None:
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }
         if parts[-1] in common_aggs:
             return parts[-1]
         return None
@@ -255,7 +299,10 @@ def extract_era5_features(
         if has_agg:
             era5_features_array = era5_features_array.assign_coords(
-                agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+                agg=(
+                    "feature",
+                    [_extract_agg_name(f) or "none" for f in era5_features],
+                ),
             )
             era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack(
                 "feature"
@@ -274,7 +321,10 @@ def extract_era5_features(
         if has_agg:
             # Add aggregation dimension
             era5_features_array = era5_features_array.assign_coords(
-                agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+                agg=(
+                    "feature",
+                    [_extract_agg_name(f) or "none" for f in era5_features],
+                ),
             )
             era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
         else:
@@ -345,7 +395,14 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea

     Returns None if no common features are found.
     """
-    common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
+    common_feature_names = [
+        "cell_area",
+        "water_area",
+        "land_area",
+        "land_ratio",
+        "lon",
+        "lat",
+    ]

     def _is_common_feature(feature: str) -> bool:
         return feature in common_feature_names

View file

@@ -66,7 +66,10 @@ def render_inference_page():
     with col4:
         st.metric("Level", selected_result.settings.get("level", "Unknown"))
     with col5:
-        st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))
+        st.metric(
+            "Target",
+            selected_result.settings.get("target", "Unknown").replace("darts_", ""),
+        )

     st.divider()

View file

@@ -10,8 +10,16 @@ from stopuhr import stopwatch

 from entropice.dashboard.utils.colors import get_palette
 from entropice.dashboard.utils.loaders import load_all_training_results
-from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
-from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
+from entropice.dashboard.utils.stats import (
+    DatasetStatistics,
+    load_all_default_dataset_statistics,
+)
+from entropice.utils.types import (
+    GridConfig,
+    L2SourceDataset,
+    TargetDataset,
+    grid_configs,
+)


 def render_sample_count_overview():
@@ -45,7 +53,12 @@ def render_sample_count_overview():
         target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]

         # Pivot for heatmap: Grid x Task
-        pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
+        pivot_df = target_df.pivot_table(
+            index="Grid",
+            columns="Task",
+            values="Samples (Coverage)",
+            aggfunc="mean",
+        )

         # Sort index by grid type and level
         sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@@ -56,7 +69,11 @@ def render_sample_count_overview():

         fig = px.imshow(
             pivot_df,
-            labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
+            labels={
+                "x": "Task",
+                "y": "Grid Configuration",
+                "color": "Sample Count",
+            },
             x=pivot_df.columns,
             y=pivot_df.index,
             color_continuous_scale=sample_colors,
@@ -87,7 +104,10 @@ def render_sample_count_overview():
         facet_col="Target",
         barmode="group",
         title="Sample Counts by Grid Configuration and Target Dataset",
-        labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Samples (Coverage)": "Number of Samples",
+        },
         color_discrete_sequence=task_colors,
         height=500,
     )
@@ -141,7 +161,10 @@ def render_feature_count_comparison():
         color="Data Source",
         barmode="stack",
         title="Total Features by Data Source Across Grid Configurations",
-        labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Number of Features": "Number of Features",
+        },
         color_discrete_sequence=source_colors,
         text_auto=False,
     )
@@ -162,7 +185,10 @@ def render_feature_count_comparison():
         y="Inference Cells",
         color="Grid",
         title="Inference Cells by Grid Configuration",
-        labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Inference Cells": "Number of Cells",
+        },
         color_discrete_sequence=grid_colors,
         text="Inference Cells",
     )
@@ -177,7 +203,10 @@ def render_feature_count_comparison():
         y="Total Samples",
         color="Grid",
         title="Total Samples by Grid Configuration",
-        labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Total Samples": "Number of Samples",
+        },
         color_discrete_sequence=grid_colors,
         text="Total Samples",
     )
@@ -226,7 +255,13 @@ def render_feature_count_comparison():

     # Display full comparison table with formatting
     display_df = comparison_df[
-        ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
+        [
+            "Grid",
+            "Total Features",
+            "Data Sources",
+            "Inference Cells",
+            "Total Samples",
+        ]
     ].copy()

     # Format numbers with commas
@@ -314,7 +349,11 @@ def render_feature_count_explorer():
     with col2:
         # Calculate minimum cells across all data sources (for inference capability)
         min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
-        st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
+        st.metric(
+            "Inference Cells",
+            f"{min_cells:,}",
+            help="Number of union of cells across all data sources",
+        )
     with col3:
         st.metric("Data Sources", len(selected_members))
     with col4:

View file

@@ -14,7 +14,10 @@ from entropice.dashboard.plots.source_data import (
     render_era5_overview,
     render_era5_plots,
 )
-from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
+from entropice.dashboard.plots.training_data import (
+    render_all_distribution_histograms,
+    render_spatial_map,
+)
 from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
 from entropice.ml.dataset import DatasetEnsemble
 from entropice.spatial import grids
@@ -42,7 +45,10 @@ def render_training_data_page():
         ]

         grid_level_combined = st.selectbox(
-            "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution level"
+            "Grid Configuration",
+            options=grid_options,
+            index=0,
+            help="Select the grid system and resolution level",
         )

         # Parse grid type and level
@@ -60,7 +66,13 @@ def render_training_data_page():

         # Members selection
         st.subheader("Dataset Members")
-        all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+        all_members = [
+            "AlphaEarth",
+            "ArcticDEM",
+            "ERA5-yearly",
+            "ERA5-seasonal",
+            "ERA5-shoulder",
+        ]
         selected_members = []

         for member in all_members:
@@ -69,7 +81,10 @@ def render_training_data_page():

         # Form submit button
         load_button = st.form_submit_button(
-            "Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0
+            "Load Dataset",
+            type="primary",
+            width="stretch",
+            disabled=len(selected_members) == 0,
         )

     # Create DatasetEnsemble only when form is submitted
# Create DatasetEnsemble only when form is submitted # Create DatasetEnsemble only when form is submitted

View file

@@ -77,7 +77,20 @@ def download(grid: Grid, level: int):
     for year in track(range(2024, 2025), total=1, description="Processing years..."):
         embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
         embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
-        aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+        aggs = [
+            "mean",
+            "stdDev",
+            "min",
+            "max",
+            "count",
+            "median",
+            "p1",
+            "p5",
+            "p25",
+            "p75",
+            "p95",
+            "p99",
+        ]
         bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]

         def extract_embedding(feature):
@@ -136,7 +149,20 @@ def combine_to_zarr(grid: Grid, level: int):
     """
     cell_ids = grids.get_cell_ids(grid, level)
     years = list(range(2018, 2025))
-    aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+    aggs = [
+        "mean",
+        "stdDev",
+        "min",
+        "max",
+        "count",
+        "median",
+        "p1",
+        "p5",
+        "p25",
+        "p75",
+        "p95",
+        "p99",
+    ]
     bands = [f"A{str(i).zfill(2)}" for i in range(64)]

     a = xr.DataArray(

View file

@@ -125,8 +125,14 @@ def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
     cs = 3600

     # Calculate safe slice bounds for edge chunks
-    y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d)
-    x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d)
+    y_start, y_end = (
+        max(0, cs * chunk_loc[0] - d),
+        min(len(y), cs * chunk_loc[0] + cs + d),
+    )
+    x_start, x_end = (
+        max(0, cs * chunk_loc[1] - d),
+        min(len(x), cs * chunk_loc[1] + cs + d),
+    )

     # Extract coordinate arrays with safe bounds
     y_chunk = cp.asarray(y[y_start:y_end])
@@ -162,7 +168,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
     if np.all(np.isnan(chunk)):
         # Return an array of NaNs with the expected shape
         return np.full(
-            (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
+            (
+                12,
+                chunk.shape[0] - 2 * large_kernels.size_px,
+                chunk.shape[1] - 2 * large_kernels.size_px,
+            ),
             np.nan,
             dtype=np.float32,
         )
@@ -171,7 +181,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No

     # Interpolate missing values in chunk (for patches smaller than 7x7 pixels)
     mask = cp.isnan(chunk)
-    mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
+    mask &= ~binary_dilation(
+        binary_erosion(mask, iterations=3, brute_force=True),
+        iterations=3,
+        brute_force=True,
+    )
     if cp.any(mask):
         # Find indices of valid values
         indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
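The first hunk in this file computes halo-padded slice bounds that are clamped at array edges. A minimal sketch of that computation with illustrative values for `cs` (chunk size), `d` (halo width), and the axis length — the pipeline's real configuration differs:

```python
def safe_bounds(chunk_index: int, cs: int, d: int, axis_len: int) -> tuple[int, int]:
    """Clamp the padded window [cs*i - d, cs*i + cs + d) to [0, axis_len)."""
    start = max(0, cs * chunk_index - d)
    end = min(axis_len, cs * chunk_index + cs + d)
    return start, end


# Interior chunk: full halo on both sides.
print(safe_bounds(1, cs=100, d=10, axis_len=350))  # (90, 210)
# First chunk: left halo clamped to 0; last chunk: right halo clamped.
print(safe_bounds(0, cs=100, d=10, axis_len=350))  # (0, 110)
print(safe_bounds(3, cs=100, d=10, axis_len=350))  # (290, 350)
```

Clamping both ends keeps edge chunks valid while interior chunks still receive the full overlap needed by the convolution kernels.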

View file

@@ -14,7 +14,12 @@ from rich.progress import track
 from stopuhr import stopwatch

 from entropice.spatial import grids
-from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
+from entropice.utils.paths import (
+    darts_ml_training_labels_repo,
+    dartsl2_cov_file,
+    dartsl2_file,
+    get_darts_rts_file,
+)
 from entropice.utils.types import Grid

 traceback.install()

View file

@@ -204,15 +204,36 @@ def download_daily_aggregated():
     )

     # Assign attributes
-    daily_raw["t2m_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"}
-    daily_raw["t2m_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"}
-    daily_raw["t2m_mean"].attrs = {"long_name": "Daily mean 2 metre temperature", "units": "K"}
+    daily_raw["t2m_max"].attrs = {
+        "long_name": "Daily maximum 2 metre temperature",
+        "units": "K",
+    }
+    daily_raw["t2m_min"].attrs = {
+        "long_name": "Daily minimum 2 metre temperature",
+        "units": "K",
+    }
+    daily_raw["t2m_mean"].attrs = {
+        "long_name": "Daily mean 2 metre temperature",
+        "units": "K",
+    }
     daily_raw["snowc_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "m"}
     daily_raw["sde_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"}
-    daily_raw["lblt_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"}
-    daily_raw["tp"].attrs = {"long_name": "Daily total precipitation", "units": "m"}  # Units are rather m^3 / m^2
-    daily_raw["sf"].attrs = {"long_name": "Daily total snow fall", "units": "m"}  # Units are rather m^3 / m^2
-    daily_raw["sshf"].attrs = {"long_name": "Daily total surface sensible heat flux", "units": "J/m²"}
+    daily_raw["lblt_max"].attrs = {
+        "long_name": "Daily maximum lake ice bottom temperature",
+        "units": "K",
+    }
+    daily_raw["tp"].attrs = {
+        "long_name": "Daily total precipitation",
+        "units": "m",
+    }  # Units are rather m^3 / m^2
+    daily_raw["sf"].attrs = {
+        "long_name": "Daily total snow fall",
+        "units": "m",
+    }  # Units are rather m^3 / m^2
+    daily_raw["sshf"].attrs = {
+        "long_name": "Daily total surface sensible heat flux",
+        "units": "J/m²",
+    }

     daily_raw = daily_raw.odc.assign_crs("epsg:4326")
     daily_raw = daily_raw.drop_vars(["surface", "number", "depthBelowLandLayer"])
@@ -281,11 +302,17 @@ def daily_enrich():
     # Formulas based on Groeke et. al. (2025) Stochastic Weather generation...
     daily["t2m_avg"] = (daily.t2m_max + daily.t2m_min) / 2
-    daily.t2m_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"}
+    daily.t2m_avg.attrs = {
+        "long_name": "Daily average 2 metre temperature",
+        "units": "K",
+    }
     _store("t2m_avg")

     daily["t2m_range"] = daily.t2m_max - daily.t2m_min
-    daily.t2m_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"}
+    daily.t2m_range.attrs = {
+        "long_name": "Daily range of 2 metre temperature",
+        "units": "K",
+    }
     _store("t2m_range")

     with np.errstate(invalid="ignore"):
@@ -298,7 +325,10 @@ def daily_enrich():
     _store("thawing_degree_days")

     daily["freezing_degree_days"] = (273.15 - daily.t2m_avg).clip(min=0)
-    daily.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"}
+    daily.freezing_degree_days.attrs = {
+        "long_name": "Freezing degree days",
+        "units": "K",
+    }
     _store("freezing_degree_days")

     daily["thawing_days"] = (daily.t2m_avg > 273.15).where(~daily.t2m_avg.isnull())
@@ -425,7 +455,12 @@ def multi_monthly_aggregate(agg: Literal["yearly", "seasonal", "shoulder"] = "ye
     multimonthly_store = get_era5_stores(agg)
     print(f"Saving empty multi-monthly ERA5 data to {multimonthly_store}.")
-    multimonthly.to_zarr(multimonthly_store, mode="w", encoding=codecs.from_ds(multimonthly), consolidated=False)
+    multimonthly.to_zarr(
+        multimonthly_store,
+        mode="w",
+        encoding=codecs.from_ds(multimonthly),
+        consolidated=False,
+    )

     def _store(var):
         nonlocal multimonthly
@@ -512,14 +547,20 @@ def yearly_thaw_periods():
     never_thaws = (daily.thawing_days.resample(time="12MS").sum(dim="time") == 0).persist()

     first_thaw_day = daily.thawing_days.resample(time="12MS").map(_get_first_day).where(~never_thaws)
-    first_thaw_day.attrs = {"long_name": "Day of first thaw in year", "units": "day of year"}
+    first_thaw_day.attrs = {
+        "long_name": "Day of first thaw in year",
+        "units": "day of year",
+    }
     _store_in_yearly(first_thaw_day, "day_of_first_thaw")

     n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").first()
     last_thaw_day = n_days_in_year - daily.thawing_days[::-1].resample(time="12MS").map(_get_first_day).where(
         ~never_thaws
     )
-    last_thaw_day.attrs = {"long_name": "Day of last thaw in year", "units": "day of year"}
+    last_thaw_day.attrs = {
+        "long_name": "Day of last thaw in year",
+        "units": "day of year",
+    }
     _store_in_yearly(last_thaw_day, "day_of_last_thaw")
@@ -532,14 +573,20 @@ def yearly_freeze_periods():
     never_freezes = (daily.freezing_days.resample(time="12MS").sum(dim="time") == 0).persist()

     first_freeze_day = daily.freezing_days.resample(time="12MS").map(_get_first_day).where(~never_freezes)
-    first_freeze_day.attrs = {"long_name": "Day of first freeze in year", "units": "day of year"}
+    first_freeze_day.attrs = {
+        "long_name": "Day of first freeze in year",
+        "units": "day of year",
+    }
     _store_in_yearly(first_freeze_day, "day_of_first_freeze")

     n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").last()
     last_freeze_day = n_days_in_year - daily.freezing_days[::-1].resample(time="12MS").map(_get_first_day).where(
         ~never_freezes
     )
-    last_freeze_day.attrs = {"long_name": "Day of last freeze in year", "units": "day of year"}
+    last_freeze_day.attrs = {
+        "long_name": "Day of last freeze in year",
+        "units": "day of year",
+    }
     _store_in_yearly(last_freeze_day, "day_of_last_freeze")
@@ -728,7 +775,12 @@ def spatial_agg(
         pxbuffer=10,
     )

-    aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
+    aggregated = aggregated.chunk(
+        {
+            "cell_ids": min(len(aggregated.cell_ids), 10000),
+            "time": len(aggregated.time),
+        }
+    )
     store = get_era5_stores(agg, grid, level)
     with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
         aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))

View file

@@ -99,7 +99,14 @@ def bin_values(
     """
     labels_dict = {
         "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
-        "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
+        "density": [
+            "Empty",
+            "Very Sparse",
+            "Sparse",
+            "Moderate",
+            "Dense",
+            "Very Dense",
+        ],
     }
     labels = labels_dict[task]
@@ -186,7 +193,13 @@ class DatasetEnsemble:
     level: int
     target: Literal["darts_rts", "darts_mllabels"]
     members: list[L2SourceDataset] = field(
-        default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+        default_factory=lambda: [
+            "AlphaEarth",
+            "ArcticDEM",
+            "ERA5-yearly",
+            "ERA5-seasonal",
+            "ERA5-shoulder",
+        ]
     )
     dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
     variable_filters: dict[str, list[str]] = field(default_factory=dict)
@@ -277,7 +290,9 @@ class DatasetEnsemble:
         return targets

     def _prep_era5(
-        self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
+        self,
+        targets: gpd.GeoDataFrame,
+        temporal: Literal["yearly", "seasonal", "shoulder"],
     ) -> pd.DataFrame:
         era5 = self._read_member("ERA5-" + temporal, targets)
         era5_df = era5.to_dataframe()
@@ -352,7 +367,9 @@ class DatasetEnsemble:
         print(f"=== Total number of features in dataset: {stats['total_features']}")

     def create(
-        self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
+        self,
+        filter_target_col: str | None = None,
+        cache_mode: Literal["n", "o", "r"] = "r",
     ) -> gpd.GeoDataFrame:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
@@ -438,7 +455,10 @@ class DatasetEnsemble:
             yield dataset

     def _cat_and_split(
-        self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
+        self,
+        dataset: gpd.GeoDataFrame,
+        task: Task,
+        device: Literal["cpu", "cuda", "torch"],
     ) -> CategoricalTrainingDataset:
         taskcol = self.taskcol(task)
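The `bin_values` hunk above maps raw counts or densities onto six ordered labels. A minimal sketch of such label binning, using the diff's label lists but hypothetical bin edges — the real function derives its own thresholds:

```python
from bisect import bisect_right

# Label lists as in the diff; edges below are illustrative only.
LABELS = {
    "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
    "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
}


def bin_value(value: float, task: str, edges: list[float]) -> str:
    """Map a value into one of six ordered labels via five ascending edges."""
    labels = LABELS[task]
    assert len(edges) == len(labels) - 1
    return labels[bisect_right(edges, value)]


edges = [1, 5, 20, 50, 100]  # hypothetical thresholds
print(bin_value(0, "count", edges))    # None
print(bin_value(3, "count", edges))    # Very Few
print(bin_value(500, "count", edges))  # Very Many
```

`bisect_right` keeps the lookup O(log n) and avoids a chain of if/elif comparisons.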

View file

@@ -20,7 +20,9 @@ set_config(array_api_dispatch=True)

 def predict_proba(
-    e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list
+    e: DatasetEnsemble,
+    clf: RandomForestClassifier | ESPAClassifier | XGBClassifier,
+    classes: list,
 ) -> gpd.GeoDataFrame:
     """Get predicted probabilities for each cell.

View file

@@ -79,7 +79,19 @@ class _Aggregations:
     def aggnames(self) -> list[str]:
         if self._common:
-            return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+            return [
+                "mean",
+                "std",
+                "min",
+                "max",
+                "median",
+                "p1",
+                "p5",
+                "p25",
+                "p75",
+                "p95",
+                "p99",
+            ]
         names = []
         if self.mean:
             names.append("mean")
@@ -105,7 +117,15 @@ class _Aggregations:
         cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
         cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
         cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
-        quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99]  # ? Ordering is important here!
+        quantiles_to_compute = [
+            0.5,
+            0.01,
+            0.05,
+            0.25,
+            0.75,
+            0.95,
+            0.99,
+        ]  # ? Ordering is important here!
         cell_data[4:, ...] = flattened_var.quantile(
             q=quantiles_to_compute,
             dim="z",
@@ -156,7 +176,11 @@ class _Aggregations:
             return self._agg_cell_data_single(flattened)
         others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
-        cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
+        cell_data = np.full(
+            (len(flattened.data_vars), len(self), *others_shape),
+            np.nan,
+            dtype=np.float32,
+        )
         for i, var in enumerate(flattened.data_vars):
             cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
         # Transform to numpy arrays
@@ -194,7 +218,9 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:


 @stopwatch("Correcting geometries", log=False)
-def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
+def _get_corrected_geoms(
+    inp: tuple[Polygon, odc.geo.geobox.GeoBox, str],
+) -> list[odc.geo.Geometry]:
     geom, gbox, crs = inp
     # cell.geometry is a shapely Polygon
     if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
@@ -440,7 +466,12 @@ def _align_partition(
     others_shape = tuple(
         [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
     )
-    ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
+    ongrid_shape = (
+        len(grid_partition_gdf),
+        len(raster.data_vars),
+        len(aggregations),
+        *others_shape,
+    )
     ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
     for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
@@ -455,7 +486,11 @@ def _align_partition(
     cell_ids = grids.convert_cell_ids(grid_partition_gdf)
     dims = ["cell_ids", "variables", "aggregations"]
-    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
+    coords = {
+        "cell_ids": cell_ids,
+        "variables": list(raster.data_vars),
+        "aggregations": aggregations.aggnames(),
+    }
     for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
         dims.append(dim)
         coords[dim] = raster.coords[dim]
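The `# ? Ordering is important here!` comment hinges on a property worth spelling out: `DataArray.quantile` returns its results in the same order as the `q` argument, so the quantile aggregation names must be listed positionally to match. A minimal sketch with a toy array (not the project's data, and assuming xarray's default linear interpolation):

```python
import numpy as np
import xarray as xr

# Toy data: the values 1..100 along a "z" dimension.
da = xr.DataArray(np.arange(1.0, 101.0), dims="z")

# Results come back in the order of q, not sorted — so the names
# ("median", "p1", "p99", ...) must be listed in the same order.
q = da.quantile(q=[0.5, 0.01, 0.99], dim="z")
print(np.round(q.values, 2))
```

For 1..100 this yields roughly `[50.5, 1.99, 99.01]`: the median first, then the 1st and 99th percentiles, mirroring the `q` list.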

View file

@@ -131,7 +131,9 @@ def create_global_hex_grid(resolution):
             executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
         }
         for future in track(
-            as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
+            as_completed(future_to_hex),
+            description="Creating hex polygons...",
+            total=len(hex0_cells),
         ):
             hex_batch, hex_id_batch, hex_area_batch = future.result()
             hex_list.extend(hex_batch)
@@ -157,7 +159,10 @@ def create_global_hex_grid(resolution):
             hex_area_list.append(hex_area)

     # Create GeoDataFrame
-    grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
+    grid = gpd.GeoDataFrame(
+        {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
+        crs="EPSG:4326",
+    )
     return grid

View file

@@ -5,7 +5,10 @@ from zarr.codecs import BloscCodec


 def from_ds(
-    ds: xr.Dataset, store_floats_as_float32: bool = True, include_coords: bool = True, filter_existing: bool = True
+    ds: xr.Dataset,
+    store_floats_as_float32: bool = True,
+    include_coords: bool = True,
+    filter_existing: bool = True,
 ) -> dict:
     """Create compression encoding for zarr dataset storage.

View file

@@ -4,7 +4,17 @@ from dataclasses import dataclass
 from typing import Literal

 type Grid = Literal["hex", "healpix"]
-type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
+type GridLevel = Literal[
+    "hex3",
+    "hex4",
+    "hex5",
+    "hex6",
+    "healpix6",
+    "healpix7",
+    "healpix8",
+    "healpix9",
+    "healpix10",
+]
 type TargetDataset = Literal["darts_rts", "darts_mllabels"]
 type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
 type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]

View file

@@ -0,0 +1,83 @@
+from itertools import product
+
+import xarray as xr
+from rich.progress import track
+
+from entropice.utils.paths import (
+    get_arcticdem_stores,
+    get_embeddings_store,
+    get_era5_stores,
+)
+from entropice.utils.types import Grid, L2SourceDataset
+
+
+def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
+    """Validate if the L2 dataset exists for the given grid and level.
+
+    Args:
+        grid (Grid): The grid type to use.
+        level (int): The grid level to use.
+        l2ds (L2SourceDataset): The L2 source dataset to validate.
+
+    Returns:
+        bool: True if the dataset exists and does not contain NaNs, False otherwise.
+
+    """
+    if l2ds == "ArcticDEM":
+        store = get_arcticdem_stores(grid, level)
+    elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly":
+        agg = l2ds.split("-")[1]
+        store = get_era5_stores(agg, grid, level)  # type: ignore
+    elif l2ds == "AlphaEarth":
+        store = get_embeddings_store(grid, level)
+    else:
+        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
+
+    if not store.exists():
+        print("\t Dataset store does not exist")
+        return False
+
+    ds = xr.open_zarr(store, consolidated=False)
+    has_nan = False
+    for var in ds.data_vars:
+        n_nans = ds[var].isnull().sum().compute().item()
+        if n_nans > 0:
+            print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
+            has_nan = True
+    if has_nan:
+        return False
+
+    return True
+
+
+def main():
+    grid_levels: set[tuple[Grid, int]] = {
+        ("hex", 3),
+        ("hex", 4),
+        ("hex", 5),
+        ("hex", 6),
+        ("healpix", 6),
+        ("healpix", 7),
+        ("healpix", 8),
+        ("healpix", 9),
+        ("healpix", 10),
+    }
+    l2_source_datasets: list[L2SourceDataset] = [
+        "ArcticDEM",
+        "ERA5-shoulder",
+        "ERA5-seasonal",
+        "ERA5-yearly",
+        "AlphaEarth",
+    ]
+    for (grid, level), l2ds in track(
+        product(grid_levels, l2_source_datasets),
+        total=len(grid_levels) * len(l2_source_datasets),
+        description="Validating L2 datasets...",
+    ):
+        is_valid = validate_l2_dataset(grid, level, l2ds)
+        status = "VALID" if is_valid else "INVALID"
+        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
+
+
+if __name__ == "__main__":
+    main()
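The core of the validation above is the per-variable NaN count over a zarr-backed dataset. That pattern can be sketched standalone on a toy in-memory dataset (the `count_nans` helper is hypothetical, not part of the repository; `.compute()` is dropped since there is no dask backing here):

```python
import numpy as np
import xarray as xr


def count_nans(ds: xr.Dataset) -> dict[str, int]:
    # Mirror the per-variable check in validate_l2_dataset:
    # count NaN entries in each data variable.
    return {str(var): int(ds[var].isnull().sum().item()) for var in ds.data_vars}


ds = xr.Dataset(
    {
        "clean": ("cell_ids", np.array([1.0, 2.0, 3.0])),
        "gappy": ("cell_ids", np.array([1.0, np.nan, np.nan])),
    }
)
print(count_nans(ds))  # {'clean': 0, 'gappy': 2}
```

A dataset would be flagged INVALID as soon as any variable reports a nonzero count, exactly as the `has_nan` flag does above.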