Update docs, instructions and format code
This commit is contained in:
parent fca232da91
commit 4260b492ab
29 changed files with 987 additions and 467 deletions

263 .github/agents/Dashboard.agent.md (vendored)

@@ -1,56 +1,54 @@
---
description: 'Specialized agent for developing and enhancing the Streamlit dashboard for data and training analysis.'
name: Dashboard-Developer
argument-hint: 'Describe dashboard features, pages, visualizations, or improvements you want to add or modify'
tools: ['edit', 'runNotebooks', 'search', 'runCommands', 'usages', 'problems', 'changes', 'testFailure', 'fetch', 'githubRepo', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'ms-toolsai.jupyter/configureNotebook', 'ms-toolsai.jupyter/listNotebookPackages', 'ms-toolsai.jupyter/installNotebookPackages', 'todos', 'runSubagent', 'runTests']
description: Develop and refactor Streamlit dashboard pages and visualizations
name: Dashboard
argument-hint: Describe dashboard features, pages, or visualizations to add or modify
tools: ['vscode', 'execute', 'read', 'edit', 'search', 'web', 'agent', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'todo']
model: Claude Sonnet 4.5
infer: true
---
# Dashboard Development Agent

You are a specialized agent for incrementally developing and enhancing the **Entropice Streamlit Dashboard** used to analyze geospatial machine learning data and training experiments.
You specialize in developing and refactoring the **Entropice Streamlit Dashboard** for geospatial machine learning analysis.

## Your Responsibilities
## Scope

### What You Should Do
**You can edit:** Files in `src/entropice/dashboard/` only
**You cannot edit:** Data pipeline scripts, training code, or configuration files

1. **Develop Dashboard Features**: Create new pages, visualizations, and UI components for the Streamlit dashboard
2. **Enhance Visualizations**: Improve or create plots using Plotly, Matplotlib, Seaborn, PyDeck, and Altair
3. **Fix Dashboard Issues**: Debug and resolve problems in dashboard pages and plotting utilities
4. **Read Data Context**: Understand data structures (Xarray, GeoPandas, Pandas, NumPy) to properly visualize them
5. **Consult Documentation**: Use #tool:fetch to read library documentation when needed:
**Primary reference:** Always consult `views/overview_page.py` for current code patterns

## Responsibilities

### ✅ What You Do

- Create/refactor dashboard pages in `views/`
- Build visualizations using Plotly, Matplotlib, Seaborn, PyDeck, Altair
- Fix dashboard bugs and improve UI/UX
- Create utility functions in `utils/` and `plots/`
- Read (but never edit) data pipeline code to understand data structures
- Use #tool:web to fetch library documentation:
  - Streamlit: https://docs.streamlit.io/
  - Plotly: https://plotly.com/python/
  - PyDeck: https://deckgl.readthedocs.io/
  - Deck.gl: https://deck.gl/
  - Matplotlib: https://matplotlib.org/
  - Seaborn: https://seaborn.pydata.org/
  - Xarray: https://docs.xarray.dev/
  - GeoPandas: https://geopandas.org/
  - Pandas: https://pandas.pydata.org/pandas-docs/
  - NumPy: https://numpy.org/doc/stable/

6. **Understand Data Sources**: Read data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`, `dataset.py`, `training.py`, `inference.py`) to understand data structures—but **NEVER edit them**
### ❌ What You Don't Do

### What You Should NOT Do
- Edit files outside `src/entropice/dashboard/`
- Modify data pipeline (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`)
- Change training code (`training.py`, `dataset.py`, `inference.py`)
- Edit configuration (`pyproject.toml`, `scripts/*.sh`)

1. **Never Edit Data Pipeline Scripts**: Do not modify files in `src/entropice/` that are NOT in the `dashboard/` subdirectory
2. **Never Edit Training Scripts**: Do not modify `training.py`, `dataset.py`, or any model-related code outside the dashboard
3. **Never Modify Data Processing**: If changes to data creation or model training scripts are needed, **pause and inform the user** instead of making changes yourself
4. **Never Edit Configuration Files**: Do not modify `pyproject.toml`, pipeline scripts in `scripts/`, or configuration files
### When to Stop

### Boundaries
If a dashboard feature requires changes outside `dashboard/`, stop and inform:

If you identify that a dashboard improvement requires changes to:
- Data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`)
- Dataset assembly (`dataset.py`)
- Model training (`training.py`, `inference.py`)
- Pipeline automation scripts (`scripts/*.sh`)

**Stop immediately** and inform the user:
```
⚠️ This dashboard feature requires changes to the data pipeline/training code.
Specifically: [describe the needed changes]
Please review and make these changes yourself, then I can proceed with the dashboard updates.
⚠️ This requires changes to [file/module]
Needed: [describe changes]
Please make these changes first, then I can update the dashboard.
```

## Dashboard Structure
@@ -60,23 +58,28 @@ The dashboard is located in `src/entropice/dashboard/` with the following struct
```
dashboard/
├── app.py # Main Streamlit app with navigation
├── overview_page.py # Overview of training results
├── training_data_page.py # Training data visualizations
├── training_analysis_page.py # CV results and hyperparameter analysis
├── model_state_page.py # Feature importance and model state
├── inference_page.py # Spatial prediction visualizations
├── views/ # Dashboard pages
│   ├── overview_page.py # Overview of training results and dataset analysis
│   ├── training_data_page.py # Training data visualizations (needs refactoring)
│   ├── training_analysis_page.py # CV results and hyperparameter analysis (needs refactoring)
│   ├── model_state_page.py # Feature importance and model state (needs refactoring)
│   └── inference_page.py # Spatial prediction visualizations (needs refactoring)
├── plots/ # Reusable plotting utilities
│   ├── colors.py # Color schemes
│   ├── hyperparameter_analysis.py
│   ├── inference.py
│   ├── model_state.py
│   ├── source_data.py
│   └── training_data.py
└── utils/ # Data loading and processing
    ├── data.py
    └── training.py
└── utils/ # Data loading and processing utilities
    ├── loaders.py # Data loaders (training results, grid data, predictions)
    ├── stats.py # Dataset statistics computation and caching
    ├── colors.py # Color palette management
    ├── formatters.py # Display formatting utilities
    └── unsembler.py # Dataset ensemble utilities
```

**Note:** Currently only `overview_page.py` has been refactored to follow the new patterns. Other pages need updating to match this structure.

## Key Technologies

- **Streamlit**: Web app framework

@@ -120,6 +123,79 @@ When working with Entropice data:
3. **Training Results**: Pickled models, Parquet/NetCDF CV results
4. **Predictions**: GeoDataFrames with predicted classes/probabilities

### Dashboard Code Patterns

**Follow these patterns when developing or refactoring dashboard pages:**

1. **Modular Render Functions**: Break pages into focused render functions

   ```python
   def render_sample_count_overview():
       """Render overview of sample counts per task+target+grid+level combination."""
       # Implementation

   def render_feature_count_section():
       """Render the feature count section with comparison and explorer."""
       # Implementation
   ```

2. **Use `@st.fragment` for Interactive Components**: Isolate reactive UI elements

   ```python
   @st.fragment
   def render_feature_count_explorer():
       """Render interactive detailed configuration explorer using fragments."""
       # Interactive selectboxes and checkboxes that re-run independently
   ```

3. **Cached Data Loading via Utilities**: Use centralized loaders from `utils/loaders.py`

   ```python
   from entropice.dashboard.utils.loaders import load_all_training_results
   from entropice.dashboard.utils.stats import load_all_default_dataset_statistics

   training_results = load_all_training_results()  # Cached via @st.cache_data
   all_stats = load_all_default_dataset_statistics()  # Cached via @st.cache_data
   ```

4. **Consistent Color Palettes**: Use `get_palette()` from `utils/colors.py`

   ```python
   from entropice.dashboard.utils.colors import get_palette

   task_colors = get_palette("task_types", n_colors=n_tasks)
   source_colors = get_palette("data_sources", n_colors=n_sources)
   ```

5. **Type Hints and Type Casting**: Use types from `entropice.utils.types`

   ```python
   from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs

   selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
   selected_members: list[L2SourceDataset] = []
   ```

6. **Tab-Based Organization**: Use tabs to organize complex visualizations

   ```python
   tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
   with tab1:
       # Heatmap visualization
   with tab2:
       # Bar chart visualization
   ```

7. **Layout with Columns**: Use columns for metrics and side-by-side content

   ```python
   col1, col2, col3 = st.columns(3)
   with col1:
       st.metric("Total Features", f"{total_features:,}")
   with col2:
       st.metric("Data Sources", len(selected_members))
   ```

8. **Comprehensive Docstrings**: Document render functions clearly

   ```python
   def render_training_results_summary(training_results):
       """Render summary metrics for training results."""
       # Implementation
   ```

### Visualization Guidelines

1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps

@@ -127,50 +203,79 @@ When working with Entropice data:
3. **Distributions**: Use Plotly or Seaborn
4. **Feature Importance**: Use Plotly bar charts
5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates
6. **Heatmaps**: Use `px.imshow()` with color palettes from `get_palette()`
7. **Interactive Tables**: Use `st.dataframe()` with `width='stretch'` and formatting

### Key Utility Modules

**`utils/loaders.py`**: Data loading with Streamlit caching
- `load_all_training_results()`: Load all training result directories
- `load_training_result(path)`: Load specific training result
- `TrainingResult` dataclass: Structured training result data

**`utils/stats.py`**: Dataset statistics computation
- `load_all_default_dataset_statistics()`: Load/compute stats for all grid configs
- `DatasetStatistics` class: Statistics per grid configuration
- `MemberStatistics` class: Statistics per L2 source dataset
- `TargetStatistics` class: Statistics per target dataset
- Helper methods: `get_sample_count_df()`, `get_feature_count_df()`, `get_feature_breakdown_df()`

**`utils/colors.py`**: Consistent color palette management
- `get_palette(variable, n_colors)`: Get color palette by semantic variable name
- `get_cmap(variable)`: Get matplotlib colormap
- "Refactor training_data_page.py to match the patterns in overview_page.py"
- "Add a new tab to the overview page showing temporal statistics"
- "Create a reusable plotting function in plots/ for feature importance"
- Uses pypalettes material design palettes with deterministic mapping

**`utils/formatters.py`**: Display formatting utilities
- `ModelDisplayInfo`: Model name formatting
- `TaskDisplayInfo`: Task name formatting
- `TrainingResultDisplayInfo`: Training result display names

## Workflow

1. **Understand the Request**: Clarify what visualization or feature is needed
2. **Search for Context**: Use #tool:search to find relevant dashboard code and data structures
3. **Read Data Pipeline**: If needed, read (but don't edit) data pipeline scripts to understand data formats
4. **Consult Documentation**: Use #tool:fetch for library documentation when needed
5. **Implement Changes**: Edit dashboard files only
6. **Test Assumptions**: Check for errors with #tool:problems after edits
7. **Track Progress**: Use #tool:todos for multi-step dashboard development
1. Check `views/overview_page.py` for current patterns
2. Use #tool:search to find relevant code and data structures
3. Read data pipeline code if needed (read-only)
4. Leverage existing utilities from `utils/`
5. Use #tool:web to fetch documentation when needed
6. Implement changes following overview_page.py patterns
7. Use #tool:todo for multi-step tasks

## Example Interactions
## Refactoring Checklist

### ✅ Good Requests (Within Scope)
When updating pages to match new patterns:

- "Add a new page to visualize feature correlations"
- "Create a PyDeck map showing RTS predictions by grid cell"
- "Improve the hyperparameter analysis plot to show confidence intervals"
- "Add a Plotly histogram showing the distribution of RTS density"
- "Fix the deprecation warning about use_container_width"
1. Move to `views/` subdirectory
2. Use cached loaders from `utils/loaders.py` and `utils/stats.py`
3. Split into focused `render_*()` functions
4. Wrap interactive UI with `@st.fragment`
5. Replace hardcoded colors with `get_palette()`
6. Add type hints from `entropice.utils.types`
7. Organize with tabs for complex views
8. Use `width='stretch'` for charts/tables
9. Add comprehensive docstrings
10. Reference `overview_page.py` patterns

### ⚠️ Boundary Cases (Requires User Approval)
## Example Tasks

User: "Add a new climate variable to the dashboard"
Agent Response:
```
⚠️ This requires changes to the data pipeline (era5.py) to extract the new variable.
Please add the variable to the ERA5 processing pipeline first, then I can add it to the dashboard visualizations.
```
**✅ In Scope:**
- "Add feature correlation heatmap to overview page"
- "Create PyDeck map for RTS predictions"
- "Refactor training_data_page.py to match overview_page.py patterns"
- "Fix use_container_width deprecation warnings"
- "Add temporal statistics tab"

## Progress Reporting
**⚠️ Out of Scope:**
- "Add new climate variable" → Requires changes to `era5.py`
- "Change training metrics" → Requires changes to `training.py`
- "Modify grid generation" → Requires changes to `grids.py`

For complex dashboard development tasks:
## Key Reminders

1. Use #tool:todos to create a task list
2. Mark tasks as in-progress before starting
3. Mark completed immediately after finishing
4. Keep the user informed of progress

## Remember

- **Read-only for data pipeline**: You can read any file to understand data structures, but only edit `dashboard/` files
- **Documentation first**: When unsure about Streamlit/Plotly/PyDeck APIs, fetch documentation
- **Modern Streamlit API**: Always use `width='stretch'` instead of `use_container_width=True`
- **Pause when needed**: If data pipeline changes are required, stop and inform the user

You are here to make the dashboard better, not to change how data is created or models are trained. Stay within these boundaries and you'll be most helpful!
- Only edit files in `dashboard/`
- Use `width='stretch'` not `use_container_width=True`
- Always reference `overview_page.py` for patterns
- Use #tool:web for documentation
- Use #tool:todo for complex multi-step work
242 .github/copilot-instructions.md (vendored)

@@ -1,53 +1,49 @@
# Entropice - GitHub Copilot Instructions
# Entropice - Copilot Instructions

## Project Overview
## Project Context

Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. The system processes multi-source geospatial data, aggregates it into discrete global grids (H3/HEALPix), and trains probabilistic classifiers to estimate RTS occurrence patterns.
This is a geospatial machine learning system for predicting Arctic permafrost degradation (Retrogressive Thaw Slumps) using entropy-optimal Scalable Probabilistic Approximations (eSPA). The system processes multi-source geospatial data through discrete global grid systems (H3, HEALPix) and trains probabilistic classifiers.

For detailed architecture information, see [ARCHITECTURE.md](../ARCHITECTURE.md).
For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
For project goals and setup, see [README.md](../README.md).
## Code Style

## Core Technologies
- Follow PEP 8 conventions with 120 character line length
- Use type hints for all function signatures
- Use google-style docstrings for public functions
- Keep functions focused and modular
- Use `ruff` for linting and formatting, `ty` for type checking

- **Python**: 3.13 (strict version requirement)
- **Package Manager**: [Pixi](https://pixi.sh/) (not pip/conda directly)
- **GPU**: CUDA 12 with RAPIDS (CuPy, cuML)
- **Geospatial**: xarray, xdggs, GeoPandas, H3, Rasterio
- **ML**: scikit-learn, XGBoost, entropy (eSPA), cuML
- **Storage**: Zarr, Icechunk, Parquet, NetCDF
- **Visualization**: Streamlit, Bokeh, Matplotlib, Cartopy
## Technology Stack

## Code Style Guidelines
- **Core**: Python 3.13, NumPy, Pandas, Xarray, GeoPandas
- **Spatial**: H3, xdggs, xvec for discrete global grid systems
- **ML**: scikit-learn, XGBoost, cuML, entropy (eSPA)
- **GPU**: CuPy, PyTorch, CUDA 12 - prefer GPU-accelerated operations
- **Storage**: Zarr, Icechunk, Parquet for intermediate data
- **CLI**: Cyclopts with dataclass-based configurations

### Python Standards
## Execution Guidelines

- Follow **PEP 8** conventions
- Use **type hints** for all function signatures
- Write **numpy-style docstrings** for public functions
- Keep functions **focused and modular**
- Prefer descriptive variable names over abbreviations
- Always use `pixi run` to execute Python commands and scripts
- Environment variables: `SCIPY_ARRAY_API=1`, `FAST_DATA_DIR=./data`

### Geospatial Best Practices
## Geospatial Best Practices

- Use **EPSG:3413** (Arctic Stereographic) for computations
- Use **EPSG:4326** (WGS84) for visualization and library compatibility
- Store gridded data using **xarray with XDGGS** indexing
- Store tabular data as **Parquet**, array data as **Zarr**
- Leverage **Dask** for lazy evaluation of large datasets
- Use **GeoPandas** for vector operations
- Handle **antimeridian** correctly for polar regions
- Use EPSG:3413 (Arctic Stereographic) for computations
- Use EPSG:4326 (WGS84) for visualization and compatibility
- Store gridded data as XDGGS Xarray datasets (Zarr format)
- Store tabular data as GeoParquet
- Handle antimeridian issues in polar regions
- Leverage Xarray/Dask for lazy evaluation and chunked processing
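The two-CRS convention above can be sketched with pyproj (the projection engine GeoPandas uses under the hood); the coordinates are illustrative:

```python
from pyproj import Transformer

# EPSG:4326 (lat/lon) -> EPSG:3413 (polar stereographic, used for computation).
# always_xy=True keeps (lon, lat) ordering regardless of CRS axis conventions.
to_arctic = Transformer.from_crs("EPSG:4326", "EPSG:3413", always_xy=True)
x, y = to_arctic.transform(-45.0, 70.0)  # an illustrative point in meters

# Round-trip back to WGS84 for visualization.
to_wgs84 = Transformer.from_crs("EPSG:3413", "EPSG:4326", always_xy=True)
lon, lat = to_wgs84.transform(x, y)
```

With GeoPandas the same switch is a single `gdf.to_crs("EPSG:4326")` before plotting.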

### Data Pipeline Conventions
## Architecture Patterns

- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → `02alphaearth.sh` → `03era5.sh` → `04arcticdem.sh` → `05train.sh`
- Each pipeline stage should produce **reproducible intermediate outputs**
- Use `src/entropice/utils/paths.py` for consistent path management
- Environment variable `FAST_DATA_DIR` controls data directory location (default: `./data`)
- Modular CLI design: each module exposes standalone Cyclopts CLI
- Configuration as code: use dataclasses for typed configs, TOML for hyperparameters
- GPU acceleration: use CuPy for arrays, cuML for ML, batch processing for memory management
- Data flow: Raw sources → Grid aggregation → L2 datasets → Training → Inference → Visualization

### Storage Hierarchy
## Data Storage Hierarchy

All data follows this structure:
```
DATA_DIR/
├── grids/ # H3/HEALPix tessellations (GeoParquet)

@@ -56,171 +52,19 @@ DATA_DIR/
├── arcticdem/ # Terrain data (Icechunk Zarr)
├── alphaearth/ # Satellite embeddings (Zarr)
├── datasets/ # L2 XDGGS datasets (Zarr)
├── training-results/ # Models, CV results, predictions
└── watermask/ # Ocean mask (GeoParquet)
└── training-results/ # Models, CV results, predictions
```
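The hierarchy above can be resolved in code roughly as follows. This is a hypothetical sketch; the project's real logic lives in `src/entropice/utils/paths.py` and may differ in names and structure:

```python
import os
from pathlib import Path

# Hypothetical helpers mirroring the layout above; the actual implementation
# in utils/paths.py may expose different names.
def data_dir() -> Path:
    """Resolve the data root from FAST_DATA_DIR, defaulting to ./data."""
    return Path(os.environ.get("FAST_DATA_DIR", "./data"))

def grids_dir() -> Path:
    return data_dir() / "grids"

def training_results_dir() -> Path:
    return data_dir() / "training-results"
```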

## Module Organization
## Key Modules

### Core Modules (`src/entropice/`)
- `entropice.spatial`: Grid generation and raster-to-vector aggregation
- `entropice.ingest`: Data extractors (DARTS, ERA5, ArcticDEM, AlphaEarth)
- `entropice.ml`: Dataset assembly, training, inference
- `entropice.dashboard`: Streamlit visualization app
- `entropice.utils`: Paths, codecs, types

The codebase is organized into four main packages:
## Testing & Notebooks

- **`entropice.ingest`**: Data ingestion from external sources
- **`entropice.spatial`**: Spatial operations and grid management
- **`entropice.ml`**: Machine learning workflows
- **`entropice.utils`**: Common utilities

#### Data Ingestion (`src/entropice/ingest/`)

- **`darts.py`**: RTS label extraction from DARTS v2 dataset
- **`era5.py`**: Climate data processing from ERA5 (Arctic-aligned years: Oct 1 - Sep 30)
- **`arcticdem.py`**: Terrain analysis from 32m Arctic elevation data
- **`alphaearth.py`**: Satellite image embeddings via Google Earth Engine

#### Spatial Operations (`src/entropice/spatial/`)

- **`grids.py`**: H3/HEALPix spatial grid generation with watermask
- **`aggregators.py`**: Raster-to-vector spatial aggregation engine
- **`watermask.py`**: Ocean masking utilities
- **`xvec.py`**: Extended vector operations for xarray

#### Machine Learning (`src/entropice/ml/`)

- **`dataset.py`**: Multi-source data integration and feature engineering
- **`training.py`**: Model training with eSPA, XGBoost, Random Forest, KNN
- **`inference.py`**: Batch prediction pipeline for trained models

#### Utilities (`src/entropice/utils/`)

- **`paths.py`**: Centralized path management
- **`codecs.py`**: Custom codecs for data serialization

### Dashboard (`src/entropice/dashboard/`)

- Streamlit-based interactive visualization
- Modular pages: overview, training data, analysis, model state, inference
- Bokeh-based geospatial plotting utilities
- Run with: `pixi run dashboard`

### Scripts (`scripts/`)

- Numbered pipeline scripts (`00grids.sh` through `05train.sh`)
- Run entire pipeline for multiple grid configurations
- Each script uses CLIs from core modules

### Notebooks (`notebooks/`)

- Exploratory analysis and validation
- **NOT committed to git**
- Keep production code in `src/entropice/`

## Development Workflow

### Setup

```bash
pixi install # NOT pip install or conda install
```

### Running Tests

```bash
pixi run pytest
```

### Running Python Commands

Always use `pixi run` to execute Python commands to use the correct environment:

```bash
pixi run python script.py
pixi run python -c "import entropice"
```

### Common Tasks

**Important**: Always use `pixi run` prefix for Python commands to ensure correct environment.

- **Generate grids**: Use `pixi run create-grid` or `spatial/grids.py` CLI
- **Process labels**: Use `pixi run darts` or `ingest/darts.py` CLI
- **Train models**: Use `pixi run train` with TOML config or `ml/training.py` CLI
- **Run inference**: Use `ml/inference.py` CLI
- **View results**: `pixi run dashboard`

## Key Design Patterns

### 1. XDGGS Indexing

All geospatial data uses discrete global grid systems (H3 or HEALPix) via `xdggs` library for consistent spatial indexing across sources.

### 2. Lazy Evaluation

Use Xarray/Dask for out-of-core computation with Zarr/Icechunk chunked storage to manage large datasets.

### 3. GPU Acceleration

Prefer GPU-accelerated operations:
- CuPy for array operations
- cuML for Random Forest and KNN
- XGBoost GPU training
- PyTorch tensors when applicable

### 4. Configuration as Code

- Use dataclasses for typed configuration
- TOML files for training hyperparameters
- Cyclopts for CLI argument parsing

## Data Sources

- **DARTS v2**: RTS labels (year, area, count, density)
- **ERA5**: Climate data (40-year history, Arctic-aligned years)
- **ArcticDEM**: 32m resolution terrain (slope, aspect, indices)
- **AlphaEarth**: 64-dimensional satellite embeddings
- **Watermask**: Ocean exclusion layer

## Model Support

Primary: **eSPA** (entropy-optimal Scalable Probabilistic Approximations)
Alternatives: XGBoost, Random Forest, K-Nearest Neighbors

Training features:
- Randomized hyperparameter search
- K-Fold cross-validation
- Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall)
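The combination of randomized search, K-fold CV, and multi-metric scoring maps onto standard scikit-learn machinery. This sketch uses a Random Forest stand-in on synthetic data (eSPA itself is not shown, and the parameter grid is illustrative):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, RandomizedSearchCV

# Tiny synthetic stand-in for the real gridded feature matrix.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={"n_estimators": [25, 50], "max_depth": [3, 5, None]},
    n_iter=4,                      # random draws from the parameter space
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring=["accuracy", "f1", "jaccard"],  # multi-metric evaluation
    refit="f1",                    # metric used to pick the final estimator
    random_state=0,
)
search.fit(X, y)
best_model = search.best_estimator_
```

With a scoring list, `search.cv_results_` carries one `mean_test_<metric>` column per metric, and `refit` names the one that selects `best_estimator_`.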

## Extension Points

To extend Entropice:

- **New data source**: Follow patterns in `ingest/era5.py` or `ingest/arcticdem.py`
- **Custom aggregations**: Add to `_Aggregations` dataclass in `spatial/aggregators.py`
- **Alternative labels**: Implement extractor following `ingest/darts.py` pattern
- **New models**: Add scikit-learn compatible estimators to `ml/training.py`
- **Dashboard pages**: Add Streamlit pages to `dashboard/` module

## Important Notes

- **Always use `pixi run` prefix** for Python commands (not plain `python`)
- Grid resolutions: **H3** (3-6), **HEALPix** (6-10)
- Arctic years run **October 1 to September 30** (not calendar years)
- Handle **antimeridian crossing** in polar regions
- Use **batch processing** for GPU memory management
- Notebooks are for exploration only - **keep production code in `src/`**
- Always use **absolute paths** or paths from `utils/paths.py`
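The Arctic-aligned year convention (October 1 to September 30) can be captured in a small helper. Whether a season is labeled by its start or end calendar year is an assumption here; this sketch labels it by the year it ends in:

```python
from datetime import date

# Illustrative helper, not project code. Convention assumed: the Arctic year
# Oct 1, 2020 - Sep 30, 2021 is labeled 2021 (the calendar year it ends in).
def arctic_year(d: date) -> int:
    """Return the Arctic-aligned year (Oct 1 - Sep 30) containing date d."""
    return d.year + 1 if d.month >= 10 else d.year
```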

## Common Issues

- **Memory**: Use batch processing and Dask chunking for large datasets
- **GPU OOM**: Reduce batch size in inference or training
- **Antimeridian**: Use proper handling in `spatial/aggregators.py` for polar grids
- **Temporal alignment**: ERA5 uses Arctic-aligned years (Oct-Sep)
- **CRS**: Compute in EPSG:3413, visualize in EPSG:4326

## References

For more details, consult:
- [ARCHITECTURE.md](../ARCHITECTURE.md) - System architecture and design patterns
- [CONTRIBUTING.md](../CONTRIBUTING.md) - Development workflow and standards
- [README.md](../README.md) - Project goals and setup instructions
- Production code belongs in `src/entropice/`, not notebooks
- Notebooks in `notebooks/` are for exploration only (not version-controlled)
- Use `pytest` for testing geospatial correctness and data integrity
7 .github/python.instructions.md (vendored)

@@ -10,7 +10,7 @@ applyTo: '**/*.py,**/*.ipynb'
- Write clear and concise comments for each function.
- Ensure functions have descriptive names and include type hints.
- Provide docstrings following PEP 257 conventions.
- Use the `typing` module for type annotations (e.g., `List[str]`, `Dict[str, int]`).
- Use the `typing` module for advanced type annotations (e.g., `TypedDict`, `Literal["a", "b", ...]`).
- Break down complex functions into smaller, more manageable functions.
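A minimal sketch of the advanced `typing` annotations mentioned above; the type and function names here are illustrative, not part of the project's code:

```python
from typing import Literal, TypedDict

# Hypothetical example: TypedDict constrains dict keys/value types,
# Literal constrains a field to an enumerated set of values.
class GridSpec(TypedDict):
    system: Literal["h3", "healpix"]
    resolution: int

def describe(spec: GridSpec) -> str:
    """Format a grid specification for display."""
    return f"{spec['system']} grid at resolution {spec['resolution']}"
```

A type checker will flag `{"system": "hexify", "resolution": 4}` as invalid because `"hexify"` is outside the `Literal` set.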

## General Instructions

@@ -27,7 +27,7 @@ applyTo: '**/*.py,**/*.ipynb'

- Follow the **PEP 8** style guide for Python.
- Maintain proper indentation (use 4 spaces for each level of indentation).
- Ensure lines do not exceed 79 characters.
- Ensure lines do not exceed 120 characters.
- Place function and class docstrings immediately after the `def` or `class` keyword.
- Use blank lines to separate functions, classes, and code blocks where appropriate.

@@ -41,6 +41,8 @@ applyTo: '**/*.py,**/*.ipynb'
## Example of Proper Documentation

```python
import math

def calculate_area(radius: float) -> float:
    """
    Calculate the area of a circle given the radius.

@@ -51,6 +53,5 @@ def calculate_area(radius: float) -> float:
    Returns:
        float: The area of the circle, calculated as π * radius^2.
    """
    import math
    return math.pi * radius ** 2
```
@ -27,12 +27,20 @@ The pipeline follows a sequential processing approach where each stage produces
|
|||
|
||||
### System Components
|
||||
|
||||
The codebase is organized into four main packages:
|
||||
The project in general is organized into four directories:
|
||||
|
||||
- **`src/entropice/`**: The main processing codebase
|
||||
- **`scripts/`**: Bash scripts and wrapper for data processing pipeline scripts
|
||||
- **`notebooks/`**: Exploratory analysis and validation notebooks (not commited to git)
|
||||
- **`tests/`**: Unit tests and manual validation scripts
|
||||
|
||||
The processing codebase is organized into five main packages:
|
||||
|
||||
- **`entropice.ingest`**: Data ingestion from external sources (DARTS, ERA5, ArcticDEM, AlphaEarth)
|
||||
- **`entropice.spatial`**: Spatial operations and grid management
|
||||
- **`entropice.ml`**: Machine learning workflows (dataset, training, inference)
|
||||
- **`entropice.utils`**: Common utilities (paths, codecs)
|
||||
- **`entropice.dashboard`**: Streamlit Dashboard for interactive visualization
|
||||
|
||||
#### 1. Spatial Grid System (`spatial/grids.py`)
|
||||
|
||||
|
|
@@ -179,6 +187,14 @@ scripts/05train.sh # Model training
- TOML files for training hyperparameters
- Environment-based path management (`utils/paths.py`)

### 6. Geospatial Best Practices

- Use **xarray** with XDGGS for gridded data storage
- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
- Leverage **Dask** for lazy evaluation of large datasets
- Use **GeoPandas** for vector operations
- Use the EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data, and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries
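The CRS guidance above can be sketched with GeoPandas; this is a minimal, hypothetical example (the sample points and buffer distance are illustrative, not project data):

```python
import geopandas as gpd
from shapely.geometry import Point

# Hypothetical sample points in WGS84 (EPSG:4326)
gdf = gpd.GeoDataFrame(
    {"id": [1, 2]},
    geometry=[Point(-50.0, 71.0), Point(20.0, 75.0)],
    crs="EPSG:4326",
)

# Reproject to Arctic Stereographic (EPSG:3413) for metric computations
gdf_proj = gdf.to_crs("EPSG:3413")
buffered = gdf_proj.buffer(10_000)  # 10 km buffers, meaningful in metres

# Reproject back to EPSG:4326 for visualization
buffered_wgs84 = buffered.to_crs("EPSG:4326")
```

Doing the buffer in EPSG:3413 rather than EPSG:4326 is the point of the rule: distances in a projected CRS are in metres, not degrees.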
## Data Storage Hierarchy

```sh
@@ -6,7 +6,6 @@ Thank you for your interest in contributing to Entropice! This document provides

### Prerequisites

- Python 3.13
- CUDA 12-compatible GPU (for full functionality)
- [Pixi package manager](https://pixi.sh/)
@@ -16,55 +15,36 @@ Thank you for your interest in contributing to Entropice! This document provides
pixi install
```

This will set up the complete environment including RAPIDS, PyTorch, and all geospatial dependencies.
This will set up the complete environment, including Python, RAPIDS, PyTorch, and all geospatial dependencies.

## Development Workflow

**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies.
> Read the [Architecture Guide](ARCHITECTURE.md) for details on the code organisation and key modules.

### Code Organization
**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies:

- **`src/entropice/ingest/`**: Data ingestion modules (darts, era5, arcticdem, alphaearth)
- **`src/entropice/spatial/`**: Spatial operations (grids, aggregators, watermask, xvec)
- **`src/entropice/ml/`**: Machine learning components (dataset, training, inference)
- **`src/entropice/utils/`**: Utilities (paths, codecs)
- **`src/entropice/dashboard/`**: Streamlit visualization dashboard
- **`scripts/`**: Data processing pipeline scripts (numbered 00-05)
- **`notebooks/`**: Exploratory analysis and validation notebooks
- **`tests/`**: Unit tests
```bash
pixi run python script.py
pixi run python -c "import entropice"
```

### Key Modules

- `spatial/grids.py`: H3/HEALPix spatial grid systems
- `ingest/darts.py`, `ingest/era5.py`, `ingest/arcticdem.py`, `ingest/alphaearth.py`: Data source processors
- `ml/dataset.py`: Dataset assembly and feature engineering
- `ml/training.py`: Model training with eSPA, XGBoost, Random Forest, KNN
- `ml/inference.py`: Prediction generation
- `utils/paths.py`: Centralized path management
## Coding Standards

### Python Style
### Python Style and Formatting

- Follow PEP 8 conventions
- Use type hints for function signatures
- Prefer numpy-style docstrings for public functions
- Prefer Google-style docstrings for public functions
- Keep functions focused and modular
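As a minimal illustration of the Google-style docstrings the guide prefers (a generic example, not a project function):

```python
def scale(values: list[float], factor: float) -> list[float]:
    """Scale each value by a constant factor.

    Args:
        values: Input numbers.
        factor: Multiplier applied to each value.

    Returns:
        A new list with each element multiplied by ``factor``.
    """
    return [v * factor for v in values]

# scale([1.0, 2.0], 3.0) -> [3.0, 6.0]
```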
### Geospatial Best Practices
`ty` and `ruff` are used for typing, linting, and formatting.
Make sure to check for warnings from both:

- Use **xarray** with XDGGS for gridded data storage
- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
- Leverage **Dask** for lazy evaluation of large datasets
- Use **GeoPandas** for vector operations
- Use the EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data, and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries
```sh
pixi run ty check # For type checks
pixi run ruff check # For linting
pixi run ruff format # For formatting
```

### Data Pipeline

- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → ... → `05train.sh`
- Each stage should produce reproducible intermediate outputs
- Document data dependencies in module docstrings
- Use `utils/paths.py` for consistent path management
Single files can be checked by appending them to the command, e.g. `pixi run ty check src/entropice/dashboard/app.py`

## Testing
@@ -74,13 +54,6 @@ Run tests for specific modules:
pixi run pytest
```

When running Python scripts or commands, always use `pixi run`:

```bash
pixi run python script.py
pixi run python -c "import entropice"
```

When adding features, include tests that verify:

- Correct handling of geospatial coordinates and projections
@@ -100,15 +73,8 @@ When adding features, include tests that verify:
### Commit Messages

- Use present tense: "Add feature" not "Added feature"
- Reference issues when applicable: "Fix #123: Correct grid aggregation"
- Keep first line under 72 characters

## Working with Data

### Local Development

- Set the `FAST_DATA_DIR` environment variable for the data directory (default: `./data`)

### Notebooks

- Notebooks in `notebooks/` are for exploration and validation; they are not committed to git
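The `FAST_DATA_DIR` convention described above might resolve along these lines; this is a hypothetical sketch, and the project's actual `utils/paths.py` may differ:

```python
import os
from pathlib import Path

def get_data_dir() -> Path:
    """Return the data root, honouring FAST_DATA_DIR (default: ./data)."""
    return Path(os.environ.get("FAST_DATA_DIR", "./data")).resolve()

data_dir = get_data_dir()  # an absolute Path, e.g. /home/user/project/data
```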
@@ -100,6 +100,9 @@ cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" }

[tool.ruff]
line-length = 120

[tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports
allowed-unused-imports = [
@@ -224,7 +224,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
alt.Chart(value_counts)
.mark_bar(color=bar_color)
.encode(
    alt.X(f"{formatted_col}:N", title=param_name, sort=None),
    alt.X(
        f"{formatted_col}:N",
        title=param_name,
        sort=None,
    ),
    alt.Y("count:Q", title="Count"),
    tooltip=[
        alt.Tooltip(param_name, format=".2e"),
|
|
@ -238,7 +242,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
|
|||
alt.Chart(value_counts)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(f"{param_name}:Q", title=param_name, scale=x_scale),
|
||||
alt.X(
|
||||
f"{param_name}:Q",
|
||||
title=param_name,
|
||||
scale=x_scale,
|
||||
),
|
||||
alt.Y("count:Q", title="Count"),
|
||||
tooltip=[
|
||||
alt.Tooltip(param_name, format=".3f"),
|
||||
|
|
@ -301,10 +309,18 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
|
|||
alt.Chart(df_plot)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name),
|
||||
alt.X(
|
||||
f"{param_name}:Q",
|
||||
bin=alt.Bin(maxbins=n_bins),
|
||||
title=param_name,
|
||||
),
|
||||
alt.Y("count()", title="Count"),
|
||||
tooltip=[
|
||||
alt.Tooltip(f"{param_name}:Q", format=format_str, bin=True),
|
||||
alt.Tooltip(
|
||||
f"{param_name}:Q",
|
||||
format=format_str,
|
||||
bin=True,
|
||||
),
|
||||
"count()",
|
||||
],
|
||||
)
|
||||
|
|
@ -396,9 +412,15 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
|
||||
tooltip=[
|
||||
alt.Tooltip(param_name, format=".2e"),
|
||||
alt.Tooltip(metric, format=".4f"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
height=300,
|
||||
title=f"{metric} vs {param_name} (log scale)",
|
||||
)
|
||||
.properties(height=300, title=f"{metric} vs {param_name} (log scale)")
|
||||
)
|
||||
else:
|
||||
chart = (
|
||||
|
|
@ -412,7 +434,10 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
|
||||
tooltip=[
|
||||
alt.Tooltip(param_name, format=".3f"),
|
||||
alt.Tooltip(metric, format=".4f"),
|
||||
],
|
||||
)
|
||||
.properties(height=300, title=f"{metric} vs {param_name}")
|
||||
)
|
||||
|
|
@ -568,7 +593,16 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
|||
if len(param_names) == 2:
|
||||
# Simple case: just one pair
|
||||
x_param, y_param = param_names_sorted
|
||||
_render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500)
|
||||
_render_2d_param_plot(
|
||||
results_binned,
|
||||
x_param,
|
||||
y_param,
|
||||
score_col,
|
||||
bin_info,
|
||||
hex_colors,
|
||||
metric,
|
||||
height=500,
|
||||
)
|
||||
else:
|
||||
# Multiple parameters: create structured plots
|
||||
st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
|
||||
|
|
@ -599,7 +633,14 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
|||
|
||||
with cols[col_idx]:
|
||||
_render_2d_param_plot(
|
||||
results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350
|
||||
results_binned,
|
||||
x_param,
|
||||
y_param,
|
||||
score_col,
|
||||
bin_info,
|
||||
hex_colors,
|
||||
metric,
|
||||
height=350,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -708,7 +749,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
|
||||
# Get colormap for score evolution
|
||||
evolution_cmap = get_cmap("score_evolution")
|
||||
evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))]
|
||||
evolution_colors = [
|
||||
mcolors.rgb2hex(evolution_cmap(0.3)),
|
||||
mcolors.rgb2hex(evolution_cmap(0.7)),
|
||||
]
|
||||
|
||||
# Create line chart
|
||||
chart = (
|
||||
|
|
@ -717,13 +761,21 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
.encode(
|
||||
alt.X("Iteration", title="Iteration"),
|
||||
alt.Y("value", title=metric.replace("_", " ").title()),
|
||||
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)),
|
||||
alt.Color(
|
||||
"Type",
|
||||
legend=alt.Legend(title=""),
|
||||
scale=alt.Scale(range=evolution_colors),
|
||||
),
|
||||
strokeDash=alt.StrokeDash(
|
||||
"Type",
|
||||
legend=None,
|
||||
scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
|
||||
),
|
||||
tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
|
||||
tooltip=[
|
||||
"Iteration",
|
||||
"Type",
|
||||
alt.Tooltip("value", format=".4f", title="Score"),
|
||||
],
|
||||
)
|
||||
.properties(height=400)
|
||||
)
|
||||
|
|
@ -1195,7 +1247,13 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
|
|||
with col1:
|
||||
# Filter by confusion category
|
||||
if task == "binary":
|
||||
categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
|
||||
categories = [
|
||||
"All",
|
||||
"True Positive",
|
||||
"False Positive",
|
||||
"True Negative",
|
||||
"False Negative",
|
||||
]
|
||||
else:
|
||||
categories = ["All", "Correct", "Incorrect"]
|
||||
|
||||
|
|
@ -1206,7 +1264,14 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
|
|||
)
|
||||
|
||||
with col2:
|
||||
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")
|
||||
opacity = st.slider(
|
||||
"Opacity",
|
||||
min_value=0.1,
|
||||
max_value=1.0,
|
||||
value=0.7,
|
||||
step=0.1,
|
||||
key="confusion_map_opacity",
|
||||
)
|
||||
|
||||
# Filter data if needed
|
||||
if selected_category != "All":
|
||||
|
|
|
|||
|
|
@ -85,7 +85,11 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
|
|||
.mark_rect()
|
||||
.encode(
|
||||
x=alt.X("year:O", title="Year"),
|
||||
y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
|
||||
y=alt.Y(
|
||||
"band:O",
|
||||
title="Band",
|
||||
sort=alt.SortField(field="band", order="ascending"),
|
||||
),
|
||||
color=alt.Color(
|
||||
"weight:Q",
|
||||
scale=alt.Scale(scheme="blues"),
|
||||
|
|
@ -105,7 +109,9 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
|
|||
return chart
|
||||
|
||||
|
||||
def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
|
||||
def plot_embedding_aggregation_summary(
|
||||
embedding_array: xr.DataArray,
|
||||
) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
|
||||
"""Create bar charts summarizing embedding weights by aggregation, band, and year.
|
||||
|
||||
Args:
|
||||
|
|
@ -345,8 +351,18 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
|
||||
|
||||
# Create DataFrames
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values})
|
||||
df_variable = pd.DataFrame(
|
||||
{
|
||||
"dimension": by_variable.index.astype(str),
|
||||
"mean_abs_weight": by_variable.values,
|
||||
}
|
||||
)
|
||||
df_season = pd.DataFrame(
|
||||
{
|
||||
"dimension": by_season.index.astype(str),
|
||||
"mean_abs_weight": by_season.values,
|
||||
}
|
||||
)
|
||||
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
|
||||
|
||||
# Sort by weight
|
||||
|
|
@ -358,7 +374,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
alt.Chart(df_variable)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Variable",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=300),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -377,7 +398,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
alt.Chart(df_season)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Season",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=200),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -396,7 +422,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
alt.Chart(df_year)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
y=alt.Y(
|
||||
"dimension:O",
|
||||
title="Year",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=200),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -418,7 +449,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
|
||||
|
||||
# Create DataFrames, handling potential MultiIndex
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_variable = pd.DataFrame(
|
||||
{
|
||||
"dimension": by_variable.index.astype(str),
|
||||
"mean_abs_weight": by_variable.values,
|
||||
}
|
||||
)
|
||||
df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
|
||||
|
||||
# Sort by weight
|
||||
|
|
@ -430,7 +466,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
alt.Chart(df_variable)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Variable",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=300),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -449,7 +490,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
|||
alt.Chart(df_time)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Time",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=200),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -508,7 +554,9 @@ def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
|
|||
return chart
|
||||
|
||||
|
||||
def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
|
||||
def plot_arcticdem_summary(
|
||||
arcticdem_array: xr.DataArray,
|
||||
) -> tuple[alt.Chart, alt.Chart]:
|
||||
"""Create bar charts summarizing ArcticDEM weights by variable and aggregation.
|
||||
|
||||
Args:
|
||||
|
|
@ -523,7 +571,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
|
|||
by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs()
|
||||
|
||||
# Create DataFrames, handling potential MultiIndex
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_variable = pd.DataFrame(
|
||||
{
|
||||
"dimension": by_variable.index.astype(str),
|
||||
"mean_abs_weight": by_variable.values,
|
||||
}
|
||||
)
|
||||
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
|
||||
|
||||
# Sort by weight
|
||||
|
|
@ -535,7 +588,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
|
|||
alt.Chart(df_variable)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Variable",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=300),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -554,7 +612,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al
|
|||
alt.Chart(df_agg)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
y=alt.Y(
|
||||
"dimension:N",
|
||||
title="Aggregation",
|
||||
sort="-x",
|
||||
axis=alt.Axis(labelLimit=200),
|
||||
),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
|
|
@ -665,7 +728,12 @@ def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str])
|
|||
alt.Chart(counts)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
|
||||
x=alt.X(
|
||||
"class:N",
|
||||
title="Class Label",
|
||||
sort=class_order,
|
||||
axis=alt.Axis(labelAngle=-45),
|
||||
),
|
||||
y=alt.Y("count:Q", title="Number of Boxes"),
|
||||
color=alt.Color(
|
||||
"class:N",
|
||||
|
|
@ -767,7 +835,10 @@ def plot_xgboost_feature_importance(
|
|||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"),
|
||||
x=alt.X(
|
||||
"importance:Q",
|
||||
title=f"{importance_type.replace('_', ' ').title()} Importance",
|
||||
),
|
||||
color=alt.value("steelblue"),
|
||||
tooltip=[
|
||||
alt.Tooltip("feature:N", title="Feature"),
|
||||
|
|
@ -880,7 +951,9 @@ def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.
|
|||
return chart
|
||||
|
||||
|
||||
def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
|
||||
def plot_rf_tree_statistics(
|
||||
model_state: xr.Dataset,
|
||||
) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
|
||||
"""Plot Random Forest tree statistics.
|
||||
|
||||
Args:
|
||||
|
|
@ -908,7 +981,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
|
|||
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"),
|
||||
y=alt.Y("count()", title="Number of Trees"),
|
||||
color=alt.value("steelblue"),
|
||||
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")],
|
||||
tooltip=[
|
||||
alt.Tooltip("count()", title="Count"),
|
||||
alt.Tooltip("value:Q", bin=True, title="Depth Range"),
|
||||
],
|
||||
)
|
||||
.properties(width=300, height=200, title="Distribution of Tree Depths")
|
||||
)
|
||||
|
|
@ -921,7 +997,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
|
|||
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"),
|
||||
y=alt.Y("count()", title="Number of Trees"),
|
||||
color=alt.value("forestgreen"),
|
||||
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")],
|
||||
tooltip=[
|
||||
alt.Tooltip("count()", title="Count"),
|
||||
alt.Tooltip("value:Q", bin=True, title="Leaves Range"),
|
||||
],
|
||||
)
|
||||
.properties(width=300, height=200, title="Distribution of Leaf Counts")
|
||||
)
|
||||
|
|
@ -934,7 +1013,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha
|
|||
x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"),
|
||||
y=alt.Y("count()", title="Number of Trees"),
|
||||
color=alt.value("darkorange"),
|
||||
tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")],
|
||||
tooltip=[
|
||||
alt.Tooltip("count()", title="Count"),
|
||||
alt.Tooltip("value:Q", bin=True, title="Nodes Range"),
|
||||
],
|
||||
)
|
||||
.properties(width=300, height=200, title="Distribution of Node Counts")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -330,7 +330,10 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
|
|||
|
||||
with col3:
|
||||
time_values = pd.to_datetime(ds["time"].values)
|
||||
st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}")
|
||||
st.metric(
|
||||
"Time Steps",
|
||||
f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}",
|
||||
)
|
||||
|
||||
with col4:
|
||||
if has_agg:
|
||||
|
|
@ -398,11 +401,15 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
|
|||
col1, col2, col3 = st.columns([2, 2, 1])
|
||||
with col1:
|
||||
selected_var = st.selectbox(
|
||||
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
|
||||
"Select variable to visualize",
|
||||
options=variables,
|
||||
key=f"era5_{temporal_type}_var_select",
|
||||
)
|
||||
with col2:
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select"
|
||||
"Aggregation",
|
||||
options=ds["aggregations"].values,
|
||||
key=f"era5_{temporal_type}_agg_select",
|
||||
)
|
||||
with col3:
|
||||
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
|
||||
|
|
@ -411,7 +418,9 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
|
|||
col1, col2 = st.columns([3, 1])
|
||||
with col1:
|
||||
selected_var = st.selectbox(
|
||||
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
|
||||
"Select variable to visualize",
|
||||
options=variables,
|
||||
key=f"era5_{temporal_type}_var_select",
|
||||
)
|
||||
with col2:
|
||||
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
|
||||
|
|
@ -614,7 +623,10 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
|
|||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": "<b>Value:</b> {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
tooltip={
|
||||
"html": "<b>Value:</b> {value}",
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
|
|
@ -746,7 +758,10 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
|
|||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
tooltip={
|
||||
"html": tooltip_html,
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
|
|
@ -784,7 +799,14 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
|
|||
)
|
||||
|
||||
with col2:
|
||||
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity")
|
||||
opacity = st.slider(
|
||||
"Opacity",
|
||||
min_value=0.1,
|
||||
max_value=1.0,
|
||||
value=0.7,
|
||||
step=0.1,
|
||||
key="areas_map_opacity",
|
||||
)
|
||||
|
||||
# Create GeoDataFrame
|
||||
gdf = grid_gdf.copy()
|
||||
|
|
@ -896,7 +918,9 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
|
|||
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
|
||||
with col2:
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
|
||||
"Aggregation",
|
||||
options=ds["aggregations"].values,
|
||||
key=f"era5_{temporal_type}_agg",
|
||||
)
|
||||
with col3:
|
||||
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
|
||||
|
|
@ -1022,7 +1046,10 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
|
|||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
tooltip={
|
||||
"html": tooltip_html,
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ from entropice.dashboard.utils.colors import get_palette
|
|||
from entropice.ml.dataset import CategoricalTrainingDataset
|
||||
|
||||
|
||||
def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]):
|
||||
def render_all_distribution_histograms(
|
||||
train_data_dict: dict[str, CategoricalTrainingDataset],
|
||||
):
|
||||
"""Render histograms for all three tasks side by side.
|
||||
|
||||
Args:
|
||||
|
|
@ -81,7 +83,13 @@ def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTra
|
|||
height=400,
|
||||
margin={"l": 20, "r": 20, "t": 20, "b": 20},
|
||||
showlegend=True,
|
||||
legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
|
||||
legend={
|
||||
"orientation": "h",
|
||||
"yanchor": "bottom",
|
||||
"y": 1.02,
|
||||
"xanchor": "right",
|
||||
"x": 1,
|
||||
},
|
||||
xaxis_title=None,
|
||||
yaxis_title="Count",
|
||||
xaxis={"tickangle": -45},
|
||||
|
|
@ -137,7 +145,10 @@ def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
|
|||
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
|
||||
|
||||
elif color_mode == "split":
|
||||
split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange
|
||||
split_colors = {
|
||||
"train": [66, 135, 245],
|
||||
"test": [245, 135, 66],
|
||||
} # Blue # Orange
|
||||
gdf["fill_color"] = gdf["split"].map(split_colors)
|
||||
|
||||
return gdf
|
||||
|
|
@ -168,7 +179,14 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
|
|||
)
|
||||
|
||||
with col2:
|
||||
opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity")
|
||||
opacity = st.slider(
|
||||
"Opacity",
|
||||
min_value=0.1,
|
||||
max_value=1.0,
|
||||
value=0.7,
|
||||
step=0.1,
|
||||
key="spatial_map_opacity",
|
||||
)
|
||||
|
||||
# Determine which task dataset to use and color mode
|
||||
if vis_mode == "split":
|
||||
|
|
@ -258,7 +276,10 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
|
|||
# Set initial view state (centered on the Arctic)
|
||||
# Adjust pitch and zoom based on whether we're using elevation
|
||||
view_state = pdk.ViewState(
|
||||
latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45
|
||||
latitude=70,
|
||||
longitude=0,
|
||||
zoom=2 if not use_elevation else 1.5,
|
||||
pitch=0 if not use_elevation else 45,
|
||||
)
|
||||
|
||||
# Create deck
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
|
|||
return colors
|
||||
|
||||
|
||||
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
|
||||
def generate_unified_colormap(
|
||||
settings: dict,
|
||||
) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
|
||||
"""Generate unified colormaps for all plotting libraries.
|
||||
|
||||
This function creates consistent color schemes across Matplotlib/Ultraplot,
|
||||
|
|
@ -125,7 +127,10 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
|
|||
if is_dark_theme:
|
||||
base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme
|
||||
else:
|
||||
base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme
|
||||
base_colors = [
|
||||
"#3498db",
|
||||
"#e74c3c",
|
||||
] # Brighter blue and red for light theme
|
||||
else:
|
||||
# For multi-class: use a sequential colormap
|
||||
# Use matplotlib's viridis colormap
|
||||
|
|
|
|||
|
|
@@ -19,7 +19,9 @@ class ModelDisplayInfo:

 model_display_infos: dict[Model, ModelDisplayInfo] = {
     "espa": ModelDisplayInfo(
-        keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm"
+        keyword="espa",
+        short="eSPA",
+        long="entropy-optimal Scalable Probabilistic Approximations algorithm",
     ),
     "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
     "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),
@@ -135,7 +135,9 @@ def load_all_training_results() -> list[TrainingResult]:
     return training_results


-def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
+def load_all_training_data(
+    e: DatasetEnsemble,
+) -> dict[Task, CategoricalTrainingDataset]:
     """Load training data for all three tasks.

     Args:
@@ -142,7 +142,9 @@ class DatasetStatistics:
     target: dict[TargetDataset, TargetStatistics]  # Statistics per target dataset

     @staticmethod
-    def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_sample_count_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert sample count data to DataFrame."""
         rows = []
         for grid_config in grid_configs:
@@ -164,7 +166,9 @@ class DatasetStatistics:
         return pd.DataFrame(rows)

     @staticmethod
-    def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_feature_count_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert feature count data to DataFrame."""
         rows = []
         for grid_config in grid_configs:
@@ -191,7 +195,9 @@ class DatasetStatistics:
         return pd.DataFrame(rows)

     @staticmethod
-    def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
+    def get_feature_breakdown_df(
+        all_stats: dict[GridLevel, "DatasetStatistics"],
+    ) -> pd.DataFrame:
         """Convert feature breakdown data to DataFrame for stacked/donut charts."""
         rows = []
         for grid_config in grid_configs:
@@ -219,7 +225,10 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
         target_statistics: dict[TargetDataset, TargetStatistics] = {}
         for target in all_target_datasets:
             target_statistics[target] = TargetStatistics.compute(
-                grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
+                grid=grid_config.grid,
+                level=grid_config.level,
+                target=target,
+                total_cells=total_cells,
             )
         member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
         for member in all_l2_source_datasets:
@@ -357,7 +366,10 @@ class TrainingDatasetStatistics:

     @classmethod
     def compute(
-        cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
+        cls,
+        ensemble: DatasetEnsemble,
+        task: Task,
+        dataset: gpd.GeoDataFrame | None = None,
     ) -> "TrainingDatasetStatistics":
         dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
         categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
@@ -65,7 +65,9 @@ def extract_embedding_features(


 def extract_era5_features(
-    model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None
+    model_state: xr.Dataset,
+    importance_type: str = "feature_weights",
+    temporal_group: str | None = None,
 ) -> xr.DataArray | None:
     """Extract ERA5 features from the model state.

@@ -100,7 +102,17 @@ def extract_era5_features(
         - Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021")
         """
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }

         # Find where the time part starts (after "era5" and variable name)
         # Pattern: era5_variable_time or era5_variable_time_agg
@@ -166,7 +178,17 @@ def extract_era5_features(

     def _extract_time_name(feature: str) -> str:
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }
         if parts[-1] in common_aggs:
             # Has aggregation: era5_var_time_agg -> time is second to last
             return parts[-2]
@@ -179,11 +201,8 @@ def extract_era5_features(

         Pattern: era5_variable_season_year_agg or era5_variable_season_year
         """
-        parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
-
         # Look through parts to find season/shoulder indicators
-        for part in parts:
+        for part in feature.split("_"):
             if part.lower() in ["summer", "winter"]:
                 return part.lower()
             elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
@@ -197,11 +216,26 @@ def extract_era5_features(
         For seasonal/shoulder features, find the year that comes after the season.
         """
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }

         # Find the season/shoulder part, then the next part should be the year
         for i, part in enumerate(parts):
-            if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+            if part.lower() in ["summer", "winter"] or part.upper() in [
+                "JFM",
+                "AMJ",
+                "JAS",
+                "OND",
+            ]:
                 # Next part should be the year
                 if i + 1 < len(parts):
                     next_part = parts[i + 1]
@@ -218,7 +252,17 @@ def extract_era5_features(

     def _extract_agg_name(feature: str) -> str | None:
         parts = feature.split("_")
-        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        common_aggs = {
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+            "sum",
+            "count",
+            "q25",
+            "q75",
+        }
         if parts[-1] in common_aggs:
             return parts[-1]
         return None
@@ -255,7 +299,10 @@ def extract_era5_features(

         if has_agg:
             era5_features_array = era5_features_array.assign_coords(
-                agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+                agg=(
+                    "feature",
+                    [_extract_agg_name(f) or "none" for f in era5_features],
+                ),
             )
             era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack(
                 "feature"
@@ -274,7 +321,10 @@ def extract_era5_features(
         if has_agg:
             # Add aggregation dimension
             era5_features_array = era5_features_array.assign_coords(
-                agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
+                agg=(
+                    "feature",
+                    [_extract_agg_name(f) or "none" for f in era5_features],
+                ),
             )
             era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
         else:
@@ -345,7 +395,14 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea
     Returns None if no common features are found.

     """
-    common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"]
+    common_feature_names = [
+        "cell_area",
+        "water_area",
+        "land_area",
+        "land_ratio",
+        "lon",
+        "lat",
+    ]

     def _is_common_feature(feature: str) -> bool:
         return feature in common_feature_names
@@ -66,7 +66,10 @@ def render_inference_page():
     with col4:
         st.metric("Level", selected_result.settings.get("level", "Unknown"))
     with col5:
-        st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))
+        st.metric(
+            "Target",
+            selected_result.settings.get("target", "Unknown").replace("darts_", ""),
+        )

     st.divider()

@@ -10,8 +10,16 @@ from stopuhr import stopwatch

 from entropice.dashboard.utils.colors import get_palette
 from entropice.dashboard.utils.loaders import load_all_training_results
-from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
-from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
+from entropice.dashboard.utils.stats import (
+    DatasetStatistics,
+    load_all_default_dataset_statistics,
+)
+from entropice.utils.types import (
+    GridConfig,
+    L2SourceDataset,
+    TargetDataset,
+    grid_configs,
+)


 def render_sample_count_overview():
@@ -45,7 +53,12 @@ def render_sample_count_overview():
         target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]

         # Pivot for heatmap: Grid x Task
-        pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
+        pivot_df = target_df.pivot_table(
+            index="Grid",
+            columns="Task",
+            values="Samples (Coverage)",
+            aggfunc="mean",
+        )

         # Sort index by grid type and level
         sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@@ -56,7 +69,11 @@ def render_sample_count_overview():

         fig = px.imshow(
             pivot_df,
-            labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
+            labels={
+                "x": "Task",
+                "y": "Grid Configuration",
+                "color": "Sample Count",
+            },
             x=pivot_df.columns,
             y=pivot_df.index,
             color_continuous_scale=sample_colors,
@@ -87,7 +104,10 @@ def render_sample_count_overview():
         facet_col="Target",
         barmode="group",
         title="Sample Counts by Grid Configuration and Target Dataset",
-        labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Samples (Coverage)": "Number of Samples",
+        },
         color_discrete_sequence=task_colors,
         height=500,
     )
@@ -141,7 +161,10 @@ def render_feature_count_comparison():
         color="Data Source",
         barmode="stack",
         title="Total Features by Data Source Across Grid Configurations",
-        labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Number of Features": "Number of Features",
+        },
         color_discrete_sequence=source_colors,
         text_auto=False,
     )
@@ -162,7 +185,10 @@ def render_feature_count_comparison():
         y="Inference Cells",
         color="Grid",
         title="Inference Cells by Grid Configuration",
-        labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Inference Cells": "Number of Cells",
+        },
         color_discrete_sequence=grid_colors,
         text="Inference Cells",
     )
@@ -177,7 +203,10 @@ def render_feature_count_comparison():
         y="Total Samples",
         color="Grid",
         title="Total Samples by Grid Configuration",
-        labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
+        labels={
+            "Grid": "Grid Configuration",
+            "Total Samples": "Number of Samples",
+        },
         color_discrete_sequence=grid_colors,
         text="Total Samples",
     )
@@ -226,7 +255,13 @@ def render_feature_count_comparison():

     # Display full comparison table with formatting
     display_df = comparison_df[
-        ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
+        [
+            "Grid",
+            "Total Features",
+            "Data Sources",
+            "Inference Cells",
+            "Total Samples",
+        ]
     ].copy()

     # Format numbers with commas
@@ -314,7 +349,11 @@ def render_feature_count_explorer():
     with col2:
         # Calculate minimum cells across all data sources (for inference capability)
         min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
-        st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
+        st.metric(
+            "Inference Cells",
+            f"{min_cells:,}",
+            help="Number of union of cells across all data sources",
+        )
     with col3:
         st.metric("Data Sources", len(selected_members))
     with col4:
@@ -14,7 +14,10 @@ from entropice.dashboard.plots.source_data import (
     render_era5_overview,
     render_era5_plots,
 )
-from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
+from entropice.dashboard.plots.training_data import (
+    render_all_distribution_histograms,
+    render_spatial_map,
+)
 from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
 from entropice.ml.dataset import DatasetEnsemble
 from entropice.spatial import grids
@@ -42,7 +45,10 @@ def render_training_data_page():
         ]

         grid_level_combined = st.selectbox(
-            "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution level"
+            "Grid Configuration",
+            options=grid_options,
+            index=0,
+            help="Select the grid system and resolution level",
         )

         # Parse grid type and level
@@ -60,7 +66,13 @@ def render_training_data_page():
         # Members selection
         st.subheader("Dataset Members")

-        all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+        all_members = [
+            "AlphaEarth",
+            "ArcticDEM",
+            "ERA5-yearly",
+            "ERA5-seasonal",
+            "ERA5-shoulder",
+        ]
         selected_members = []

         for member in all_members:
@@ -69,7 +81,10 @@ def render_training_data_page():

         # Form submit button
         load_button = st.form_submit_button(
-            "Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0
+            "Load Dataset",
+            type="primary",
+            width="stretch",
+            disabled=len(selected_members) == 0,
         )

         # Create DatasetEnsemble only when form is submitted
@@ -77,7 +77,20 @@ def download(grid: Grid, level: int):
     for year in track(range(2024, 2025), total=1, description="Processing years..."):
         embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
         embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
-        aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+        aggs = [
+            "mean",
+            "stdDev",
+            "min",
+            "max",
+            "count",
+            "median",
+            "p1",
+            "p5",
+            "p25",
+            "p75",
+            "p95",
+            "p99",
+        ]
         bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]

         def extract_embedding(feature):
@@ -136,7 +149,20 @@ def combine_to_zarr(grid: Grid, level: int):
     """
     cell_ids = grids.get_cell_ids(grid, level)
     years = list(range(2018, 2025))
-    aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+    aggs = [
+        "mean",
+        "stdDev",
+        "min",
+        "max",
+        "count",
+        "median",
+        "p1",
+        "p5",
+        "p25",
+        "p75",
+        "p95",
+        "p99",
+    ]
     bands = [f"A{str(i).zfill(2)}" for i in range(64)]

     a = xr.DataArray(
@@ -125,8 +125,14 @@ def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
     cs = 3600

     # Calculate safe slice bounds for edge chunks
-    y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d)
-    x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d)
+    y_start, y_end = (
+        max(0, cs * chunk_loc[0] - d),
+        min(len(y), cs * chunk_loc[0] + cs + d),
+    )
+    x_start, x_end = (
+        max(0, cs * chunk_loc[1] - d),
+        min(len(x), cs * chunk_loc[1] + cs + d),
+    )

     # Extract coordinate arrays with safe bounds
     y_chunk = cp.asarray(y[y_start:y_end])
@@ -162,7 +168,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No
     if np.all(np.isnan(chunk)):
         # Return an array of NaNs with the expected shape
         return np.full(
-            (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
+            (
+                12,
+                chunk.shape[0] - 2 * large_kernels.size_px,
+                chunk.shape[1] - 2 * large_kernels.size_px,
+            ),
             np.nan,
             dtype=np.float32,
         )
@@ -171,7 +181,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No

     # Interpolate missing values in chunk (for patches smaller than 7x7 pixels)
     mask = cp.isnan(chunk)
-    mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
+    mask &= ~binary_dilation(
+        binary_erosion(mask, iterations=3, brute_force=True),
+        iterations=3,
+        brute_force=True,
+    )
     if cp.any(mask):
         # Find indices of valid values
         indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
@@ -14,7 +14,12 @@ from rich.progress import track
 from stopuhr import stopwatch

 from entropice.spatial import grids
-from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
+from entropice.utils.paths import (
+    darts_ml_training_labels_repo,
+    dartsl2_cov_file,
+    dartsl2_file,
+    get_darts_rts_file,
+)
 from entropice.utils.types import Grid

 traceback.install()
@@ -204,15 +204,36 @@ def download_daily_aggregated():
     )

     # Assign attributes
-    daily_raw["t2m_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"}
-    daily_raw["t2m_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"}
-    daily_raw["t2m_mean"].attrs = {"long_name": "Daily mean 2 metre temperature", "units": "K"}
+    daily_raw["t2m_max"].attrs = {
+        "long_name": "Daily maximum 2 metre temperature",
+        "units": "K",
+    }
+    daily_raw["t2m_min"].attrs = {
+        "long_name": "Daily minimum 2 metre temperature",
+        "units": "K",
+    }
+    daily_raw["t2m_mean"].attrs = {
+        "long_name": "Daily mean 2 metre temperature",
+        "units": "K",
+    }
     daily_raw["snowc_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "m"}
     daily_raw["sde_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"}
-    daily_raw["lblt_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"}
-    daily_raw["tp"].attrs = {"long_name": "Daily total precipitation", "units": "m"}  # Units are rather m^3 / m^2
-    daily_raw["sf"].attrs = {"long_name": "Daily total snow fall", "units": "m"}  # Units are rather m^3 / m^2
-    daily_raw["sshf"].attrs = {"long_name": "Daily total surface sensible heat flux", "units": "J/m²"}
+    daily_raw["lblt_max"].attrs = {
+        "long_name": "Daily maximum lake ice bottom temperature",
+        "units": "K",
+    }
+    daily_raw["tp"].attrs = {
+        "long_name": "Daily total precipitation",
+        "units": "m",
+    }  # Units are rather m^3 / m^2
+    daily_raw["sf"].attrs = {
+        "long_name": "Daily total snow fall",
+        "units": "m",
+    }  # Units are rather m^3 / m^2
+    daily_raw["sshf"].attrs = {
+        "long_name": "Daily total surface sensible heat flux",
+        "units": "J/m²",
+    }

     daily_raw = daily_raw.odc.assign_crs("epsg:4326")
     daily_raw = daily_raw.drop_vars(["surface", "number", "depthBelowLandLayer"])
@@ -281,11 +302,17 @@ def daily_enrich():

     # Formulas based on Groeke et. al. (2025) Stochastic Weather generation...
     daily["t2m_avg"] = (daily.t2m_max + daily.t2m_min) / 2
-    daily.t2m_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"}
+    daily.t2m_avg.attrs = {
+        "long_name": "Daily average 2 metre temperature",
+        "units": "K",
+    }
     _store("t2m_avg")

     daily["t2m_range"] = daily.t2m_max - daily.t2m_min
-    daily.t2m_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"}
+    daily.t2m_range.attrs = {
+        "long_name": "Daily range of 2 metre temperature",
+        "units": "K",
+    }
     _store("t2m_range")

     with np.errstate(invalid="ignore"):
@@ -298,7 +325,10 @@ def daily_enrich():
     _store("thawing_degree_days")

     daily["freezing_degree_days"] = (273.15 - daily.t2m_avg).clip(min=0)
-    daily.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"}
+    daily.freezing_degree_days.attrs = {
+        "long_name": "Freezing degree days",
+        "units": "K",
+    }
     _store("freezing_degree_days")

     daily["thawing_days"] = (daily.t2m_avg > 273.15).where(~daily.t2m_avg.isnull())
@@ -425,7 +455,12 @@ def multi_monthly_aggregate(agg: Literal["yearly", "seasonal", "shoulder"] = "ye

     multimonthly_store = get_era5_stores(agg)
     print(f"Saving empty multi-monthly ERA5 data to {multimonthly_store}.")
-    multimonthly.to_zarr(multimonthly_store, mode="w", encoding=codecs.from_ds(multimonthly), consolidated=False)
+    multimonthly.to_zarr(
+        multimonthly_store,
+        mode="w",
+        encoding=codecs.from_ds(multimonthly),
+        consolidated=False,
+    )

     def _store(var):
         nonlocal multimonthly
@@ -512,14 +547,20 @@ def yearly_thaw_periods():
     never_thaws = (daily.thawing_days.resample(time="12MS").sum(dim="time") == 0).persist()

     first_thaw_day = daily.thawing_days.resample(time="12MS").map(_get_first_day).where(~never_thaws)
-    first_thaw_day.attrs = {"long_name": "Day of first thaw in year", "units": "day of year"}
+    first_thaw_day.attrs = {
+        "long_name": "Day of first thaw in year",
+        "units": "day of year",
+    }
     _store_in_yearly(first_thaw_day, "day_of_first_thaw")

     n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").first()
     last_thaw_day = n_days_in_year - daily.thawing_days[::-1].resample(time="12MS").map(_get_first_day).where(
         ~never_thaws
     )
-    last_thaw_day.attrs = {"long_name": "Day of last thaw in year", "units": "day of year"}
+    last_thaw_day.attrs = {
+        "long_name": "Day of last thaw in year",
+        "units": "day of year",
+    }
     _store_in_yearly(last_thaw_day, "day_of_last_thaw")


@@ -532,14 +573,20 @@ def yearly_freeze_periods():
     never_freezes = (daily.freezing_days.resample(time="12MS").sum(dim="time") == 0).persist()

     first_freeze_day = daily.freezing_days.resample(time="12MS").map(_get_first_day).where(~never_freezes)
-    first_freeze_day.attrs = {"long_name": "Day of first freeze in year", "units": "day of year"}
+    first_freeze_day.attrs = {
+        "long_name": "Day of first freeze in year",
+        "units": "day of year",
+    }
     _store_in_yearly(first_freeze_day, "day_of_first_freeze")

     n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").last()
     last_freeze_day = n_days_in_year - daily.freezing_days[::-1].resample(time="12MS").map(_get_first_day).where(
         ~never_freezes
     )
-    last_freeze_day.attrs = {"long_name": "Day of last freeze in year", "units": "day of year"}
+    last_freeze_day.attrs = {
+        "long_name": "Day of last freeze in year",
+        "units": "day of year",
+    }
     _store_in_yearly(last_freeze_day, "day_of_last_freeze")


@@ -728,7 +775,12 @@ def spatial_agg(
         pxbuffer=10,
     )

-    aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
+    aggregated = aggregated.chunk(
+        {
+            "cell_ids": min(len(aggregated.cell_ids), 10000),
+            "time": len(aggregated.time),
+        }
+    )
     store = get_era5_stores(agg, grid, level)
     with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
         aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
@@ -99,7 +99,14 @@ def bin_values(
     """
     labels_dict = {
         "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
-        "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
+        "density": [
+            "Empty",
+            "Very Sparse",
+            "Sparse",
+            "Moderate",
+            "Dense",
+            "Very Dense",
+        ],
     }
     labels = labels_dict[task]

@@ -186,7 +193,13 @@ class DatasetEnsemble:
     level: int
     target: Literal["darts_rts", "darts_mllabels"]
     members: list[L2SourceDataset] = field(
-        default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
+        default_factory=lambda: [
+            "AlphaEarth",
+            "ArcticDEM",
+            "ERA5-yearly",
+            "ERA5-seasonal",
+            "ERA5-shoulder",
+        ]
     )
     dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
     variable_filters: dict[str, list[str]] = field(default_factory=dict)
@@ -277,7 +290,9 @@ class DatasetEnsemble:
         return targets

     def _prep_era5(
-        self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
+        self,
+        targets: gpd.GeoDataFrame,
+        temporal: Literal["yearly", "seasonal", "shoulder"],
     ) -> pd.DataFrame:
         era5 = self._read_member("ERA5-" + temporal, targets)
         era5_df = era5.to_dataframe()
@@ -352,7 +367,9 @@ class DatasetEnsemble:
         print(f"=== Total number of features in dataset: {stats['total_features']}")

     def create(
-        self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
+        self,
+        filter_target_col: str | None = None,
+        cache_mode: Literal["n", "o", "r"] = "r",
     ) -> gpd.GeoDataFrame:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
@@ -438,7 +455,10 @@ class DatasetEnsemble:
         yield dataset

     def _cat_and_split(
-        self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
+        self,
+        dataset: gpd.GeoDataFrame,
+        task: Task,
+        device: Literal["cpu", "cuda", "torch"],
     ) -> CategoricalTrainingDataset:
         taskcol = self.taskcol(task)

@@ -20,7 +20,9 @@ set_config(array_api_dispatch=True)


 def predict_proba(
-    e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list
+    e: DatasetEnsemble,
+    clf: RandomForestClassifier | ESPAClassifier | XGBClassifier,
+    classes: list,
 ) -> gpd.GeoDataFrame:
     """Get predicted probabilities for each cell.

@@ -79,7 +79,19 @@ class _Aggregations:

     def aggnames(self) -> list[str]:
         if self._common:
-            return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
+            return [
+                "mean",
+                "std",
+                "min",
+                "max",
+                "median",
+                "p1",
+                "p5",
+                "p25",
+                "p75",
+                "p95",
+                "p99",
+            ]
         names = []
         if self.mean:
             names.append("mean")
@@ -105,7 +117,15 @@ class _Aggregations:
         cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
         cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
         cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
-        quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99]  # ? Ordering is important here!
+        quantiles_to_compute = [
+            0.5,
+            0.01,
+            0.05,
+            0.25,
+            0.75,
+            0.95,
+            0.99,
+        ]  # ? Ordering is important here!
         cell_data[4:, ...] = flattened_var.quantile(
             q=quantiles_to_compute,
             dim="z",
@@ -156,7 +176,11 @@ class _Aggregations:
             return self._agg_cell_data_single(flattened)

         others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
-        cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
+        cell_data = np.full(
+            (len(flattened.data_vars), len(self), *others_shape),
+            np.nan,
+            dtype=np.float32,
+        )
         for i, var in enumerate(flattened.data_vars):
             cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
         # Transform to numpy arrays
@@ -194,7 +218,9 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:


 @stopwatch("Correcting geometries", log=False)
-def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
+def _get_corrected_geoms(
+    inp: tuple[Polygon, odc.geo.geobox.GeoBox, str],
+) -> list[odc.geo.Geometry]:
     geom, gbox, crs = inp
     # cell.geometry is a shapely Polygon
     if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
@@ -440,7 +466,12 @@ def _align_partition(
     others_shape = tuple(
         [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
     )
-    ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
+    ongrid_shape = (
+        len(grid_partition_gdf),
+        len(raster.data_vars),
+        len(aggregations),
+        *others_shape,
+    )
     ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)

     for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
@@ -455,7 +486,11 @@ def _align_partition(

     cell_ids = grids.convert_cell_ids(grid_partition_gdf)
     dims = ["cell_ids", "variables", "aggregations"]
-    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
+    coords = {
+        "cell_ids": cell_ids,
+        "variables": list(raster.data_vars),
+        "aggregations": aggregations.aggnames(),
+    }
     for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
         dims.append(dim)
         coords[dim] = raster.coords[dim]
@@ -131,7 +131,9 @@ def create_global_hex_grid(resolution):
             executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
         }
         for future in track(
-            as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
+            as_completed(future_to_hex),
+            description="Creating hex polygons...",
+            total=len(hex0_cells),
         ):
             hex_batch, hex_id_batch, hex_area_batch = future.result()
             hex_list.extend(hex_batch)
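The executor/`as_completed` pattern in this hunk — submit one future per input, keep a future-to-input mapping, and collect results as they finish — is standard `concurrent.futures` usage and can be reduced to a stdlib-only sketch (the `square` worker is a placeholder for `_get_cell_polygon`):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def square(x: int) -> int:
    # Placeholder worker; the real code computes hex cell polygons.
    return x * x

results = []
with ThreadPoolExecutor(max_workers=4) as executor:
    # Map each future back to its input, like future_to_hex above.
    future_to_input = {executor.submit(square, n): n for n in range(5)}
    for future in as_completed(future_to_input):
        results.append(future.result())
```

Because `as_completed` yields futures in completion order, results arrive unordered; the mapping is what lets the caller recover which input each result belongs to.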
@@ -157,7 +159,10 @@ def create_global_hex_grid(resolution):
             hex_area_list.append(hex_area)
 
     # Create GeoDataFrame
-    grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
+    grid = gpd.GeoDataFrame(
+        {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
+        crs="EPSG:4326",
+    )
 
     return grid
@@ -5,7 +5,10 @@ from zarr.codecs import BloscCodec
 
 
 def from_ds(
-    ds: xr.Dataset, store_floats_as_float32: bool = True, include_coords: bool = True, filter_existing: bool = True
+    ds: xr.Dataset,
+    store_floats_as_float32: bool = True,
+    include_coords: bool = True,
+    filter_existing: bool = True,
 ) -> dict:
     """Create compression encoding for zarr dataset storage.
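`from_ds` builds a per-variable encoding dict, and its `store_floats_as_float32` flag suggests float64 variables get downcast on write. The body is not shown in this diff, so the following is purely an assumption about that downcasting step, with invented variable names and no zarr dependency:

```python
# Hypothetical dtype inventory standing in for ds[var].dtype lookups.
dtypes = {"dem": "float64", "mask": "uint8", "temp": "float64"}

# Downcast float64 to float32 in the encoding; leave other dtypes alone.
encoding = {
    var: {"dtype": "float32"} if dt == "float64" else {"dtype": dt}
    for var, dt in dtypes.items()
}
```

The real function presumably also attaches the `BloscCodec` compressor from the import above and honors `include_coords` / `filter_existing`; those parts are omitted here.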
@@ -4,7 +4,17 @@ from dataclasses import dataclass
 from typing import Literal
 
 type Grid = Literal["hex", "healpix"]
-type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
+type GridLevel = Literal[
+    "hex3",
+    "hex4",
+    "hex5",
+    "hex6",
+    "healpix6",
+    "healpix7",
+    "healpix8",
+    "healpix9",
+    "healpix10",
+]
 type TargetDataset = Literal["darts_rts", "darts_mllabels"]
 type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
 type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
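The `GridLevel` literals above fuse a grid name and an integer level into one string ("healpix8"). Code consuming these values will at some point need to split them back apart; whether the repo has such a helper is not visible in this diff, so this parser is an illustrative assumption:

```python
import re

def split_grid_level(grid_level: str) -> tuple[str, int]:
    """Split a combined grid-level string like "healpix8" into ("healpix", 8)."""
    match = re.fullmatch(r"(hex|healpix)(\d+)", grid_level)
    if match is None:
        raise ValueError(f"Unrecognized grid level: {grid_level}")
    return match.group(1), int(match.group(2))
```

Anchoring the regex alternation to the two `Grid` literals keeps the parser in sync with the type aliases: adding a third grid type would require touching both.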
83
tests/l2dataset_validation.py
Normal file
@@ -0,0 +1,83 @@
+from itertools import product
+
+import xarray as xr
+from rich.progress import track
+
+from entropice.utils.paths import (
+    get_arcticdem_stores,
+    get_embeddings_store,
+    get_era5_stores,
+)
+from entropice.utils.types import Grid, L2SourceDataset
+
+
+def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
+    """Validate if the L2 dataset exists for the given grid and level.
+
+    Args:
+        grid (Grid): The grid type to use.
+        level (int): The grid level to use.
+        l2ds (L2SourceDataset): The L2 source dataset to validate.
+
+    Returns:
+        bool: True if the dataset exists and does not contain NaNs, False otherwise.
+
+    """
+    if l2ds == "ArcticDEM":
+        store = get_arcticdem_stores(grid, level)
+    elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly":
+        agg = l2ds.split("-")[1]
+        store = get_era5_stores(agg, grid, level)  # type: ignore
+    elif l2ds == "AlphaEarth":
+        store = get_embeddings_store(grid, level)
+    else:
+        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
+
+    if not store.exists():
+        print("\t Dataset store does not exist")
+        return False
+
+    ds = xr.open_zarr(store, consolidated=False)
+    has_nan = False
+    for var in ds.data_vars:
+        n_nans = ds[var].isnull().sum().compute().item()
+        if n_nans > 0:
+            print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
+            has_nan = True
+    if has_nan:
+        return False
+    return True
+
+
+def main():
+    grid_levels: set[tuple[Grid, int]] = {
+        ("hex", 3),
+        ("hex", 4),
+        ("hex", 5),
+        ("hex", 6),
+        ("healpix", 6),
+        ("healpix", 7),
+        ("healpix", 8),
+        ("healpix", 9),
+        ("healpix", 10),
+    }
+    l2_source_datasets: list[L2SourceDataset] = [
+        "ArcticDEM",
+        "ERA5-shoulder",
+        "ERA5-seasonal",
+        "ERA5-yearly",
+        "AlphaEarth",
+    ]
+
+    for (grid, level), l2ds in track(
+        product(grid_levels, l2_source_datasets),
+        total=len(grid_levels) * len(l2_source_datasets),
+        description="Validating L2 datasets...",
+    ):
+        is_valid = validate_l2_dataset(grid, level, l2ds)
+        status = "VALID" if is_valid else "INVALID"
+        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
+
+
+if __name__ == "__main__":
+    main()
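The validation script's `main` iterates the full cross product of grid/level pairs and L2 datasets, passing `total=len(grid_levels) * len(l2_source_datasets)` because `product` itself has no length. That bookkeeping can be checked in miniature with stdlib only (the inputs here are a small subset for illustration):

```python
from itertools import product

grids = [("hex", 3), ("healpix", 6)]
datasets = ["ArcticDEM", "AlphaEarth"]

# product yields every (grid_level, dataset) pair; its length is the
# product of the input lengths, which is what the script passes to track().
combos = list(product(grids, datasets))
```

With the full inputs above this gives 9 × 5 = 45 validation runs per invocation.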