diff --git a/pixi.lock b/pixi.lock
index c1369b7..8ccac5f 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -9,6 +9,7 @@ environments:
     - https://pypi.org/simple
     options:
       channel-priority: disabled
+      pypi-prerelease-mode: if-necessary-or-explicit
     packages:
       linux-64:
       - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
@@ -478,7 +479,7 @@ environments:
       - pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
       - pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
       - pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
-      - pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
+      - pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/ee/34/a9914e676971a13d6cc671b1ed172f9804b50a3a80a143ff196e52f4c7ee/azure_core-1.37.0-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
@@ -498,7 +499,7 @@ environments:
       - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl
       - pypi: https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl
-      - pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
+      - pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
@@ -567,7 +568,7 @@ environments:
       - pypi: https://files.pythonhosted.org/packages/84/99/6636f7097a5e461d560317024522279f52931b5a52c8caa0755a14d5f1fd/odc_loader-0.6.0-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/e2/c7/b8f2b3e53f26f8f463002f3e8023189653b627b22ba6c00ef86eaba50b73/odc_stac-0.5.0-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl
-      - pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
+      - pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/e7/c3/3031c931098de393393e1f93a38dc9ed6805d86bb801acc3cf2d5bd1e6b7/plotly-6.5.0-py3-none-any.whl
@@ -1025,10 +1026,10 @@ packages:
   - astropy[dev] ; extra == 'dev-all'
   - astropy[test-all] ; extra == 'dev-all'
   requires_python: '>=3.11'
-- pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
   name: astropy-iers-data
-  version: 0.2025.12.15.0.40.51
-  sha256: 5f89c4b9d711638005407b4b57f76d82484f0182e459df91fcdfc354c3a34d8e
+  version: 0.2025.12.22.0.40.30
+  sha256: 2fbc71988d96aa29566667c6568a2bc5ca00748174b1f8ac3e9f7b09d4c27cac
   requires_dist:
   - pytest ; extra == 'docs'
   - hypothesis ; extra == 'test'
@@ -2516,10 +2517,10 @@ packages:
   - pkg:pypi/cycler?source=hash-mapping
   size: 14778
   timestamp: 1764466758386
-- pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
   name: cyclopts
-  version: 4.4.0
-  sha256: 78ff95a5e52e738a1d0f01e5a3af48049c47748fa2c255f2629a4cef54dcf2b3
+  version: 4.4.1
+  sha256: 67500e9fde90f335fddbf9c452d2e7c4f58209dffe52e7abb1e272796a963bde
   requires_dist:
   - attrs>=23.1.0
   - docstring-parser>=0.15,<4.0
@@ -2932,7 +2933,6 @@ packages:
   - ruff>=0.14.9,<0.15
   - pandas-stubs>=2.3.3.251201,<3
   requires_python: '>=3.13,<3.14'
-  editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
   name: entropy
   version: 0.1.0
@@ -7326,12 +7326,12 @@ packages:
   - pkg:pypi/pandas?source=compressed-mapping
   size: 14912799
   timestamp: 1764615091147
-- pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
   name: pandas-stubs
-  version: 2.3.3.251201
-  sha256: eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5
+  version: 2.3.3.251219
+  sha256: ccc6337febb51d6d8a08e4c96b479478a0da0ef704b5e08bd212423fe1cb549c
   requires_dist:
-  - numpy>=1.23.5
+  - numpy>=1.23.5,<=2.3.5
   - types-pytz>=2022.1.1
   requires_python: '>=3.10'
- conda: https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2
diff --git a/scripts/01darts.sh b/scripts/01darts.sh
index e363b45..720d860 100644
--- a/scripts/01darts.sh
+++ b/scripts/01darts.sh
@@ -1,11 +1,22 @@
 #! /bin/bash
 
-pixi run darts extract_darts_mllabels --grid hex --level 3
-pixi run darts extract_darts_mllabels --grid hex --level 4
-pixi run darts extract_darts_mllabels --grid hex --level 5
-pixi run darts extract_darts_mllabels --grid hex --level 6
-pixi run darts extract_darts_mllabels --grid healpix --level 6
-pixi run darts extract_darts_mllabels --grid healpix --level 7
-pixi run darts extract_darts_mllabels --grid healpix --level 8
-pixi run darts extract_darts_mllabels --grid healpix --level 9
-pixi run darts extract_darts_mllabels --grid healpix --level 10
\ No newline at end of file
+# pixi shell
+darts extract-darts-rts --grid hex --level 3
+darts extract-darts-rts --grid hex --level 4
+darts extract-darts-rts --grid hex --level 5
+darts extract-darts-rts --grid hex --level 6
+darts extract-darts-rts --grid healpix --level 6
+darts extract-darts-rts --grid healpix --level 7
+darts extract-darts-rts --grid healpix --level 8
+darts extract-darts-rts --grid healpix --level 9
+darts extract-darts-rts --grid healpix --level 10
+
+darts extract-darts-mllabels --grid hex --level 3
+darts extract-darts-mllabels --grid hex --level 4
+darts extract-darts-mllabels --grid hex --level 5
+darts extract-darts-mllabels --grid hex --level 6
+darts extract-darts-mllabels --grid healpix --level 6
+darts extract-darts-mllabels --grid healpix --level 7
+darts extract-darts-mllabels --grid healpix --level 8
+darts extract-darts-mllabels --grid healpix --level 9
+darts extract-darts-mllabels --grid healpix --level 10
\ No newline at end of file
diff --git a/scripts/05train.sh b/scripts/05train.sh
new file mode 100644
index 0000000..851e505
--- /dev/null
+++ b/scripts/05train.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model espa
+pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model xgboost
\ No newline at end of file
diff --git a/scripts/fix_xgboost_importance.py b/scripts/fix_xgboost_importance.py
new file mode 100644
index 0000000..a343fff
--- /dev/null
+++ b/scripts/fix_xgboost_importance.py
@@ -0,0 +1,191 @@
+"""Fix XGBoost feature importance in existing model state files.
+
+This script repairs XGBoost model state files that have all-zero feature importance
+values due to incorrect feature name lookup. It reloads the pickled models and
+regenerates the feature importance arrays with the correct feature index mapping.
+"""
+
+import pickle
+from pathlib import Path
+
+import toml
+import xarray as xr
+from rich import print
+
+from entropice.paths import RESULTS_DIR
+
+
+def fix_xgboost_model_state(results_dir: Path) -> bool:
+    """Fix a single XGBoost model state file.
+
+    Args:
+        results_dir: Directory containing the model files.
+
+    Returns:
+        True if fixed successfully, False if the directory was skipped or the fix failed.
+
+    """
+    # Check if this is an XGBoost model
+    settings_file = results_dir / "search_settings.toml"
+    if not settings_file.exists():
+        return False
+
+    settings = toml.load(settings_file)
+    model_type = settings.get("settings", {}).get("model", "")
+
+    if model_type != "xgboost":
+        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
+        return False
+
+    # Check if required files exist
+    model_file = results_dir / "best_estimator_model.pkl"
+    state_file = results_dir / "best_estimator_state.nc"
+
+    if not model_file.exists():
+        print(f"⚠️ Missing model file in {results_dir.name}")
+        return False
+
+    if not state_file.exists():
+        print(f"⚠️ Missing state file in {results_dir.name}")
+        return False
+
+    # Load the pickled model
+    print(f"Loading model from {results_dir.name}...")
+    with open(model_file, "rb") as f:
+        best_estimator = pickle.load(f)
+
+    # Load the old state to get feature names
+    old_state = xr.open_dataset(state_file, engine="h5netcdf")
+    features = old_state.coords["feature"].values.tolist()
+    old_state.close()
+
+    # Get the booster and extract feature importance with correct mapping
+    booster = best_estimator.get_booster()
+
+    importance_weight = booster.get_score(importance_type="weight")
+    importance_gain = booster.get_score(importance_type="gain")
+    importance_cover = booster.get_score(importance_type="cover")
+    importance_total_gain = booster.get_score(importance_type="total_gain")
+    importance_total_cover = booster.get_score(importance_type="total_cover")
+
+    # Align importance using feature indices (f0, f1, ...)
+    def align_importance(importance_dict, features):
+        """Align importance dict to feature list using feature indices."""
+        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]
+
+    # Create new DataArrays
+    feature_importance_weight = xr.DataArray(
+        align_importance(importance_weight, features),
+        dims=["feature"],
+        coords={"feature": features},
+        name="feature_importance_weight",
+        attrs={"description": "Number of times a feature is used to split the data across all trees."},
+    )
+    feature_importance_gain = xr.DataArray(
+        align_importance(importance_gain, features),
+        dims=["feature"],
+        coords={"feature": features},
+        name="feature_importance_gain",
+        attrs={"description": "Average gain across all splits the feature is used in."},
+    )
+    feature_importance_cover = xr.DataArray(
+        align_importance(importance_cover, features),
+        dims=["feature"],
+        coords={"feature": features},
+        name="feature_importance_cover",
+        attrs={"description": "Average coverage across all splits the feature is used in."},
+    )
+    feature_importance_total_gain = xr.DataArray(
+        align_importance(importance_total_gain, features),
+        dims=["feature"],
+        coords={"feature": features},
+        name="feature_importance_total_gain",
+        attrs={"description": "Total gain across all splits the feature is used in."},
+    )
+    feature_importance_total_cover = xr.DataArray(
+        align_importance(importance_total_cover, features),
+        dims=["feature"],
+        coords={"feature": features},
+        name="feature_importance_total_cover",
+        attrs={"description": "Total coverage across all splits the feature is used in."},
+    )
+
+    # Create new state dataset
+    n_trees = booster.num_boosted_rounds()
+    state = xr.Dataset(
+        {
+            "feature_importance_weight": feature_importance_weight,
+            "feature_importance_gain": feature_importance_gain,
+            "feature_importance_cover": feature_importance_cover,
+            "feature_importance_total_gain": feature_importance_total_gain,
+            "feature_importance_total_cover": feature_importance_total_cover,
+        },
+        attrs={
+            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
+            "n_trees": n_trees,
+            "objective": str(best_estimator.objective),
+        },
+    )
+
+    # Backup the old file
+    backup_file = state_file.with_suffix(".nc.backup")
+    if not backup_file.exists():
+        print(f"  Creating backup: {backup_file.name}")
+        state_file.rename(backup_file)
+    else:
+        print("  Backup already exists, removing old state file")
+        state_file.unlink()
+
+    # Save the fixed state
+    print(f"  Saving fixed state to {state_file.name}")
+    state.to_netcdf(state_file, engine="h5netcdf")
+
+    # Verify the fix
+    total_importance = (
+        feature_importance_weight.sum().item()
+        + feature_importance_gain.sum().item()
+        + feature_importance_cover.sum().item()
+    )
+
+    if total_importance > 0:
+        print(f"  ✓ Success! Total importance: {total_importance:.2f}")
+        return True
+    else:
+        print("  ✗ Warning: Total importance is still 0!")
+        return False
+
+
+def main():
+    """Find and fix all XGBoost model state files."""
+    print("Scanning for XGBoost model results...")
+
+    # Find all result directories
+    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
+    print(f"Found {len(result_dirs)} result directories")
+
+    fixed_count = 0
+    skipped_count = 0
+    failed_count = 0
+
+    for result_dir in sorted(result_dirs):
+        try:
+            success = fix_xgboost_model_state(result_dir)
+            if success:
+                fixed_count += 1
+            else:
+                # Skipped: not an XGBoost model, missing files, or importance still zero
+                skipped_count += 1
+        except Exception as e:
+            print(f"❌ Error processing {result_dir.name}: {e}")
+            failed_count += 1
+
+    print("\n" + "=" * 60)
+    print("Summary:")
+    print(f"  ✓ Fixed: {fixed_count}")
+    print(f"  ⊘ Skipped: {skipped_count}")
+    print(f"  ✗ Failed: {failed_count}")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/entropice/dashboard/model_state_page.py b/src/entropice/dashboard/model_state_page.py
index 3e424dd..031df78 100644
--- a/src/entropice/dashboard/model_state_page.py
+++ b/src/entropice/dashboard/model_state_page.py
@@ -5,6 +5,8 @@ import xarray as xr
 
 from entropice.dashboard.plots.colors import generate_unified_colormap
 from entropice.dashboard.plots.model_state import (
+    plot_arcticdem_heatmap,
+    plot_arcticdem_summary,
     plot_box_assignment_bars,
     plot_box_assignments,
     plot_common_features,
@@ -12,12 +14,15 @@ from entropice.dashboard.plots.model_state import (
     plot_embedding_heatmap,
     plot_era5_heatmap,
     plot_era5_summary,
+    plot_era5_time_heatmap,
     plot_top_features,
 )
 from entropice.dashboard.utils.data import (
+    extract_arcticdem_features,
     extract_common_features,
     extract_embedding_features,
     extract_era5_features,
+    get_members_from_settings,
     load_all_training_results,
 )
 from entropice.dashboard.utils.training import load_model_state
@@ -35,14 +40,20 @@ def render_model_state_page():
         st.error("No training results found. Please run a training search first.")
         return
 
-    # Result selection
-    result_options = {tr.name: tr for tr in training_results}
-    selected_name = st.selectbox(
-        "Select Training Result",
-        options=list(result_options.keys()),
-        help="Choose a training result to visualize model state",
-    )
-    selected_result = result_options[selected_name]
+    # Sidebar: Training run selection
+    with st.sidebar:
+        st.header("Select Training Run")
+
+        # Result selection with model-first naming
+        result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
+        selected_name = st.selectbox(
+            "Training Run",
+            options=list(result_options.keys()),
+            help="Choose a training result to visualize model state",
+        )
+        selected_result = result_options[selected_name]
+
+        st.divider()
 
     # Get the model type from settings
     model_type = selected_result.settings.get("model", "espa")
@@ -62,6 +73,33 @@ def render_model_state_page():
         st.write(f"**Coordinates:** {list(model_state.coords)}")
         st.write(f"**Attributes:** {dict(model_state.attrs)}")
 
+    # Display dataset members summary
+    st.header("📊 Training Data Summary")
+    members = get_members_from_settings(selected_result.settings)
+
+    st.markdown(f"""
+    **Dataset Members Used in Training:** {len(members)}
+
+    The following data sources were used to train this model:
+    """)
+
+    # Create a nice display of members with emojis
+    member_display = {
+        "AlphaEarth": "🛰️ AlphaEarth (Satellite Embeddings)",
+        "ArcticDEM": "🏔️ ArcticDEM (Topography)",
+        "ERA5-yearly": "⛅ ERA5 Yearly (Climate)",
+        "ERA5-seasonal": "⛅ ERA5 Seasonal (Summer/Winter)",
+        "ERA5-shoulder": "⛅ ERA5 Shoulder Seasons (JFM/AMJ/JAS/OND)",
+    }
+
+    cols = st.columns(min(len(members), 3))
+    for idx, member in enumerate(members):
+        with cols[idx % 3]:
+            display_name = member_display.get(member, f"📁 {member}")
+            st.info(display_name)
+
+    st.divider()
+
     # Render model-specific visualizations
     if model_type == "espa":
         render_espa_model_state(model_state, selected_result)
@@ -81,9 +119,28 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
     n_features = model_state.sizes["feature"]
     model_state["feature_weights"] *= n_features
 
-    # Extract different feature types
-    embedding_feature_array = extract_embedding_features(model_state)
-    era5_feature_array = extract_era5_features(model_state)
+    # Get members used in training
+    members = get_members_from_settings(selected_result.settings)
+
+    # Extract different feature types based on what was used in training
+    embedding_feature_array = None
+    if "AlphaEarth" in members:
+        embedding_feature_array = extract_embedding_features(model_state)
+
+    era5_yearly_array = None
+    era5_seasonal_array = None
+    era5_shoulder_array = None
+    if "ERA5-yearly" in members:
+        era5_yearly_array = extract_era5_features(model_state, temporal_group="yearly")
+    if "ERA5-seasonal" in members:
+        era5_seasonal_array = extract_era5_features(model_state, temporal_group="seasonal")
+    if "ERA5-shoulder" in members:
+        era5_shoulder_array = extract_era5_features(model_state, temporal_group="shoulder")
+
+    arcticdem_feature_array = None
+    if "ArcticDEM" in members:
+        arcticdem_feature_array = extract_arcticdem_features(model_state)
+
     common_feature_array = extract_common_features(model_state)
 
     # Generate unified colormaps
@@ -107,7 +164,7 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
 
     with st.spinner("Generating feature importance plot..."):
         feature_chart = plot_top_features(model_state, top_n=top_n)
-        st.altair_chart(feature_chart, use_container_width=True)
+        st.altair_chart(feature_chart, width="stretch")
 
     st.markdown(
         """
@@ -136,12 +193,12 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
     with col1:
         st.markdown("### Assignment Heatmap")
         box_assignment_heatmap = plot_box_assignments(model_state)
-        st.altair_chart(box_assignment_heatmap, use_container_width=True)
+        st.altair_chart(box_assignment_heatmap, width="stretch")
 
     with col2:
         st.markdown("### Box Count by Class")
         box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
-        st.altair_chart(box_assignment_bars, use_container_width=True)
+        st.altair_chart(box_assignment_bars, width="stretch")
 
     # Show statistics
     with st.expander("Box Assignment Statistics"):
@@ -179,9 +236,19 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
     if embedding_feature_array is not None:
         render_embedding_features(embedding_feature_array)
 
-    # ERA5 features analysis (if present)
-    if era5_feature_array is not None:
-        render_era5_features(era5_feature_array)
+    # ERA5 features analysis (if present) - split by temporal group
+    if era5_yearly_array is not None:
+        render_era5_features(era5_yearly_array, temporal_group="Yearly")
+
+    if era5_seasonal_array is not None:
+        render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
+
+    if era5_shoulder_array is not None:
+        render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
+
+    # ArcticDEM features analysis (if present)
+    if arcticdem_feature_array is not None:
+        render_arcticdem_features(arcticdem_feature_array)
 
     # Common features analysis (if present)
     if common_feature_array is not None:
@@ -237,7 +304,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
 
     with st.spinner("Generating feature importance plot..."):
         importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
-        st.altair_chart(importance_chart, use_container_width=True)
+        st.altair_chart(importance_chart, width="stretch")
 
     # Comparison of importance types
     st.subheader("Importance Type Comparison")
@@ -245,7 +312,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
 
     with st.spinner("Generating importance comparison..."):
         comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
-        st.altair_chart(comparison_chart, use_container_width=True)
+        st.altair_chart(comparison_chart, width="stretch")
 
     # Statistics
     with st.expander("Model Statistics"):
@@ -256,6 +323,64 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
         with col2:
             st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
 
+    # Feature source analysis
+    st.subheader("Feature Importance by Data Source")
+    st.markdown(
+        """
+        Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
+        ArcticDEM topography, and common features).
+        """
+    )
+
+    # Get members used in training
+    members = get_members_from_settings(selected_result.settings)
+
+    # Extract features by source using the selected importance type
+    importance_var = f"feature_importance_{importance_type}"
+
+    embedding_feature_array = None
+    if "AlphaEarth" in members:
+        embedding_feature_array = extract_embedding_features(model_state, importance_type=importance_var)
+
+    era5_yearly_array = None
+    era5_seasonal_array = None
+    era5_shoulder_array = None
+    if "ERA5-yearly" in members:
+        era5_yearly_array = extract_era5_features(model_state, importance_type=importance_var, temporal_group="yearly")
+    if "ERA5-seasonal" in members:
+        era5_seasonal_array = extract_era5_features(
+            model_state, importance_type=importance_var, temporal_group="seasonal"
+        )
+    if "ERA5-shoulder" in members:
+        era5_shoulder_array = extract_era5_features(
+            model_state, importance_type=importance_var, temporal_group="shoulder"
+        )
+
+    arcticdem_feature_array = None
+    if "ArcticDEM" in members:
+        arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type=importance_var)
+
+    common_feature_array = extract_common_features(model_state, importance_type=importance_var)
+
+    # Render each source's features if present
+    if embedding_feature_array is not None:
+        render_embedding_features(embedding_feature_array)
+
+    if era5_yearly_array is not None:
+        render_era5_features(era5_yearly_array, temporal_group="Yearly")
+
+    if era5_seasonal_array is not None:
+        render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
+
+    if era5_shoulder_array is not None:
+        render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
+
+    if arcticdem_feature_array is not None:
+        render_arcticdem_features(arcticdem_feature_array)
+
+    if common_feature_array is not None:
+        render_common_features(common_feature_array)
+
 
 def render_rf_model_state(model_state: xr.Dataset, selected_result):
     """Render visualizations for Random Forest model."""
@@ -302,7 +427,7 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
 
     with st.spinner("Generating feature importance plot..."):
         importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
-        st.altair_chart(importance_chart, use_container_width=True)
+        st.altair_chart(importance_chart, width="stretch")
 
     # Tree statistics (only if available - sklearn RF has them, cuML RF doesn't)
     if not is_cuml and "tree_depths" in model_state:
@@ -315,11 +440,11 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
             chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
             col1, col2, col3 = st.columns(3)
             with col1:
-                st.altair_chart(chart_depths, use_container_width=True)
+                st.altair_chart(chart_depths, width="stretch")
             with col2:
-                st.altair_chart(chart_leaves, use_container_width=True)
+                st.altair_chart(chart_leaves, width="stretch")
             with col3:
-                st.altair_chart(chart_nodes, use_container_width=True)
+                st.altair_chart(chart_nodes, width="stretch")
 
     # Statistics
     with st.expander("Forest Statistics"):
@@ -345,6 +470,64 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
                 st.metric("Max Nodes", f"{nodes.max()}")
                 st.metric("Min Nodes", f"{nodes.min()}")
 
+    # Feature source analysis
+    st.subheader("Feature Importance by Data Source")
+    st.markdown(
+        """
+        Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
+        ArcticDEM topography, and common features).
+        """
+    )
+
+    # Get members used in training
+    members = get_members_from_settings(selected_result.settings)
+
+    # Extract features by source
+    embedding_feature_array = None
+    if "AlphaEarth" in members:
+        embedding_feature_array = extract_embedding_features(model_state, importance_type="feature_importance")
+
+    era5_yearly_array = None
+    era5_seasonal_array = None
+    era5_shoulder_array = None
+    if "ERA5-yearly" in members:
+        era5_yearly_array = extract_era5_features(
+            model_state, importance_type="feature_importance", temporal_group="yearly"
+        )
+    if "ERA5-seasonal" in members:
+        era5_seasonal_array = extract_era5_features(
+            model_state, importance_type="feature_importance", temporal_group="seasonal"
+        )
+    if "ERA5-shoulder" in members:
+        era5_shoulder_array = extract_era5_features(
+            model_state, importance_type="feature_importance", temporal_group="shoulder"
+        )
+
+    arcticdem_feature_array = None
+    if "ArcticDEM" in members:
+        arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type="feature_importance")
+
+    common_feature_array = extract_common_features(model_state, importance_type="feature_importance")
+
+    # Render each source's features if present
+    if embedding_feature_array is not None:
+        render_embedding_features(embedding_feature_array)
+
+    if era5_yearly_array is not None:
+        render_era5_features(era5_yearly_array, temporal_group="Yearly")
+
+    if era5_seasonal_array is not None:
+        render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
+
+    if era5_shoulder_array is not None:
+        render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
+
+    if arcticdem_feature_array is not None:
+        render_arcticdem_features(arcticdem_feature_array)
+
+    if common_feature_array is not None:
+        render_common_features(common_feature_array)
+
 
 def render_knn_model_state(model_state: xr.Dataset, selected_result):
     """Render visualizations for KNN model."""
@@ -402,18 +585,18 @@ def render_embedding_features(embedding_feature_array):
         chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
         col1, col2, col3 = st.columns(3)
         with col1:
-            st.altair_chart(chart_agg, use_container_width=True)
+            st.altair_chart(chart_agg, width="stretch")
         with col2:
-            st.altair_chart(chart_band, use_container_width=True)
+            st.altair_chart(chart_band, width="stretch")
        with col3:
-            st.altair_chart(chart_year, use_container_width=True)
+            st.altair_chart(chart_year, width="stretch")
 
     # Detailed heatmap
     st.markdown("### Detailed Heatmap by Aggregation")
     st.markdown("Shows the weight of each band-year combination for each aggregation type.")
     with st.spinner("Generating heatmap..."):
         heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
-        st.altair_chart(heatmap_chart, use_container_width=True)
+        st.altair_chart(heatmap_chart, width="stretch")
 
     # Statistics
     with st.expander("Embedding Feature Statistics"):
+ + """ + group_suffix = f" ({temporal_group})" if temporal_group else "" + with st.container(border=True): - st.header("⛅ ERA5 Feature Analysis") + st.header(f"⛅ ERA5 Feature Analysis{group_suffix}") st.markdown( - """ - Analysis of ERA5 climate features showing which variables and time periods + f""" + Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods are most important for the model predictions. """ ) @@ -450,19 +641,50 @@ def render_era5_features(era5_feature_array): # Summary bar charts st.markdown("### Importance by Dimension") with st.spinner("Generating ERA5 dimension summaries..."): - chart_variable, chart_time = plot_era5_summary(era5_feature_array) - col1, col2 = st.columns(2) - with col1: - st.altair_chart(chart_variable, use_container_width=True) - with col2: - st.altair_chart(chart_time, use_container_width=True) + charts = plot_era5_summary(era5_feature_array) + + # Check if this is seasonal/shoulder data (returns 3 charts) or yearly (returns 2 charts) + if len(charts) == 3: + chart_variable, chart_season, chart_year = charts + col1, col2, col3 = st.columns(3) + with col1: + st.altair_chart(chart_variable, width="stretch") + with col2: + st.altair_chart(chart_season, width="stretch") + with col3: + st.altair_chart(chart_year, width="stretch") + else: + chart_variable, chart_time = charts + col1, col2 = st.columns(2) + with col1: + st.altair_chart(chart_variable, width="stretch") + with col2: + st.altair_chart(chart_time, width="stretch") # Detailed heatmap st.markdown("### Detailed Heatmap") - st.markdown("Shows the weight of each variable-time combination.") - with st.spinner("Generating ERA5 heatmap..."): - era5_heatmap_chart = plot_era5_heatmap(era5_feature_array) - st.altair_chart(era5_heatmap_chart, use_container_width=True) + + # Check if this is seasonal/shoulder data + has_season = "season" in era5_feature_array.dims + + if has_season: + st.markdown("Shows the weight of each variable-season-year combination.") + with st.spinner("Generating ERA5 season heatmap..."): + era5_heatmap_chart = plot_era5_heatmap(era5_feature_array) + st.altair_chart(era5_heatmap_chart, width="stretch") + + # Add time-based heatmap for seasonal/shoulder + st.markdown("### By Time Heatmap") + st.markdown("Shows temporal trends by averaging over seasons.") + with st.spinner("Generating ERA5 time heatmap..."): + era5_time_heatmap_chart = plot_era5_time_heatmap(era5_feature_array) + if era5_time_heatmap_chart is not None: + st.altair_chart(era5_time_heatmap_chart, width="stretch") + else: + st.markdown("Shows the weight of each variable-time combination.") + with st.spinner("Generating ERA5 heatmap..."): + era5_heatmap_chart = plot_era5_heatmap(era5_feature_array) + st.altair_chart(era5_heatmap_chart, width="stretch") # Statistics with st.expander("ERA5 Feature Statistics"): @@ -481,10 +703,61 @@ def render_era5_features(era5_feature_array): # Show top ERA5 features st.write("**Top 10 ERA5 Features:**") era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() - top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] + # Get all columns except 'weight' for display + display_cols = [col for col in era5_df.columns if col != "weight"] + ["weight"] + top_era5 = era5_df.nlargest(10, "weight")[display_cols] st.dataframe(top_era5, width="stretch") +def render_arcticdem_features(arcticdem_feature_array): + """Render ArcticDEM feature visualizations.""" + with 
st.container(border=True): + st.header("🏔️ ArcticDEM Feature Analysis") + st.markdown( + """ + Analysis of ArcticDEM topographic features showing which terrain variables and + aggregations are most important for the model predictions. + """ + ) + + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating ArcticDEM dimension summaries..."): + chart_variable, chart_agg = plot_arcticdem_summary(arcticdem_feature_array) + col1, col2 = st.columns(2) + with col1: + st.altair_chart(chart_variable, width="stretch") + with col2: + st.altair_chart(chart_agg, width="stretch") + + # Detailed heatmap + st.markdown("### Detailed Heatmap") + st.markdown("Shows the weight of each variable-aggregation combination.") + with st.spinner("Generating ArcticDEM heatmap..."): + arcticdem_heatmap_chart = plot_arcticdem_heatmap(arcticdem_feature_array) + st.altair_chart(arcticdem_heatmap_chart, width="stretch") + + # Statistics + with st.expander("ArcticDEM Feature Statistics"): + st.write("**Overall Statistics:**") + n_arcticdem_features = arcticdem_feature_array.size + mean_weight = float(arcticdem_feature_array.mean().values) + max_weight = float(arcticdem_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total ArcticDEM Features", n_arcticdem_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + + # Show top ArcticDEM features + st.write("**Top 10 ArcticDEM Features:**") + arcticdem_df = arcticdem_feature_array.to_dataframe(name="weight").reset_index() + top_arcticdem = arcticdem_df.nlargest(10, "weight")[["variable", "agg", "weight"]] + st.dataframe(top_arcticdem, width="stretch") + + def render_common_features(common_feature_array): """Render common feature visualizations.""" with st.container(border=True): @@ -499,7 +772,7 @@ def render_common_features(common_feature_array): # Bar chart showing all common feature weights with st.spinner("Generating common features chart..."): common_chart = plot_common_features(common_feature_array) - st.altair_chart(common_chart, use_container_width=True) + st.altair_chart(common_chart, width="stretch") # Statistics with st.expander("Common Feature Statistics"): diff --git a/src/entropice/dashboard/overview_page.py b/src/entropice/dashboard/overview_page.py index 5ea8754..e373dae 100644 --- a/src/entropice/dashboard/overview_page.py +++ b/src/entropice/dashboard/overview_page.py @@ -96,7 +96,7 @@ def render_overview_page(): st.subheader("Detailed Results") for tr in training_results: - with st.expander(tr.name): + with st.expander(tr.get_display_name("task_first")): col1, col2 = st.columns([1, 2]) with col1: diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index 58eafae..75381c5 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -35,7 +35,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str): best_score = results[col].max() best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"}) - st.dataframe(pd.DataFrame(best_scores), hide_index=True, use_container_width=True) + st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch") with col2: st.markdown("#### Score Statistics") @@ -51,7 +51,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str): } ) - 
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True) + st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch") # Show best parameter combination in a cleaner format (similar to old dashboard) st.markdown("#### 🏆 Best Parameter Combination") @@ -305,7 +305,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None .properties(height=250, title=param_name) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") else: # Categorical parameter - use bar chart @@ -323,7 +323,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None .properties(height=250, title=param_name) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") def render_score_vs_parameter(results: pd.DataFrame, metric: str): @@ -411,7 +411,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): .properties(height=300, title=f"{metric} vs {param_name}") ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") else: # Categorical parameter - box plot @@ -428,7 +428,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): .properties(height=300, title=f"{metric} vs {param_name}") ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") def render_parameter_correlation(results: pd.DataFrame, metric: str): @@ -481,7 +481,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str): .properties(height=max(200, len(correlations) * 30)) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") def render_binned_parameter_space(results: pd.DataFrame, metric: str): @@ -673,7 +673,7 @@ def _render_2d_param_plot( .interactive() ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") def render_score_evolution(results: pd.DataFrame, metric: str): @@ -722,7 +722,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str): .properties(height=400) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") # Show statistics col1, col2, col3, col4 = st.columns(4) @@ -831,19 +831,37 @@ def render_multi_metric_comparison(results: pd.DataFrame): else: tooltip_list.append(color_col) + # Calculate axis domains starting from min values + x_min = df_plot[metric1].min() + x_max = df_plot[metric1].max() + y_min = df_plot[metric2].min() + y_max = df_plot[metric2].max() + + # Add small padding (2% of range) for better visualization + x_padding = (x_max - x_min) * 0.02 + y_padding = (y_max - y_min) * 0.02 + chart = ( alt.Chart(df_plot) .mark_circle(size=60, opacity=0.6) .encode( - alt.X(metric1, title=metric1.replace("_", " ").title()), - alt.Y(metric2, title=metric2.replace("_", " ").title()), + alt.X( + metric1, + title=metric1.replace("_", " ").title(), + scale=alt.Scale(domain=[x_min - x_padding, x_max + x_padding]), + ), + alt.Y( + metric2, + title=metric2.replace("_", " ").title(), + scale=alt.Scale(domain=[y_min - y_padding, y_max + y_padding]), + ), alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()), tooltip=tooltip_list, ) .properties(height=500) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") # Calculate correlation corr = df_plot[[metric1, metric2]].corr().iloc[0, 1] @@ -1015,7 +1033,7 @@ def render_espa_binned_parameter_space(results: pd.DataFrame, metric: 
str, k_bin ) ) - st.altair_chart(chart, use_container_width=True) + st.altair_chart(chart, width="stretch") # Show statistics about the binning n_bins = len(bin_order) @@ -1073,4 +1091,4 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1 score_col_display = metric.replace("_", " ").title() display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}") - st.dataframe(display_df, hide_index=True, use_container_width=True) + st.dataframe(display_df, hide_index=True, width="stretch") diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py index 113bfac..2305fda 100644 --- a/src/entropice/dashboard/plots/model_state.py +++ b/src/entropice/dashboard/plots/model_state.py @@ -69,6 +69,13 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart: Altair chart showing the heatmap. """ + # Filter out "count" aggregation if present + if "agg" in embedding_array.dims: + agg_values = embedding_array.coords["agg"].values + non_count_aggs = [agg for agg in agg_values if agg != "count"] + if non_count_aggs: + embedding_array = embedding_array.sel(agg=non_count_aggs) + # Convert to DataFrame for plotting df = embedding_array.to_dataframe(name="weight").reset_index() @@ -81,7 +88,7 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart: y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")), color=alt.Color( "weight:Q", - scale=alt.Scale(scheme="redblue", domainMid=0), + scale=alt.Scale(scheme="blues"), title="Weight", ), tooltip=[ @@ -108,15 +115,22 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a Tuple of three Altair charts (by_agg, by_band, by_year). """ + # Filter out "count" aggregation if present + if "agg" in embedding_array.dims: + agg_values = embedding_array.coords["agg"].values + non_count_aggs = [agg for agg in agg_values if agg != "count"] + if non_count_aggs: + embedding_array = embedding_array.sel(agg=non_count_aggs) + # Aggregate by different dimensions by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs() by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs() by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs() - # Create DataFrames - df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()}) - df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()}) - df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()}) + # Create DataFrames, handling potential MultiIndex + df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values}) + df_band = pd.DataFrame({"dimension": by_band.index.astype(str), "mean_abs_weight": by_band.values}) + df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values}) # Sort by weight df_agg = df_agg.sort_values("mean_abs_weight", ascending=True) @@ -185,10 +199,10 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart: - """Create a heatmap showing ERA5 feature weights across variables and time. + """Create a heatmap showing ERA5 feature weights across variables and time/season. Args: - era5_array: DataArray with dimensions (variable, time) containing feature weights. 
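The inline 2%-padding logic added to `render_multi_metric_comparison` above is a reusable pattern; a stand-alone version might look like this (`padded_scale` is a hypothetical helper, not part of this changeset):

# Hypothetical helper mirroring the inline domain-padding logic above.
import altair as alt
import pandas as pd

def padded_scale(series: pd.Series, frac: float = 0.02) -> alt.Scale:
    """Build an Altair scale whose domain spans the data plus a small margin."""
    lo, hi = float(series.min()), float(series.max())
    pad = (hi - lo) * frac
    return alt.Scale(domain=[lo - pad, hi + pad])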
diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py
index 113bfac..2305fda 100644
--- a/src/entropice/dashboard/plots/model_state.py
+++ b/src/entropice/dashboard/plots/model_state.py
@@ -69,6 +69,13 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
         Altair chart showing the heatmap.
 
     """
+    # Filter out "count" aggregation if present
+    if "agg" in embedding_array.dims:
+        agg_values = embedding_array.coords["agg"].values
+        non_count_aggs = [agg for agg in agg_values if agg != "count"]
+        if non_count_aggs:
+            embedding_array = embedding_array.sel(agg=non_count_aggs)
+
     # Convert to DataFrame for plotting
     df = embedding_array.to_dataframe(name="weight").reset_index()
 
@@ -81,7 +88,7 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
             y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
             color=alt.Color(
                 "weight:Q",
-                scale=alt.Scale(scheme="redblue", domainMid=0),
+                scale=alt.Scale(scheme="blues"),
                 title="Weight",
             ),
             tooltip=[
@@ -108,15 +115,22 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
         Tuple of three Altair charts (by_agg, by_band, by_year).
 
     """
+    # Filter out "count" aggregation if present
+    if "agg" in embedding_array.dims:
+        agg_values = embedding_array.coords["agg"].values
+        non_count_aggs = [agg for agg in agg_values if agg != "count"]
+        if non_count_aggs:
+            embedding_array = embedding_array.sel(agg=non_count_aggs)
+
     # Aggregate by different dimensions
     by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
     by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
     by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()
 
-    # Create DataFrames
-    df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()})
-    df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()})
-    df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()})
+    # Create DataFrames, handling potential MultiIndex
+    df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
+    df_band = pd.DataFrame({"dimension": by_band.index.astype(str), "mean_abs_weight": by_band.values})
+    df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
 
     # Sort by weight
     df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
@@ -185,10 +199,10 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
 
 
 def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
-    """Create a heatmap showing ERA5 feature weights across variables and time.
+    """Create a heatmap showing ERA5 feature weights across variables and time/season.
 
     Args:
-        era5_array: DataArray with dimensions (variable, time) containing feature weights.
+        era5_array: DataArray with dimensions (variable, time) or (variable, season, year) containing feature weights.
 
     Returns:
         Altair chart showing the heatmap.
@@ -197,12 +211,89 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
     # Convert to DataFrame for plotting
     df = era5_array.to_dataframe(name="weight").reset_index()
 
+    # Determine if this is seasonal/shoulder data (has 'season' dimension)
+    has_season = "season" in df.columns
+
+    if has_season:
+        # For seasonal/shoulder: create faceted heatmap by season
+        chart = (
+            alt.Chart(df)
+            .mark_rect()
+            .encode(
+                x=alt.X("year:N", title="Year", sort=None),
+                y=alt.Y("variable:N", title="Variable", sort="-color"),
+                color=alt.Color(
+                    "weight:Q",
+                    scale=alt.Scale(scheme="redblue", domainMid=0),
+                    title="Weight",
+                ),
+                tooltip=[
+                    alt.Tooltip("variable:N", title="Variable"),
+                    alt.Tooltip("season:N", title="Season"),
+                    alt.Tooltip("year:N", title="Year"),
+                    alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+                ],
+            )
+            .properties(width=200, height=300)
+            .facet(facet=alt.Facet("season:N", title="Season"), columns=4)
+        )
+    else:
+        # For yearly: simple heatmap
+        chart = (
+            alt.Chart(df)
+            .mark_rect()
+            .encode(
+                x=alt.X("time:N", title="Time", sort=None),
+                y=alt.Y("variable:N", title="Variable", sort="-color"),
+                color=alt.Color(
+                    "weight:Q",
+                    scale=alt.Scale(scheme="redblue", domainMid=0),
+                    title="Weight",
+                ),
+                tooltip=[
+                    alt.Tooltip("variable:N", title="Variable"),
+                    alt.Tooltip("time:N", title="Time"),
+                    alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+                ],
+            )
+            .properties(
+                height=400,
+                title="ERA5 Feature Weights Heatmap",
+            )
+        )
+
+    return chart
+
+
+def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart | None:
+    """Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).
+
+    This is specifically for seasonal/shoulder data to show temporal trends.
+
+    Args:
+        era5_array: DataArray with dimensions (variable, season, year, agg) containing feature weights.
+
+    Returns:
+        Altair chart showing the variable-year heatmap, or None if the array has no season dimension.
+
+    """
+    # Check if this has season dimension
+    if "season" not in era5_array.dims:
+        return None
+
+    # Average over season and agg to get (variable, year)
+    dims_to_avg = [d for d in era5_array.dims if d not in ["variable", "year"]]
+    by_time = era5_array.mean(dim=dims_to_avg)
+
+    # Convert to DataFrame for plotting
+    df = by_time.to_dataframe(name="weight").reset_index()
+
     # Create heatmap
     chart = (
         alt.Chart(df)
         .mark_rect()
         .encode(
-            x=alt.X("time:N", title="Time", sort=None),
+            x=alt.X("year:N", title="Year", sort=None),
             y=alt.Y("variable:N", title="Variable", sort="-color"),
             color=alt.Color(
                 "weight:Q",
@@ -211,40 +302,233 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
                 scale=alt.Scale(scheme="redblue", domainMid=0),
             ),
             tooltip=[
                 alt.Tooltip("variable:N", title="Variable"),
-                alt.Tooltip("time:N", title="Time"),
-                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+                alt.Tooltip("year:N", title="Year"),
+                alt.Tooltip("weight:Q", format=".4f", title="Avg Weight"),
             ],
         )
         .properties(
             height=400,
-            title="ERA5 Feature Weights Heatmap",
+            title="ERA5 Feature Weights by Time (averaged over seasons)",
         )
     )
 
     return chart
 
 
-def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
-    """Create bar charts summarizing ERA5 weights by variable and time.
+def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
+    """Create bar charts summarizing ERA5 weights by variable and time/season.
 
     Args:
-        era5_array: DataArray with dimensions (variable, time) containing feature weights.
+        era5_array: DataArray with dimensions (variable, time) or (variable, season, year, agg) containing feature weights.
 
     Returns:
-        Tuple of two Altair charts (by_variable, by_time).
+        Tuple of Altair charts:
+        - For yearly data: (by_variable, by_time)
+        - For seasonal/shoulder data: (by_variable, by_season, by_year)
+
+    """
+    # Determine which dimensions to average over
+    dims_to_avg_for_var = [d for d in era5_array.dims if d != "variable"]
+
+    # Check if this is seasonal/shoulder data
+    has_season = "season" in era5_array.dims
+
+    # Aggregate by variable (average over all other dimensions)
+    by_variable = era5_array.mean(dim=dims_to_avg_for_var).to_pandas().abs()
+
+    if has_season:
+        # For seasonal/shoulder: aggregate by season and year
+        dims_to_avg_for_season = [d for d in era5_array.dims if d != "season"]
+        by_season = era5_array.mean(dim=dims_to_avg_for_season).to_pandas().abs()
+
+        dims_to_avg_for_year = [d for d in era5_array.dims if d != "year"]
+        by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
+
+        # Create DataFrames
+        df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
+        df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values})
+        df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
+
+        # Sort by weight
+        df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
+        df_season = df_season.sort_values("mean_abs_weight", ascending=True)
+        df_year = df_year.sort_values("mean_abs_weight", ascending=True)
+
+        chart_variable = (
+            alt.Chart(df_variable)
+            .mark_bar()
+            .encode(
+                y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+                color=alt.Color(
+                    "mean_abs_weight:Q",
+                    scale=alt.Scale(scheme="purples"),
+                    legend=None,
+                ),
+                tooltip=[
+                    alt.Tooltip("dimension:N", title="Variable"),
+                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+                ],
+            )
+            .properties(width=400, height=300, title="By Variable")
+        )
+
+        chart_season = (
+            alt.Chart(df_season)
+            .mark_bar()
+            .encode(
+                y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)),
+                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+                color=alt.Color(
+                    "mean_abs_weight:Q",
+                    scale=alt.Scale(scheme="teals"),
+                    legend=None,
+                ),
+                tooltip=[
+                    alt.Tooltip("dimension:N", title="Season"),
+                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+                ],
+            )
+            .properties(width=400, height=300, title="By Season")
+        )
+
+        chart_year = (
+            alt.Chart(df_year)
+            .mark_bar()
+            .encode(
+                y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)),
+                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+                color=alt.Color(
+                    "mean_abs_weight:Q",
+                    scale=alt.Scale(scheme="oranges"),
+                    legend=None,
+                ),
+                tooltip=[
+                    alt.Tooltip("dimension:O", title="Year"),
+                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+                ],
+            )
+            .properties(width=400, height=300, title="By Year")
+        )
+
+        return chart_variable, chart_season, chart_year
+    else:
+        # For yearly: aggregate by time
+        dims_to_avg_for_time = [d for d in era5_array.dims if d != "time"]
+        by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
+
+        # Create DataFrames, handling potential MultiIndex
+        df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
+        df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
+
+        # Sort by weight
+        df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
+        df_time = df_time.sort_values("mean_abs_weight", ascending=True)
+
+        # Create charts with different colors
+        chart_variable = (
+            alt.Chart(df_variable)
+            .mark_bar()
+            .encode(
+                y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+                color=alt.Color(
+                    "mean_abs_weight:Q",
+                    scale=alt.Scale(scheme="purples"),
+                    legend=None,
+                ),
+                tooltip=[
+                    alt.Tooltip("dimension:N", title="Variable"),
+                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+                ],
+            )
+            .properties(width=400, height=300, title="By Variable")
+        )
+
+        chart_time = (
+            alt.Chart(df_time)
+            .mark_bar()
+            .encode(
+                y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)),
+                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+                color=alt.Color(
+                    "mean_abs_weight:Q",
+                    scale=alt.Scale(scheme="teals"),
+                    legend=None,
+                ),
+                tooltip=[
+                    alt.Tooltip("dimension:N", title="Time"),
+                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+                ],
+            )
+            .properties(width=400, height=300, title="By Time")
+        )
+
+        return chart_variable, chart_time
+
+
+def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
+    """Create a heatmap showing ArcticDEM feature weights across variables and aggregations.
+
+    Args:
+        arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.
+
+    Returns:
+        Altair chart showing the heatmap.
+
+    """
+    # Convert to DataFrame for plotting
+    df = arcticdem_array.to_dataframe(name="weight").reset_index()
+
+    # Create heatmap
+    chart = (
+        alt.Chart(df)
+        .mark_rect()
+        .encode(
+            x=alt.X("agg:N", title="Aggregation", sort=None),
+            y=alt.Y("variable:N", title="Variable", sort="-color"),
+            color=alt.Color(
+                "weight:Q",
+                scale=alt.Scale(scheme="redblue", domainMid=0),
+                title="Weight",
+            ),
+            tooltip=[
+                alt.Tooltip("variable:N", title="Variable"),
+                alt.Tooltip("agg:N", title="Aggregation"),
+                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+            ],
+        )
+        .properties(
+            width=400,
+            height=300,
+            title="ArcticDEM Feature Weights Heatmap",
+        )
+    )
+
+    return chart
+
+
+def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
+    """Create bar charts summarizing ArcticDEM weights by variable and aggregation.
+
+    Args:
+        arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.
+
+    Returns:
+        Tuple of two Altair charts (by_variable, by_agg).
 
     """
     # Aggregate by different dimensions
-    by_variable = era5_array.mean(dim="time").to_pandas().abs()
-    by_time = era5_array.mean(dim="variable").to_pandas().abs()
+    by_variable = arcticdem_array.mean(dim="agg").to_pandas().abs()
+    by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs()
 
-    # Create DataFrames
-    df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()})
-    df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()})
+    # Create DataFrames, handling potential MultiIndex
+    df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
+    df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
 
     # Sort by weight
     df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
-    df_time = df_time.sort_values("mean_abs_weight", ascending=True)
+    df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
 
     # Create charts with different colors
     chart_variable = (
@@ -255,7 +539,7 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
             x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
             color=alt.Color(
                 "mean_abs_weight:Q",
-                scale=alt.Scale(scheme="purples"),
+                scale=alt.Scale(scheme="browns"),
                 legend=None,
             ),
             tooltip=[
@@ -266,26 +550,26 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
         .properties(width=400, height=300, title="By Variable")
     )
 
-    chart_time = (
-        alt.Chart(df_time)
+    chart_agg = (
+        alt.Chart(df_agg)
         .mark_bar()
         .encode(
-            y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)),
+            y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)),
             x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
             color=alt.Color(
                 "mean_abs_weight:Q",
-                scale=alt.Scale(scheme="teals"),
+                scale=alt.Scale(scheme="reds"),
                 legend=None,
             ),
            tooltip=[
-                alt.Tooltip("dimension:N", title="Time"),
+                alt.Tooltip("dimension:N", title="Aggregation"),
                 alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
             ],
         )
-        .properties(width=400, height=300, title="By Time")
+        .properties(width=400, height=300, title="By Aggregation")
     )
 
-    return chart_variable, chart_time
+    return chart_variable, chart_agg
 
 
 def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
file_name=f"{selected_result.path.name}_results.csv", mime="text/csv", - use_container_width=True, + width="stretch", ) with col2: @@ -239,9 +239,9 @@ def render_training_analysis_page(): data=settings_json, file_name=f"{selected_result.path.name}_settings.json", mime="application/json", - use_container_width=True, + width="stretch", ) # Show raw data preview st.subheader("Raw Data Preview") - st.dataframe(results.head(100), use_container_width=True) + st.dataframe(results.head(100), width="stretch") diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py index 1e8f805..1dc280b 100644 --- a/src/entropice/dashboard/training_data_page.py +++ b/src/entropice/dashboard/training_data_page.py @@ -78,7 +78,7 @@ def render_training_data_page(): # Form submit button load_button = st.form_submit_button( - "Load Dataset", type="primary", use_container_width=True, disabled=len(selected_members) == 0 + "Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0 ) # Create DatasetEnsemble only when form is submitted diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/data.py index d64a70f..83515b0 100644 --- a/src/entropice/dashboard/utils/data.py +++ b/src/entropice/dashboard/utils/data.py @@ -25,6 +25,27 @@ class TrainingResult: results: pd.DataFrame created_at: float + def get_display_name(self, format_type: str = "task_first") -> str: + """Get formatted display name for the training result. + + Args: + format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state) + + Returns: + Formatted name string + + """ + task = self.settings.get("task", "Unknown").capitalize() + model = self.settings.get("model", "Unknown").upper() + grid = self.settings.get("grid", "Unknown").capitalize() + level = self.settings.get("level", "Unknown") + timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M") + + if format_type == "model_first": + return f"{model} - {task} - {grid}-{level} ({timestamp})" + else: # task_first + return f"{task} - {model} - {grid}-{level} ({timestamp})" + @classmethod def from_path(cls, result_path: Path) -> "TrainingResult": """Load a TrainingResult from a given result directory path.""" @@ -38,9 +59,11 @@ class TrainingResult: settings = toml.load(settings_file)["settings"] results = pd.read_parquet(result_file) - # Name should be "task grid-level (created_at)" + # Name should be "task model grid-level (created_at)" + model = settings.get("model", "Unknown").upper() name = ( - f"**{settings.get('task', 'Unknown').capitalize()}** -" + f"{settings.get('task', 'Unknown').capitalize()} -" + f" {model} -" f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}" f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})" ) @@ -121,11 +144,40 @@ def load_source_data(e: DatasetEnsemble, source: str): return ds, targets -def extract_embedding_features(model_state) -> xr.DataArray | None: - """Extract embedding features from the model state. +def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray: + """Get the appropriate feature importance array from model state. Args: model_state: The xarray Dataset containing the model state. + importance_type: Type of importance to extract. For ESPA: 'feature_weights'. + For XGBoost: 'feature_importance_weight', 'feature_importance_gain', etc. + For RF: 'feature_importance'. 
+ + Returns: + xr.DataArray with feature importance values. + + """ + if importance_type in model_state: + return model_state[importance_type] + # Fallback for compatibility + if "feature_weights" in model_state: + return model_state["feature_weights"] + if "feature_importance" in model_state: + return model_state["feature_importance"] + raise ValueError(f"No feature importance array found. Available: {list(model_state.data_vars)}") + + +def extract_embedding_features( + model_state: xr.Dataset, importance_type: str = "feature_weights" +) -> xr.DataArray | None: + """Extract embedding features from the model state. + + Feature naming pattern: `embeddings_{agg}_{band}_{year}` + Example: `embeddings_mean_B02_2020` + + Args: + model_state: The xarray Dataset containing the model state. + importance_type: Type of feature importance to extract. Returns: xr.DataArray: The extracted embedding features. This DataArray has dimensions @@ -135,14 +187,17 @@ def extract_embedding_features(model_state) -> xr.DataArray | None: """ def _is_embedding_feature(feature: str) -> bool: - return feature.startswith("embedding_") + return feature.startswith("embeddings_") embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)] if len(embedding_features) == 0: return None + # Get the appropriate importance array + importance_array = _get_feature_importance_array(model_state, importance_type) + # Split the single feature dimension of embedding features into separate dimensions (agg, band, year) - embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"] + embedding_feature_array = importance_array.sel(feature=embedding_features) embedding_feature_array = embedding_feature_array.assign_coords( agg=("feature", [f.split("_")[1] for f in embedding_features]), band=("feature", [f.split("_")[2] for f in embedding_features]), @@ -152,15 +207,26 @@ def extract_embedding_features(model_state) -> xr.DataArray | None: return embedding_feature_array -def extract_era5_features(model_state) -> xr.DataArray | None: +def extract_era5_features( + model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None +) -> xr.DataArray | None: """Extract ERA5 features from the model state. + Feature naming patterns: + - Without aggregations: `era5_{variable}_{time}` + Example: `era5_temperature_2020_summer` + - With aggregations: `era5_{variable}_{time}_{agg}` + Example: `era5_temperature_2020_summer_mean` + Args: model_state: The xarray Dataset containing the model state. + importance_type: Type of feature importance to extract. + temporal_group: Filter to specific temporal group ('yearly', 'seasonal', 'shoulder'). + If None, extracts all ERA5 features. Returns: xr.DataArray: The extracted ERA5 features. This DataArray has dimensions - ('variable', 'time') corresponding to the different components of the ERA5 features. + ('variable', 'time') or ('variable', 'time', 'agg') depending on the feature structure. Returns None if no ERA5 features are found. """ @@ -168,34 +234,254 @@ def extract_era5_features(model_state) -> xr.DataArray | None: def _is_era5_feature(feature: str) -> bool: return feature.startswith("era5_") + def _get_temporal_group(feature: str) -> str: + """Determine temporal group from feature name based on time component. 
+
+        Time patterns:
+        - Yearly: just a year number (e.g., "2020")
+        - Seasonal: a season plus a year, in either order (e.g., "summer_2020" or "2020_summer")
+        - Shoulder: a shoulder code plus a year, in either order (e.g., "JFM_2020" or "2020_OND")
+        """
+        parts = feature.split("_")
+        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+
+        # Find where the time part starts (after "era5" and variable name)
+        # Pattern: era5_variable_time or era5_variable_time_agg
+        # Time can be: year, season_year, or SHOULDER_year
+
+        # Remove era5 prefix and aggregation suffix if present
+        if parts[-1] in common_aggs:
+            time_parts = parts[1:-1]  # Remove era5 and agg
+        else:
+            time_parts = parts[1:]  # Remove era5
+
+        # The last part (or last two parts) is the time
+        # Check last part first
+        last_part = time_parts[-1]
+
+        if last_part.isdigit():
+            # Could be yearly (just year) or part of seasonal/shoulder (season_year or SHOULDER_year)
+            if len(time_parts) >= 2:
+                second_last = time_parts[-2]
+                # Check if second to last is a season or shoulder indicator
+                if second_last.lower() in ["summer", "winter"]:
+                    return "seasonal"
+                elif second_last.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+                    return "shoulder"
+            # Just a year number
+            return "yearly"
+        elif last_part.lower() in ["summer", "winter"]:
+            return "seasonal"
+        elif last_part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+            return "shoulder"
+
+        return "unknown"
+
     def _extract_var_name(feature: str) -> str:
         parts = feature.split("_")
-        # era5_variablename_timetype format
-        return "_".join(parts[1:-1])
+        # Pattern: era5_variable_year_agg OR era5_variable_season_year_agg
+        # Variable can include stat names like t2m_max, t2m_min, etc.
+        # We need to find where the year starts (it's always a 4-digit number)
+
+        # Find the first year (4-digit number)
+        year_idx = None
+        for i, part in enumerate(parts):
+            if part.isdigit() and len(part) == 4:
+                year_idx = i
+                break
+
+        if year_idx is None:
+            # Fallback: return everything after era5
+            return "_".join(parts[1:])
+
+        # Variable is everything between era5 and year
+        # But also check if there's a season before the year
+        seasons = {"summer", "winter"}
+        shoulders = {"JFM", "AMJ", "JAS", "OND"}
+
+        var_end_idx = year_idx
+        # Check if the part before year is a season/shoulder
+        if year_idx > 1 and (parts[year_idx - 1].lower() in seasons or parts[year_idx - 1].upper() in shoulders):
+            var_end_idx = year_idx - 1
+
+        # Variable is from index 1 (after era5) to var_end_idx
+        return "_".join(parts[1:var_end_idx])
 
     def _extract_time_name(feature: str) -> str:
         parts = feature.split("_")
-        # Last part is the time type
-        return parts[-1]
+        common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
+        if parts[-1] in common_aggs:
+            # Has aggregation: era5_var_time_agg -> time is second to last
+            return parts[-2]
+        else:
+            # No aggregation: era5_var_time -> time is last
+            return parts[-1]
 
+    def _extract_season(feature: str) -> str | None:
+        """Extract season/shoulder from seasonal/shoulder features.
+
+        Matches the season/shoulder token wherever it appears in the feature name
+        (e.g., era5_var_summer_2020_mean or era5_var_2020_summer).
+        """
+        parts = feature.split("_")
+
+        # Look through parts to find season/shoulder indicators
+        for part in parts:
+            if part.lower() in ["summer", "winter"]:
+                return part.lower()
+            elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
+                return part.upper()
+
+        return None
+
+    def _extract_year_from_time(feature: str) -> str:
+        """Extract just the year component from time.
+ + For seasonal/shoulder features, find the year that comes after the season. + """ + parts = feature.split("_") + common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + + # Find the season/shoulder part, then the next part should be the year + for i, part in enumerate(parts): + if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]: + # Next part should be the year + if i + 1 < len(parts): + next_part = parts[i + 1] + # Skip if it's an aggregation, year should be before agg + if next_part not in common_aggs: + return next_part + + # Fallback: return last numeric part that's not an aggregation + for part in reversed(parts): + if part.isdigit(): + return part + + return "unknown" + + def _extract_agg_name(feature: str) -> str | None: + parts = feature.split("_") + common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"} + if parts[-1] in common_aggs: + return parts[-1] + return None + + # Get all ERA5 features era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)] + + # Filter by temporal group if specified + if temporal_group is not None: + era5_features = [f for f in era5_features if _get_temporal_group(f) == temporal_group] + if len(era5_features) == 0: return None - # Split the single feature dimension of era5 features into separate dimensions (variable, time) - era5_features_array = model_state.sel(feature=era5_features)["feature_weights"] - era5_features_array = era5_features_array.assign_coords( - variable=("feature", [_extract_var_name(f) for f in era5_features]), - time=("feature", [_extract_time_name(f) for f in era5_features]), - ) - era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 + + # Get the appropriate importance array + importance_array = _get_feature_importance_array(model_state, importance_type) + + # Check if features have aggregations + has_agg = any(_extract_agg_name(f) is not None for f in era5_features) + + # Check if features have season/shoulder split (for seasonal/shoulder temporal groups) + has_season = any(_extract_season(f) is not None for f in era5_features) + + # Split the single feature dimension of era5 features into separate dimensions + era5_features_array = importance_array.sel(feature=era5_features) + + if has_season: + # For seasonal/shoulder: split into variable, season, year, (agg) + era5_features_array = era5_features_array.assign_coords( + variable=("feature", [_extract_var_name(f) for f in era5_features]), + season=("feature", [_extract_season(f) for f in era5_features]), + year=("feature", [_extract_year_from_time(f) for f in era5_features]), + ) + + if has_agg: + era5_features_array = era5_features_array.assign_coords( + agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]), + ) + era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack( + "feature" + ) # noqa: PD010 + else: + era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year"]).unstack( + "feature" + ) # noqa: PD010 + else: + # For yearly: keep as variable, time, (agg) + era5_features_array = era5_features_array.assign_coords( + variable=("feature", [_extract_var_name(f) for f in era5_features]), + time=("feature", [_extract_time_name(f) for f in era5_features]), + ) + + if has_agg: + # Add aggregation dimension + era5_features_array = era5_features_array.assign_coords( + agg=("feature", 
[_extract_agg_name(f) or "none" for f in era5_features]), + ) + era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010 + else: + era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 + return era5_features_array -def extract_common_features(model_state) -> xr.DataArray | None: +def extract_arcticdem_features( + model_state: xr.Dataset, importance_type: str = "feature_weights" +) -> xr.DataArray | None: + """Extract ArcticDEM features from the model state. + + Feature naming pattern: `arcticdem_{variable}_{agg}` + Example: `arcticdem_elevation_mean`, `arcticdem_slope_std` + + Args: + model_state: The xarray Dataset containing the model state. + importance_type: Type of feature importance to extract. + + Returns: + xr.DataArray: The extracted ArcticDEM features. This DataArray has dimensions + ('variable', 'agg') corresponding to the different components of the ArcticDEM features. + Returns None if no ArcticDEM features are found. + + """ + + def _is_arcticdem_feature(feature: str) -> bool: + return feature.startswith("arcticdem_") + + def _extract_var_name(feature: str) -> str: + parts = feature.split("_") + # arcticdem_variable_agg format + # Variable name is everything between arcticdem_ and the last part + return "_".join(parts[1:-1]) + + def _extract_agg_name(feature: str) -> str: + parts = feature.split("_") + # Last part is the aggregation + return parts[-1] + + arcticdem_features = [f for f in model_state.feature.to_numpy() if _is_arcticdem_feature(f)] + if len(arcticdem_features) == 0: + return None + + # Get the appropriate importance array + importance_array = _get_feature_importance_array(model_state, importance_type) + + # Split the single feature dimension into separate dimensions (variable, agg) + arcticdem_feature_array = importance_array.sel(feature=arcticdem_features) + arcticdem_feature_array = arcticdem_feature_array.assign_coords( + variable=("feature", [_extract_var_name(f) for f in arcticdem_features]), + agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]), + ) + arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") # noqa: PD010 + return arcticdem_feature_array + + +def extract_common_features(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray | None: """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state. Args: model_state: The xarray Dataset containing the model state. + importance_type: Type of feature importance to extract. Returns: xr.DataArray: The extracted common features with a single 'feature' dimension. @@ -211,6 +497,22 @@ def extract_common_features(model_state) -> xr.DataArray | None: if len(common_features) == 0: return None - # Extract the feature weights for common features - common_feature_array = model_state.sel(feature=common_features)["feature_weights"] + # Get the appropriate importance array + importance_array = _get_feature_importance_array(model_state, importance_type) + + # Extract the feature importance for common features + common_feature_array = importance_array.sel(feature=common_features) return common_feature_array + + +def get_members_from_settings(settings: dict) -> list[str]: + """Extract the list of dataset members used in training from settings. + + Args: + settings: Training settings dictionary. 
+
+    Returns:
+        List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']).
+
+    """
+    return settings.get("members", [])
diff --git a/src/entropice/training.py b/src/entropice/training.py
index bab3363..bc34803 100644
--- a/src/entropice/training.py
+++ b/src/entropice/training.py
@@ -278,6 +278,7 @@ def random_cv(
         booster = best_estimator.get_booster()
 
         # Feature importance with different importance types
+        # Note: get_score() keys its dict by 'f0', 'f1', ... (positional indices) when no feature names are stored on the booster
        importance_weight = booster.get_score(importance_type="weight")
         importance_gain = booster.get_score(importance_type="gain")
         importance_cover = booster.get_score(importance_type="cover")
@@ -286,8 +287,11 @@ def random_cv(
         # Create aligned arrays for all features (including zero-importance)
         def align_importance(importance_dict, features):
-            """Align importance dict to feature list, filling missing with 0."""
-            return [importance_dict.get(f, 0.0) for f in features]
+            """Align importance dict to feature list, filling missing with 0.
+
+            XGBoost keys get_score() by positional index (f0, f1, ...) here, so map list position i to key f"f{i}".
+            """
+            return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]
 
         feature_importance_weight = xr.DataArray(
             align_importance(importance_weight, features),
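The align_importance change above relies on a subtle XGBoost behavior: Booster.get_score() keys its result by positional names ('f0', 'f1', ...) only when no feature names were attached to the training data. Below is a minimal standalone sketch to verify that mapping; the toy data and feature names are illustrative, not taken from the training pipeline.

import numpy as np
import xgboost as xgb

# Stand-in for the `features` list used in random_cv(); names are made up.
features = ["era5_t2m_2020", "arcticdem_elevation_mean", "land_ratio"]
X = np.random.rand(64, len(features))
y = np.random.rand(64)

# Plain numpy input -> no feature names stored -> get_score() keys are 'f0', 'f1', ...
booster = xgb.train({"max_depth": 3}, xgb.DMatrix(X, label=y), num_boost_round=5)
gain = booster.get_score(importance_type="gain")

# Same alignment as align_importance(): position i maps to key f"f{i}";
# features never used in a split are missing from the dict and default to 0.0.
aligned = [gain.get(f"f{i}", 0.0) for i in range(len(features))]
print(dict(zip(features, aligned)))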
diff --git a/test_feature_extraction.py b/test_feature_extraction.py
new file mode 100644
index 0000000..fee12ad
--- /dev/null
+++ b/test_feature_extraction.py
@@ -0,0 +1,146 @@
+"""Test script to verify feature extraction works correctly."""
+
+import numpy as np
+import xarray as xr
+
+# Create a mock model state with various feature types
+features = [
+    # Embedding features: embeddings_{agg}_{band}_{year} (prefix must match _is_embedding_feature)
+    "embeddings_mean_B02_2020",
+    "embeddings_std_B03_2021",
+    "embeddings_max_B04_2022",
+    # ERA5 features without aggregations: era5_{variable}_{time}
+    "era5_temperature_2020_summer",
+    "era5_precipitation_2021_winter",
+    # ERA5 features with aggregations: era5_{variable}_{time}_{agg}
+    "era5_temperature_2020_summer_mean",
+    "era5_precipitation_2021_winter_std",
+    # ArcticDEM features: arcticdem_{variable}_{agg}
+    "arcticdem_elevation_mean",
+    "arcticdem_slope_std",
+    "arcticdem_aspect_max",
+    # Common features
+    "cell_area",
+    "water_area",
+    "land_area",
+    "land_ratio",
+    "lon",
+    "lat",
+]
+
+# Create mock importance values
+importance_values = np.random.rand(len(features))
+
+# Create a mock model state for ESPA
+model_state_espa = xr.Dataset(
+    {
+        "feature_weights": xr.DataArray(
+            importance_values,
+            dims=["feature"],
+            coords={"feature": features},
+        )
+    }
+)
+
+# Create a mock model state for XGBoost
+model_state_xgb = xr.Dataset(
+    {
+        "feature_importance_gain": xr.DataArray(
+            importance_values,
+            dims=["feature"],
+            coords={"feature": features},
+        ),
+        "feature_importance_weight": xr.DataArray(
+            importance_values * 0.8,
+            dims=["feature"],
+            coords={"feature": features},
+        ),
+    }
+)
+
+# Create a mock model state for Random Forest
+model_state_rf = xr.Dataset(
+    {
+        "feature_importance": xr.DataArray(
+            importance_values,
+            dims=["feature"],
+            coords={"feature": features},
+        )
+    }
+)
+
+# Test extraction functions
+from entropice.dashboard.utils.data import (
+    extract_arcticdem_features,
+    extract_common_features,
+    extract_embedding_features,
+    extract_era5_features,
+)
+
+print("=" * 80)
+print("Testing ESPA model state")
+print("=" * 80)
+
+embedding_array = extract_embedding_features(model_state_espa)
+print(f"\nEmbedding features extracted: {embedding_array is not None}")
+if embedding_array is not None:
+    print(f"  Dimensions: {embedding_array.dims}")
+    print(f"  Shape: {embedding_array.shape}")
+    print(f"  Coordinates: {list(embedding_array.coords)}")
+
+era5_array = extract_era5_features(model_state_espa)
+print(f"\nERA5 features extracted: {era5_array is not None}")
+if era5_array is not None:
+    print(f"  Dimensions: {era5_array.dims}")
+    print(f"  Shape: {era5_array.shape}")
+    print(f"  Coordinates: {list(era5_array.coords)}")
+
+arcticdem_array = extract_arcticdem_features(model_state_espa)
+print(f"\nArcticDEM features extracted: {arcticdem_array is not None}")
+if arcticdem_array is not None:
+    print(f"  Dimensions: {arcticdem_array.dims}")
+    print(f"  Shape: {arcticdem_array.shape}")
+    print(f"  Coordinates: {list(arcticdem_array.coords)}")
+
+common_array = extract_common_features(model_state_espa)
+print(f"\nCommon features extracted: {common_array is not None}")
+if common_array is not None:
+    print(f"  Dimensions: {common_array.dims}")
+    print(f"  Shape: {common_array.shape}")
+    print(f"  Size: {common_array.size}")
+
+print("\n" + "=" * 80)
+print("Testing XGBoost model state")
+print("=" * 80)
+
+embedding_array_xgb = extract_embedding_features(model_state_xgb, importance_type="feature_importance_gain")
+print(f"\nEmbedding features (gain) extracted: {embedding_array_xgb is not None}")
+if embedding_array_xgb is not None:
+    print(f"  Dimensions: {embedding_array_xgb.dims}")
+    print(f"  Shape: {embedding_array_xgb.shape}")
+
+era5_array_xgb = extract_era5_features(model_state_xgb, importance_type="feature_importance_weight")
+print(f"\nERA5 features (weight) extracted: {era5_array_xgb is not None}")
+if era5_array_xgb is not None:
+    print(f"  Dimensions: {era5_array_xgb.dims}")
+    print(f"  Shape: {era5_array_xgb.shape}")
+
+print("\n" + "=" * 80)
+print("Testing Random Forest model state")
+print("=" * 80)
+
+embedding_array_rf = extract_embedding_features(model_state_rf, importance_type="feature_importance")
+print(f"\nEmbedding features extracted: {embedding_array_rf is not None}")
+if embedding_array_rf is not None:
+    print(f"  Dimensions: {embedding_array_rf.dims}")
+    print(f"  Shape: {embedding_array_rf.shape}")
+
+arcticdem_array_rf = extract_arcticdem_features(model_state_rf, importance_type="feature_importance")
+print(f"\nArcticDEM features extracted: {arcticdem_array_rf is not None}")
+if arcticdem_array_rf is not None:
+    print(f"  Dimensions: {arcticdem_array_rf.dims}")
+    print(f"  Shape: {arcticdem_array_rf.shape}")
+
+print("\n" + "=" * 80)
+print("All tests completed successfully!")
+print("=" * 80)
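The smoke script above only prints shapes, so it exits cleanly even if an extractor silently returns None. A pytest-style variant could pin down the expected dimensions instead. The following is a sketch under the same assumptions as test_feature_extraction.py (a mock model state carrying the default 'feature_weights' importance array); it is not part of the diff.

import numpy as np
import xarray as xr

from entropice.dashboard.utils.data import extract_arcticdem_features, extract_embedding_features

# Minimal mock state: one embedding feature, one ArcticDEM feature, one common feature.
features = ["embeddings_mean_B02_2020", "arcticdem_elevation_mean", "lon"]
model_state = xr.Dataset(
    {"feature_weights": xr.DataArray(np.random.rand(len(features)), dims=["feature"], coords={"feature": features})}
)


def test_embedding_dims():
    # embeddings_{agg}_{band}_{year} should unstack into (agg, band, year)
    arr = extract_embedding_features(model_state)
    assert arr is not None
    assert set(arr.dims) == {"agg", "band", "year"}


def test_arcticdem_dims():
    # arcticdem_{variable}_{agg} should unstack into (variable, agg)
    arr = extract_arcticdem_features(model_state)
    assert arr is not None
    assert set(arr.dims) == {"variable", "agg"}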