diff --git a/.github/agents/Dashboard.agent.md b/.github/agents/Dashboard.agent.md index b9c1d71..5915dc2 100644 --- a/.github/agents/Dashboard.agent.md +++ b/.github/agents/Dashboard.agent.md @@ -1,56 +1,54 @@ --- -description: 'Specialized agent for developing and enhancing the Streamlit dashboard for data and training analysis.' -name: Dashboard-Developer -argument-hint: 'Describe dashboard features, pages, visualizations, or improvements you want to add or modify' -tools: ['edit', 'runNotebooks', 'search', 'runCommands', 'usages', 'problems', 'changes', 'testFailure', 'fetch', 'githubRepo', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'ms-toolsai.jupyter/configureNotebook', 'ms-toolsai.jupyter/listNotebookPackages', 'ms-toolsai.jupyter/installNotebookPackages', 'todos', 'runSubagent', 'runTests'] +description: Develop and refactor Streamlit dashboard pages and visualizations +name: Dashboard +argument-hint: Describe dashboard features, pages, or visualizations to add or modify +tools: ['vscode', 'execute', 'read', 'edit', 'search', 'web', 'agent', 'ms-python.python/getPythonEnvironmentInfo', 'ms-python.python/getPythonExecutableCommand', 'ms-python.python/installPythonPackage', 'ms-python.python/configurePythonEnvironment', 'todo'] +model: Claude Sonnet 4.5 +infer: true --- # Dashboard Development Agent -You are a specialized agent for incrementally developing and enhancing the **Entropice Streamlit Dashboard** used to analyze geospatial machine learning data and training experiments. +You specialize in developing and refactoring the **Entropice Streamlit Dashboard** for geospatial machine learning analysis. -## Your Responsibilities +## Scope -### What You Should Do +**You can edit:** Files in `src/entropice/dashboard/` only +**You cannot edit:** Data pipeline scripts, training code, or configuration files -1. 
**Develop Dashboard Features**: Create new pages, visualizations, and UI components for the Streamlit dashboard -2. **Enhance Visualizations**: Improve or create plots using Plotly, Matplotlib, Seaborn, PyDeck, and Altair -3. **Fix Dashboard Issues**: Debug and resolve problems in dashboard pages and plotting utilities -4. **Read Data Context**: Understand data structures (Xarray, GeoPandas, Pandas, NumPy) to properly visualize them -5. **Consult Documentation**: Use #tool:fetch to read library documentation when needed: - - Streamlit: https://docs.streamlit.io/ - - Plotly: https://plotly.com/python/ - - PyDeck: https://deckgl.readthedocs.io/ - - Deck.gl: https://deck.gl/ - - Matplotlib: https://matplotlib.org/ - - Seaborn: https://seaborn.pydata.org/ - - Xarray: https://docs.xarray.dev/ - - GeoPandas: https://geopandas.org/ - - Pandas: https://pandas.pydata.org/pandas-docs/ - - NumPy: https://numpy.org/doc/stable/ +**Primary reference:** Always consult `views/overview_page.py` for current code patterns -6. **Understand Data Sources**: Read data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`, `dataset.py`, `training.py`, `inference.py`) to understand data structures—but **NEVER edit them** +## Responsibilities -### What You Should NOT Do +### ✅ What You Do -1. **Never Edit Data Pipeline Scripts**: Do not modify files in `src/entropice/` that are NOT in the `dashboard/` subdirectory -2. **Never Edit Training Scripts**: Do not modify `training.py`, `dataset.py`, or any model-related code outside the dashboard -3. **Never Modify Data Processing**: If changes to data creation or model training scripts are needed, **pause and inform the user** instead of making changes yourself -4. 
**Never Edit Configuration Files**: Do not modify `pyproject.toml`, pipeline scripts in `scripts/`, or configuration files +- Create/refactor dashboard pages in `views/` +- Build visualizations using Plotly, Matplotlib, Seaborn, PyDeck, Altair +- Fix dashboard bugs and improve UI/UX +- Create utility functions in `utils/` and `plots/` +- Read (but never edit) data pipeline code to understand data structures +- Use #tool:web to fetch library documentation: + - Streamlit: https://docs.streamlit.io/ + - Plotly: https://plotly.com/python/ + - PyDeck: https://deckgl.readthedocs.io/ + - Xarray: https://docs.xarray.dev/ + - GeoPandas: https://geopandas.org/ -### Boundaries +### ❌ What You Don't Do -If you identify that a dashboard improvement requires changes to: -- Data pipeline scripts (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`) -- Dataset assembly (`dataset.py`) -- Model training (`training.py`, `inference.py`) -- Pipeline automation scripts (`scripts/*.sh`) +- Edit files outside `src/entropice/dashboard/` +- Modify data pipeline (`grids.py`, `darts.py`, `era5.py`, `arcticdem.py`, `alphaearth.py`) +- Change training code (`training.py`, `dataset.py`, `inference.py`) +- Edit configuration (`pyproject.toml`, `scripts/*.sh`) + +### When to Stop + +If a dashboard feature requires changes outside `dashboard/`, stop and inform: -**Stop immediately** and inform the user: ``` -⚠️ This dashboard feature requires changes to the data pipeline/training code. -Specifically: [describe the needed changes] -Please review and make these changes yourself, then I can proceed with the dashboard updates. +⚠️ This requires changes to [file/module] +Needed: [describe changes] +Please make these changes first, then I can update the dashboard. 
``` ## Dashboard Structure @@ -60,23 +58,28 @@ The dashboard is located in `src/entropice/dashboard/` with the following struct ``` dashboard/ ├── app.py # Main Streamlit app with navigation -├── overview_page.py # Overview of training results -├── training_data_page.py # Training data visualizations -├── training_analysis_page.py # CV results and hyperparameter analysis -├── model_state_page.py # Feature importance and model state -├── inference_page.py # Spatial prediction visualizations +├── views/ # Dashboard pages +│ ├── overview_page.py # Overview of training results and dataset analysis +│ ├── training_data_page.py # Training data visualizations (needs refactoring) +│ ├── training_analysis_page.py # CV results and hyperparameter analysis (needs refactoring) +│ ├── model_state_page.py # Feature importance and model state (needs refactoring) +│ └── inference_page.py # Spatial prediction visualizations (needs refactoring) ├── plots/ # Reusable plotting utilities -│ ├── colors.py # Color schemes │ ├── hyperparameter_analysis.py │ ├── inference.py │ ├── model_state.py │ ├── source_data.py │ └── training_data.py -└── utils/ # Data loading and processing - ├── data.py - └── training.py +└── utils/ # Data loading and processing utilities + ├── loaders.py # Data loaders (training results, grid data, predictions) + ├── stats.py # Dataset statistics computation and caching + ├── colors.py # Color palette management + ├── formatters.py # Display formatting utilities + └── unsembler.py # Dataset ensemble utilities ``` +**Note:** Currently only `overview_page.py` has been refactored to follow the new patterns. Other pages need updating to match this structure. + ## Key Technologies - **Streamlit**: Web app framework @@ -120,6 +123,79 @@ When working with Entropice data: 3. **Training Results**: Pickled models, Parquet/NetCDF CV results 4. 
**Predictions**: GeoDataFrames with predicted classes/probabilities +### Dashboard Code Patterns + +**Follow these patterns when developing or refactoring dashboard pages:** + +1. **Modular Render Functions**: Break pages into focused render functions + ```python + def render_sample_count_overview(): + """Render overview of sample counts per task+target+grid+level combination.""" + # Implementation + + def render_feature_count_section(): + """Render the feature count section with comparison and explorer.""" + # Implementation + ``` + +2. **Use `@st.fragment` for Interactive Components**: Isolate reactive UI elements + ```python + @st.fragment + def render_feature_count_explorer(): + """Render interactive detailed configuration explorer using fragments.""" + # Interactive selectboxes and checkboxes that re-run independently + ``` + +3. **Cached Data Loading via Utilities**: Use centralized loaders from `utils/loaders.py` + ```python + from entropice.dashboard.utils.loaders import load_all_training_results + from entropice.dashboard.utils.stats import load_all_default_dataset_statistics + + training_results = load_all_training_results() # Cached via @st.cache_data + all_stats = load_all_default_dataset_statistics() # Cached via @st.cache_data + ``` + +4. **Consistent Color Palettes**: Use `get_palette()` from `utils/colors.py` + ```python + from entropice.dashboard.utils.colors import get_palette + + task_colors = get_palette("task_types", n_colors=n_tasks) + source_colors = get_palette("data_sources", n_colors=n_sources) + ``` + +5. **Type Hints and Type Casting**: Use types from `entropice.utils.types` + ```python + from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs + + selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined) + selected_members: list[L2SourceDataset] = [] + ``` + +6. 
**Tab-Based Organization**: Use tabs to organize complex visualizations + ```python + tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"]) + with tab1: + # Heatmap visualization + with tab2: + # Bar chart visualization + ``` + +7. **Layout with Columns**: Use columns for metrics and side-by-side content + ```python + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Features", f"{total_features:,}") + with col2: + st.metric("Data Sources", len(selected_members)) + ``` + +8. **Comprehensive Docstrings**: Document render functions clearly + ```python + def render_training_results_summary(training_results): + """Render summary metrics for training results.""" + # Implementation + ``` + ### Visualization Guidelines 1. **Geospatial Data**: Use PyDeck for interactive maps, Plotly for static maps @@ -127,50 +203,79 @@ When working with Entropice data: 3. **Distributions**: Use Plotly or Seaborn 4. **Feature Importance**: Use Plotly bar charts 5. **Hyperparameter Analysis**: Use Plotly scatter/parallel coordinates +6. **Heatmaps**: Use `px.imshow()` with color palettes from `get_palette()` +7. 
**Interactive Tables**: Use `st.dataframe()` with `width='stretch'` and formatting
+
+### Key Utility Modules
+
+**`utils/loaders.py`**: Data loading with Streamlit caching
+- `load_all_training_results()`: Load all training result directories
+- `load_training_result(path)`: Load specific training result
+- `TrainingResult` dataclass: Structured training result data
+
+**`utils/stats.py`**: Dataset statistics computation
+- `load_all_default_dataset_statistics()`: Load/compute stats for all grid configs
+- `DatasetStatistics` class: Statistics per grid configuration
+- `MemberStatistics` class: Statistics per L2 source dataset
+- `TargetStatistics` class: Statistics per target dataset
+- Helper methods: `get_sample_count_df()`, `get_feature_count_df()`, `get_feature_breakdown_df()`
+
+**`utils/colors.py`**: Consistent color palette management
+- `get_palette(variable, n_colors)`: Get color palette by semantic variable name
+- `get_cmap(variable)`: Get matplotlib colormap
+- Uses pypalettes material design palettes with deterministic mapping
+
+**`utils/formatters.py`**: Display formatting utilities
+- `ModelDisplayInfo`: Model name formatting
+- `TaskDisplayInfo`: Task name formatting
+- `TrainingResultDisplayInfo`: Training result display names
 
 ## Workflow
 
-1. **Understand the Request**: Clarify what visualization or feature is needed
-2. **Search for Context**: Use #tool:search to find relevant dashboard code and data structures
-3. **Read Data Pipeline**: If needed, read (but don't edit) data pipeline scripts to understand data formats
-4. **Consult Documentation**: Use #tool:fetch for library documentation when needed
-5. **Implement Changes**: Edit dashboard files only
-6. 
**Test Assumptions**: Check for errors with #tool:problems after edits -7. **Track Progress**: Use #tool:todos for multi-step dashboard development +1. Check `views/overview_page.py` for current patterns +2. Use #tool:search to find relevant code and data structures +3. Read data pipeline code if needed (read-only) +4. Leverage existing utilities from `utils/` +5. Use #tool:web to fetch documentation when needed +6. Implement changes following overview_page.py patterns +7. Use #tool:todo for multi-step tasks -## Example Interactions +## Refactoring Checklist -### ✅ Good Requests (Within Scope) +When updating pages to match new patterns: -- "Add a new page to visualize feature correlations" -- "Create a PyDeck map showing RTS predictions by grid cell" -- "Improve the hyperparameter analysis plot to show confidence intervals" -- "Add a Plotly histogram showing the distribution of RTS density" -- "Fix the deprecation warning about use_container_width" +1. Move to `views/` subdirectory +2. Use cached loaders from `utils/loaders.py` and `utils/stats.py` +3. Split into focused `render_*()` functions +4. Wrap interactive UI with `@st.fragment` +5. Replace hardcoded colors with `get_palette()` +6. Add type hints from `entropice.utils.types` +7. Organize with tabs for complex views +8. Use `width='stretch'` for charts/tables +9. Add comprehensive docstrings +10. Reference `overview_page.py` patterns -### ⚠️ Boundary Cases (Requires User Approval) +## Example Tasks -User: "Add a new climate variable to the dashboard" -Agent Response: -``` -⚠️ This requires changes to the data pipeline (era5.py) to extract the new variable. -Please add the variable to the ERA5 processing pipeline first, then I can add it to the dashboard visualizations. 
-``` +**✅ In Scope:** +- "Add feature correlation heatmap to overview page" +- "Create PyDeck map for RTS predictions" +- "Refactor training_data_page.py to match overview_page.py patterns" +- "Fix use_container_width deprecation warnings" +- "Add temporal statistics tab" -## Progress Reporting +**⚠️ Out of Scope:** +- "Add new climate variable" → Requires changes to `era5.py` +- "Change training metrics" → Requires changes to `training.py` +- "Modify grid generation" → Requires changes to `grids.py` -For complex dashboard development tasks: +## Key Reminders -1. Use #tool:todos to create a task list -2. Mark tasks as in-progress before starting -3. Mark completed immediately after finishing -4. Keep the user informed of progress - -## Remember - -- **Read-only for data pipeline**: You can read any file to understand data structures, but only edit `dashboard/` files -- **Documentation first**: When unsure about Streamlit/Plotly/PyDeck APIs, fetch documentation -- **Modern Streamlit API**: Always use `width='stretch'` instead of `use_container_width=True` -- **Pause when needed**: If data pipeline changes are required, stop and inform the user - -You are here to make the dashboard better, not to change how data is created or models are trained. Stay within these boundaries and you'll be most helpful! 
+- Only edit files in `dashboard/` +- Use `width='stretch'` not `use_container_width=True` +- Always reference `overview_page.py` for patterns +- Use #tool:web for documentation +- Use #tool:todo for complex multi-step work diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index f856fe3..7bd5a5d 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -1,226 +1,70 @@ -# Entropice - GitHub Copilot Instructions +# Entropice - Copilot Instructions -## Project Overview +## Project Context -Entropice is a geospatial machine learning system for predicting **Retrogressive Thaw Slump (RTS)** density across the Arctic using **entropy-optimal Scalable Probabilistic Approximations (eSPA)**. The system processes multi-source geospatial data, aggregates it into discrete global grids (H3/HEALPix), and trains probabilistic classifiers to estimate RTS occurrence patterns. +This is a geospatial machine learning system for predicting Arctic permafrost degradation (Retrogressive Thaw Slumps) using entropy-optimal Scalable Probabilistic Approximations (eSPA). The system processes multi-source geospatial data through discrete global grid systems (H3, HEALPix) and trains probabilistic classifiers. -For detailed architecture information, see [ARCHITECTURE.md](../ARCHITECTURE.md). -For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md). -For project goals and setup, see [README.md](../README.md). 
+## Code Style -## Core Technologies +- Follow PEP 8 conventions with 120 character line length +- Use type hints for all function signatures +- Use google-style docstrings for public functions +- Keep functions focused and modular +- Use `ruff` for linting and formatting, `ty` for type checking -- **Python**: 3.13 (strict version requirement) -- **Package Manager**: [Pixi](https://pixi.sh/) (not pip/conda directly) -- **GPU**: CUDA 12 with RAPIDS (CuPy, cuML) -- **Geospatial**: xarray, xdggs, GeoPandas, H3, Rasterio -- **ML**: scikit-learn, XGBoost, entropy (eSPA), cuML -- **Storage**: Zarr, Icechunk, Parquet, NetCDF -- **Visualization**: Streamlit, Bokeh, Matplotlib, Cartopy +## Technology Stack -## Code Style Guidelines +- **Core**: Python 3.13, NumPy, Pandas, Xarray, GeoPandas +- **Spatial**: H3, xdggs, xvec for discrete global grid systems +- **ML**: scikit-learn, XGBoost, cuML, entropy (eSPA) +- **GPU**: CuPy, PyTorch, CUDA 12 - prefer GPU-accelerated operations +- **Storage**: Zarr, Icechunk, Parquet for intermediate data +- **CLI**: Cyclopts with dataclass-based configurations -### Python Standards +## Execution Guidelines -- Follow **PEP 8** conventions -- Use **type hints** for all function signatures -- Write **numpy-style docstrings** for public functions -- Keep functions **focused and modular** -- Prefer descriptive variable names over abbreviations +- Always use `pixi run` to execute Python commands and scripts +- Environment variables: `SCIPY_ARRAY_API=1`, `FAST_DATA_DIR=./data` -### Geospatial Best Practices +## Geospatial Best Practices -- Use **EPSG:3413** (Arctic Stereographic) for computations -- Use **EPSG:4326** (WGS84) for visualization and library compatibility -- Store gridded data using **xarray with XDGGS** indexing -- Store tabular data as **Parquet**, array data as **Zarr** -- Leverage **Dask** for lazy evaluation of large datasets -- Use **GeoPandas** for vector operations -- Handle **antimeridian** correctly for polar regions +- Use 
EPSG:3413 (Arctic Stereographic) for computations +- Use EPSG:4326 (WGS84) for visualization and compatibility +- Store gridded data as XDGGS Xarray datasets (Zarr format) +- Store tabular data as GeoParquet +- Handle antimeridian issues in polar regions +- Leverage Xarray/Dask for lazy evaluation and chunked processing -### Data Pipeline Conventions +## Architecture Patterns -- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → `02alphaearth.sh` → `03era5.sh` → `04arcticdem.sh` → `05train.sh` -- Each pipeline stage should produce **reproducible intermediate outputs** -- Use `src/entropice/utils/paths.py` for consistent path management -- Environment variable `FAST_DATA_DIR` controls data directory location (default: `./data`) +- Modular CLI design: each module exposes standalone Cyclopts CLI +- Configuration as code: use dataclasses for typed configs, TOML for hyperparameters +- GPU acceleration: use CuPy for arrays, cuML for ML, batch processing for memory management +- Data flow: Raw sources → Grid aggregation → L2 datasets → Training → Inference → Visualization -### Storage Hierarchy +## Data Storage Hierarchy -All data follows this structure: ``` DATA_DIR/ -├── grids/ # H3/HEALPix tessellations (GeoParquet) -├── darts/ # RTS labels (GeoParquet) -├── era5/ # Climate data (Zarr) -├── arcticdem/ # Terrain data (Icechunk Zarr) -├── alphaearth/ # Satellite embeddings (Zarr) -├── datasets/ # L2 XDGGS datasets (Zarr) -├── training-results/ # Models, CV results, predictions -└── watermask/ # Ocean mask (GeoParquet) +├── grids/ # H3/HEALPix tessellations (GeoParquet) +├── darts/ # RTS labels (GeoParquet) +├── era5/ # Climate data (Zarr) +├── arcticdem/ # Terrain data (Icechunk Zarr) +├── alphaearth/ # Satellite embeddings (Zarr) +├── datasets/ # L2 XDGGS datasets (Zarr) +└── training-results/ # Models, CV results, predictions ``` -## Module Organization +## Key Modules -### Core Modules (`src/entropice/`) +- `entropice.spatial`: Grid generation and 
raster-to-vector aggregation +- `entropice.ingest`: Data extractors (DARTS, ERA5, ArcticDEM, AlphaEarth) +- `entropice.ml`: Dataset assembly, training, inference +- `entropice.dashboard`: Streamlit visualization app +- `entropice.utils`: Paths, codecs, types -The codebase is organized into four main packages: +## Testing & Notebooks -- **`entropice.ingest`**: Data ingestion from external sources -- **`entropice.spatial`**: Spatial operations and grid management -- **`entropice.ml`**: Machine learning workflows -- **`entropice.utils`**: Common utilities - -#### Data Ingestion (`src/entropice/ingest/`) - -- **`darts.py`**: RTS label extraction from DARTS v2 dataset -- **`era5.py`**: Climate data processing from ERA5 (Arctic-aligned years: Oct 1 - Sep 30) -- **`arcticdem.py`**: Terrain analysis from 32m Arctic elevation data -- **`alphaearth.py`**: Satellite image embeddings via Google Earth Engine - -#### Spatial Operations (`src/entropice/spatial/`) - -- **`grids.py`**: H3/HEALPix spatial grid generation with watermask -- **`aggregators.py`**: Raster-to-vector spatial aggregation engine -- **`watermask.py`**: Ocean masking utilities -- **`xvec.py`**: Extended vector operations for xarray - -#### Machine Learning (`src/entropice/ml/`) - -- **`dataset.py`**: Multi-source data integration and feature engineering -- **`training.py`**: Model training with eSPA, XGBoost, Random Forest, KNN -- **`inference.py`**: Batch prediction pipeline for trained models - -#### Utilities (`src/entropice/utils/`) - -- **`paths.py`**: Centralized path management -- **`codecs.py`**: Custom codecs for data serialization - -### Dashboard (`src/entropice/dashboard/`) - -- Streamlit-based interactive visualization -- Modular pages: overview, training data, analysis, model state, inference -- Bokeh-based geospatial plotting utilities -- Run with: `pixi run dashboard` - -### Scripts (`scripts/`) - -- Numbered pipeline scripts (`00grids.sh` through `05train.sh`) -- Run entire pipeline for 
multiple grid configurations -- Each script uses CLIs from core modules - -### Notebooks (`notebooks/`) - -- Exploratory analysis and validation -- **NOT committed to git** -- Keep production code in `src/entropice/` - -## Development Workflow - -### Setup - -```bash -pixi install # NOT pip install or conda install -``` - -### Running Tests - -```bash -pixi run pytest -``` - -### Running Python Commands - -Always use `pixi run` to execute Python commands to use the correct environment: - -```bash -pixi run python script.py -pixi run python -c "import entropice" -``` - -### Common Tasks - -**Important**: Always use `pixi run` prefix for Python commands to ensure correct environment. - -- **Generate grids**: Use `pixi run create-grid` or `spatial/grids.py` CLI -- **Process labels**: Use `pixi run darts` or `ingest/darts.py` CLI -- **Train models**: Use `pixi run train` with TOML config or `ml/training.py` CLI -- **Run inference**: Use `ml/inference.py` CLI -- **View results**: `pixi run dashboard` - -## Key Design Patterns - -### 1. XDGGS Indexing - -All geospatial data uses discrete global grid systems (H3 or HEALPix) via `xdggs` library for consistent spatial indexing across sources. - -### 2. Lazy Evaluation - -Use Xarray/Dask for out-of-core computation with Zarr/Icechunk chunked storage to manage large datasets. - -### 3. GPU Acceleration - -Prefer GPU-accelerated operations: -- CuPy for array operations -- cuML for Random Forest and KNN -- XGBoost GPU training -- PyTorch tensors when applicable - -### 4. 
Configuration as Code - -- Use dataclasses for typed configuration -- TOML files for training hyperparameters -- Cyclopts for CLI argument parsing - -## Data Sources - -- **DARTS v2**: RTS labels (year, area, count, density) -- **ERA5**: Climate data (40-year history, Arctic-aligned years) -- **ArcticDEM**: 32m resolution terrain (slope, aspect, indices) -- **AlphaEarth**: 64-dimensional satellite embeddings -- **Watermask**: Ocean exclusion layer - -## Model Support - -Primary: **eSPA** (entropy-optimal Scalable Probabilistic Approximations) -Alternatives: XGBoost, Random Forest, K-Nearest Neighbors - -Training features: -- Randomized hyperparameter search -- K-Fold cross-validation -- Multi-metric evaluation (accuracy, F1, Jaccard, precision, recall) - -## Extension Points - -To extend Entropice: - -- **New data source**: Follow patterns in `ingest/era5.py` or `ingest/arcticdem.py` -- **Custom aggregations**: Add to `_Aggregations` dataclass in `spatial/aggregators.py` -- **Alternative labels**: Implement extractor following `ingest/darts.py` pattern -- **New models**: Add scikit-learn compatible estimators to `ml/training.py` -- **Dashboard pages**: Add Streamlit pages to `dashboard/` module - -## Important Notes - -- **Always use `pixi run` prefix** for Python commands (not plain `python`) -- Grid resolutions: **H3** (3-6), **HEALPix** (6-10) -- Arctic years run **October 1 to September 30** (not calendar years) -- Handle **antimeridian crossing** in polar regions -- Use **batch processing** for GPU memory management -- Notebooks are for exploration only - **keep production code in `src/`** -- Always use **absolute paths** or paths from `utils/paths.py` - -## Common Issues - -- **Memory**: Use batch processing and Dask chunking for large datasets -- **GPU OOM**: Reduce batch size in inference or training -- **Antimeridian**: Use proper handling in `spatial/aggregators.py` for polar grids -- **Temporal alignment**: ERA5 uses Arctic-aligned years (Oct-Sep) -- 
**CRS**: Compute in EPSG:3413, visualize in EPSG:4326 - -## References - -For more details, consult: -- [ARCHITECTURE.md](../ARCHITECTURE.md) - System architecture and design patterns -- [CONTRIBUTING.md](../CONTRIBUTING.md) - Development workflow and standards -- [README.md](../README.md) - Project goals and setup instructions +- Production code belongs in `src/entropice/`, not notebooks +- Notebooks in `notebooks/` are for exploration only (not version-controlled) +- Use `pytest` for testing geospatial correctness and data integrity diff --git a/.github/python.instructions.md b/.github/python.instructions.md index 3ae0290..8b7a5d2 100644 --- a/.github/python.instructions.md +++ b/.github/python.instructions.md @@ -10,7 +10,7 @@ applyTo: '**/*.py,**/*.ipynb' - Write clear and concise comments for each function. - Ensure functions have descriptive names and include type hints. - Provide docstrings following PEP 257 conventions. -- Use the `typing` module for type annotations (e.g., `List[str]`, `Dict[str, int]`). +- Use the `typing` module for advanced type annotations (e.g., `TypedDict`, `Literal["a", "b", ...]`). - Break down complex functions into smaller, more manageable functions. ## General Instructions @@ -27,7 +27,7 @@ applyTo: '**/*.py,**/*.ipynb' - Follow the **PEP 8** style guide for Python. - Maintain proper indentation (use 4 spaces for each level of indentation). -- Ensure lines do not exceed 79 characters. +- Ensure lines do not exceed 120 characters. - Place function and class docstrings immediately after the `def` or `class` keyword. - Use blank lines to separate functions, classes, and code blocks where appropriate. @@ -41,6 +41,8 @@ applyTo: '**/*.py,**/*.ipynb' ## Example of Proper Documentation ```python +import math + def calculate_area(radius: float) -> float: """ Calculate the area of a circle given the radius. @@ -51,6 +53,5 @@ def calculate_area(radius: float) -> float: Returns: float: The area of the circle, calculated as π * radius^2. 
"""
-    import math
     return math.pi * radius ** 2
 ```
diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md
index db19c12..04ffe73 100644
--- a/ARCHITECTURE.md
+++ b/ARCHITECTURE.md
@@ -27,12 +27,20 @@ The pipeline follows a sequential processing approach where each stage produces 
 ### System Components
 
-The codebase is organized into four main packages:
+The project is organized into four top-level directories:
+
+- **`src/entropice/`**: The main processing codebase
+- **`scripts/`**: Bash wrapper scripts for the data processing pipeline
+- **`notebooks/`**: Exploratory analysis and validation notebooks (not committed to git)
+- **`tests/`**: Unit tests and manual validation scripts
+
+The processing codebase is organized into five main packages:
 
 - **`entropice.ingest`**: Data ingestion from external sources (DARTS, ERA5, ArcticDEM, AlphaEarth)
 - **`entropice.spatial`**: Spatial operations and grid management
 - **`entropice.ml`**: Machine learning workflows (dataset, training, inference)
 - **`entropice.utils`**: Common utilities (paths, codecs)
+- **`entropice.dashboard`**: Streamlit dashboard for interactive visualization
 
 #### 1. Spatial Grid System (`spatial/grids.py`)
 
@@ -179,6 +187,14 @@ scripts/05train.sh # Model training
 - TOML files for training hyperparameters
 - Environment-based path management (`utils/paths.py`)
 
+### 6. 
Geospatial Best Practices
+
+- Use **xarray** with XDGGS for gridded data storage
+- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
+- Leverage **Dask** for lazy evaluation of large datasets
+- Use **GeoPandas** for vector operations
+- Use EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data and EPSG:4326 (WGS84) for data visualization and compatibility with some libraries
+
 ## Data Storage Hierarchy
 
 ```sh
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ec4fe38..5cdaa31 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -6,7 +6,6 @@ Thank you for your interest in contributing to Entropice! This document provides 
 ### Prerequisites
 
-- Python 3.13
 - CUDA 12 compatible GPU (for full functionality)
 - [Pixi package manager](https://pixi.sh/)
 
@@ -16,55 +15,36 @@ Thank you for your interest in contributing to Entropice! This document provides 
 pixi install
 ```
 
-This will set up the complete environment including RAPIDS, PyTorch, and all geospatial dependencies.
+This will set up the complete environment including Python, RAPIDS, PyTorch, and all geospatial dependencies.
 
 ## Development Workflow
 
-**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies. 
+> Read about the code organization and key modules in the [Architecture Guide](ARCHITECTURE.md)
 
-### Code Organization
 
-- **`src/entropice/ingest/`**: Data ingestion modules (darts, era5, arcticdem, alphaearth)
-- **`src/entropice/spatial/`**: Spatial operations (grids, aggregators, watermask, xvec)
-- **`src/entropice/ml/`**: Machine learning components (dataset, training, inference)
-- **`src/entropice/utils/`**: Utilities (paths, codecs)
-- **`src/entropice/dashboard/`**: Streamlit visualization dashboard
-- **`scripts/`**: Data processing pipeline scripts (numbered 00-05)
-- **`notebooks/`**: Exploratory analysis and validation notebooks
-- **`tests/`**: Unit tests
+**Important**: Always use `pixi run` to execute Python commands and scripts to ensure you're using the correct environment with all dependencies:
 
-### Key Modules
-
+```bash
+pixi run python script.py
+pixi run python -c "import entropice"
+```
-
-- `spatial/grids.py`: H3/HEALPix spatial grid systems
-- `ingest/darts.py`, `ingest/era5.py`, `ingest/arcticdem.py`, `ingest/alphaearth.py`: Data source processors
-- `ml/dataset.py`: Dataset assembly and feature engineering
-- `ml/training.py`: Model training with eSPA, XGBoost, Random Forest, KNN
-- `ml/inference.py`: Prediction generation
-- `utils/paths.py`: Centralized path management
-
-## Coding Standards
-
-### Python Style
+### Python Style and Formatting
 
 - Follow PEP 8 conventions
 - Use type hints for function signatures
-- Prefer numpy-style docstrings for public functions
+- Prefer google-style docstrings for public functions
 - Keep functions focused and modular
 
-### Geospatial Best Practices
+`ty` and `ruff` are used for type checking, linting, and formatting. 
+Make sure to check for warnings from both of these:
-
-- Use **xarray** with XDGGS for gridded data storage
-- Store intermediate results as **Parquet** (tabular) or **Zarr** (arrays)
-- Leverage **Dask** for lazy evaluation of large datasets
-- Use **GeoPandas** for vector operations
-- Use EPSG:3413 (Arctic Stereographic) coordinate reference system (CRS) for any computation on the data and EPSG:4326 (WGS84) for data visualization and compatability with some libraries
+```sh
+pixi run ty check # For type checks
+pixi run ruff check # For linting
+pixi run ruff format # For formatting
+```
 
-### Data Pipeline
-
-- Follow the numbered script sequence: `00grids.sh` → `01darts.sh` → ... → `05train.sh`
-- Each stage should produce reproducible intermediate outputs
-- Document data dependencies in module docstrings
-- Use `utils/paths.py` for consistent path management
+A single file can be checked by appending it to the command, e.g. `pixi run ty check src/entropice/dashboard/app.py`
 
 ## Testing
 
@@ -74,13 +54,6 @@
 
 Run tests for specific modules:
 
 pixi run pytest
 ```
 
-When running Python scripts or commands, always use `pixi run`:
-
-```bash
-pixi run python script.py
-pixi run python -c "import entropice"
-```
-
 When adding features, include tests that verify:
 
 - Correct handling of geospatial coordinates and projections
 
@@ -100,15 +73,8 @@ When adding features, include tests that verify:
 
 ### Commit Messages
 
 - Use present tense: "Add feature" not "Added feature"
-- Reference issues when applicable: "Fix #123: Correct grid aggregation"
 - Keep first line under 72 characters
 
-## Working with Data
-
-### Local Development
-
-- Set `FAST_DATA_DIR` environment variable for data directory (default: `./data`)
-
 ### Notebooks
 
 - Notebooks in `notebooks/` are for exploration and validation, they are not commited to git
diff --git a/pyproject.toml b/pyproject.toml
index 1643f13..6f51207 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,9 @@ cudf-cu12 = { index = 
"nvidia" } cuml-cu12 = { index = "nvidia" } cuspatial-cu12 = { index = "nvidia" } +[tool.ruff] +line-length = 120 + [tool.ruff.lint.pyflakes] # Ignore libraries when checking for unused imports allowed-unused-imports = [ diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index 9141f95..e63620f 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -224,7 +224,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None alt.Chart(value_counts) .mark_bar(color=bar_color) .encode( - alt.X(f"{formatted_col}:N", title=param_name, sort=None), + alt.X( + f"{formatted_col}:N", + title=param_name, + sort=None, + ), alt.Y("count:Q", title="Count"), tooltip=[ alt.Tooltip(param_name, format=".2e"), @@ -238,7 +242,11 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None alt.Chart(value_counts) .mark_bar(color=bar_color) .encode( - alt.X(f"{param_name}:Q", title=param_name, scale=x_scale), + alt.X( + f"{param_name}:Q", + title=param_name, + scale=x_scale, + ), alt.Y("count:Q", title="Count"), tooltip=[ alt.Tooltip(param_name, format=".3f"), @@ -301,10 +309,18 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None alt.Chart(df_plot) .mark_bar(color=bar_color) .encode( - alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name), + alt.X( + f"{param_name}:Q", + bin=alt.Bin(maxbins=n_bins), + title=param_name, + ), alt.Y("count()", title="Count"), tooltip=[ - alt.Tooltip(f"{param_name}:Q", format=format_str, bin=True), + alt.Tooltip( + f"{param_name}:Q", + format=format_str, + bin=True, + ), "count()", ], ) @@ -396,9 +412,15 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): scale=alt.Scale(range=get_palette(metric, n_colors=256)), legend=None, ), - tooltip=[alt.Tooltip(param_name, format=".2e"), 
alt.Tooltip(metric, format=".4f")], + tooltip=[ + alt.Tooltip(param_name, format=".2e"), + alt.Tooltip(metric, format=".4f"), + ], + ) + .properties( + height=300, + title=f"{metric} vs {param_name} (log scale)", ) - .properties(height=300, title=f"{metric} vs {param_name} (log scale)") ) else: chart = ( @@ -412,7 +434,10 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): scale=alt.Scale(range=get_palette(metric, n_colors=256)), legend=None, ), - tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")], + tooltip=[ + alt.Tooltip(param_name, format=".3f"), + alt.Tooltip(metric, format=".4f"), + ], ) .properties(height=300, title=f"{metric} vs {param_name}") ) @@ -568,7 +593,16 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str): if len(param_names) == 2: # Simple case: just one pair x_param, y_param = param_names_sorted - _render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500) + _render_2d_param_plot( + results_binned, + x_param, + y_param, + score_col, + bin_info, + hex_colors, + metric, + height=500, + ) else: # Multiple parameters: create structured plots st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}") @@ -599,7 +633,14 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str): with cols[col_idx]: _render_2d_param_plot( - results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350 + results_binned, + x_param, + y_param, + score_col, + bin_info, + hex_colors, + metric, + height=350, ) @@ -708,7 +749,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str): # Get colormap for score evolution evolution_cmap = get_cmap("score_evolution") - evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))] + evolution_colors = [ + mcolors.rgb2hex(evolution_cmap(0.3)), + mcolors.rgb2hex(evolution_cmap(0.7)), + ] # Create line 
chart chart = ( @@ -717,13 +761,21 @@ def render_score_evolution(results: pd.DataFrame, metric: str): .encode( alt.X("Iteration", title="Iteration"), alt.Y("value", title=metric.replace("_", " ").title()), - alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)), + alt.Color( + "Type", + legend=alt.Legend(title=""), + scale=alt.Scale(range=evolution_colors), + ), strokeDash=alt.StrokeDash( "Type", legend=None, scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]), ), - tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")], + tooltip=[ + "Iteration", + "Type", + alt.Tooltip("value", format=".4f", title="Score"), + ], ) .properties(height=400) ) @@ -1195,7 +1247,13 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): with col1: # Filter by confusion category if task == "binary": - categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"] + categories = [ + "All", + "True Positive", + "False Positive", + "True Negative", + "False Negative", + ] else: categories = ["All", "Correct", "Incorrect"] @@ -1206,7 +1264,14 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): ) with col2: - opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity") + opacity = st.slider( + "Opacity", + min_value=0.1, + max_value=1.0, + value=0.7, + step=0.1, + key="confusion_map_opacity", + ) # Filter data if needed if selected_category != "All": diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py index 2305fda..c0b6981 100644 --- a/src/entropice/dashboard/plots/model_state.py +++ b/src/entropice/dashboard/plots/model_state.py @@ -85,7 +85,11 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart: .mark_rect() .encode( x=alt.X("year:O", title="Year"), - y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", 
order="ascending")), + y=alt.Y( + "band:O", + title="Band", + sort=alt.SortField(field="band", order="ascending"), + ), color=alt.Color( "weight:Q", scale=alt.Scale(scheme="blues"), @@ -105,7 +109,9 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart: return chart -def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]: +def plot_embedding_aggregation_summary( + embedding_array: xr.DataArray, +) -> tuple[alt.Chart, alt.Chart, alt.Chart]: """Create bar charts summarizing embedding weights by aggregation, band, and year. Args: @@ -345,8 +351,18 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs() # Create DataFrames - df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) - df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values}) + df_variable = pd.DataFrame( + { + "dimension": by_variable.index.astype(str), + "mean_abs_weight": by_variable.values, + } + ) + df_season = pd.DataFrame( + { + "dimension": by_season.index.astype(str), + "mean_abs_weight": by_season.values, + } + ) df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values}) # Sort by weight @@ -358,7 +374,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: alt.Chart(df_variable) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), + y=alt.Y( + "dimension:N", + title="Variable", + sort="-x", + axis=alt.Axis(labelLimit=300), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -377,7 +398,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: alt.Chart(df_season) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Season", sort="-x", 
axis=alt.Axis(labelLimit=200)), + y=alt.Y( + "dimension:N", + title="Season", + sort="-x", + axis=alt.Axis(labelLimit=200), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -396,7 +422,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: alt.Chart(df_year) .mark_bar() .encode( - y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)), + y=alt.Y( + "dimension:O", + title="Year", + sort="-x", + axis=alt.Axis(labelLimit=200), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -418,7 +449,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs() # Create DataFrames, handling potential MultiIndex - df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) + df_variable = pd.DataFrame( + { + "dimension": by_variable.index.astype(str), + "mean_abs_weight": by_variable.values, + } + ) df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values}) # Sort by weight @@ -430,7 +466,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: alt.Chart(df_variable) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), + y=alt.Y( + "dimension:N", + title="Variable", + sort="-x", + axis=alt.Axis(labelLimit=300), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -449,7 +490,12 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]: alt.Chart(df_time) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)), + y=alt.Y( + "dimension:N", + title="Time", + sort="-x", + axis=alt.Axis(labelLimit=200), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute 
Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -508,7 +554,9 @@ def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart: return chart -def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]: +def plot_arcticdem_summary( + arcticdem_array: xr.DataArray, +) -> tuple[alt.Chart, alt.Chart]: """Create bar charts summarizing ArcticDEM weights by variable and aggregation. Args: @@ -523,7 +571,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs() # Create DataFrames, handling potential MultiIndex - df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values}) + df_variable = pd.DataFrame( + { + "dimension": by_variable.index.astype(str), + "mean_abs_weight": by_variable.values, + } + ) df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values}) # Sort by weight @@ -535,7 +588,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al alt.Chart(df_variable) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), + y=alt.Y( + "dimension:N", + title="Variable", + sort="-x", + axis=alt.Axis(labelLimit=300), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -554,7 +612,12 @@ def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, al alt.Chart(df_agg) .mark_bar() .encode( - y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)), + y=alt.Y( + "dimension:N", + title="Aggregation", + sort="-x", + axis=alt.Axis(labelLimit=200), + ), x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), color=alt.Color( "mean_abs_weight:Q", @@ -665,7 +728,12 @@ def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]) alt.Chart(counts) .mark_bar() .encode( 
- x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)), + x=alt.X( + "class:N", + title="Class Label", + sort=class_order, + axis=alt.Axis(labelAngle=-45), + ), y=alt.Y("count:Q", title="Number of Boxes"), color=alt.Color( "class:N", @@ -767,7 +835,10 @@ def plot_xgboost_feature_importance( .mark_bar() .encode( y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), - x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"), + x=alt.X( + "importance:Q", + title=f"{importance_type.replace('_', ' ').title()} Importance", + ), color=alt.value("steelblue"), tooltip=[ alt.Tooltip("feature:N", title="Feature"), @@ -880,7 +951,9 @@ def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt. return chart -def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]: +def plot_rf_tree_statistics( + model_state: xr.Dataset, +) -> tuple[alt.Chart, alt.Chart, alt.Chart]: """Plot Random Forest tree statistics. 
Args: @@ -908,7 +981,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"), y=alt.Y("count()", title="Number of Trees"), color=alt.value("steelblue"), - tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")], + tooltip=[ + alt.Tooltip("count()", title="Count"), + alt.Tooltip("value:Q", bin=True, title="Depth Range"), + ], ) .properties(width=300, height=200, title="Distribution of Tree Depths") ) @@ -921,7 +997,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"), y=alt.Y("count()", title="Number of Trees"), color=alt.value("forestgreen"), - tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")], + tooltip=[ + alt.Tooltip("count()", title="Count"), + alt.Tooltip("value:Q", bin=True, title="Leaves Range"), + ], ) .properties(width=300, height=200, title="Distribution of Leaf Counts") ) @@ -934,7 +1013,10 @@ def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Cha x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"), y=alt.Y("count()", title="Number of Trees"), color=alt.value("darkorange"), - tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")], + tooltip=[ + alt.Tooltip("count()", title="Count"), + alt.Tooltip("value:Q", bin=True, title="Nodes Range"), + ], ) .properties(width=300, height=200, title="Distribution of Node Counts") ) diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py index ca11da9..12a74c9 100644 --- a/src/entropice/dashboard/plots/source_data.py +++ b/src/entropice/dashboard/plots/source_data.py @@ -330,7 +330,10 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str): with col3: time_values = 
pd.to_datetime(ds["time"].values) - st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}") + st.metric( + "Time Steps", + f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}", + ) with col4: if has_agg: @@ -398,11 +401,15 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str): col1, col2, col3 = st.columns([2, 2, 1]) with col1: selected_var = st.selectbox( - "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" + "Select variable to visualize", + options=variables, + key=f"era5_{temporal_type}_var_select", ) with col2: selected_agg = st.selectbox( - "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select" + "Aggregation", + options=ds["aggregations"].values, + key=f"era5_{temporal_type}_agg_select", ) with col3: show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") @@ -411,7 +418,9 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str): col1, col2 = st.columns([3, 1]) with col1: selected_var = st.selectbox( - "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" + "Select variable to visualize", + options=variables, + key=f"era5_{temporal_type}_var_select", ) with col2: show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") @@ -614,7 +623,10 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): deck = pdk.Deck( layers=[layer], initial_view_state=view_state, - tooltip={"html": "Value: {value}", "style": {"backgroundColor": "steelblue", "color": "white"}}, + tooltip={ + "html": "Value: {value}", + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", ) @@ -746,7 +758,10 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): deck = pdk.Deck( layers=[layer], 
initial_view_state=view_state, - tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", ) @@ -784,7 +799,14 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str): ) with col2: - opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity") + opacity = st.slider( + "Opacity", + min_value=0.1, + max_value=1.0, + value=0.7, + step=0.1, + key="areas_map_opacity", + ) # Create GeoDataFrame gdf = grid_gdf.copy() @@ -896,7 +918,9 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") with col2: selected_agg = st.selectbox( - "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg" + "Aggregation", + options=ds["aggregations"].values, + key=f"era5_{temporal_type}_agg", ) with col3: opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") @@ -1022,7 +1046,10 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor deck = pdk.Deck( layers=[layer], initial_view_state=view_state, - tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", ) diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py index 1148b58..73ee446 100644 --- a/src/entropice/dashboard/plots/training_data.py +++ b/src/entropice/dashboard/plots/training_data.py @@ -11,7 +11,9 @@ from entropice.dashboard.utils.colors import get_palette from entropice.ml.dataset import 
CategoricalTrainingDataset -def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]): +def render_all_distribution_histograms( + train_data_dict: dict[str, CategoricalTrainingDataset], +): """Render histograms for all three tasks side by side. Args: @@ -81,7 +83,13 @@ def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTra height=400, margin={"l": 20, "r": 20, "t": 20, "b": 20}, showlegend=True, - legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1}, + legend={ + "orientation": "h", + "yanchor": "bottom", + "y": 1.02, + "xanchor": "right", + "x": 1, + }, xaxis_title=None, yaxis_title="Count", xaxis={"tickangle": -45}, @@ -137,7 +145,10 @@ def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task): gdf["fill_color"] = gdf["color"].apply(hex_to_rgb) elif color_mode == "split": - split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange + split_colors = { + "train": [66, 135, 245], + "test": [245, 135, 66], + } # Blue # Orange gdf["fill_color"] = gdf["split"].map(split_colors) return gdf @@ -168,7 +179,14 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]): ) with col2: - opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity") + opacity = st.slider( + "Opacity", + min_value=0.1, + max_value=1.0, + value=0.7, + step=0.1, + key="spatial_map_opacity", + ) # Determine which task dataset to use and color mode if vis_mode == "split": @@ -258,7 +276,10 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]): # Set initial view state (centered on the Arctic) # Adjust pitch and zoom based on whether we're using elevation view_state = pdk.ViewState( - latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45 + latitude=70, + longitude=0, + zoom=2 if not use_elevation else 1.5, + pitch=0 if not 
use_elevation else 45, ) # Create deck diff --git a/src/entropice/dashboard/utils/colors.py b/src/entropice/dashboard/utils/colors.py index cd77d3e..97b3838 100644 --- a/src/entropice/dashboard/utils/colors.py +++ b/src/entropice/dashboard/utils/colors.py @@ -95,7 +95,9 @@ def get_palette(variable: str, n_colors: int) -> list[str]: return colors -def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]: +def generate_unified_colormap( + settings: dict, +) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]: """Generate unified colormaps for all plotting libraries. This function creates consistent color schemes across Matplotlib/Ultraplot, @@ -125,7 +127,10 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m if is_dark_theme: base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme else: - base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme + base_colors = [ + "#3498db", + "#e74c3c", + ] # Brighter blue and red for light theme else: # For multi-class: use a sequential colormap # Use matplotlib's viridis colormap diff --git a/src/entropice/dashboard/utils/formatters.py b/src/entropice/dashboard/utils/formatters.py index 042a490..5db49a4 100644 --- a/src/entropice/dashboard/utils/formatters.py +++ b/src/entropice/dashboard/utils/formatters.py @@ -19,7 +19,9 @@ class ModelDisplayInfo: model_display_infos: dict[Model, ModelDisplayInfo] = { "espa": ModelDisplayInfo( - keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm" + keyword="espa", + short="eSPA", + long="entropy-optimal Scalable Probabilistic Approximations algorithm", ), "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"), "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"), diff --git a/src/entropice/dashboard/utils/loaders.py 
b/src/entropice/dashboard/utils/loaders.py index c7f0645..34233a1 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -135,7 +135,9 @@ def load_all_training_results() -> list[TrainingResult]: return training_results -def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]: +def load_all_training_data( + e: DatasetEnsemble, +) -> dict[Task, CategoricalTrainingDataset]: """Load training data for all three tasks. Args: diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py index 2d35e2f..004c7f5 100644 --- a/src/entropice/dashboard/utils/stats.py +++ b/src/entropice/dashboard/utils/stats.py @@ -142,7 +142,9 @@ class DatasetStatistics: target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset @staticmethod - def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + def get_sample_count_df( + all_stats: dict[GridLevel, "DatasetStatistics"], + ) -> pd.DataFrame: """Convert sample count data to DataFrame.""" rows = [] for grid_config in grid_configs: @@ -164,7 +166,9 @@ class DatasetStatistics: return pd.DataFrame(rows) @staticmethod - def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + def get_feature_count_df( + all_stats: dict[GridLevel, "DatasetStatistics"], + ) -> pd.DataFrame: """Convert feature count data to DataFrame.""" rows = [] for grid_config in grid_configs: @@ -191,7 +195,9 @@ class DatasetStatistics: return pd.DataFrame(rows) @staticmethod - def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame: + def get_feature_breakdown_df( + all_stats: dict[GridLevel, "DatasetStatistics"], + ) -> pd.DataFrame: """Convert feature breakdown data to DataFrame for stacked/donut charts.""" rows = [] for grid_config in grid_configs: @@ -219,7 +225,10 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, 
DatasetStatistics]: target_statistics: dict[TargetDataset, TargetStatistics] = {} for target in all_target_datasets: target_statistics[target] = TargetStatistics.compute( - grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells + grid=grid_config.grid, + level=grid_config.level, + target=target, + total_cells=total_cells, ) member_statistics: dict[L2SourceDataset, MemberStatistics] = {} for member in all_l2_source_datasets: @@ -357,7 +366,10 @@ class TrainingDatasetStatistics: @classmethod def compute( - cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None + cls, + ensemble: DatasetEnsemble, + task: Task, + dataset: gpd.GeoDataFrame | None = None, ) -> "TrainingDatasetStatistics": dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol) categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu") diff --git a/src/entropice/dashboard/utils/unsembler.py b/src/entropice/dashboard/utils/unsembler.py index 6f9e154..3cec383 100644 --- a/src/entropice/dashboard/utils/unsembler.py +++ b/src/entropice/dashboard/utils/unsembler.py @@ -65,7 +65,9 @@ def extract_embedding_features( def extract_era5_features( - model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None + model_state: xr.Dataset, + importance_type: str = "feature_weights", + temporal_group: str | None = None, ) -> xr.DataArray | None: """Extract ERA5 features from the model state. 
@@ -100,7 +102,17 @@ def extract_era5_features( - Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021") """ parts = feature.split("_") - common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + common_aggs = { + "mean", + "std", + "min", + "max", + "median", + "sum", + "count", + "q25", + "q75", + } # Find where the time part starts (after "era5" and variable name) # Pattern: era5_variable_time or era5_variable_time_agg @@ -166,7 +178,17 @@ def extract_era5_features( def _extract_time_name(feature: str) -> str: parts = feature.split("_") - common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + common_aggs = { + "mean", + "std", + "min", + "max", + "median", + "sum", + "count", + "q25", + "q75", + } if parts[-1] in common_aggs: # Has aggregation: era5_var_time_agg -> time is second to last return parts[-2] @@ -179,11 +201,8 @@ def extract_era5_features( Pattern: era5_variable_season_year_agg or era5_variable_season_year """ - parts = feature.split("_") - common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} - # Look through parts to find season/shoulder indicators - for part in parts: + for part in feature.split("_"): if part.lower() in ["summer", "winter"]: return part.lower() elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]: @@ -197,11 +216,26 @@ def extract_era5_features( For seasonal/shoulder features, find the year that comes after the season. 
""" parts = feature.split("_") - common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + common_aggs = { + "mean", + "std", + "min", + "max", + "median", + "sum", + "count", + "q25", + "q75", + } # Find the season/shoulder part, then the next part should be the year for i, part in enumerate(parts): - if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]: + if part.lower() in ["summer", "winter"] or part.upper() in [ + "JFM", + "AMJ", + "JAS", + "OND", + ]: # Next part should be the year if i + 1 < len(parts): next_part = parts[i + 1] @@ -218,7 +252,17 @@ def extract_era5_features( def _extract_agg_name(feature: str) -> str | None: parts = feature.split("_") - common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + common_aggs = { + "mean", + "std", + "min", + "max", + "median", + "sum", + "count", + "q25", + "q75", + } if parts[-1] in common_aggs: return parts[-1] return None @@ -255,7 +299,10 @@ def extract_era5_features( if has_agg: era5_features_array = era5_features_array.assign_coords( - agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]), + agg=( + "feature", + [_extract_agg_name(f) or "none" for f in era5_features], + ), ) era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack( "feature" @@ -274,7 +321,10 @@ def extract_era5_features( if has_agg: # Add aggregation dimension era5_features_array = era5_features_array.assign_coords( - agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]), + agg=( + "feature", + [_extract_agg_name(f) or "none" for f in era5_features], + ), ) era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") else: @@ -345,7 +395,14 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea Returns None if no common features are found. 
""" - common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"] + common_feature_names = [ + "cell_area", + "water_area", + "land_area", + "land_ratio", + "lon", + "lat", + ] def _is_common_feature(feature: str) -> bool: return feature in common_feature_names diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py index 86958b2..20d1759 100644 --- a/src/entropice/dashboard/views/inference_page.py +++ b/src/entropice/dashboard/views/inference_page.py @@ -66,7 +66,10 @@ def render_inference_page(): with col4: st.metric("Level", selected_result.settings.get("level", "Unknown")) with col5: - st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", "")) + st.metric( + "Target", + selected_result.settings.get("target", "Unknown").replace("darts_", ""), + ) st.divider() diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index d9676a3..55c0220 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -10,8 +10,16 @@ from stopuhr import stopwatch from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.loaders import load_all_training_results -from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics -from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs +from entropice.dashboard.utils.stats import ( + DatasetStatistics, + load_all_default_dataset_statistics, +) +from entropice.utils.types import ( + GridConfig, + L2SourceDataset, + TargetDataset, + grid_configs, +) def render_sample_count_overview(): @@ -45,7 +53,12 @@ def render_sample_count_overview(): target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")] # Pivot for heatmap: Grid x Task - pivot_df = target_df.pivot_table(index="Grid", columns="Task", 
values="Samples (Coverage)", aggfunc="mean") + pivot_df = target_df.pivot_table( + index="Grid", + columns="Task", + values="Samples (Coverage)", + aggfunc="mean", + ) # Sort index by grid type and level sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid") @@ -56,7 +69,11 @@ def render_sample_count_overview(): fig = px.imshow( pivot_df, - labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"}, + labels={ + "x": "Task", + "y": "Grid Configuration", + "color": "Sample Count", + }, x=pivot_df.columns, y=pivot_df.index, color_continuous_scale=sample_colors, @@ -87,7 +104,10 @@ def render_sample_count_overview(): facet_col="Target", barmode="group", title="Sample Counts by Grid Configuration and Target Dataset", - labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"}, + labels={ + "Grid": "Grid Configuration", + "Samples (Coverage)": "Number of Samples", + }, color_discrete_sequence=task_colors, height=500, ) @@ -141,7 +161,10 @@ def render_feature_count_comparison(): color="Data Source", barmode="stack", title="Total Features by Data Source Across Grid Configurations", - labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"}, + labels={ + "Grid": "Grid Configuration", + "Number of Features": "Number of Features", + }, color_discrete_sequence=source_colors, text_auto=False, ) @@ -162,7 +185,10 @@ def render_feature_count_comparison(): y="Inference Cells", color="Grid", title="Inference Cells by Grid Configuration", - labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"}, + labels={ + "Grid": "Grid Configuration", + "Inference Cells": "Number of Cells", + }, color_discrete_sequence=grid_colors, text="Inference Cells", ) @@ -177,7 +203,10 @@ def render_feature_count_comparison(): y="Total Samples", color="Grid", title="Total Samples by Grid Configuration", - labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"}, + labels={ 
+ "Grid": "Grid Configuration", + "Total Samples": "Number of Samples", + }, color_discrete_sequence=grid_colors, text="Total Samples", ) @@ -226,7 +255,13 @@ def render_feature_count_comparison(): # Display full comparison table with formatting display_df = comparison_df[ - ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"] + [ + "Grid", + "Total Features", + "Data Sources", + "Inference Cells", + "Total Samples", + ] ].copy() # Format numbers with commas @@ -314,7 +349,11 @@ def render_feature_count_explorer(): with col2: # Calculate minimum cells across all data sources (for inference capability) min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values()) - st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources") + st.metric( + "Inference Cells", + f"{min_cells:,}", + help="Number of union of cells across all data sources", + ) with col3: st.metric("Data Sources", len(selected_members)) with col4: diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py index 275d3c0..4523f4c 100644 --- a/src/entropice/dashboard/views/training_data_page.py +++ b/src/entropice/dashboard/views/training_data_page.py @@ -14,7 +14,10 @@ from entropice.dashboard.plots.source_data import ( render_era5_overview, render_era5_plots, ) -from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map +from entropice.dashboard.plots.training_data import ( + render_all_distribution_histograms, + render_spatial_map, +) from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data from entropice.ml.dataset import DatasetEnsemble from entropice.spatial import grids @@ -42,7 +45,10 @@ def render_training_data_page(): ] grid_level_combined = st.selectbox( - "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution 
level" + "Grid Configuration", + options=grid_options, + index=0, + help="Select the grid system and resolution level", ) # Parse grid type and level @@ -60,7 +66,13 @@ def render_training_data_page(): # Members selection st.subheader("Dataset Members") - all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] + all_members = [ + "AlphaEarth", + "ArcticDEM", + "ERA5-yearly", + "ERA5-seasonal", + "ERA5-shoulder", + ] selected_members = [] for member in all_members: @@ -69,7 +81,10 @@ def render_training_data_page(): # Form submit button load_button = st.form_submit_button( - "Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0 + "Load Dataset", + type="primary", + width="stretch", + disabled=len(selected_members) == 0, ) # Create DatasetEnsemble only when form is submitted diff --git a/src/entropice/ingest/alphaearth.py b/src/entropice/ingest/alphaearth.py index 55098fe..35562d0 100644 --- a/src/entropice/ingest/alphaearth.py +++ b/src/entropice/ingest/alphaearth.py @@ -77,7 +77,20 @@ def download(grid: Grid, level: int): for year in track(range(2024, 2025), total=1, description="Processing years..."): embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL") embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31") - aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"] + aggs = [ + "mean", + "stdDev", + "min", + "max", + "count", + "median", + "p1", + "p5", + "p25", + "p75", + "p95", + "p99", + ] bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs] def extract_embedding(feature): @@ -136,7 +149,20 @@ def combine_to_zarr(grid: Grid, level: int): """ cell_ids = grids.get_cell_ids(grid, level) years = list(range(2018, 2025)) - aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"] + aggs = [ + "mean", + "stdDev", + "min", + "max", + 
"count", + "median", + "p1", + "p5", + "p25", + "p75", + "p95", + "p99", + ] bands = [f"A{str(i).zfill(2)}" for i in range(64)] a = xr.DataArray( diff --git a/src/entropice/ingest/arcticdem.py b/src/entropice/ingest/arcticdem.py index 7837111..51fce8e 100644 --- a/src/entropice/ingest/arcticdem.py +++ b/src/entropice/ingest/arcticdem.py @@ -125,8 +125,14 @@ def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No cs = 3600 # Calculate safe slice bounds for edge chunks - y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d) - x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d) + y_start, y_end = ( + max(0, cs * chunk_loc[0] - d), + min(len(y), cs * chunk_loc[0] + cs + d), + ) + x_start, x_end = ( + max(0, cs * chunk_loc[1] - d), + min(len(x), cs * chunk_loc[1] + cs + d), + ) # Extract coordinate arrays with safe bounds y_chunk = cp.asarray(y[y_start:y_end]) @@ -162,7 +168,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No if np.all(np.isnan(chunk)): # Return an array of NaNs with the expected shape return np.full( - (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px), + ( + 12, + chunk.shape[0] - 2 * large_kernels.size_px, + chunk.shape[1] - 2 * large_kernels.size_px, + ), np.nan, dtype=np.float32, ) @@ -171,7 +181,11 @@ def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=No # Interpolate missing values in chunk (for patches smaller than 7x7 pixels) mask = cp.isnan(chunk) - mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True) + mask &= ~binary_dilation( + binary_erosion(mask, iterations=3, brute_force=True), + iterations=3, + brute_force=True, + ) if cp.any(mask): # Find indices of valid values indices = distance_transform_edt(mask, return_distances=False, return_indices=True) diff --git 
a/src/entropice/ingest/darts.py b/src/entropice/ingest/darts.py index 19c6760..4e3dfd1 100644 --- a/src/entropice/ingest/darts.py +++ b/src/entropice/ingest/darts.py @@ -14,7 +14,12 @@ from rich.progress import track from stopuhr import stopwatch from entropice.spatial import grids -from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file +from entropice.utils.paths import ( + darts_ml_training_labels_repo, + dartsl2_cov_file, + dartsl2_file, + get_darts_rts_file, +) from entropice.utils.types import Grid traceback.install() diff --git a/src/entropice/ingest/era5.py b/src/entropice/ingest/era5.py index 9859f65..667e43e 100644 --- a/src/entropice/ingest/era5.py +++ b/src/entropice/ingest/era5.py @@ -204,15 +204,36 @@ def download_daily_aggregated(): ) # Assign attributes - daily_raw["t2m_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"} - daily_raw["t2m_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"} - daily_raw["t2m_mean"].attrs = {"long_name": "Daily mean 2 metre temperature", "units": "K"} + daily_raw["t2m_max"].attrs = { + "long_name": "Daily maximum 2 metre temperature", + "units": "K", + } + daily_raw["t2m_min"].attrs = { + "long_name": "Daily minimum 2 metre temperature", + "units": "K", + } + daily_raw["t2m_mean"].attrs = { + "long_name": "Daily mean 2 metre temperature", + "units": "K", + } daily_raw["snowc_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "m"} daily_raw["sde_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"} - daily_raw["lblt_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"} - daily_raw["tp"].attrs = {"long_name": "Daily total precipitation", "units": "m"} # Units are rather m^3 / m^2 - daily_raw["sf"].attrs = {"long_name": "Daily total snow fall", "units": "m"} # Units are rather m^3 / m^2 - daily_raw["sshf"].attrs = {"long_name": "Daily total surface 
sensible heat flux", "units": "J/m²"} + daily_raw["lblt_max"].attrs = { + "long_name": "Daily maximum lake ice bottom temperature", + "units": "K", + } + daily_raw["tp"].attrs = { + "long_name": "Daily total precipitation", + "units": "m", + } # Units are rather m^3 / m^2 + daily_raw["sf"].attrs = { + "long_name": "Daily total snow fall", + "units": "m", + } # Units are rather m^3 / m^2 + daily_raw["sshf"].attrs = { + "long_name": "Daily total surface sensible heat flux", + "units": "J/m²", + } daily_raw = daily_raw.odc.assign_crs("epsg:4326") daily_raw = daily_raw.drop_vars(["surface", "number", "depthBelowLandLayer"]) @@ -281,11 +302,17 @@ def daily_enrich(): # Formulas based on Groeke et. al. (2025) Stochastic Weather generation... daily["t2m_avg"] = (daily.t2m_max + daily.t2m_min) / 2 - daily.t2m_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"} + daily.t2m_avg.attrs = { + "long_name": "Daily average 2 metre temperature", + "units": "K", + } _store("t2m_avg") daily["t2m_range"] = daily.t2m_max - daily.t2m_min - daily.t2m_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"} + daily.t2m_range.attrs = { + "long_name": "Daily range of 2 metre temperature", + "units": "K", + } _store("t2m_range") with np.errstate(invalid="ignore"): @@ -298,7 +325,10 @@ def daily_enrich(): _store("thawing_degree_days") daily["freezing_degree_days"] = (273.15 - daily.t2m_avg).clip(min=0) - daily.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"} + daily.freezing_degree_days.attrs = { + "long_name": "Freezing degree days", + "units": "K", + } _store("freezing_degree_days") daily["thawing_days"] = (daily.t2m_avg > 273.15).where(~daily.t2m_avg.isnull()) @@ -425,7 +455,12 @@ def multi_monthly_aggregate(agg: Literal["yearly", "seasonal", "shoulder"] = "ye multimonthly_store = get_era5_stores(agg) print(f"Saving empty multi-monthly ERA5 data to {multimonthly_store}.") - 
multimonthly.to_zarr(multimonthly_store, mode="w", encoding=codecs.from_ds(multimonthly), consolidated=False) + multimonthly.to_zarr( + multimonthly_store, + mode="w", + encoding=codecs.from_ds(multimonthly), + consolidated=False, + ) def _store(var): nonlocal multimonthly @@ -512,14 +547,20 @@ def yearly_thaw_periods(): never_thaws = (daily.thawing_days.resample(time="12MS").sum(dim="time") == 0).persist() first_thaw_day = daily.thawing_days.resample(time="12MS").map(_get_first_day).where(~never_thaws) - first_thaw_day.attrs = {"long_name": "Day of first thaw in year", "units": "day of year"} + first_thaw_day.attrs = { + "long_name": "Day of first thaw in year", + "units": "day of year", + } _store_in_yearly(first_thaw_day, "day_of_first_thaw") n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").first() last_thaw_day = n_days_in_year - daily.thawing_days[::-1].resample(time="12MS").map(_get_first_day).where( ~never_thaws ) - last_thaw_day.attrs = {"long_name": "Day of last thaw in year", "units": "day of year"} + last_thaw_day.attrs = { + "long_name": "Day of last thaw in year", + "units": "day of year", + } _store_in_yearly(last_thaw_day, "day_of_last_thaw") @@ -532,14 +573,20 @@ def yearly_freeze_periods(): never_freezes = (daily.freezing_days.resample(time="12MS").sum(dim="time") == 0).persist() first_freeze_day = daily.freezing_days.resample(time="12MS").map(_get_first_day).where(~never_freezes) - first_freeze_day.attrs = {"long_name": "Day of first freeze in year", "units": "day of year"} + first_freeze_day.attrs = { + "long_name": "Day of first freeze in year", + "units": "day of year", + } _store_in_yearly(first_freeze_day, "day_of_first_freeze") n_days_in_year = xr.where(366, 365, daily.time[::-1].dt.is_leap_year).resample(time="12MS").last() last_freeze_day = n_days_in_year - daily.freezing_days[::-1].resample(time="12MS").map(_get_first_day).where( ~never_freezes ) - last_freeze_day.attrs = {"long_name": "Day of 
last freeze in year", "units": "day of year"} + last_freeze_day.attrs = { + "long_name": "Day of last freeze in year", + "units": "day of year", + } _store_in_yearly(last_freeze_day, "day_of_last_freeze") @@ -728,7 +775,12 @@ def spatial_agg( pxbuffer=10, ) - aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)}) + aggregated = aggregated.chunk( + { + "cell_ids": min(len(aggregated.cell_ids), 10000), + "time": len(aggregated.time), + } + ) store = get_era5_stores(agg, grid, level) with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"): aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index 0344e39..8959f0a 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -99,7 +99,14 @@ def bin_values( """ labels_dict = { "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"], - "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"], + "density": [ + "Empty", + "Very Sparse", + "Sparse", + "Moderate", + "Dense", + "Very Dense", + ], } labels = labels_dict[task] @@ -186,7 +193,13 @@ class DatasetEnsemble: level: int target: Literal["darts_rts", "darts_mllabels"] members: list[L2SourceDataset] = field( - default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] + default_factory=lambda: [ + "AlphaEarth", + "ArcticDEM", + "ERA5-yearly", + "ERA5-seasonal", + "ERA5-shoulder", + ] ) dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) variable_filters: dict[str, list[str]] = field(default_factory=dict) @@ -277,7 +290,9 @@ class DatasetEnsemble: return targets def _prep_era5( - self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"] + self, + targets: gpd.GeoDataFrame, + temporal: Literal["yearly", "seasonal", "shoulder"], ) -> 
pd.DataFrame: era5 = self._read_member("ERA5-" + temporal, targets) era5_df = era5.to_dataframe() @@ -352,7 +367,9 @@ class DatasetEnsemble: print(f"=== Total number of features in dataset: {stats['total_features']}") def create( - self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r" + self, + filter_target_col: str | None = None, + cache_mode: Literal["n", "o", "r"] = "r", ) -> gpd.GeoDataFrame: # n: no cache, o: overwrite cache, r: read cache if exists cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col) @@ -438,7 +455,10 @@ class DatasetEnsemble: yield dataset def _cat_and_split( - self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"] + self, + dataset: gpd.GeoDataFrame, + task: Task, + device: Literal["cpu", "cuda", "torch"], ) -> CategoricalTrainingDataset: taskcol = self.taskcol(task) diff --git a/src/entropice/ml/inference.py b/src/entropice/ml/inference.py index 05c9dae..d285d65 100644 --- a/src/entropice/ml/inference.py +++ b/src/entropice/ml/inference.py @@ -20,7 +20,9 @@ set_config(array_api_dispatch=True) def predict_proba( - e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list + e: DatasetEnsemble, + clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, + classes: list, ) -> gpd.GeoDataFrame: """Get predicted probabilities for each cell. 
diff --git a/src/entropice/spatial/aggregators.py b/src/entropice/spatial/aggregators.py index d775e87..f6e6ba0 100644 --- a/src/entropice/spatial/aggregators.py +++ b/src/entropice/spatial/aggregators.py @@ -79,7 +79,19 @@ class _Aggregations: def aggnames(self) -> list[str]: if self._common: - return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"] + return [ + "mean", + "std", + "min", + "max", + "median", + "p1", + "p5", + "p25", + "p75", + "p95", + "p99", + ] names = [] if self.mean: names.append("mean") @@ -105,7 +117,15 @@ class _Aggregations: cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy() cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy() cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy() - quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99] # ? Ordering is important here! + quantiles_to_compute = [ + 0.5, + 0.01, + 0.05, + 0.25, + 0.75, + 0.95, + 0.99, + ] # ? Ordering is important here! cell_data[4:, ...] = flattened_var.quantile( q=quantiles_to_compute, dim="z", @@ -156,7 +176,11 @@ class _Aggregations: return self._agg_cell_data_single(flattened) others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"]) - cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32) + cell_data = np.full( + (len(flattened.data_vars), len(self), *others_shape), + np.nan, + dtype=np.float32, + ) for i, var in enumerate(flattened.data_vars): cell_data[i, ...] 
= self._agg_cell_data_single(flattened[var]) # Transform to numpy arrays @@ -194,7 +218,9 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool: @stopwatch("Correcting geometries", log=False) -def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]: +def _get_corrected_geoms( + inp: tuple[Polygon, odc.geo.geobox.GeoBox, str], +) -> list[odc.geo.Geometry]: geom, gbox, crs = inp # cell.geometry is a shapely Polygon if crs != "EPSG:4326" or not _crosses_antimeridian(geom): @@ -440,7 +466,12 @@ def _align_partition( others_shape = tuple( [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] ) - ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) + ongrid_shape = ( + len(grid_partition_gdf), + len(raster.data_vars), + len(aggregations), + *others_shape, + ) ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32) for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): @@ -455,7 +486,11 @@ def _align_partition( cell_ids = grids.convert_cell_ids(grid_partition_gdf) dims = ["cell_ids", "variables", "aggregations"] - coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} + coords = { + "cell_ids": cell_ids, + "variables": list(raster.data_vars), + "aggregations": aggregations.aggnames(), + } for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: dims.append(dim) coords[dim] = raster.coords[dim] diff --git a/src/entropice/spatial/grids.py b/src/entropice/spatial/grids.py index 3f99867..3a60d99 100644 --- a/src/entropice/spatial/grids.py +++ b/src/entropice/spatial/grids.py @@ -131,7 +131,9 @@ def create_global_hex_grid(resolution): executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells } for future in track( - as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells) + 
as_completed(future_to_hex), + description="Creating hex polygons...", + total=len(hex0_cells), ): hex_batch, hex_id_batch, hex_area_batch = future.result() hex_list.extend(hex_batch) @@ -157,7 +159,10 @@ def create_global_hex_grid(resolution): hex_area_list.append(hex_area) # Create GeoDataFrame - grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326") + grid = gpd.GeoDataFrame( + {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, + crs="EPSG:4326", + ) return grid diff --git a/src/entropice/utils/codecs.py b/src/entropice/utils/codecs.py index c2f0282..7078b94 100644 --- a/src/entropice/utils/codecs.py +++ b/src/entropice/utils/codecs.py @@ -5,7 +5,10 @@ from zarr.codecs import BloscCodec def from_ds( - ds: xr.Dataset, store_floats_as_float32: bool = True, include_coords: bool = True, filter_existing: bool = True + ds: xr.Dataset, + store_floats_as_float32: bool = True, + include_coords: bool = True, + filter_existing: bool = True, ) -> dict: """Create compression encoding for zarr dataset storage. 
diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py index eb7eb91..1db121c 100644 --- a/src/entropice/utils/types.py +++ b/src/entropice/utils/types.py @@ -4,7 +4,17 @@ from dataclasses import dataclass from typing import Literal type Grid = Literal["hex", "healpix"] -type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"] +type GridLevel = Literal[ + "hex3", + "hex4", + "hex5", + "hex6", + "healpix6", + "healpix7", + "healpix8", + "healpix9", + "healpix10", +] type TargetDataset = Literal["darts_rts", "darts_mllabels"] type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] diff --git a/tests/l2dataset_validation.py b/tests/l2dataset_validation.py new file mode 100644 index 0000000..d41ae5a --- /dev/null +++ b/tests/l2dataset_validation.py @@ -0,0 +1,83 @@ +from itertools import product + +import xarray as xr +from rich.progress import track + +from entropice.utils.paths import ( + get_arcticdem_stores, + get_embeddings_store, + get_era5_stores, +) +from entropice.utils.types import Grid, L2SourceDataset + + +def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool: + """Validate if the L2 dataset exists for the given grid and level. + + Args: + grid (Grid): The grid type to use. + level (int): The grid level to use. + l2ds (L2SourceDataset): The L2 source dataset to validate. + + Returns: + bool: True if the dataset exists and does not contain NaNs, False otherwise. 
+ + """ + if l2ds == "ArcticDEM": + store = get_arcticdem_stores(grid, level) + elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly": + agg = l2ds.split("-")[1] + store = get_era5_stores(agg, grid, level) # type: ignore + elif l2ds == "AlphaEarth": + store = get_embeddings_store(grid, level) + else: + raise ValueError(f"Unsupported L2 source dataset: {l2ds}") + + if not store.exists(): + print("\t Dataset store does not exist") + return False + + ds = xr.open_zarr(store, consolidated=False) + has_nan = False + for var in ds.data_vars: + n_nans = ds[var].isnull().sum().compute().item() + if n_nans > 0: + print(f"\t Dataset contains {n_nans} NaNs in variable {var}") + has_nan = True + if has_nan: + return False + return True + + +def main(): + grid_levels: set[tuple[Grid, int]] = { + ("hex", 3), + ("hex", 4), + ("hex", 5), + ("hex", 6), + ("healpix", 6), + ("healpix", 7), + ("healpix", 8), + ("healpix", 9), + ("healpix", 10), + } + l2_source_datasets: list[L2SourceDataset] = [ + "ArcticDEM", + "ERA5-shoulder", + "ERA5-seasonal", + "ERA5-yearly", + "AlphaEarth", + ] + + for (grid, level), l2ds in track( + product(grid_levels, l2_source_datasets), + total=len(grid_levels) * len(l2_source_datasets), + description="Validating L2 datasets...", + ): + is_valid = validate_l2_dataset(grid, level, l2ds) + status = "VALID" if is_valid else "INVALID" + print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}") + + +if __name__ == "__main__": + main()