Make the Model State Page great again

This commit is contained in:
Tobias Hölzer 2025-12-25 18:19:11 +01:00
parent 591da6992e
commit 1919cc6a7e
13 changed files with 1375 additions and 142 deletions

28
pixi.lock generated
View file

@ -9,6 +9,7 @@ environments:
- https://pypi.org/simple
options:
channel-priority: disabled
pypi-prerelease-mode: if-necessary-or-explicit
packages:
linux-64:
- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
@ -478,7 +479,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ee/34/a9914e676971a13d6cc671b1ed172f9804b50a3a80a143ff196e52f4c7ee/azure_core-1.37.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
@ -498,7 +499,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
@ -567,7 +568,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/84/99/6636f7097a5e461d560317024522279f52931b5a52c8caa0755a14d5f1fd/odc_loader-0.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e2/c7/b8f2b3e53f26f8f463002f3e8023189653b627b22ba6c00ef86eaba50b73/odc_stac-0.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/c3/3031c931098de393393e1f93a38dc9ed6805d86bb801acc3cf2d5bd1e6b7/plotly-6.5.0-py3-none-any.whl
@ -1025,10 +1026,10 @@ packages:
- astropy[dev] ; extra == 'dev-all'
- astropy[test-all] ; extra == 'dev-all'
requires_python: '>=3.11'
- pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
name: astropy-iers-data
version: 0.2025.12.15.0.40.51
sha256: 5f89c4b9d711638005407b4b57f76d82484f0182e459df91fcdfc354c3a34d8e
version: 0.2025.12.22.0.40.30
sha256: 2fbc71988d96aa29566667c6568a2bc5ca00748174b1f8ac3e9f7b09d4c27cac
requires_dist:
- pytest ; extra == 'docs'
- hypothesis ; extra == 'test'
@ -2516,10 +2517,10 @@ packages:
- pkg:pypi/cycler?source=hash-mapping
size: 14778
timestamp: 1764466758386
- pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
name: cyclopts
version: 4.4.0
sha256: 78ff95a5e52e738a1d0f01e5a3af48049c47748fa2c255f2629a4cef54dcf2b3
version: 4.4.1
sha256: 67500e9fde90f335fddbf9c452d2e7c4f58209dffe52e7abb1e272796a963bde
requires_dist:
- attrs>=23.1.0
- docstring-parser>=0.15,<4.0
@ -2932,7 +2933,6 @@ packages:
- ruff>=0.14.9,<0.15
- pandas-stubs>=2.3.3.251201,<3
requires_python: '>=3.13,<3.14'
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
version: 0.1.0
@ -7326,12 +7326,12 @@ packages:
- pkg:pypi/pandas?source=compressed-mapping
size: 14912799
timestamp: 1764615091147
- pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
name: pandas-stubs
version: 2.3.3.251201
sha256: eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5
version: 2.3.3.251219
sha256: ccc6337febb51d6d8a08e4c96b479478a0da0ef704b5e08bd212423fe1cb549c
requires_dist:
- numpy>=1.23.5
- numpy>=1.23.5,<=2.3.5
- types-pytz>=2022.1.1
requires_python: '>=3.10'
- conda: https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2

View file

@ -1,11 +1,22 @@
#! /bin/bash
pixi run darts extract_darts_mllabels --grid hex --level 3
pixi run darts extract_darts_mllabels --grid hex --level 4
pixi run darts extract_darts_mllabels --grid hex --level 5
pixi run darts extract_darts_mllabels --grid hex --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 7
pixi run darts extract_darts_mllabels --grid healpix --level 8
pixi run darts extract_darts_mllabels --grid healpix --level 9
pixi run darts extract_darts_mllabels --grid healpix --level 10
# pixi shell
darts extract-darts-rts --grid hex --level 3
darts extract-darts-rts --grid hex --level 4
darts extract-darts-rts --grid hex --level 5
darts extract-darts-rts --grid hex --level 6
darts extract-darts-rts --grid healpix --level 6
darts extract-darts-rts --grid healpix --level 7
darts extract-darts-rts --grid healpix --level 8
darts extract-darts-rts --grid healpix --level 9
darts extract-darts-rts --grid healpix --level 10
darts extract-darts-mllabels --grid hex --level 3
darts extract-darts-mllabels --grid hex --level 4
darts extract-darts-mllabels --grid hex --level 5
darts extract-darts-mllabels --grid hex --level 6
darts extract-darts-mllabels --grid healpix --level 6
darts extract-darts-mllabels --grid healpix --level 7
darts extract-darts-mllabels --grid healpix --level 8
darts extract-darts-mllabels --grid healpix --level 9
darts extract-darts-mllabels --grid healpix --level 10

4
scripts/05train.sh Normal file
View file

@ -0,0 +1,4 @@
#!/bin/bash
# Train two density models on the hex level-5 grid against the
# darts_mllabels target, 1000 hyperparameter-search iterations each:
# first the ESPA model, then an XGBoost baseline for comparison.
# NOTE(review): runs sequentially and continues even if the first job
# fails (no `set -e`) — presumably intentional so both models get trained.
pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model espa
pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model xgboost

View file

@ -0,0 +1,191 @@
"""Fix XGBoost feature importance in existing model state files.
This script repairs XGBoost model state files that have all-zero feature importance
values due to incorrect feature name lookup. It reloads the pickled models and
regenerates the feature importance arrays with the correct feature index mapping.
"""
import pickle
from pathlib import Path
import toml
import xarray as xr
from rich import print
from entropice.paths import RESULTS_DIR
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, rebuilds all five feature-importance
    arrays using an index-based (``f0``, ``f1``, ...) lookup into the
    booster's score dict, backs up the broken state file once, and writes a
    corrected ``best_estimator_state.nc``.

    Args:
        results_dir: Directory containing the model files
            (``search_settings.toml``, ``best_estimator_model.pkl``,
            ``best_estimator_state.nc``).

    Returns:
        True if fixed successfully, False otherwise (not an XGBoost run,
        missing files, or the rebuilt importance is still all-zero).
    """
    # Only operate on XGBoost runs; every other model type is skipped.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False
    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")
    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Both the pickled estimator and the old state file must be present.
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"
    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False
    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # NOTE(review): pickle.load executes arbitrary code from the file —
    # acceptable here only because these are our own result artifacts.
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Recover the feature names from the existing (broken) state file.
    # Context manager guarantees the dataset is closed even if reading fails
    # (the original called .close() manually with no try/finally).
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    booster = best_estimator.get_booster()

    def align_importance(importance_dict: dict) -> list:
        """Align an importance dict to the feature list via f<i> index keys.

        The booster reports scores keyed by feature index ("f0", "f1", ...);
        features never used in any split are absent and map to 0.0.
        """
        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]

    # One (importance_type -> description) entry per DataArray to build;
    # replaces five copy-pasted constructor calls in the original.
    importance_specs = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }
    importance_arrays = {}
    for imp_type, description in importance_specs.items():
        var_name = f"feature_importance_{imp_type}"
        importance_arrays[var_name] = xr.DataArray(
            align_importance(booster.get_score(importance_type=imp_type)),
            dims=["feature"],
            coords={"feature": features},
            name=var_name,
            attrs={"description": description},
        )

    # Assemble the corrected state dataset with the same attrs as before.
    state = xr.Dataset(
        importance_arrays,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": booster.num_boosted_rounds(),
            "objective": str(best_estimator.objective),
        },
    )

    # Keep only the first backup; reruns just replace the broken state file.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Sanity check: a non-zero total means the index-based lookup worked.
    total_importance = sum(
        importance_arrays[f"feature_importance_{t}"].sum().item()
        for t in ("weight", "gain", "cover")
    )
    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    print(" ✗ Warning: Total importance is still 0!")
    return False
def main():
    """Find and fix all XGBoost model state files under RESULTS_DIR.

    Iterates every result directory, attempts the fix on each, and prints a
    fixed/skipped/failed summary at the end.
    """
    print("Scanning for XGBoost model results...")
    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed_count = 0
    skipped_count = 0
    failed_count = 0
    for result_dir in sorted(result_dirs):
        try:
            if fix_xgboost_model_state(result_dir):
                fixed_count += 1
            else:
                # False covers both "not an XGBoost model" and recoverable
                # problems (missing files, zero importance); the worker has
                # already printed the reason, so tally all of them as skipped.
                # (The original's `elif success is False:` was a dead
                # distinction — the function always returns a bool.)
                skipped_count += 1
        except Exception as e:
            # Unexpected errors (bad pickle, corrupt netCDF, ...) are the
            # only thing counted as failed.
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed_count += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {fixed_count}")
    print(f" ⊘ Skipped: {skipped_count}")
    print(f" ✗ Failed: {failed_count}")
    print("=" * 60)


if __name__ == "__main__":
    main()

View file

@ -5,6 +5,8 @@ import xarray as xr
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap,
plot_arcticdem_summary,
plot_box_assignment_bars,
plot_box_assignments,
plot_common_features,
@ -12,12 +14,15 @@ from entropice.dashboard.plots.model_state import (
plot_embedding_heatmap,
plot_era5_heatmap,
plot_era5_summary,
plot_era5_time_heatmap,
plot_top_features,
)
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
@ -35,15 +40,21 @@ def render_model_state_page():
st.error("No training results found. Please run a training search first.")
return
# Result selection
result_options = {tr.name: tr for tr in training_results}
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
# Result selection with model-first naming
result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Select Training Result",
"Training Run",
options=list(result_options.keys()),
help="Choose a training result to visualize model state",
)
selected_result = result_options[selected_name]
st.divider()
# Get the model type from settings
model_type = selected_result.settings.get("model", "espa")
@ -62,6 +73,33 @@ def render_model_state_page():
st.write(f"**Coordinates:** {list(model_state.coords)}")
st.write(f"**Attributes:** {dict(model_state.attrs)}")
# Display dataset members summary
st.header("📊 Training Data Summary")
members = get_members_from_settings(selected_result.settings)
st.markdown(f"""
**Dataset Members Used in Training:** {len(members)}
The following data sources were used to train this model:
""")
# Create a nice display of members with emojis
member_display = {
"AlphaEarth": "🛰️ AlphaEarth (Satellite Embeddings)",
"ArcticDEM": "🏔️ ArcticDEM (Topography)",
"ERA5-yearly": "⛅ ERA5 Yearly (Climate)",
"ERA5-seasonal": "⛅ ERA5 Seasonal (Summer/Winter)",
"ERA5-shoulder": "⛅ ERA5 Shoulder Seasons (JFM/AMJ/JAS/OND)",
}
cols = st.columns(min(len(members), 3))
for idx, member in enumerate(members):
with cols[idx % 3]:
display_name = member_display.get(member, f"📁 {member}")
st.info(display_name)
st.divider()
# Render model-specific visualizations
if model_type == "espa":
render_espa_model_state(model_state, selected_result)
@ -81,9 +119,28 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
# Extract different feature types
# Get members used in training
members = get_members_from_settings(selected_result.settings)
# Extract different feature types based on what was used in training
embedding_feature_array = None
if "AlphaEarth" in members:
embedding_feature_array = extract_embedding_features(model_state)
era5_feature_array = extract_era5_features(model_state)
era5_yearly_array = None
era5_seasonal_array = None
era5_shoulder_array = None
if "ERA5-yearly" in members:
era5_yearly_array = extract_era5_features(model_state, temporal_group="yearly")
if "ERA5-seasonal" in members:
era5_seasonal_array = extract_era5_features(model_state, temporal_group="seasonal")
if "ERA5-shoulder" in members:
era5_shoulder_array = extract_era5_features(model_state, temporal_group="shoulder")
arcticdem_feature_array = None
if "ArcticDEM" in members:
arcticdem_feature_array = extract_arcticdem_features(model_state)
common_feature_array = extract_common_features(model_state)
# Generate unified colormaps
@ -107,7 +164,7 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
with st.spinner("Generating feature importance plot..."):
feature_chart = plot_top_features(model_state, top_n=top_n)
st.altair_chart(feature_chart, use_container_width=True)
st.altair_chart(feature_chart, width="stretch")
st.markdown(
"""
@ -136,12 +193,12 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
with col1:
st.markdown("### Assignment Heatmap")
box_assignment_heatmap = plot_box_assignments(model_state)
st.altair_chart(box_assignment_heatmap, use_container_width=True)
st.altair_chart(box_assignment_heatmap, width="stretch")
with col2:
st.markdown("### Box Count by Class")
box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
st.altair_chart(box_assignment_bars, use_container_width=True)
st.altair_chart(box_assignment_bars, width="stretch")
# Show statistics
with st.expander("Box Assignment Statistics"):
@ -179,9 +236,19 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
if embedding_feature_array is not None:
render_embedding_features(embedding_feature_array)
# ERA5 features analysis (if present)
if era5_feature_array is not None:
render_era5_features(era5_feature_array)
# ERA5 features analysis (if present) - split by temporal group
if era5_yearly_array is not None:
render_era5_features(era5_yearly_array, temporal_group="Yearly")
if era5_seasonal_array is not None:
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
if era5_shoulder_array is not None:
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
# ArcticDEM features analysis (if present)
if arcticdem_feature_array is not None:
render_arcticdem_features(arcticdem_feature_array)
# Common features analysis (if present)
if common_feature_array is not None:
@ -237,7 +304,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
with st.spinner("Generating feature importance plot..."):
importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
st.altair_chart(importance_chart, use_container_width=True)
st.altair_chart(importance_chart, width="stretch")
# Comparison of importance types
st.subheader("Importance Type Comparison")
@ -245,7 +312,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
with st.spinner("Generating importance comparison..."):
comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
st.altair_chart(comparison_chart, use_container_width=True)
st.altair_chart(comparison_chart, width="stretch")
# Statistics
with st.expander("Model Statistics"):
@ -256,6 +323,64 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
with col2:
st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
# Feature source analysis
st.subheader("Feature Importance by Data Source")
st.markdown(
"""
Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
ArcticDEM topography, and common features).
"""
)
# Get members used in training
members = get_members_from_settings(selected_result.settings)
# Extract features by source using the selected importance type
importance_var = f"feature_importance_{importance_type}"
embedding_feature_array = None
if "AlphaEarth" in members:
embedding_feature_array = extract_embedding_features(model_state, importance_type=importance_var)
era5_yearly_array = None
era5_seasonal_array = None
era5_shoulder_array = None
if "ERA5-yearly" in members:
era5_yearly_array = extract_era5_features(model_state, importance_type=importance_var, temporal_group="yearly")
if "ERA5-seasonal" in members:
era5_seasonal_array = extract_era5_features(
model_state, importance_type=importance_var, temporal_group="seasonal"
)
if "ERA5-shoulder" in members:
era5_shoulder_array = extract_era5_features(
model_state, importance_type=importance_var, temporal_group="shoulder"
)
arcticdem_feature_array = None
if "ArcticDEM" in members:
arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type=importance_var)
common_feature_array = extract_common_features(model_state, importance_type=importance_var)
# Render each source's features if present
if embedding_feature_array is not None:
render_embedding_features(embedding_feature_array)
if era5_yearly_array is not None:
render_era5_features(era5_yearly_array, temporal_group="Yearly")
if era5_seasonal_array is not None:
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
if era5_shoulder_array is not None:
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
if arcticdem_feature_array is not None:
render_arcticdem_features(arcticdem_feature_array)
if common_feature_array is not None:
render_common_features(common_feature_array)
def render_rf_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for Random Forest model."""
@ -302,7 +427,7 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
with st.spinner("Generating feature importance plot..."):
importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
st.altair_chart(importance_chart, use_container_width=True)
st.altair_chart(importance_chart, width="stretch")
# Tree statistics (only if available - sklearn RF has them, cuML RF doesn't)
if not is_cuml and "tree_depths" in model_state:
@ -315,11 +440,11 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
col1, col2, col3 = st.columns(3)
with col1:
st.altair_chart(chart_depths, use_container_width=True)
st.altair_chart(chart_depths, width="stretch")
with col2:
st.altair_chart(chart_leaves, use_container_width=True)
st.altair_chart(chart_leaves, width="stretch")
with col3:
st.altair_chart(chart_nodes, use_container_width=True)
st.altair_chart(chart_nodes, width="stretch")
# Statistics
with st.expander("Forest Statistics"):
@ -345,6 +470,64 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
st.metric("Max Nodes", f"{nodes.max()}")
st.metric("Min Nodes", f"{nodes.min()}")
# Feature source analysis
st.subheader("Feature Importance by Data Source")
st.markdown(
"""
Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
ArcticDEM topography, and common features).
"""
)
# Get members used in training
members = get_members_from_settings(selected_result.settings)
# Extract features by source
embedding_feature_array = None
if "AlphaEarth" in members:
embedding_feature_array = extract_embedding_features(model_state, importance_type="feature_importance")
era5_yearly_array = None
era5_seasonal_array = None
era5_shoulder_array = None
if "ERA5-yearly" in members:
era5_yearly_array = extract_era5_features(
model_state, importance_type="feature_importance", temporal_group="yearly"
)
if "ERA5-seasonal" in members:
era5_seasonal_array = extract_era5_features(
model_state, importance_type="feature_importance", temporal_group="seasonal"
)
if "ERA5-shoulder" in members:
era5_shoulder_array = extract_era5_features(
model_state, importance_type="feature_importance", temporal_group="shoulder"
)
arcticdem_feature_array = None
if "ArcticDEM" in members:
arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type="feature_importance")
common_feature_array = extract_common_features(model_state, importance_type="feature_importance")
# Render each source's features if present
if embedding_feature_array is not None:
render_embedding_features(embedding_feature_array)
if era5_yearly_array is not None:
render_era5_features(era5_yearly_array, temporal_group="Yearly")
if era5_seasonal_array is not None:
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
if era5_shoulder_array is not None:
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
if arcticdem_feature_array is not None:
render_arcticdem_features(arcticdem_feature_array)
if common_feature_array is not None:
render_common_features(common_feature_array)
def render_knn_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for KNN model."""
@ -402,18 +585,18 @@ def render_embedding_features(embedding_feature_array):
chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
col1, col2, col3 = st.columns(3)
with col1:
st.altair_chart(chart_agg, use_container_width=True)
st.altair_chart(chart_agg, width="stretch")
with col2:
st.altair_chart(chart_band, use_container_width=True)
st.altair_chart(chart_band, width="stretch")
with col3:
st.altair_chart(chart_year, use_container_width=True)
st.altair_chart(chart_year, width="stretch")
# Detailed heatmap
st.markdown("### Detailed Heatmap by Aggregation")
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
with st.spinner("Generating heatmap..."):
heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
st.altair_chart(heatmap_chart, use_container_width=True)
st.altair_chart(heatmap_chart, width="stretch")
# Statistics
with st.expander("Embedding Feature Statistics"):
@ -436,13 +619,21 @@ def render_embedding_features(embedding_feature_array):
st.dataframe(top_emb, width="stretch")
def render_era5_features(era5_feature_array):
"""Render ERA5 feature visualizations."""
with st.container(border=True):
st.header("⛅ ERA5 Feature Analysis")
st.markdown(
def render_era5_features(era5_feature_array, temporal_group: str = ""):
"""Render ERA5 feature visualizations.
Args:
era5_feature_array: ERA5 feature importance array.
temporal_group: Name of the temporal grouping (e.g., "Yearly", "Seasonal", "Shoulder").
"""
Analysis of ERA5 climate features showing which variables and time periods
group_suffix = f" ({temporal_group})" if temporal_group else ""
with st.container(border=True):
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
st.markdown(
f"""
Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods
are most important for the model predictions.
"""
)
@ -450,19 +641,50 @@ def render_era5_features(era5_feature_array):
# Summary bar charts
st.markdown("### Importance by Dimension")
with st.spinner("Generating ERA5 dimension summaries..."):
chart_variable, chart_time = plot_era5_summary(era5_feature_array)
charts = plot_era5_summary(era5_feature_array)
# Check if this is seasonal/shoulder data (returns 3 charts) or yearly (returns 2 charts)
if len(charts) == 3:
chart_variable, chart_season, chart_year = charts
col1, col2, col3 = st.columns(3)
with col1:
st.altair_chart(chart_variable, width="stretch")
with col2:
st.altair_chart(chart_season, width="stretch")
with col3:
st.altair_chart(chart_year, width="stretch")
else:
chart_variable, chart_time = charts
col1, col2 = st.columns(2)
with col1:
st.altair_chart(chart_variable, use_container_width=True)
st.altair_chart(chart_variable, width="stretch")
with col2:
st.altair_chart(chart_time, use_container_width=True)
st.altair_chart(chart_time, width="stretch")
# Detailed heatmap
st.markdown("### Detailed Heatmap")
# Check if this is seasonal/shoulder data
has_season = "season" in era5_feature_array.dims
if has_season:
st.markdown("Shows the weight of each variable-season-year combination.")
with st.spinner("Generating ERA5 season heatmap..."):
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
st.altair_chart(era5_heatmap_chart, width="stretch")
# Add time-based heatmap for seasonal/shoulder
st.markdown("### By Time Heatmap")
st.markdown("Shows temporal trends by averaging over seasons.")
with st.spinner("Generating ERA5 time heatmap..."):
era5_time_heatmap_chart = plot_era5_time_heatmap(era5_feature_array)
if era5_time_heatmap_chart is not None:
st.altair_chart(era5_time_heatmap_chart, width="stretch")
else:
st.markdown("Shows the weight of each variable-time combination.")
with st.spinner("Generating ERA5 heatmap..."):
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
st.altair_chart(era5_heatmap_chart, use_container_width=True)
st.altair_chart(era5_heatmap_chart, width="stretch")
# Statistics
with st.expander("ERA5 Feature Statistics"):
@ -481,10 +703,61 @@ def render_era5_features(era5_feature_array):
# Show top ERA5 features
st.write("**Top 10 ERA5 Features:**")
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
# Get all columns except 'weight' for display
display_cols = [col for col in era5_df.columns if col != "weight"] + ["weight"]
top_era5 = era5_df.nlargest(10, "weight")[display_cols]
st.dataframe(top_era5, width="stretch")
def render_arcticdem_features(arcticdem_feature_array):
"""Render ArcticDEM feature visualizations."""
with st.container(border=True):
st.header("🏔️ ArcticDEM Feature Analysis")
st.markdown(
"""
Analysis of ArcticDEM topographic features showing which terrain variables and
aggregations are most important for the model predictions.
"""
)
# Summary bar charts
st.markdown("### Importance by Dimension")
with st.spinner("Generating ArcticDEM dimension summaries..."):
chart_variable, chart_agg = plot_arcticdem_summary(arcticdem_feature_array)
col1, col2 = st.columns(2)
with col1:
st.altair_chart(chart_variable, width="stretch")
with col2:
st.altair_chart(chart_agg, width="stretch")
# Detailed heatmap
st.markdown("### Detailed Heatmap")
st.markdown("Shows the weight of each variable-aggregation combination.")
with st.spinner("Generating ArcticDEM heatmap..."):
arcticdem_heatmap_chart = plot_arcticdem_heatmap(arcticdem_feature_array)
st.altair_chart(arcticdem_heatmap_chart, width="stretch")
# Statistics
with st.expander("ArcticDEM Feature Statistics"):
st.write("**Overall Statistics:**")
n_arcticdem_features = arcticdem_feature_array.size
mean_weight = float(arcticdem_feature_array.mean().values)
max_weight = float(arcticdem_feature_array.max().values)
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total ArcticDEM Features", n_arcticdem_features)
with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3:
st.metric("Max Weight", f"{max_weight:.4f}")
# Show top ArcticDEM features
st.write("**Top 10 ArcticDEM Features:**")
arcticdem_df = arcticdem_feature_array.to_dataframe(name="weight").reset_index()
top_arcticdem = arcticdem_df.nlargest(10, "weight")[["variable", "agg", "weight"]]
st.dataframe(top_arcticdem, width="stretch")
def render_common_features(common_feature_array):
"""Render common feature visualizations."""
with st.container(border=True):
@ -499,7 +772,7 @@ def render_common_features(common_feature_array):
# Bar chart showing all common feature weights
with st.spinner("Generating common features chart..."):
common_chart = plot_common_features(common_feature_array)
st.altair_chart(common_chart, use_container_width=True)
st.altair_chart(common_chart, width="stretch")
# Statistics
with st.expander("Common Feature Statistics"):

View file

@ -96,7 +96,7 @@ def render_overview_page():
st.subheader("Detailed Results")
for tr in training_results:
with st.expander(tr.name):
with st.expander(tr.get_display_name("task_first")):
col1, col2 = st.columns([1, 2])
with col1:

View file

@ -35,7 +35,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
best_score = results[col].max()
best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})
st.dataframe(pd.DataFrame(best_scores), hide_index=True, use_container_width=True)
st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch")
with col2:
st.markdown("#### Score Statistics")
@ -51,7 +51,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
}
)
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch")
# Show best parameter combination in a cleaner format (similar to old dashboard)
st.markdown("#### 🏆 Best Parameter Combination")
@ -305,7 +305,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
.properties(height=250, title=param_name)
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
else:
# Categorical parameter - use bar chart
@ -323,7 +323,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
.properties(height=250, title=param_name)
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
def render_score_vs_parameter(results: pd.DataFrame, metric: str):
@ -411,7 +411,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
.properties(height=300, title=f"{metric} vs {param_name}")
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
else:
# Categorical parameter - box plot
@ -428,7 +428,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
.properties(height=300, title=f"{metric} vs {param_name}")
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
def render_parameter_correlation(results: pd.DataFrame, metric: str):
@ -481,7 +481,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
.properties(height=max(200, len(correlations) * 30))
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
def render_binned_parameter_space(results: pd.DataFrame, metric: str):
@ -673,7 +673,7 @@ def _render_2d_param_plot(
.interactive()
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
def render_score_evolution(results: pd.DataFrame, metric: str):
@ -722,7 +722,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
.properties(height=400)
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
# Show statistics
col1, col2, col3, col4 = st.columns(4)
@ -831,19 +831,37 @@ def render_multi_metric_comparison(results: pd.DataFrame):
else:
tooltip_list.append(color_col)
# Calculate axis domains starting from min values
x_min = df_plot[metric1].min()
x_max = df_plot[metric1].max()
y_min = df_plot[metric2].min()
y_max = df_plot[metric2].max()
# Add small padding (2% of range) for better visualization
x_padding = (x_max - x_min) * 0.02
y_padding = (y_max - y_min) * 0.02
chart = (
alt.Chart(df_plot)
.mark_circle(size=60, opacity=0.6)
.encode(
alt.X(metric1, title=metric1.replace("_", " ").title()),
alt.Y(metric2, title=metric2.replace("_", " ").title()),
alt.X(
metric1,
title=metric1.replace("_", " ").title(),
scale=alt.Scale(domain=[x_min - x_padding, x_max + x_padding]),
),
alt.Y(
metric2,
title=metric2.replace("_", " ").title(),
scale=alt.Scale(domain=[y_min - y_padding, y_max + y_padding]),
),
alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
tooltip=tooltip_list,
)
.properties(height=500)
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
# Calculate correlation
corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
@ -1015,7 +1033,7 @@ def render_espa_binned_parameter_space(results: pd.DataFrame, metric: str, k_bin
)
)
st.altair_chart(chart, use_container_width=True)
st.altair_chart(chart, width="stretch")
# Show statistics about the binning
n_bins = len(bin_order)
@ -1073,4 +1091,4 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
score_col_display = metric.replace("_", " ").title()
display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
st.dataframe(display_df, hide_index=True, use_container_width=True)
st.dataframe(display_df, hide_index=True, width="stretch")

View file

@ -69,6 +69,13 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
Altair chart showing the heatmap.
"""
# Filter out "count" aggregation if present
if "agg" in embedding_array.dims:
agg_values = embedding_array.coords["agg"].values
non_count_aggs = [agg for agg in agg_values if agg != "count"]
if non_count_aggs:
embedding_array = embedding_array.sel(agg=non_count_aggs)
# Convert to DataFrame for plotting
df = embedding_array.to_dataframe(name="weight").reset_index()
@ -81,7 +88,7 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
color=alt.Color(
"weight:Q",
scale=alt.Scale(scheme="redblue", domainMid=0),
scale=alt.Scale(scheme="blues"),
title="Weight",
),
tooltip=[
@ -108,15 +115,22 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
Tuple of three Altair charts (by_agg, by_band, by_year).
"""
# Filter out "count" aggregation if present
if "agg" in embedding_array.dims:
agg_values = embedding_array.coords["agg"].values
non_count_aggs = [agg for agg in agg_values if agg != "count"]
if non_count_aggs:
embedding_array = embedding_array.sel(agg=non_count_aggs)
# Aggregate by different dimensions
by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()
# Create DataFrames
df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()})
df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()})
df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()})
# Create DataFrames, handling potential MultiIndex
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
df_band = pd.DataFrame({"dimension": by_band.index.astype(str), "mean_abs_weight": by_band.values})
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
# Sort by weight
df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
@ -185,10 +199,10 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
"""Create a heatmap showing ERA5 feature weights across variables and time.
"""Create a heatmap showing ERA5 feature weights across variables and time/season.
Args:
era5_array: DataArray with dimensions (variable, time) containing feature weights.
era5_array: DataArray with dimensions (variable, time) or (variable, season, year) containing feature weights.
Returns:
Altair chart showing the heatmap.
@ -197,7 +211,34 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
# Convert to DataFrame for plotting
df = era5_array.to_dataframe(name="weight").reset_index()
# Create heatmap
# Determine if this is seasonal/shoulder data (has 'season' dimension)
has_season = "season" in df.columns
if has_season:
# For seasonal/shoulder: create faceted heatmap by season
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("year:N", title="Year", sort=None),
y=alt.Y("variable:N", title="Variable", sort="-color"),
color=alt.Color(
"weight:Q",
scale=alt.Scale(scheme="redblue", domainMid=0),
title="Weight",
),
tooltip=[
alt.Tooltip("variable:N", title="Variable"),
alt.Tooltip("season:N", title="Season"),
alt.Tooltip("year:N", title="Year"),
alt.Tooltip("weight:Q", format=".4f", title="Weight"),
],
)
.properties(width=200, height=300)
.facet(facet=alt.Facet("season:N", title="Season"), columns=4)
)
else:
# For yearly: simple heatmap
chart = (
alt.Chart(df)
.mark_rect()
@ -224,23 +265,161 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
return chart
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
"""Create bar charts summarizing ERA5 weights by variable and time.
def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart | None:
    """Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).

    This is specifically for seasonal/shoulder data to show temporal trends.

    Args:
        era5_array: DataArray with dimensions (variable, season, year, agg) containing feature weights.

    Returns:
        Altair chart showing the variable-year heatmap, or None when the array
        has no 'season' dimension (i.e. it holds yearly data).
    """
    # Only seasonal/shoulder arrays carry a season dimension; yearly data has
    # no seasonal structure to average over, so there is nothing to plot here.
    if "season" not in era5_array.dims:
        return None

    # Average over every dimension except (variable, year) — i.e. season and,
    # if present, agg — to expose the temporal trend per variable.
    dims_to_avg = [d for d in era5_array.dims if d not in ["variable", "year"]]
    by_time = era5_array.mean(dim=dims_to_avg)

    # Convert to DataFrame for plotting
    df = by_time.to_dataframe(name="weight").reset_index()

    # Create heatmap with a diverging scale centred on zero so the sign of the
    # averaged weight stays visible.
    chart = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x=alt.X("year:N", title="Year", sort=None),
            y=alt.Y("variable:N", title="Variable", sort="-color"),
            color=alt.Color(
                "weight:Q",
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("variable:N", title="Variable"),
                alt.Tooltip("year:N", title="Year"),
                alt.Tooltip("weight:Q", format=".4f", title="Avg Weight"),
            ],
        )
        .properties(
            height=400,
            title="ERA5 Feature Weights by Time (averaged over seasons)",
        )
    )

    return chart
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
"""Create bar charts summarizing ERA5 weights by variable and time/season.
Args:
era5_array: DataArray with dimensions (variable, time) or (variable, season, year, agg) containing feature weights.
Returns:
Tuple of Altair charts:
- For yearly data: (by_variable, by_time)
- For seasonal/shoulder data: (by_variable, by_season, by_year)
"""
# Determine which dimensions to average over
dims_to_avg_for_var = [d for d in era5_array.dims if d != "variable"]
# Check if this is seasonal/shoulder data
has_season = "season" in era5_array.dims
# Aggregate by variable (average over all other dimensions)
by_variable = era5_array.mean(dim=dims_to_avg_for_var).to_pandas().abs()
if has_season:
# For seasonal/shoulder: aggregate by season and year
dims_to_avg_for_season = [d for d in era5_array.dims if d != "season"]
by_season = era5_array.mean(dim=dims_to_avg_for_season).to_pandas().abs()
dims_to_avg_for_year = [d for d in era5_array.dims if d != "year"]
by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
# Create DataFrames
df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()})
df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()})
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values})
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
# Sort by weight
df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
df_season = df_season.sort_values("mean_abs_weight", ascending=True)
df_year = df_year.sort_values("mean_abs_weight", ascending=True)
chart_variable = (
alt.Chart(df_variable)
.mark_bar()
.encode(
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
scale=alt.Scale(scheme="purples"),
legend=None,
),
tooltip=[
alt.Tooltip("dimension:N", title="Variable"),
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
],
)
.properties(width=400, height=300, title="By Variable")
)
chart_season = (
alt.Chart(df_season)
.mark_bar()
.encode(
y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
scale=alt.Scale(scheme="teals"),
legend=None,
),
tooltip=[
alt.Tooltip("dimension:N", title="Season"),
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
],
)
.properties(width=400, height=300, title="By Season")
)
chart_year = (
alt.Chart(df_year)
.mark_bar()
.encode(
y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)),
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
color=alt.Color(
"mean_abs_weight:Q",
scale=alt.Scale(scheme="oranges"),
legend=None,
),
tooltip=[
alt.Tooltip("dimension:O", title="Year"),
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
],
)
.properties(width=400, height=300, title="By Time")
)
return chart_variable, chart_season, chart_year
else:
# For yearly: aggregate by time
dims_to_avg_for_time = [d for d in era5_array.dims if d != "time"]
by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
# Create DataFrames, handling potential MultiIndex
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
# Sort by weight
df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
@ -288,6 +467,111 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
return chart_variable, chart_time
def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
    """Create a heatmap showing ArcticDEM feature weights across variables and aggregations.

    Args:
        arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.

    Returns:
        Altair chart showing the heatmap.
    """
    # Flatten the (variable, agg) grid into tidy rows for Altair.
    heatmap_df = arcticdem_array.to_dataframe(name="weight").reset_index()

    # Diverging red/blue scale centred on zero keeps the weight sign visible.
    weight_color = alt.Color(
        "weight:Q",
        scale=alt.Scale(scheme="redblue", domainMid=0),
        title="Weight",
    )
    hover_fields = [
        alt.Tooltip("variable:N", title="Variable"),
        alt.Tooltip("agg:N", title="Aggregation"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
    ]

    return (
        alt.Chart(heatmap_df)
        .mark_rect()
        .encode(
            x=alt.X("agg:N", title="Aggregation", sort=None),
            y=alt.Y("variable:N", title="Variable", sort="-color"),
            color=weight_color,
            tooltip=hover_fields,
        )
        .properties(
            width=400,
            height=300,
            title="ArcticDEM Feature Weights Heatmap",
        )
    )
def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
    """Create bar charts summarizing ArcticDEM weights by variable and aggregation.

    Args:
        arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_agg).
    """

    def _mean_abs_frame(reduce_dim: str) -> pd.DataFrame:
        # Average out one dimension, take absolute values, and return tidy rows
        # sorted ascending by weight (str() on the index guards MultiIndex labels).
        series = arcticdem_array.mean(dim=reduce_dim).to_pandas().abs()
        frame = pd.DataFrame({"dimension": series.index.astype(str), "mean_abs_weight": series.values})
        return frame.sort_values("mean_abs_weight", ascending=True)

    def _bar_chart(frame: pd.DataFrame, label: str, scheme: str, label_limit: int, title: str) -> alt.Chart:
        # Horizontal bar chart, descending by weight, coloured by magnitude.
        return (
            alt.Chart(frame)
            .mark_bar()
            .encode(
                y=alt.Y("dimension:N", title=label, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
                color=alt.Color(
                    "mean_abs_weight:Q",
                    scale=alt.Scale(scheme=scheme),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip("dimension:N", title=label),
                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
                ],
            )
            .properties(width=400, height=300, title=title)
        )

    # Reducing over 'agg' leaves the per-variable view; reducing over
    # 'variable' leaves the per-aggregation view.
    chart_variable = _bar_chart(_mean_abs_frame("agg"), "Variable", "browns", 300, "By Variable")
    chart_agg = _bar_chart(_mean_abs_frame("variable"), "Aggregation", "reds", 200, "By Aggregation")
    return chart_variable, chart_agg
def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
"""Create a heatmap showing which boxes are assigned to which labels/classes.

View file

@ -36,8 +36,8 @@ def render_training_analysis_page():
with st.sidebar:
st.header("Select Training Run")
# Create selection options
training_options = {tr.name: tr for tr in training_results}
# Create selection options with task-first naming
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
@ -163,7 +163,7 @@ def render_training_analysis_page():
with st.expander("📋 Parameter Space Summary", expanded=False):
param_summary = get_parameter_space_summary(results)
if not param_summary.empty:
st.dataframe(param_summary, hide_index=True, use_container_width=True)
st.dataframe(param_summary, hide_index=True, width="stretch")
else:
st.info("No parameter information available.")
@ -226,7 +226,7 @@ def render_training_analysis_page():
data=csv_data,
file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv",
use_container_width=True,
width="stretch",
)
with col2:
@ -239,9 +239,9 @@ def render_training_analysis_page():
data=settings_json,
file_name=f"{selected_result.path.name}_settings.json",
mime="application/json",
use_container_width=True,
width="stretch",
)
# Show raw data preview
st.subheader("Raw Data Preview")
st.dataframe(results.head(100), use_container_width=True)
st.dataframe(results.head(100), width="stretch")

View file

@ -78,7 +78,7 @@ def render_training_data_page():
# Form submit button
load_button = st.form_submit_button(
"Load Dataset", type="primary", use_container_width=True, disabled=len(selected_members) == 0
"Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0
)
# Create DatasetEnsemble only when form is submitted

View file

@ -25,6 +25,27 @@ class TrainingResult:
results: pd.DataFrame
created_at: float
def get_display_name(self, format_type: str = "task_first") -> str:
    """Get formatted display name for the training result.

    Args:
        format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state)

    Returns:
        Formatted name string
    """
    # Missing settings fall back to "Unknown" so the name always renders.
    cfg = self.settings
    task = cfg.get("task", "Unknown").capitalize()
    model = cfg.get("model", "Unknown").upper()
    grid = cfg.get("grid", "Unknown").capitalize()
    level = cfg.get("level", "Unknown")
    stamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M")

    # Only the task/model ordering differs between the two formats.
    leading = f"{model} - {task}" if format_type == "model_first" else f"{task} - {model}"
    return f"{leading} - {grid}-{level} ({stamp})"
@classmethod
def from_path(cls, result_path: Path) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
@ -38,9 +59,11 @@ class TrainingResult:
settings = toml.load(settings_file)["settings"]
results = pd.read_parquet(result_file)
# Name should be "task grid-level (created_at)"
# Name should be "task model grid-level (created_at)"
model = settings.get("model", "Unknown").upper()
name = (
f"**{settings.get('task', 'Unknown').capitalize()}** -"
f"{settings.get('task', 'Unknown').capitalize()} -"
f" {model} -"
f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
)
@ -121,11 +144,40 @@ def load_source_data(e: DatasetEnsemble, source: str):
return ds, targets
def extract_embedding_features(model_state) -> xr.DataArray | None:
"""Extract embedding features from the model state.
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
"""Get the appropriate feature importance array from model state.
Args:
model_state: The xarray Dataset containing the model state.
importance_type: Type of importance to extract. For ESPA: 'feature_weights'.
For XGBoost: 'feature_importance_weight', 'feature_importance_gain', etc.
For RF: 'feature_importance'.
Returns:
xr.DataArray with feature importance values.
"""
if importance_type in model_state:
return model_state[importance_type]
# Fallback for compatibility
if "feature_weights" in model_state:
return model_state["feature_weights"]
if "feature_importance" in model_state:
return model_state["feature_importance"]
raise ValueError(f"No feature importance array found. Available: {list(model_state.data_vars)}")
def extract_embedding_features(
model_state: xr.Dataset, importance_type: str = "feature_weights"
) -> xr.DataArray | None:
"""Extract embedding features from the model state.
Feature naming pattern: `embeddings_{agg}_{band}_{year}`
Example: `embeddings_mean_B02_2020`
Args:
model_state: The xarray Dataset containing the model state.
importance_type: Type of feature importance to extract.
Returns:
xr.DataArray: The extracted embedding features. This DataArray has dimensions
@ -135,14 +187,17 @@ def extract_embedding_features(model_state) -> xr.DataArray | None:
"""
def _is_embedding_feature(feature: str) -> bool:
return feature.startswith("embedding_")
return feature.startswith("embeddings_")
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
if len(embedding_features) == 0:
return None
# Get the appropriate importance array
importance_array = _get_feature_importance_array(model_state, importance_type)
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
embedding_feature_array = importance_array.sel(feature=embedding_features)
embedding_feature_array = embedding_feature_array.assign_coords(
agg=("feature", [f.split("_")[1] for f in embedding_features]),
band=("feature", [f.split("_")[2] for f in embedding_features]),
@ -152,15 +207,26 @@ def extract_embedding_features(model_state) -> xr.DataArray | None:
return embedding_feature_array
def extract_era5_features(model_state) -> xr.DataArray | None:
def extract_era5_features(
model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None
) -> xr.DataArray | None:
"""Extract ERA5 features from the model state.
Feature naming patterns:
- Without aggregations: `era5_{variable}_{time}`
Example: `era5_temperature_2020_summer`
- With aggregations: `era5_{variable}_{time}_{agg}`
Example: `era5_temperature_2020_summer_mean`
Args:
model_state: The xarray Dataset containing the model state.
importance_type: Type of feature importance to extract.
temporal_group: Filter to specific temporal group ('yearly', 'seasonal', 'shoulder').
If None, extracts all ERA5 features.
Returns:
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
('variable', 'time') corresponding to the different components of the ERA5 features.
('variable', 'time') or ('variable', 'time', 'agg') depending on the feature structure.
Returns None if no ERA5 features are found.
"""
@ -168,34 +234,254 @@ def extract_era5_features(model_state) -> xr.DataArray | None:
def _is_era5_feature(feature: str) -> bool:
return feature.startswith("era5_")
def _get_temporal_group(feature: str) -> str:
"""Determine temporal group from feature name based on time component.
Time patterns:
- Yearly: just a year number (e.g., "2020")
- Seasonal: season_year (e.g., "summer_2020", "winter_2021")
- Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021")
"""
parts = feature.split("_")
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
# Find where the time part starts (after "era5" and variable name)
# Pattern: era5_variable_time or era5_variable_time_agg
# Time can be: year, season_year, or SHOULDER_year
# Remove era5 prefix and aggregation suffix if present
if parts[-1] in common_aggs:
time_parts = parts[1:-1] # Remove era5 and agg
else:
time_parts = parts[1:] # Remove era5
# The last part (or last two parts) is the time
# Check last part first
last_part = time_parts[-1]
if last_part.isdigit():
# Could be yearly (just year) or part of seasonal/shoulder (season_year or SHOULDER_year)
if len(time_parts) >= 2:
second_last = time_parts[-2]
# Check if second to last is a season or shoulder indicator
if second_last.lower() in ["summer", "winter"]:
return "seasonal"
elif second_last.upper() in ["JFM", "AMJ", "JAS", "OND"]:
return "shoulder"
# Just a year number
return "yearly"
elif last_part.lower() in ["summer", "winter"]:
return "seasonal"
elif last_part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
return "shoulder"
return "unknown"
def _extract_var_name(feature: str) -> str:
parts = feature.split("_")
# era5_variablename_timetype format
return "_".join(parts[1:-1])
# Pattern: era5_variable_year_agg OR era5_variable_season_year_agg
# Variable can include stat names like t2m_max, t2m_min, etc.
# We need to find where the year starts (it's always a 4-digit number)
# Find the first year (4-digit number)
year_idx = None
for i, part in enumerate(parts):
if part.isdigit() and len(part) == 4:
year_idx = i
break
if year_idx is None:
# Fallback: return everything after era5
return "_".join(parts[1:])
# Variable is everything between era5 and year
# But also check if there's a season before the year
seasons = {"summer", "winter"}
shoulders = {"JFM", "AMJ", "JAS", "OND"}
var_end_idx = year_idx
# Check if the part before year is a season/shoulder
if year_idx > 1 and (parts[year_idx - 1].lower() in seasons or parts[year_idx - 1].upper() in shoulders):
var_end_idx = year_idx - 1
# Variable is from index 1 (after era5) to var_end_idx
return "_".join(parts[1:var_end_idx])
def _extract_time_name(feature: str) -> str:
parts = feature.split("_")
# Last part is the time type
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
if parts[-1] in common_aggs:
# Has aggregation: era5_var_time_agg -> time is second to last
return parts[-2]
else:
# No aggregation: era5_var_time -> time is last
return parts[-1]
def _extract_season(feature: str) -> str | None:
"""Extract season/shoulder from seasonal/shoulder features.
Pattern: era5_variable_season_year_agg or era5_variable_season_year
"""
parts = feature.split("_")
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
# Look through parts to find season/shoulder indicators
for part in parts:
if part.lower() in ["summer", "winter"]:
return part.lower()
elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
return part.upper()
return None
def _extract_year_from_time(feature: str) -> str:
"""Extract just the year component from time.
For seasonal/shoulder features, find the year that comes after the season.
"""
parts = feature.split("_")
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
# Find the season/shoulder part, then the next part should be the year
for i, part in enumerate(parts):
if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
# Next part should be the year
if i + 1 < len(parts):
next_part = parts[i + 1]
# Skip if it's an aggregation, year should be before agg
if next_part not in common_aggs:
return next_part
# Fallback: return last numeric part that's not an aggregation
for part in reversed(parts):
if part.isdigit():
return part
return "unknown"
def _extract_agg_name(feature: str) -> str | None:
parts = feature.split("_")
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
if parts[-1] in common_aggs:
return parts[-1]
return None
# Get all ERA5 features
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
# Filter by temporal group if specified
if temporal_group is not None:
era5_features = [f for f in era5_features if _get_temporal_group(f) == temporal_group]
if len(era5_features) == 0:
return None
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
# Get the appropriate importance array
importance_array = _get_feature_importance_array(model_state, importance_type)
# Check if features have aggregations
has_agg = any(_extract_agg_name(f) is not None for f in era5_features)
# Check if features have season/shoulder split (for seasonal/shoulder temporal groups)
has_season = any(_extract_season(f) is not None for f in era5_features)
# Split the single feature dimension of era5 features into separate dimensions
era5_features_array = importance_array.sel(feature=era5_features)
if has_season:
# For seasonal/shoulder: split into variable, season, year, (agg)
era5_features_array = era5_features_array.assign_coords(
variable=("feature", [_extract_var_name(f) for f in era5_features]),
season=("feature", [_extract_season(f) for f in era5_features]),
year=("feature", [_extract_year_from_time(f) for f in era5_features]),
)
if has_agg:
era5_features_array = era5_features_array.assign_coords(
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack(
"feature"
) # noqa: PD010
else:
era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year"]).unstack(
"feature"
) # noqa: PD010
else:
# For yearly: keep as variable, time, (agg)
era5_features_array = era5_features_array.assign_coords(
variable=("feature", [_extract_var_name(f) for f in era5_features]),
time=("feature", [_extract_time_name(f) for f in era5_features]),
)
if has_agg:
# Add aggregation dimension
era5_features_array = era5_features_array.assign_coords(
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010
else:
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
return era5_features_array
def extract_common_features(model_state) -> xr.DataArray | None:
def extract_arcticdem_features(
    model_state: xr.Dataset, importance_type: str = "feature_weights"
) -> xr.DataArray | None:
    """Extract ArcticDEM feature importances from a model state.

    ArcticDEM features follow the naming pattern `arcticdem_{variable}_{agg}`,
    e.g. `arcticdem_elevation_mean` or `arcticdem_slope_std`.

    Args:
        model_state: The xarray Dataset containing the model state.
        importance_type: Type of feature importance to extract.

    Returns:
        xr.DataArray: Importances reshaped onto the dimensions ('variable', 'agg')
            corresponding to the different components of the ArcticDEM features.
            Returns None when the model state holds no ArcticDEM features.
    """
    prefix = "arcticdem_"
    selected = [name for name in model_state.feature.to_numpy() if name.startswith(prefix)]
    if not selected:
        return None

    # `arcticdem_{variable}_{agg}`: the trailing token is the aggregation;
    # everything between the prefix and that token is the variable name
    # (which may itself contain underscores).
    variables = ["_".join(name.split("_")[1:-1]) for name in selected]
    aggs = [name.split("_")[-1] for name in selected]

    # Pick the requested importance variable, then keep only ArcticDEM rows.
    importances = _get_feature_importance_array(model_state, importance_type).sel(feature=selected)

    # Re-index the flat 'feature' axis by (variable, agg) pairs and unstack it
    # into two separate dimensions.
    importances = importances.assign_coords(
        variable=("feature", variables),
        agg=("feature", aggs),
    )
    return importances.set_index(feature=["variable", "agg"]).unstack("feature")  # noqa: PD010
def extract_common_features(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray | None:
"""Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.
Args:
model_state: The xarray Dataset containing the model state.
importance_type: Type of feature importance to extract.
Returns:
xr.DataArray: The extracted common features with a single 'feature' dimension.
@ -211,6 +497,22 @@ def extract_common_features(model_state) -> xr.DataArray | None:
if len(common_features) == 0:
return None
# Extract the feature weights for common features
common_feature_array = model_state.sel(feature=common_features)["feature_weights"]
# Get the appropriate importance array
importance_array = _get_feature_importance_array(model_state, importance_type)
# Extract the feature importance for common features
common_feature_array = importance_array.sel(feature=common_features)
return common_feature_array
def get_members_from_settings(settings: dict) -> list[str]:
    """Extract the list of dataset members used in training from settings.

    Args:
        settings: Training settings dictionary.

    Returns:
        List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']);
        an empty list when the settings carry no 'members' key.
    """
    # Membership check mirrors dict.get's default semantics exactly:
    # a present-but-falsy value is returned as-is, only a missing key yields [].
    if "members" in settings:
        return settings["members"]
    return []

View file

@ -278,6 +278,7 @@ def random_cv(
booster = best_estimator.get_booster()
# Feature importance with different importance types
# Note: get_score() returns dict with keys like 'f0', 'f1', etc. (feature indices)
importance_weight = booster.get_score(importance_type="weight")
importance_gain = booster.get_score(importance_type="gain")
importance_cover = booster.get_score(importance_type="cover")
@ -286,8 +287,11 @@ def random_cv(
# Create aligned arrays for all features (including zero-importance)
def align_importance(importance_dict, features):
    """Align an XGBoost importance dict to the full feature list.

    XGBoost's `get_score()` keys importances by feature index ('f0', 'f1', ...)
    and omits features never used in any split. Map each positional feature
    back to its 'f{i}' key and fill missing entries with 0.0.

    Args:
        importance_dict: Mapping from 'f{i}' keys to importance values.
        features: Ordered list of feature names the booster was trained on.

    Returns:
        List of importance values aligned with `features`, zeros for unused ones.
    """
    # Note: the previous name-keyed lookup (importance_dict.get(f, 0.0)) always
    # returned zeros, because get_score() does not key by feature name here.
    return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]
feature_importance_weight = xr.DataArray(
align_importance(importance_weight, features),

146
test_feature_extraction.py Normal file
View file

@ -0,0 +1,146 @@
"""Test script to verify feature extraction works correctly."""

import numpy as np
import xarray as xr

# Imports hoisted to module top per PEP 8 (they were previously mid-script).
from entropice.dashboard.utils.data import (
    extract_arcticdem_features,
    extract_common_features,
    extract_embedding_features,
    extract_era5_features,
)

# Create a mock model state covering every feature-naming pattern the
# extractors understand.
features = [
    # Embedding features: embedding_{agg}_{band}_{year}
    "embedding_mean_B02_2020",
    "embedding_std_B03_2021",
    "embedding_max_B04_2022",
    # ERA5 features without aggregations: era5_{variable}_{time}
    "era5_temperature_2020_summer",
    "era5_precipitation_2021_winter",
    # ERA5 features with aggregations: era5_{variable}_{time}_{agg}
    "era5_temperature_2020_summer_mean",
    "era5_precipitation_2021_winter_std",
    # ArcticDEM features: arcticdem_{variable}_{agg}
    "arcticdem_elevation_mean",
    "arcticdem_slope_std",
    "arcticdem_aspect_max",
    # Common features
    "cell_area",
    "water_area",
    "land_area",
    "land_ratio",
    "lon",
    "lat",
]

# Seeded generator so reruns produce identical mock importances
# (the original unseeded np.random.rand made runs non-reproducible).
rng = np.random.default_rng(42)
importance_values = rng.random(len(features))

# Mock model state for ESPA (single 'feature_weights' variable).
model_state_espa = xr.Dataset(
    {
        "feature_weights": xr.DataArray(
            importance_values,
            dims=["feature"],
            coords={"feature": features},
        )
    }
)

# Mock model state for XGBoost (separate gain/weight importance variables).
model_state_xgb = xr.Dataset(
    {
        "feature_importance_gain": xr.DataArray(
            importance_values,
            dims=["feature"],
            coords={"feature": features},
        ),
        "feature_importance_weight": xr.DataArray(
            importance_values * 0.8,
            dims=["feature"],
            coords={"feature": features},
        ),
    }
)

# Mock model state for Random Forest (single 'feature_importance' variable).
model_state_rf = xr.Dataset(
    {
        "feature_importance": xr.DataArray(
            importance_values,
            dims=["feature"],
            coords={"feature": features},
        )
    }
)

# --- ESPA: exercise all four extractors against the default importance type ---
print("=" * 80)
print("Testing ESPA model state")
print("=" * 80)

embedding_array = extract_embedding_features(model_state_espa)
print(f"\nEmbedding features extracted: {embedding_array is not None}")
if embedding_array is not None:
    print(f" Dimensions: {embedding_array.dims}")
    print(f" Shape: {embedding_array.shape}")
    print(f" Coordinates: {list(embedding_array.coords)}")

era5_array = extract_era5_features(model_state_espa)
print(f"\nERA5 features extracted: {era5_array is not None}")
if era5_array is not None:
    print(f" Dimensions: {era5_array.dims}")
    print(f" Shape: {era5_array.shape}")
    print(f" Coordinates: {list(era5_array.coords)}")

arcticdem_array = extract_arcticdem_features(model_state_espa)
print(f"\nArcticDEM features extracted: {arcticdem_array is not None}")
if arcticdem_array is not None:
    print(f" Dimensions: {arcticdem_array.dims}")
    print(f" Shape: {arcticdem_array.shape}")
    print(f" Coordinates: {list(arcticdem_array.coords)}")

common_array = extract_common_features(model_state_espa)
print(f"\nCommon features extracted: {common_array is not None}")
if common_array is not None:
    print(f" Dimensions: {common_array.dims}")
    print(f" Shape: {common_array.shape}")
    print(f" Size: {common_array.size}")

# --- XGBoost: verify the non-default importance_type variables are selectable ---
print("\n" + "=" * 80)
print("Testing XGBoost model state")
print("=" * 80)

embedding_array_xgb = extract_embedding_features(model_state_xgb, importance_type="feature_importance_gain")
print(f"\nEmbedding features (gain) extracted: {embedding_array_xgb is not None}")
if embedding_array_xgb is not None:
    print(f" Dimensions: {embedding_array_xgb.dims}")
    print(f" Shape: {embedding_array_xgb.shape}")

era5_array_xgb = extract_era5_features(model_state_xgb, importance_type="feature_importance_weight")
print(f"\nERA5 features (weight) extracted: {era5_array_xgb is not None}")
if era5_array_xgb is not None:
    print(f" Dimensions: {era5_array_xgb.dims}")
    print(f" Shape: {era5_array_xgb.shape}")

# --- Random Forest: single importance variable under a different name ---
print("\n" + "=" * 80)
print("Testing Random Forest model state")
print("=" * 80)

embedding_array_rf = extract_embedding_features(model_state_rf, importance_type="feature_importance")
print(f"\nEmbedding features extracted: {embedding_array_rf is not None}")
if embedding_array_rf is not None:
    print(f" Dimensions: {embedding_array_rf.dims}")
    print(f" Shape: {embedding_array_rf.shape}")

arcticdem_array_rf = extract_arcticdem_features(model_state_rf, importance_type="feature_importance")
print(f"\nArcticDEM features extracted: {arcticdem_array_rf is not None}")
if arcticdem_array_rf is not None:
    print(f" Dimensions: {arcticdem_array_rf.dims}")
    print(f" Shape: {arcticdem_array_rf.shape}")

print("\n" + "=" * 80)
print("All tests completed successfully!")
print("=" * 80)