Make the Model State Page great again
This commit is contained in:
parent
591da6992e
commit
1919cc6a7e
13 changed files with 1375 additions and 142 deletions
28
pixi.lock
generated
28
pixi.lock
generated
|
|
@ -9,6 +9,7 @@ environments:
|
|||
- https://pypi.org/simple
|
||||
options:
|
||||
channel-priority: disabled
|
||||
pypi-prerelease-mode: if-necessary-or-explicit
|
||||
packages:
|
||||
linux-64:
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
|
||||
|
|
@ -478,7 +479,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ee/34/a9914e676971a13d6cc671b1ed172f9804b50a3a80a143ff196e52f4c7ee/azure_core-1.37.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
|
||||
|
|
@ -498,7 +499,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
|
||||
|
|
@ -567,7 +568,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/84/99/6636f7097a5e461d560317024522279f52931b5a52c8caa0755a14d5f1fd/odc_loader-0.6.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e2/c7/b8f2b3e53f26f8f463002f3e8023189653b627b22ba6c00ef86eaba50b73/odc_stac-0.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/c3/3031c931098de393393e1f93a38dc9ed6805d86bb801acc3cf2d5bd1e6b7/plotly-6.5.0-py3-none-any.whl
|
||||
|
|
@ -1025,10 +1026,10 @@ packages:
|
|||
- astropy[dev] ; extra == 'dev-all'
|
||||
- astropy[test-all] ; extra == 'dev-all'
|
||||
requires_python: '>=3.11'
|
||||
- pypi: https://files.pythonhosted.org/packages/e1/2b/a3d75e4a8966460a9f79635797bb3e85e29c3bff86640087c7814429d917/astropy_iers_data-0.2025.12.15.0.40.51-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
|
||||
name: astropy-iers-data
|
||||
version: 0.2025.12.15.0.40.51
|
||||
sha256: 5f89c4b9d711638005407b4b57f76d82484f0182e459df91fcdfc354c3a34d8e
|
||||
version: 0.2025.12.22.0.40.30
|
||||
sha256: 2fbc71988d96aa29566667c6568a2bc5ca00748174b1f8ac3e9f7b09d4c27cac
|
||||
requires_dist:
|
||||
- pytest ; extra == 'docs'
|
||||
- hypothesis ; extra == 'test'
|
||||
|
|
@ -2516,10 +2517,10 @@ packages:
|
|||
- pkg:pypi/cycler?source=hash-mapping
|
||||
size: 14778
|
||||
timestamp: 1764466758386
|
||||
- pypi: https://files.pythonhosted.org/packages/07/18/5ca04dfda3e53b5d07b072033cc9f7bf10f93f78019366bff411433690d1/cyclopts-4.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
|
||||
name: cyclopts
|
||||
version: 4.4.0
|
||||
sha256: 78ff95a5e52e738a1d0f01e5a3af48049c47748fa2c255f2629a4cef54dcf2b3
|
||||
version: 4.4.1
|
||||
sha256: 67500e9fde90f335fddbf9c452d2e7c4f58209dffe52e7abb1e272796a963bde
|
||||
requires_dist:
|
||||
- attrs>=23.1.0
|
||||
- docstring-parser>=0.15,<4.0
|
||||
|
|
@ -2932,7 +2933,6 @@ packages:
|
|||
- ruff>=0.14.9,<0.15
|
||||
- pandas-stubs>=2.3.3.251201,<3
|
||||
requires_python: '>=3.13,<3.14'
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
name: entropy
|
||||
version: 0.1.0
|
||||
|
|
@ -7326,12 +7326,12 @@ packages:
|
|||
- pkg:pypi/pandas?source=compressed-mapping
|
||||
size: 14912799
|
||||
timestamp: 1764615091147
|
||||
- pypi: https://files.pythonhosted.org/packages/e2/68/78a3c253f146254b8e2c19f4a4768f272e12ef11001d9b45ec7b165db054/pandas_stubs-2.3.3.251201-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/64/20/69f2a39792a653fd64d916cd563ed79ec6e5dcfa6408c4674021d810afcf/pandas_stubs-2.3.3.251219-py3-none-any.whl
|
||||
name: pandas-stubs
|
||||
version: 2.3.3.251201
|
||||
sha256: eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5
|
||||
version: 2.3.3.251219
|
||||
sha256: ccc6337febb51d6d8a08e4c96b479478a0da0ef704b5e08bd212423fe1cb549c
|
||||
requires_dist:
|
||||
- numpy>=1.23.5
|
||||
- numpy>=1.23.5,<=2.3.5
|
||||
- types-pytz>=2022.1.1
|
||||
requires_python: '>=3.10'
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2
|
||||
|
|
|
|||
|
|
@ -1,11 +1,22 @@
|
|||
#! /bin/bash
|
||||
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 3
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 4
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 5
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 6
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 6
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 7
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 8
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 9
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 10
|
||||
# pixi shell
|
||||
darts extract-darts-rts --grid hex --level 3
|
||||
darts extract-darts-rts --grid hex --level 4
|
||||
darts extract-darts-rts --grid hex --level 5
|
||||
darts extract-darts-rts --grid hex --level 6
|
||||
darts extract-darts-rts --grid healpix --level 6
|
||||
darts extract-darts-rts --grid healpix --level 7
|
||||
darts extract-darts-rts --grid healpix --level 8
|
||||
darts extract-darts-rts --grid healpix --level 9
|
||||
darts extract-darts-rts --grid healpix --level 10
|
||||
|
||||
darts extract-darts-mllabels --grid hex --level 3
|
||||
darts extract-darts-mllabels --grid hex --level 4
|
||||
darts extract-darts-mllabels --grid hex --level 5
|
||||
darts extract-darts-mllabels --grid hex --level 6
|
||||
darts extract-darts-mllabels --grid healpix --level 6
|
||||
darts extract-darts-mllabels --grid healpix --level 7
|
||||
darts extract-darts-mllabels --grid healpix --level 8
|
||||
darts extract-darts-mllabels --grid healpix --level 9
|
||||
darts extract-darts-mllabels --grid healpix --level 10
|
||||
4
scripts/05train.sh
Normal file
4
scripts/05train.sh
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash
# Train both model flavours on the same data configuration so their results
# can be compared directly: hex grid at level 5, DARTS ML labels as target,
# density task, 1000 iterations (--n-iter) each.

pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model espa
pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model xgboost
|
||||
191
scripts/fix_xgboost_importance.py
Normal file
191
scripts/fix_xgboost_importance.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Fix XGBoost feature importance in existing model state files.
|
||||
|
||||
This script repairs XGBoost model state files that have all-zero feature importance
|
||||
values due to incorrect feature name lookup. It reloads the pickled models and
|
||||
regenerates the feature importance arrays with the correct feature index mapping.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
import xarray as xr
|
||||
from rich import print
|
||||
|
||||
from entropice.paths import RESULTS_DIR
|
||||
|
||||
|
||||
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator and regenerates the feature-importance
    arrays in ``best_estimator_state.nc``, mapping XGBoost's index-based score
    keys (``f0``, ``f1``, ...) back onto the feature names recorded in the old
    state file. The original state file is preserved as ``.nc.backup``.

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise (not an XGBoost run,
        missing files, or the regenerated importance is still all-zero).

    """
    # Only XGBoost runs are affected; the model type is recorded in the
    # search settings written at training time.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False

    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")

    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Both the pickled estimator and the broken state file must be present.
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"

    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False

    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # Load the pickled model (trusted local artifact written by our own
    # training code — pickle is acceptable here).
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # The old state still carries the correct feature names; only its
    # importance values are wrong. Context manager guarantees the file
    # handle is released even if reading fails.
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    booster = best_estimator.get_booster()

    def align_importance(importance_dict: dict) -> list:
        """Align an importance dict to the feature list via f<i> index keys.

        Features the booster never used are absent from the dict and get 0.0.
        """
        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]

    # One spec per importance type replaces five copies of identical
    # DataArray boilerplate.
    importance_specs = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }
    importance_arrays = {
        f"feature_importance_{kind}": xr.DataArray(
            align_importance(booster.get_score(importance_type=kind)),
            dims=["feature"],
            coords={"feature": features},
            name=f"feature_importance_{kind}",
            attrs={"description": description},
        )
        for kind, description in importance_specs.items()
    }

    # Assemble the repaired state dataset with the same attrs the training
    # code writes.
    state = xr.Dataset(
        importance_arrays,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": booster.num_boosted_rounds(),
            "objective": str(best_estimator.objective),
        },
    )

    # Keep the first backup ever made; on repeated runs just replace the
    # broken state file.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    # Save the fixed state
    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Sanity check: a correctly repaired file must have non-zero importance
    # in at least one of the per-split importance types.
    total_importance = sum(
        importance_arrays[f"feature_importance_{kind}"].sum().item()
        for kind in ("weight", "gain", "cover")
    )

    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    else:
        print(" ✗ Warning: Total importance is still 0!")
        return False
|
||||
|
||||
|
||||
def main() -> None:
    """Find and fix all XGBoost model state files under RESULTS_DIR."""
    print("Scanning for XGBoost model results...")

    # Every immediate subdirectory of RESULTS_DIR is one training run.
    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed_count = 0
    skipped_count = 0
    failed_count = 0

    for result_dir in sorted(result_dirs):
        try:
            # fix_xgboost_model_state returns a plain bool, so a simple
            # else-branch is the correct dual of the success case; the old
            # `elif success is False` would silently drop any other return.
            if fix_xgboost_model_state(result_dir):
                fixed_count += 1
            else:
                # Skipped: not an XGBoost run, missing files, or the
                # regenerated importance failed verification.
                skipped_count += 1
        except Exception as e:
            # Keep going — one corrupt run must not abort the whole sweep.
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed_count += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {fixed_count}")
    print(f" ⊘ Skipped: {skipped_count}")
    print(f" ✗ Failed: {failed_count}")
    print("=" * 60)
|
||||
|
||||
|
||||
# Run as a standalone script: scan RESULTS_DIR and repair every XGBoost run.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -5,6 +5,8 @@ import xarray as xr
|
|||
|
||||
from entropice.dashboard.plots.colors import generate_unified_colormap
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_arcticdem_heatmap,
|
||||
plot_arcticdem_summary,
|
||||
plot_box_assignment_bars,
|
||||
plot_box_assignments,
|
||||
plot_common_features,
|
||||
|
|
@ -12,12 +14,15 @@ from entropice.dashboard.plots.model_state import (
|
|||
plot_embedding_heatmap,
|
||||
plot_era5_heatmap,
|
||||
plot_era5_summary,
|
||||
plot_era5_time_heatmap,
|
||||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
get_members_from_settings,
|
||||
load_all_training_results,
|
||||
)
|
||||
from entropice.dashboard.utils.training import load_model_state
|
||||
|
|
@ -35,15 +40,21 @@ def render_model_state_page():
|
|||
st.error("No training results found. Please run a training search first.")
|
||||
return
|
||||
|
||||
# Result selection
|
||||
result_options = {tr.name: tr for tr in training_results}
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
st.header("Select Training Run")
|
||||
|
||||
# Result selection with model-first naming
|
||||
result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
|
||||
selected_name = st.selectbox(
|
||||
"Select Training Result",
|
||||
"Training Run",
|
||||
options=list(result_options.keys()),
|
||||
help="Choose a training result to visualize model state",
|
||||
)
|
||||
selected_result = result_options[selected_name]
|
||||
|
||||
st.divider()
|
||||
|
||||
# Get the model type from settings
|
||||
model_type = selected_result.settings.get("model", "espa")
|
||||
|
||||
|
|
@ -62,6 +73,33 @@ def render_model_state_page():
|
|||
st.write(f"**Coordinates:** {list(model_state.coords)}")
|
||||
st.write(f"**Attributes:** {dict(model_state.attrs)}")
|
||||
|
||||
# Display dataset members summary
|
||||
st.header("📊 Training Data Summary")
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
|
||||
st.markdown(f"""
|
||||
**Dataset Members Used in Training:** {len(members)}
|
||||
|
||||
The following data sources were used to train this model:
|
||||
""")
|
||||
|
||||
# Create a nice display of members with emojis
|
||||
member_display = {
|
||||
"AlphaEarth": "🛰️ AlphaEarth (Satellite Embeddings)",
|
||||
"ArcticDEM": "🏔️ ArcticDEM (Topography)",
|
||||
"ERA5-yearly": "⛅ ERA5 Yearly (Climate)",
|
||||
"ERA5-seasonal": "⛅ ERA5 Seasonal (Summer/Winter)",
|
||||
"ERA5-shoulder": "⛅ ERA5 Shoulder Seasons (JFM/AMJ/JAS/OND)",
|
||||
}
|
||||
|
||||
cols = st.columns(min(len(members), 3))
|
||||
for idx, member in enumerate(members):
|
||||
with cols[idx % 3]:
|
||||
display_name = member_display.get(member, f"📁 {member}")
|
||||
st.info(display_name)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Render model-specific visualizations
|
||||
if model_type == "espa":
|
||||
render_espa_model_state(model_state, selected_result)
|
||||
|
|
@ -81,9 +119,28 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
n_features = model_state.sizes["feature"]
|
||||
model_state["feature_weights"] *= n_features
|
||||
|
||||
# Extract different feature types
|
||||
# Get members used in training
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
|
||||
# Extract different feature types based on what was used in training
|
||||
embedding_feature_array = None
|
||||
if "AlphaEarth" in members:
|
||||
embedding_feature_array = extract_embedding_features(model_state)
|
||||
era5_feature_array = extract_era5_features(model_state)
|
||||
|
||||
era5_yearly_array = None
|
||||
era5_seasonal_array = None
|
||||
era5_shoulder_array = None
|
||||
if "ERA5-yearly" in members:
|
||||
era5_yearly_array = extract_era5_features(model_state, temporal_group="yearly")
|
||||
if "ERA5-seasonal" in members:
|
||||
era5_seasonal_array = extract_era5_features(model_state, temporal_group="seasonal")
|
||||
if "ERA5-shoulder" in members:
|
||||
era5_shoulder_array = extract_era5_features(model_state, temporal_group="shoulder")
|
||||
|
||||
arcticdem_feature_array = None
|
||||
if "ArcticDEM" in members:
|
||||
arcticdem_feature_array = extract_arcticdem_features(model_state)
|
||||
|
||||
common_feature_array = extract_common_features(model_state)
|
||||
|
||||
# Generate unified colormaps
|
||||
|
|
@ -107,7 +164,7 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
with st.spinner("Generating feature importance plot..."):
|
||||
feature_chart = plot_top_features(model_state, top_n=top_n)
|
||||
st.altair_chart(feature_chart, use_container_width=True)
|
||||
st.altair_chart(feature_chart, width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
|
|
@ -136,12 +193,12 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
with col1:
|
||||
st.markdown("### Assignment Heatmap")
|
||||
box_assignment_heatmap = plot_box_assignments(model_state)
|
||||
st.altair_chart(box_assignment_heatmap, use_container_width=True)
|
||||
st.altair_chart(box_assignment_heatmap, width="stretch")
|
||||
|
||||
with col2:
|
||||
st.markdown("### Box Count by Class")
|
||||
box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
|
||||
st.altair_chart(box_assignment_bars, use_container_width=True)
|
||||
st.altair_chart(box_assignment_bars, width="stretch")
|
||||
|
||||
# Show statistics
|
||||
with st.expander("Box Assignment Statistics"):
|
||||
|
|
@ -179,9 +236,19 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
if embedding_feature_array is not None:
|
||||
render_embedding_features(embedding_feature_array)
|
||||
|
||||
# ERA5 features analysis (if present)
|
||||
if era5_feature_array is not None:
|
||||
render_era5_features(era5_feature_array)
|
||||
# ERA5 features analysis (if present) - split by temporal group
|
||||
if era5_yearly_array is not None:
|
||||
render_era5_features(era5_yearly_array, temporal_group="Yearly")
|
||||
|
||||
if era5_seasonal_array is not None:
|
||||
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
|
||||
|
||||
if era5_shoulder_array is not None:
|
||||
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
|
||||
|
||||
# ArcticDEM features analysis (if present)
|
||||
if arcticdem_feature_array is not None:
|
||||
render_arcticdem_features(arcticdem_feature_array)
|
||||
|
||||
# Common features analysis (if present)
|
||||
if common_feature_array is not None:
|
||||
|
|
@ -237,7 +304,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
with st.spinner("Generating feature importance plot..."):
|
||||
importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
|
||||
st.altair_chart(importance_chart, use_container_width=True)
|
||||
st.altair_chart(importance_chart, width="stretch")
|
||||
|
||||
# Comparison of importance types
|
||||
st.subheader("Importance Type Comparison")
|
||||
|
|
@ -245,7 +312,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
with st.spinner("Generating importance comparison..."):
|
||||
comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
|
||||
st.altair_chart(comparison_chart, use_container_width=True)
|
||||
st.altair_chart(comparison_chart, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("Model Statistics"):
|
||||
|
|
@ -256,6 +323,64 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
|||
with col2:
|
||||
st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
|
||||
|
||||
# Feature source analysis
|
||||
st.subheader("Feature Importance by Data Source")
|
||||
st.markdown(
|
||||
"""
|
||||
Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
|
||||
ArcticDEM topography, and common features).
|
||||
"""
|
||||
)
|
||||
|
||||
# Get members used in training
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
|
||||
# Extract features by source using the selected importance type
|
||||
importance_var = f"feature_importance_{importance_type}"
|
||||
|
||||
embedding_feature_array = None
|
||||
if "AlphaEarth" in members:
|
||||
embedding_feature_array = extract_embedding_features(model_state, importance_type=importance_var)
|
||||
|
||||
era5_yearly_array = None
|
||||
era5_seasonal_array = None
|
||||
era5_shoulder_array = None
|
||||
if "ERA5-yearly" in members:
|
||||
era5_yearly_array = extract_era5_features(model_state, importance_type=importance_var, temporal_group="yearly")
|
||||
if "ERA5-seasonal" in members:
|
||||
era5_seasonal_array = extract_era5_features(
|
||||
model_state, importance_type=importance_var, temporal_group="seasonal"
|
||||
)
|
||||
if "ERA5-shoulder" in members:
|
||||
era5_shoulder_array = extract_era5_features(
|
||||
model_state, importance_type=importance_var, temporal_group="shoulder"
|
||||
)
|
||||
|
||||
arcticdem_feature_array = None
|
||||
if "ArcticDEM" in members:
|
||||
arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type=importance_var)
|
||||
|
||||
common_feature_array = extract_common_features(model_state, importance_type=importance_var)
|
||||
|
||||
# Render each source's features if present
|
||||
if embedding_feature_array is not None:
|
||||
render_embedding_features(embedding_feature_array)
|
||||
|
||||
if era5_yearly_array is not None:
|
||||
render_era5_features(era5_yearly_array, temporal_group="Yearly")
|
||||
|
||||
if era5_seasonal_array is not None:
|
||||
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
|
||||
|
||||
if era5_shoulder_array is not None:
|
||||
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
|
||||
|
||||
if arcticdem_feature_array is not None:
|
||||
render_arcticdem_features(arcticdem_feature_array)
|
||||
|
||||
if common_feature_array is not None:
|
||||
render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for Random Forest model."""
|
||||
|
|
@ -302,7 +427,7 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
with st.spinner("Generating feature importance plot..."):
|
||||
importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
|
||||
st.altair_chart(importance_chart, use_container_width=True)
|
||||
st.altair_chart(importance_chart, width="stretch")
|
||||
|
||||
# Tree statistics (only if available - sklearn RF has them, cuML RF doesn't)
|
||||
if not is_cuml and "tree_depths" in model_state:
|
||||
|
|
@ -315,11 +440,11 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
|||
chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_depths, use_container_width=True)
|
||||
st.altair_chart(chart_depths, width="stretch")
|
||||
with col2:
|
||||
st.altair_chart(chart_leaves, use_container_width=True)
|
||||
st.altair_chart(chart_leaves, width="stretch")
|
||||
with col3:
|
||||
st.altair_chart(chart_nodes, use_container_width=True)
|
||||
st.altair_chart(chart_nodes, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("Forest Statistics"):
|
||||
|
|
@ -345,6 +470,64 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
|||
st.metric("Max Nodes", f"{nodes.max()}")
|
||||
st.metric("Min Nodes", f"{nodes.min()}")
|
||||
|
||||
# Feature source analysis
|
||||
st.subheader("Feature Importance by Data Source")
|
||||
st.markdown(
|
||||
"""
|
||||
Breakdown of feature importance by data source (AlphaEarth embeddings, ERA5 climate,
|
||||
ArcticDEM topography, and common features).
|
||||
"""
|
||||
)
|
||||
|
||||
# Get members used in training
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
|
||||
# Extract features by source
|
||||
embedding_feature_array = None
|
||||
if "AlphaEarth" in members:
|
||||
embedding_feature_array = extract_embedding_features(model_state, importance_type="feature_importance")
|
||||
|
||||
era5_yearly_array = None
|
||||
era5_seasonal_array = None
|
||||
era5_shoulder_array = None
|
||||
if "ERA5-yearly" in members:
|
||||
era5_yearly_array = extract_era5_features(
|
||||
model_state, importance_type="feature_importance", temporal_group="yearly"
|
||||
)
|
||||
if "ERA5-seasonal" in members:
|
||||
era5_seasonal_array = extract_era5_features(
|
||||
model_state, importance_type="feature_importance", temporal_group="seasonal"
|
||||
)
|
||||
if "ERA5-shoulder" in members:
|
||||
era5_shoulder_array = extract_era5_features(
|
||||
model_state, importance_type="feature_importance", temporal_group="shoulder"
|
||||
)
|
||||
|
||||
arcticdem_feature_array = None
|
||||
if "ArcticDEM" in members:
|
||||
arcticdem_feature_array = extract_arcticdem_features(model_state, importance_type="feature_importance")
|
||||
|
||||
common_feature_array = extract_common_features(model_state, importance_type="feature_importance")
|
||||
|
||||
# Render each source's features if present
|
||||
if embedding_feature_array is not None:
|
||||
render_embedding_features(embedding_feature_array)
|
||||
|
||||
if era5_yearly_array is not None:
|
||||
render_era5_features(era5_yearly_array, temporal_group="Yearly")
|
||||
|
||||
if era5_seasonal_array is not None:
|
||||
render_era5_features(era5_seasonal_array, temporal_group="Seasonal")
|
||||
|
||||
if era5_shoulder_array is not None:
|
||||
render_era5_features(era5_shoulder_array, temporal_group="Shoulder")
|
||||
|
||||
if arcticdem_feature_array is not None:
|
||||
render_arcticdem_features(arcticdem_feature_array)
|
||||
|
||||
if common_feature_array is not None:
|
||||
render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_knn_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for KNN model."""
|
||||
|
|
@ -402,18 +585,18 @@ def render_embedding_features(embedding_feature_array):
|
|||
chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_agg, use_container_width=True)
|
||||
st.altair_chart(chart_agg, width="stretch")
|
||||
with col2:
|
||||
st.altair_chart(chart_band, use_container_width=True)
|
||||
st.altair_chart(chart_band, width="stretch")
|
||||
with col3:
|
||||
st.altair_chart(chart_year, use_container_width=True)
|
||||
st.altair_chart(chart_year, width="stretch")
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap by Aggregation")
|
||||
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||
with st.spinner("Generating heatmap..."):
|
||||
heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
|
||||
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||
st.altair_chart(heatmap_chart, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("Embedding Feature Statistics"):
|
||||
|
|
@ -436,13 +619,21 @@ def render_embedding_features(embedding_feature_array):
|
|||
st.dataframe(top_emb, width="stretch")
|
||||
|
||||
|
||||
def render_era5_features(era5_feature_array):
|
||||
"""Render ERA5 feature visualizations."""
|
||||
with st.container(border=True):
|
||||
st.header("⛅ ERA5 Feature Analysis")
|
||||
st.markdown(
|
||||
def render_era5_features(era5_feature_array, temporal_group: str = ""):
|
||||
"""Render ERA5 feature visualizations.
|
||||
|
||||
Args:
|
||||
era5_feature_array: ERA5 feature importance array.
|
||||
temporal_group: Name of the temporal grouping (e.g., "Yearly", "Seasonal", "Shoulder").
|
||||
|
||||
"""
|
||||
Analysis of ERA5 climate features showing which variables and time periods
|
||||
group_suffix = f" ({temporal_group})" if temporal_group else ""
|
||||
|
||||
with st.container(border=True):
|
||||
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
|
||||
st.markdown(
|
||||
f"""
|
||||
Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
|
@ -450,19 +641,50 @@ def render_era5_features(era5_feature_array):
|
|||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||
chart_variable, chart_time = plot_era5_summary(era5_feature_array)
|
||||
charts = plot_era5_summary(era5_feature_array)
|
||||
|
||||
# Check if this is seasonal/shoulder data (returns 3 charts) or yearly (returns 2 charts)
|
||||
if len(charts) == 3:
|
||||
chart_variable, chart_season, chart_year = charts
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, width="stretch")
|
||||
with col2:
|
||||
st.altair_chart(chart_season, width="stretch")
|
||||
with col3:
|
||||
st.altair_chart(chart_year, width="stretch")
|
||||
else:
|
||||
chart_variable, chart_time = charts
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, use_container_width=True)
|
||||
st.altair_chart(chart_variable, width="stretch")
|
||||
with col2:
|
||||
st.altair_chart(chart_time, use_container_width=True)
|
||||
st.altair_chart(chart_time, width="stretch")
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
|
||||
# Check if this is seasonal/shoulder data
|
||||
has_season = "season" in era5_feature_array.dims
|
||||
|
||||
if has_season:
|
||||
st.markdown("Shows the weight of each variable-season-year combination.")
|
||||
with st.spinner("Generating ERA5 season heatmap..."):
|
||||
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, width="stretch")
|
||||
|
||||
# Add time-based heatmap for seasonal/shoulder
|
||||
st.markdown("### By Time Heatmap")
|
||||
st.markdown("Shows temporal trends by averaging over seasons.")
|
||||
with st.spinner("Generating ERA5 time heatmap..."):
|
||||
era5_time_heatmap_chart = plot_era5_time_heatmap(era5_feature_array)
|
||||
if era5_time_heatmap_chart is not None:
|
||||
st.altair_chart(era5_time_heatmap_chart, width="stretch")
|
||||
else:
|
||||
st.markdown("Shows the weight of each variable-time combination.")
|
||||
with st.spinner("Generating ERA5 heatmap..."):
|
||||
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||
st.altair_chart(era5_heatmap_chart, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("ERA5 Feature Statistics"):
|
||||
|
|
@ -481,10 +703,61 @@ def render_era5_features(era5_feature_array):
|
|||
# Show top ERA5 features
|
||||
st.write("**Top 10 ERA5 Features:**")
|
||||
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||
# Get all columns except 'weight' for display
|
||||
display_cols = [col for col in era5_df.columns if col != "weight"] + ["weight"]
|
||||
top_era5 = era5_df.nlargest(10, "weight")[display_cols]
|
||||
st.dataframe(top_era5, width="stretch")
|
||||
|
||||
|
||||
def render_arcticdem_features(arcticdem_feature_array):
|
||||
"""Render ArcticDEM feature visualizations."""
|
||||
with st.container(border=True):
|
||||
st.header("🏔️ ArcticDEM Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of ArcticDEM topographic features showing which terrain variables and
|
||||
aggregations are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ArcticDEM dimension summaries..."):
|
||||
chart_variable, chart_agg = plot_arcticdem_summary(arcticdem_feature_array)
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, width="stretch")
|
||||
with col2:
|
||||
st.altair_chart(chart_agg, width="stretch")
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
st.markdown("Shows the weight of each variable-aggregation combination.")
|
||||
with st.spinner("Generating ArcticDEM heatmap..."):
|
||||
arcticdem_heatmap_chart = plot_arcticdem_heatmap(arcticdem_feature_array)
|
||||
st.altair_chart(arcticdem_heatmap_chart, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("ArcticDEM Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_arcticdem_features = arcticdem_feature_array.size
|
||||
mean_weight = float(arcticdem_feature_array.mean().values)
|
||||
max_weight = float(arcticdem_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total ArcticDEM Features", n_arcticdem_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top ArcticDEM features
|
||||
st.write("**Top 10 ArcticDEM Features:**")
|
||||
arcticdem_df = arcticdem_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_arcticdem = arcticdem_df.nlargest(10, "weight")[["variable", "agg", "weight"]]
|
||||
st.dataframe(top_arcticdem, width="stretch")
|
||||
|
||||
|
||||
def render_common_features(common_feature_array):
|
||||
"""Render common feature visualizations."""
|
||||
with st.container(border=True):
|
||||
|
|
@ -499,7 +772,7 @@ def render_common_features(common_feature_array):
|
|||
# Bar chart showing all common feature weights
|
||||
with st.spinner("Generating common features chart..."):
|
||||
common_chart = plot_common_features(common_feature_array)
|
||||
st.altair_chart(common_chart, use_container_width=True)
|
||||
st.altair_chart(common_chart, width="stretch")
|
||||
|
||||
# Statistics
|
||||
with st.expander("Common Feature Statistics"):
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def render_overview_page():
|
|||
st.subheader("Detailed Results")
|
||||
|
||||
for tr in training_results:
|
||||
with st.expander(tr.name):
|
||||
with st.expander(tr.get_display_name("task_first")):
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
|||
best_score = results[col].max()
|
||||
best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})
|
||||
|
||||
st.dataframe(pd.DataFrame(best_scores), hide_index=True, use_container_width=True)
|
||||
st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch")
|
||||
|
||||
with col2:
|
||||
st.markdown("#### Score Statistics")
|
||||
|
|
@ -51,7 +51,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
|||
}
|
||||
)
|
||||
|
||||
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
|
||||
st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch")
|
||||
|
||||
# Show best parameter combination in a cleaner format (similar to old dashboard)
|
||||
st.markdown("#### 🏆 Best Parameter Combination")
|
||||
|
|
@ -305,7 +305,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
|
|||
.properties(height=250, title=param_name)
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
else:
|
||||
# Categorical parameter - use bar chart
|
||||
|
|
@ -323,7 +323,7 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
|
|||
.properties(height=250, title=param_name)
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
|
||||
def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
||||
|
|
@ -411,7 +411,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
.properties(height=300, title=f"{metric} vs {param_name}")
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
else:
|
||||
# Categorical parameter - box plot
|
||||
|
|
@ -428,7 +428,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
.properties(height=300, title=f"{metric} vs {param_name}")
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
|
||||
def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
||||
|
|
@ -481,7 +481,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
|||
.properties(height=max(200, len(correlations) * 30))
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
|
||||
def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
||||
|
|
@ -673,7 +673,7 @@ def _render_2d_param_plot(
|
|||
.interactive()
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
|
||||
def render_score_evolution(results: pd.DataFrame, metric: str):
|
||||
|
|
@ -722,7 +722,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
.properties(height=400)
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
# Show statistics
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
|
@ -831,19 +831,37 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
|||
else:
|
||||
tooltip_list.append(color_col)
|
||||
|
||||
# Calculate axis domains starting from min values
|
||||
x_min = df_plot[metric1].min()
|
||||
x_max = df_plot[metric1].max()
|
||||
y_min = df_plot[metric2].min()
|
||||
y_max = df_plot[metric2].max()
|
||||
|
||||
# Add small padding (2% of range) for better visualization
|
||||
x_padding = (x_max - x_min) * 0.02
|
||||
y_padding = (y_max - y_min) * 0.02
|
||||
|
||||
chart = (
|
||||
alt.Chart(df_plot)
|
||||
.mark_circle(size=60, opacity=0.6)
|
||||
.encode(
|
||||
alt.X(metric1, title=metric1.replace("_", " ").title()),
|
||||
alt.Y(metric2, title=metric2.replace("_", " ").title()),
|
||||
alt.X(
|
||||
metric1,
|
||||
title=metric1.replace("_", " ").title(),
|
||||
scale=alt.Scale(domain=[x_min - x_padding, x_max + x_padding]),
|
||||
),
|
||||
alt.Y(
|
||||
metric2,
|
||||
title=metric2.replace("_", " ").title(),
|
||||
scale=alt.Scale(domain=[y_min - y_padding, y_max + y_padding]),
|
||||
),
|
||||
alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
|
||||
tooltip=tooltip_list,
|
||||
)
|
||||
.properties(height=500)
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
# Calculate correlation
|
||||
corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
|
||||
|
|
@ -1015,7 +1033,7 @@ def render_espa_binned_parameter_space(results: pd.DataFrame, metric: str, k_bin
|
|||
)
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
st.altair_chart(chart, width="stretch")
|
||||
|
||||
# Show statistics about the binning
|
||||
n_bins = len(bin_order)
|
||||
|
|
@ -1073,4 +1091,4 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
|
|||
score_col_display = metric.replace("_", " ").title()
|
||||
display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
|
||||
|
||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
|
|
|||
|
|
@ -69,6 +69,13 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
|
|||
Altair chart showing the heatmap.
|
||||
|
||||
"""
|
||||
# Filter out "count" aggregation if present
|
||||
if "agg" in embedding_array.dims:
|
||||
agg_values = embedding_array.coords["agg"].values
|
||||
non_count_aggs = [agg for agg in agg_values if agg != "count"]
|
||||
if non_count_aggs:
|
||||
embedding_array = embedding_array.sel(agg=non_count_aggs)
|
||||
|
||||
# Convert to DataFrame for plotting
|
||||
df = embedding_array.to_dataframe(name="weight").reset_index()
|
||||
|
||||
|
|
@ -81,7 +88,7 @@ def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
|
|||
y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
|
||||
color=alt.Color(
|
||||
"weight:Q",
|
||||
scale=alt.Scale(scheme="redblue", domainMid=0),
|
||||
scale=alt.Scale(scheme="blues"),
|
||||
title="Weight",
|
||||
),
|
||||
tooltip=[
|
||||
|
|
@ -108,15 +115,22 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
|
|||
Tuple of three Altair charts (by_agg, by_band, by_year).
|
||||
|
||||
"""
|
||||
# Filter out "count" aggregation if present
|
||||
if "agg" in embedding_array.dims:
|
||||
agg_values = embedding_array.coords["agg"].values
|
||||
non_count_aggs = [agg for agg in agg_values if agg != "count"]
|
||||
if non_count_aggs:
|
||||
embedding_array = embedding_array.sel(agg=non_count_aggs)
|
||||
|
||||
# Aggregate by different dimensions
|
||||
by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
|
||||
by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
|
||||
by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()
|
||||
|
||||
# Create DataFrames
|
||||
df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()})
|
||||
df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()})
|
||||
df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()})
|
||||
# Create DataFrames, handling potential MultiIndex
|
||||
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
|
||||
df_band = pd.DataFrame({"dimension": by_band.index.astype(str), "mean_abs_weight": by_band.values})
|
||||
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
|
||||
|
||||
# Sort by weight
|
||||
df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
|
||||
|
|
@ -185,10 +199,10 @@ def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[a
|
|||
|
||||
|
||||
def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
|
||||
"""Create a heatmap showing ERA5 feature weights across variables and time.
|
||||
"""Create a heatmap showing ERA5 feature weights across variables and time/season.
|
||||
|
||||
Args:
|
||||
era5_array: DataArray with dimensions (variable, time) containing feature weights.
|
||||
era5_array: DataArray with dimensions (variable, time) or (variable, season, year) containing feature weights.
|
||||
|
||||
Returns:
|
||||
Altair chart showing the heatmap.
|
||||
|
|
@ -197,7 +211,34 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
|
|||
# Convert to DataFrame for plotting
|
||||
df = era5_array.to_dataframe(name="weight").reset_index()
|
||||
|
||||
# Create heatmap
|
||||
# Determine if this is seasonal/shoulder data (has 'season' dimension)
|
||||
has_season = "season" in df.columns
|
||||
|
||||
if has_season:
|
||||
# For seasonal/shoulder: create faceted heatmap by season
|
||||
chart = (
|
||||
alt.Chart(df)
|
||||
.mark_rect()
|
||||
.encode(
|
||||
x=alt.X("year:N", title="Year", sort=None),
|
||||
y=alt.Y("variable:N", title="Variable", sort="-color"),
|
||||
color=alt.Color(
|
||||
"weight:Q",
|
||||
scale=alt.Scale(scheme="redblue", domainMid=0),
|
||||
title="Weight",
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("variable:N", title="Variable"),
|
||||
alt.Tooltip("season:N", title="Season"),
|
||||
alt.Tooltip("year:N", title="Year"),
|
||||
alt.Tooltip("weight:Q", format=".4f", title="Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=200, height=300)
|
||||
.facet(facet=alt.Facet("season:N", title="Season"), columns=4)
|
||||
)
|
||||
else:
|
||||
# For yearly: simple heatmap
|
||||
chart = (
|
||||
alt.Chart(df)
|
||||
.mark_rect()
|
||||
|
|
@ -224,23 +265,161 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
|
|||
return chart
|
||||
|
||||
|
||||
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
|
||||
"""Create bar charts summarizing ERA5 weights by variable and time.
|
||||
def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart:
|
||||
"""Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).
|
||||
|
||||
This is specifically for seasonal/shoulder data to show temporal trends.
|
||||
|
||||
Args:
|
||||
era5_array: DataArray with dimensions (variable, time) containing feature weights.
|
||||
era5_array: DataArray with dimensions (variable, season, year, agg) containing feature weights.
|
||||
|
||||
Returns:
|
||||
Tuple of two Altair charts (by_variable, by_time).
|
||||
Altair chart showing the variable-year heatmap.
|
||||
|
||||
"""
|
||||
# Aggregate by different dimensions
|
||||
by_variable = era5_array.mean(dim="time").to_pandas().abs()
|
||||
by_time = era5_array.mean(dim="variable").to_pandas().abs()
|
||||
# Check if this has season dimension
|
||||
if "season" not in era5_array.dims:
|
||||
return None
|
||||
|
||||
# Average over season and agg to get (variable, year)
|
||||
dims_to_avg = [d for d in era5_array.dims if d not in ["variable", "year"]]
|
||||
by_time = era5_array.mean(dim=dims_to_avg)
|
||||
|
||||
# Convert to DataFrame for plotting
|
||||
df = by_time.to_dataframe(name="weight").reset_index()
|
||||
|
||||
# Create heatmap
|
||||
chart = (
|
||||
alt.Chart(df)
|
||||
.mark_rect()
|
||||
.encode(
|
||||
x=alt.X("year:N", title="Year", sort=None),
|
||||
y=alt.Y("variable:N", title="Variable", sort="-color"),
|
||||
color=alt.Color(
|
||||
"weight:Q",
|
||||
scale=alt.Scale(scheme="redblue", domainMid=0),
|
||||
title="Weight",
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("variable:N", title="Variable"),
|
||||
alt.Tooltip("year:N", title="Year"),
|
||||
alt.Tooltip("weight:Q", format=".4f", title="Avg Weight"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
height=400,
|
||||
title="ERA5 Feature Weights by Time (averaged over seasons)",
|
||||
)
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, ...]:
|
||||
"""Create bar charts summarizing ERA5 weights by variable and time/season.
|
||||
|
||||
Args:
|
||||
era5_array: DataArray with dimensions (variable, time) or (variable, season, year, agg) containing feature weights.
|
||||
|
||||
Returns:
|
||||
Tuple of Altair charts:
|
||||
- For yearly data: (by_variable, by_time)
|
||||
- For seasonal/shoulder data: (by_variable, by_season, by_year)
|
||||
|
||||
"""
|
||||
# Determine which dimensions to average over
|
||||
dims_to_avg_for_var = [d for d in era5_array.dims if d != "variable"]
|
||||
|
||||
# Check if this is seasonal/shoulder data
|
||||
has_season = "season" in era5_array.dims
|
||||
|
||||
# Aggregate by variable (average over all other dimensions)
|
||||
by_variable = era5_array.mean(dim=dims_to_avg_for_var).to_pandas().abs()
|
||||
|
||||
if has_season:
|
||||
# For seasonal/shoulder: aggregate by season and year
|
||||
dims_to_avg_for_season = [d for d in era5_array.dims if d != "season"]
|
||||
by_season = era5_array.mean(dim=dims_to_avg_for_season).to_pandas().abs()
|
||||
|
||||
dims_to_avg_for_year = [d for d in era5_array.dims if d != "year"]
|
||||
by_year = era5_array.mean(dim=dims_to_avg_for_year).to_pandas().abs()
|
||||
|
||||
# Create DataFrames
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()})
|
||||
df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()})
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_season = pd.DataFrame({"dimension": by_season.index.astype(str), "mean_abs_weight": by_season.values})
|
||||
df_year = pd.DataFrame({"dimension": by_year.index.astype(str), "mean_abs_weight": by_year.values})
|
||||
|
||||
# Sort by weight
|
||||
df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
|
||||
df_season = df_season.sort_values("mean_abs_weight", ascending=True)
|
||||
df_year = df_year.sort_values("mean_abs_weight", ascending=True)
|
||||
|
||||
chart_variable = (
|
||||
alt.Chart(df_variable)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
scale=alt.Scale(scheme="purples"),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("dimension:N", title="Variable"),
|
||||
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=400, height=300, title="By Variable")
|
||||
)
|
||||
|
||||
chart_season = (
|
||||
alt.Chart(df_season)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Season", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
scale=alt.Scale(scheme="teals"),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("dimension:N", title="Season"),
|
||||
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=400, height=300, title="By Season")
|
||||
)
|
||||
|
||||
chart_year = (
|
||||
alt.Chart(df_year)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:O", title="Year", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
scale=alt.Scale(scheme="oranges"),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("dimension:O", title="Year"),
|
||||
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=400, height=300, title="By Time")
|
||||
)
|
||||
|
||||
return chart_variable, chart_season, chart_year
|
||||
else:
|
||||
# For yearly: aggregate by time
|
||||
dims_to_avg_for_time = [d for d in era5_array.dims if d != "time"]
|
||||
by_time = era5_array.mean(dim=dims_to_avg_for_time).to_pandas().abs()
|
||||
|
||||
# Create DataFrames, handling potential MultiIndex
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_time = pd.DataFrame({"dimension": by_time.index.astype(str), "mean_abs_weight": by_time.values})
|
||||
|
||||
# Sort by weight
|
||||
df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
|
||||
|
|
@ -288,6 +467,111 @@ def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
|
|||
return chart_variable, chart_time
|
||||
|
||||
|
||||
def plot_arcticdem_heatmap(arcticdem_array: xr.DataArray) -> alt.Chart:
|
||||
"""Create a heatmap showing ArcticDEM feature weights across variables and aggregations.
|
||||
|
||||
Args:
|
||||
arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.
|
||||
|
||||
Returns:
|
||||
Altair chart showing the heatmap.
|
||||
|
||||
"""
|
||||
# Convert to DataFrame for plotting
|
||||
df = arcticdem_array.to_dataframe(name="weight").reset_index()
|
||||
|
||||
# Create heatmap
|
||||
chart = (
|
||||
alt.Chart(df)
|
||||
.mark_rect()
|
||||
.encode(
|
||||
x=alt.X("agg:N", title="Aggregation", sort=None),
|
||||
y=alt.Y("variable:N", title="Variable", sort="-color"),
|
||||
color=alt.Color(
|
||||
"weight:Q",
|
||||
scale=alt.Scale(scheme="redblue", domainMid=0),
|
||||
title="Weight",
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("variable:N", title="Variable"),
|
||||
alt.Tooltip("agg:N", title="Aggregation"),
|
||||
alt.Tooltip("weight:Q", format=".4f", title="Weight"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
width=400,
|
||||
height=300,
|
||||
title="ArcticDEM Feature Weights Heatmap",
|
||||
)
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def plot_arcticdem_summary(arcticdem_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
|
||||
"""Create bar charts summarizing ArcticDEM weights by variable and aggregation.
|
||||
|
||||
Args:
|
||||
arcticdem_array: DataArray with dimensions (variable, agg) containing feature weights.
|
||||
|
||||
Returns:
|
||||
Tuple of two Altair charts (by_variable, by_agg).
|
||||
|
||||
"""
|
||||
# Aggregate by different dimensions
|
||||
by_variable = arcticdem_array.mean(dim="agg").to_pandas().abs()
|
||||
by_agg = arcticdem_array.mean(dim="variable").to_pandas().abs()
|
||||
|
||||
# Create DataFrames, handling potential MultiIndex
|
||||
df_variable = pd.DataFrame({"dimension": by_variable.index.astype(str), "mean_abs_weight": by_variable.values})
|
||||
df_agg = pd.DataFrame({"dimension": by_agg.index.astype(str), "mean_abs_weight": by_agg.values})
|
||||
|
||||
# Sort by weight
|
||||
df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
|
||||
df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
|
||||
|
||||
# Create charts with different colors
|
||||
chart_variable = (
|
||||
alt.Chart(df_variable)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
scale=alt.Scale(scheme="browns"),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("dimension:N", title="Variable"),
|
||||
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=400, height=300, title="By Variable")
|
||||
)
|
||||
|
||||
chart_agg = (
|
||||
alt.Chart(df_agg)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
y=alt.Y("dimension:N", title="Aggregation", sort="-x", axis=alt.Axis(labelLimit=200)),
|
||||
x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
|
||||
color=alt.Color(
|
||||
"mean_abs_weight:Q",
|
||||
scale=alt.Scale(scheme="reds"),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("dimension:N", title="Aggregation"),
|
||||
alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
|
||||
],
|
||||
)
|
||||
.properties(width=400, height=300, title="By Aggregation")
|
||||
)
|
||||
|
||||
return chart_variable, chart_agg
|
||||
|
||||
|
||||
def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
|
||||
"""Create a heatmap showing which boxes are assigned to which labels/classes.
|
||||
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ def render_training_analysis_page():
|
|||
with st.sidebar:
|
||||
st.header("Select Training Run")
|
||||
|
||||
# Create selection options
|
||||
training_options = {tr.name: tr for tr in training_results}
|
||||
# Create selection options with task-first naming
|
||||
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
|
||||
|
||||
selected_name = st.selectbox(
|
||||
"Training Run",
|
||||
|
|
@ -163,7 +163,7 @@ def render_training_analysis_page():
|
|||
with st.expander("📋 Parameter Space Summary", expanded=False):
|
||||
param_summary = get_parameter_space_summary(results)
|
||||
if not param_summary.empty:
|
||||
st.dataframe(param_summary, hide_index=True, use_container_width=True)
|
||||
st.dataframe(param_summary, hide_index=True, width="stretch")
|
||||
else:
|
||||
st.info("No parameter information available.")
|
||||
|
||||
|
|
@ -226,7 +226,7 @@ def render_training_analysis_page():
|
|||
data=csv_data,
|
||||
file_name=f"{selected_result.path.name}_results.csv",
|
||||
mime="text/csv",
|
||||
use_container_width=True,
|
||||
width="stretch",
|
||||
)
|
||||
|
||||
with col2:
|
||||
|
|
@ -239,9 +239,9 @@ def render_training_analysis_page():
|
|||
data=settings_json,
|
||||
file_name=f"{selected_result.path.name}_settings.json",
|
||||
mime="application/json",
|
||||
use_container_width=True,
|
||||
width="stretch",
|
||||
)
|
||||
|
||||
# Show raw data preview
|
||||
st.subheader("Raw Data Preview")
|
||||
st.dataframe(results.head(100), use_container_width=True)
|
||||
st.dataframe(results.head(100), width="stretch")
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ def render_training_data_page():
|
|||
|
||||
# Form submit button
|
||||
load_button = st.form_submit_button(
|
||||
"Load Dataset", type="primary", use_container_width=True, disabled=len(selected_members) == 0
|
||||
"Load Dataset", type="primary", width="stretch", disabled=len(selected_members) == 0
|
||||
)
|
||||
|
||||
# Create DatasetEnsemble only when form is submitted
|
||||
|
|
|
|||
|
|
@ -25,6 +25,27 @@ class TrainingResult:
|
|||
results: pd.DataFrame
|
||||
created_at: float
|
||||
|
||||
def get_display_name(self, format_type: str = "task_first") -> str:
|
||||
"""Get formatted display name for the training result.
|
||||
|
||||
Args:
|
||||
format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state)
|
||||
|
||||
Returns:
|
||||
Formatted name string
|
||||
|
||||
"""
|
||||
task = self.settings.get("task", "Unknown").capitalize()
|
||||
model = self.settings.get("model", "Unknown").upper()
|
||||
grid = self.settings.get("grid", "Unknown").capitalize()
|
||||
level = self.settings.get("level", "Unknown")
|
||||
timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
if format_type == "model_first":
|
||||
return f"{model} - {task} - {grid}-{level} ({timestamp})"
|
||||
else: # task_first
|
||||
return f"{task} - {model} - {grid}-{level} ({timestamp})"
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
||||
"""Load a TrainingResult from a given result directory path."""
|
||||
|
|
@ -38,9 +59,11 @@ class TrainingResult:
|
|||
settings = toml.load(settings_file)["settings"]
|
||||
results = pd.read_parquet(result_file)
|
||||
|
||||
# Name should be "task grid-level (created_at)"
|
||||
# Name should be "task model grid-level (created_at)"
|
||||
model = settings.get("model", "Unknown").upper()
|
||||
name = (
|
||||
f"**{settings.get('task', 'Unknown').capitalize()}** -"
|
||||
f"{settings.get('task', 'Unknown').capitalize()} -"
|
||||
f" {model} -"
|
||||
f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
|
||||
f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
|
||||
)
|
||||
|
|
@ -121,11 +144,40 @@ def load_source_data(e: DatasetEnsemble, source: str):
|
|||
return ds, targets
|
||||
|
||||
|
||||
def extract_embedding_features(model_state) -> xr.DataArray | None:
|
||||
"""Extract embedding features from the model state.
|
||||
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
|
||||
"""Get the appropriate feature importance array from model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
importance_type: Type of importance to extract. For ESPA: 'feature_weights'.
|
||||
For XGBoost: 'feature_importance_weight', 'feature_importance_gain', etc.
|
||||
For RF: 'feature_importance'.
|
||||
|
||||
Returns:
|
||||
xr.DataArray with feature importance values.
|
||||
|
||||
"""
|
||||
if importance_type in model_state:
|
||||
return model_state[importance_type]
|
||||
# Fallback for compatibility
|
||||
if "feature_weights" in model_state:
|
||||
return model_state["feature_weights"]
|
||||
if "feature_importance" in model_state:
|
||||
return model_state["feature_importance"]
|
||||
raise ValueError(f"No feature importance array found. Available: {list(model_state.data_vars)}")
|
||||
|
||||
|
||||
def extract_embedding_features(
|
||||
model_state: xr.Dataset, importance_type: str = "feature_weights"
|
||||
) -> xr.DataArray | None:
|
||||
"""Extract embedding features from the model state.
|
||||
|
||||
Feature naming pattern: `embeddings_{agg}_{band}_{year}`
|
||||
Example: `embeddings_mean_B02_2020`
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
importance_type: Type of feature importance to extract.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted embedding features. This DataArray has dimensions
|
||||
|
|
@ -135,14 +187,17 @@ def extract_embedding_features(model_state) -> xr.DataArray | None:
|
|||
"""
|
||||
|
||||
def _is_embedding_feature(feature: str) -> bool:
|
||||
return feature.startswith("embedding_")
|
||||
return feature.startswith("embeddings_")
|
||||
|
||||
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
|
||||
if len(embedding_features) == 0:
|
||||
return None
|
||||
|
||||
# Get the appropriate importance array
|
||||
importance_array = _get_feature_importance_array(model_state, importance_type)
|
||||
|
||||
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
|
||||
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
|
||||
embedding_feature_array = importance_array.sel(feature=embedding_features)
|
||||
embedding_feature_array = embedding_feature_array.assign_coords(
|
||||
agg=("feature", [f.split("_")[1] for f in embedding_features]),
|
||||
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||
|
|
@ -152,15 +207,26 @@ def extract_embedding_features(model_state) -> xr.DataArray | None:
|
|||
return embedding_feature_array
|
||||
|
||||
|
||||
def extract_era5_features(model_state) -> xr.DataArray | None:
|
||||
def extract_era5_features(
|
||||
model_state: xr.Dataset, importance_type: str = "feature_weights", temporal_group: str | None = None
|
||||
) -> xr.DataArray | None:
|
||||
"""Extract ERA5 features from the model state.
|
||||
|
||||
Feature naming patterns:
|
||||
- Without aggregations: `era5_{variable}_{time}`
|
||||
Example: `era5_temperature_2020_summer`
|
||||
- With aggregations: `era5_{variable}_{time}_{agg}`
|
||||
Example: `era5_temperature_2020_summer_mean`
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
importance_type: Type of feature importance to extract.
|
||||
temporal_group: Filter to specific temporal group ('yearly', 'seasonal', 'shoulder').
|
||||
If None, extracts all ERA5 features.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
|
||||
('variable', 'time') corresponding to the different components of the ERA5 features.
|
||||
('variable', 'time') or ('variable', 'time', 'agg') depending on the feature structure.
|
||||
Returns None if no ERA5 features are found.
|
||||
|
||||
"""
|
||||
|
|
@ -168,34 +234,254 @@ def extract_era5_features(model_state) -> xr.DataArray | None:
|
|||
def _is_era5_feature(feature: str) -> bool:
|
||||
return feature.startswith("era5_")
|
||||
|
||||
def _get_temporal_group(feature: str) -> str:
|
||||
"""Determine temporal group from feature name based on time component.
|
||||
|
||||
Time patterns:
|
||||
- Yearly: just a year number (e.g., "2020")
|
||||
- Seasonal: season_year (e.g., "summer_2020", "winter_2021")
|
||||
- Shoulder: SHOULDER_year (e.g., "JFM_2020", "OND_2021")
|
||||
"""
|
||||
parts = feature.split("_")
|
||||
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
|
||||
|
||||
# Find where the time part starts (after "era5" and variable name)
|
||||
# Pattern: era5_variable_time or era5_variable_time_agg
|
||||
# Time can be: year, season_year, or SHOULDER_year
|
||||
|
||||
# Remove era5 prefix and aggregation suffix if present
|
||||
if parts[-1] in common_aggs:
|
||||
time_parts = parts[1:-1] # Remove era5 and agg
|
||||
else:
|
||||
time_parts = parts[1:] # Remove era5
|
||||
|
||||
# The last part (or last two parts) is the time
|
||||
# Check last part first
|
||||
last_part = time_parts[-1]
|
||||
|
||||
if last_part.isdigit():
|
||||
# Could be yearly (just year) or part of seasonal/shoulder (season_year or SHOULDER_year)
|
||||
if len(time_parts) >= 2:
|
||||
second_last = time_parts[-2]
|
||||
# Check if second to last is a season or shoulder indicator
|
||||
if second_last.lower() in ["summer", "winter"]:
|
||||
return "seasonal"
|
||||
elif second_last.upper() in ["JFM", "AMJ", "JAS", "OND"]:
|
||||
return "shoulder"
|
||||
# Just a year number
|
||||
return "yearly"
|
||||
elif last_part.lower() in ["summer", "winter"]:
|
||||
return "seasonal"
|
||||
elif last_part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
|
||||
return "shoulder"
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _extract_var_name(feature: str) -> str:
|
||||
parts = feature.split("_")
|
||||
# era5_variablename_timetype format
|
||||
return "_".join(parts[1:-1])
|
||||
# Pattern: era5_variable_year_agg OR era5_variable_season_year_agg
|
||||
# Variable can include stat names like t2m_max, t2m_min, etc.
|
||||
# We need to find where the year starts (it's always a 4-digit number)
|
||||
|
||||
# Find the first year (4-digit number)
|
||||
year_idx = None
|
||||
for i, part in enumerate(parts):
|
||||
if part.isdigit() and len(part) == 4:
|
||||
year_idx = i
|
||||
break
|
||||
|
||||
if year_idx is None:
|
||||
# Fallback: return everything after era5
|
||||
return "_".join(parts[1:])
|
||||
|
||||
# Variable is everything between era5 and year
|
||||
# But also check if there's a season before the year
|
||||
seasons = {"summer", "winter"}
|
||||
shoulders = {"JFM", "AMJ", "JAS", "OND"}
|
||||
|
||||
var_end_idx = year_idx
|
||||
# Check if the part before year is a season/shoulder
|
||||
if year_idx > 1 and (parts[year_idx - 1].lower() in seasons or parts[year_idx - 1].upper() in shoulders):
|
||||
var_end_idx = year_idx - 1
|
||||
|
||||
# Variable is from index 1 (after era5) to var_end_idx
|
||||
return "_".join(parts[1:var_end_idx])
|
||||
|
||||
def _extract_time_name(feature: str) -> str:
|
||||
parts = feature.split("_")
|
||||
# Last part is the time type
|
||||
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
|
||||
if parts[-1] in common_aggs:
|
||||
# Has aggregation: era5_var_time_agg -> time is second to last
|
||||
return parts[-2]
|
||||
else:
|
||||
# No aggregation: era5_var_time -> time is last
|
||||
return parts[-1]
|
||||
|
||||
def _extract_season(feature: str) -> str | None:
|
||||
"""Extract season/shoulder from seasonal/shoulder features.
|
||||
|
||||
Pattern: era5_variable_season_year_agg or era5_variable_season_year
|
||||
"""
|
||||
parts = feature.split("_")
|
||||
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
|
||||
|
||||
# Look through parts to find season/shoulder indicators
|
||||
for part in parts:
|
||||
if part.lower() in ["summer", "winter"]:
|
||||
return part.lower()
|
||||
elif part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
|
||||
return part.upper()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_year_from_time(feature: str) -> str:
|
||||
"""Extract just the year component from time.
|
||||
|
||||
For seasonal/shoulder features, find the year that comes after the season.
|
||||
"""
|
||||
parts = feature.split("_")
|
||||
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
|
||||
|
||||
# Find the season/shoulder part, then the next part should be the year
|
||||
for i, part in enumerate(parts):
|
||||
if part.lower() in ["summer", "winter"] or part.upper() in ["JFM", "AMJ", "JAS", "OND"]:
|
||||
# Next part should be the year
|
||||
if i + 1 < len(parts):
|
||||
next_part = parts[i + 1]
|
||||
# Skip if it's an aggregation, year should be before agg
|
||||
if next_part not in common_aggs:
|
||||
return next_part
|
||||
|
||||
# Fallback: return last numeric part that's not an aggregation
|
||||
for part in reversed(parts):
|
||||
if part.isdigit():
|
||||
return part
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _extract_agg_name(feature: str) -> str | None:
|
||||
parts = feature.split("_")
|
||||
common_aggs = {"mean", "std", "min", "max", "median", "sum", "count", "q25", "q75"}
|
||||
if parts[-1] in common_aggs:
|
||||
return parts[-1]
|
||||
return None
|
||||
|
||||
# Get all ERA5 features
|
||||
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
|
||||
|
||||
# Filter by temporal group if specified
|
||||
if temporal_group is not None:
|
||||
era5_features = [f for f in era5_features if _get_temporal_group(f) == temporal_group]
|
||||
|
||||
if len(era5_features) == 0:
|
||||
return None
|
||||
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
|
||||
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
|
||||
|
||||
# Get the appropriate importance array
|
||||
importance_array = _get_feature_importance_array(model_state, importance_type)
|
||||
|
||||
# Check if features have aggregations
|
||||
has_agg = any(_extract_agg_name(f) is not None for f in era5_features)
|
||||
|
||||
# Check if features have season/shoulder split (for seasonal/shoulder temporal groups)
|
||||
has_season = any(_extract_season(f) is not None for f in era5_features)
|
||||
|
||||
# Split the single feature dimension of era5 features into separate dimensions
|
||||
era5_features_array = importance_array.sel(feature=era5_features)
|
||||
|
||||
if has_season:
|
||||
# For seasonal/shoulder: split into variable, season, year, (agg)
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
variable=("feature", [_extract_var_name(f) for f in era5_features]),
|
||||
season=("feature", [_extract_season(f) for f in era5_features]),
|
||||
year=("feature", [_extract_year_from_time(f) for f in era5_features]),
|
||||
)
|
||||
|
||||
if has_agg:
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
|
||||
)
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year", "agg"]).unstack(
|
||||
"feature"
|
||||
) # noqa: PD010
|
||||
else:
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "season", "year"]).unstack(
|
||||
"feature"
|
||||
) # noqa: PD010
|
||||
else:
|
||||
# For yearly: keep as variable, time, (agg)
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
variable=("feature", [_extract_var_name(f) for f in era5_features]),
|
||||
time=("feature", [_extract_time_name(f) for f in era5_features]),
|
||||
)
|
||||
|
||||
if has_agg:
|
||||
# Add aggregation dimension
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
|
||||
)
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010
|
||||
else:
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
||||
|
||||
return era5_features_array
|
||||
|
||||
|
||||
def extract_common_features(model_state) -> xr.DataArray | None:
|
||||
def extract_arcticdem_features(
    model_state: xr.Dataset, importance_type: str = "feature_weights"
) -> xr.DataArray | None:
    """Extract ArcticDEM features from the model state.

    Feature naming pattern: `arcticdem_{variable}_{agg}`
    Example: `arcticdem_elevation_mean`, `arcticdem_slope_std`

    Args:
        model_state: The xarray Dataset containing the model state.
        importance_type: Type of feature importance to extract.

    Returns:
        xr.DataArray: The extracted ArcticDEM features. This DataArray has dimensions
            ('variable', 'agg') corresponding to the different components of the ArcticDEM features.
            Returns None if no ArcticDEM features are found.

    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("arcticdem_")]
    if not names:
        return None

    # "arcticdem_{variable}_{agg}": the last underscore-separated token is the
    # aggregation; the variable is everything between the prefix and that token.
    tokens = [name.split("_") for name in names]
    variables = ["_".join(parts[1:-1]) for parts in tokens]
    aggs = [parts[-1] for parts in tokens]

    # Select the importance variable appropriate for the trained model type.
    importance_array = _get_feature_importance_array(model_state, importance_type)

    # Re-label the flat feature axis, then promote (variable, agg) to real dims.
    dem_array = importance_array.sel(feature=names).assign_coords(
        variable=("feature", variables),
        agg=("feature", aggs),
    )
    return dem_array.set_index(feature=["variable", "agg"]).unstack("feature")  # noqa: PD010
|
||||
|
||||
|
||||
def extract_common_features(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray | None:
|
||||
"""Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
importance_type: Type of feature importance to extract.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted common features with a single 'feature' dimension.
|
||||
|
|
@ -211,6 +497,22 @@ def extract_common_features(model_state) -> xr.DataArray | None:
|
|||
if len(common_features) == 0:
|
||||
return None
|
||||
|
||||
# Extract the feature weights for common features
|
||||
common_feature_array = model_state.sel(feature=common_features)["feature_weights"]
|
||||
# Get the appropriate importance array
|
||||
importance_array = _get_feature_importance_array(model_state, importance_type)
|
||||
|
||||
# Extract the feature importance for common features
|
||||
common_feature_array = importance_array.sel(feature=common_features)
|
||||
return common_feature_array
|
||||
|
||||
|
||||
def get_members_from_settings(settings: dict) -> list[str]:
    """Extract the list of dataset members used in training from settings.

    Args:
        settings: Training settings dictionary.

    Returns:
        List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']).

    """
    # An absent "members" key means no member datasets were recorded.
    return settings["members"] if "members" in settings else []
|
||||
|
|
|
|||
|
|
@ -278,6 +278,7 @@ def random_cv(
|
|||
booster = best_estimator.get_booster()
|
||||
|
||||
# Feature importance with different importance types
|
||||
# Note: get_score() returns dict with keys like 'f0', 'f1', etc. (feature indices)
|
||||
importance_weight = booster.get_score(importance_type="weight")
|
||||
importance_gain = booster.get_score(importance_type="gain")
|
||||
importance_cover = booster.get_score(importance_type="cover")
|
||||
|
|
@ -286,8 +287,11 @@ def random_cv(
|
|||
|
||||
# Create aligned arrays for all features (including zero-importance)
|
||||
def align_importance(importance_dict, features):
    """Align importance dict to feature list, filling missing with 0.

    XGBoost returns feature indices (f0, f1, ...) as keys, so we need to map them.
    """
    aligned = []
    for index, _feature in enumerate(features):
        # Features absent from the booster's score dict carry zero importance.
        aligned.append(importance_dict.get(f"f{index}", 0.0))
    return aligned
|
||||
|
||||
feature_importance_weight = xr.DataArray(
|
||||
align_importance(importance_weight, features),
|
||||
|
|
|
|||
146
test_feature_extraction.py
Normal file
146
test_feature_extraction.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""Test script to verify feature extraction works correctly."""
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
|
||||
# Create a mock model state with various feature types
|
||||
features = [
|
||||
# Embedding features: embedding_{agg}_{band}_{year}
|
||||
"embedding_mean_B02_2020",
|
||||
"embedding_std_B03_2021",
|
||||
"embedding_max_B04_2022",
|
||||
# ERA5 features without aggregations: era5_{variable}_{time}
|
||||
"era5_temperature_2020_summer",
|
||||
"era5_precipitation_2021_winter",
|
||||
# ERA5 features with aggregations: era5_{variable}_{time}_{agg}
|
||||
"era5_temperature_2020_summer_mean",
|
||||
"era5_precipitation_2021_winter_std",
|
||||
# ArcticDEM features: arcticdem_{variable}_{agg}
|
||||
"arcticdem_elevation_mean",
|
||||
"arcticdem_slope_std",
|
||||
"arcticdem_aspect_max",
|
||||
# Common features
|
||||
"cell_area",
|
||||
"water_area",
|
||||
"land_area",
|
||||
"land_ratio",
|
||||
"lon",
|
||||
"lat",
|
||||
]
|
||||
|
||||
# Create mock importance values
|
||||
importance_values = np.random.rand(len(features))
|
||||
|
||||
# Create a mock model state for ESPA
|
||||
model_state_espa = xr.Dataset(
|
||||
{
|
||||
"feature_weights": xr.DataArray(
|
||||
importance_values,
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock model state for XGBoost
|
||||
model_state_xgb = xr.Dataset(
|
||||
{
|
||||
"feature_importance_gain": xr.DataArray(
|
||||
importance_values,
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
),
|
||||
"feature_importance_weight": xr.DataArray(
|
||||
importance_values * 0.8,
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock model state for Random Forest
|
||||
model_state_rf = xr.Dataset(
|
||||
{
|
||||
"feature_importance": xr.DataArray(
|
||||
importance_values,
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# Test extraction functions
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
print("Testing ESPA model state")
|
||||
print("=" * 80)
|
||||
|
||||
embedding_array = extract_embedding_features(model_state_espa)
|
||||
print(f"\nEmbedding features extracted: {embedding_array is not None}")
|
||||
if embedding_array is not None:
|
||||
print(f" Dimensions: {embedding_array.dims}")
|
||||
print(f" Shape: {embedding_array.shape}")
|
||||
print(f" Coordinates: {list(embedding_array.coords)}")
|
||||
|
||||
era5_array = extract_era5_features(model_state_espa)
|
||||
print(f"\nERA5 features extracted: {era5_array is not None}")
|
||||
if era5_array is not None:
|
||||
print(f" Dimensions: {era5_array.dims}")
|
||||
print(f" Shape: {era5_array.shape}")
|
||||
print(f" Coordinates: {list(era5_array.coords)}")
|
||||
|
||||
arcticdem_array = extract_arcticdem_features(model_state_espa)
|
||||
print(f"\nArcticDEM features extracted: {arcticdem_array is not None}")
|
||||
if arcticdem_array is not None:
|
||||
print(f" Dimensions: {arcticdem_array.dims}")
|
||||
print(f" Shape: {arcticdem_array.shape}")
|
||||
print(f" Coordinates: {list(arcticdem_array.coords)}")
|
||||
|
||||
common_array = extract_common_features(model_state_espa)
|
||||
print(f"\nCommon features extracted: {common_array is not None}")
|
||||
if common_array is not None:
|
||||
print(f" Dimensions: {common_array.dims}")
|
||||
print(f" Shape: {common_array.shape}")
|
||||
print(f" Size: {common_array.size}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Testing XGBoost model state")
|
||||
print("=" * 80)
|
||||
|
||||
embedding_array_xgb = extract_embedding_features(model_state_xgb, importance_type="feature_importance_gain")
|
||||
print(f"\nEmbedding features (gain) extracted: {embedding_array_xgb is not None}")
|
||||
if embedding_array_xgb is not None:
|
||||
print(f" Dimensions: {embedding_array_xgb.dims}")
|
||||
print(f" Shape: {embedding_array_xgb.shape}")
|
||||
|
||||
era5_array_xgb = extract_era5_features(model_state_xgb, importance_type="feature_importance_weight")
|
||||
print(f"\nERA5 features (weight) extracted: {era5_array_xgb is not None}")
|
||||
if era5_array_xgb is not None:
|
||||
print(f" Dimensions: {era5_array_xgb.dims}")
|
||||
print(f" Shape: {era5_array_xgb.shape}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Testing Random Forest model state")
|
||||
print("=" * 80)
|
||||
|
||||
embedding_array_rf = extract_embedding_features(model_state_rf, importance_type="feature_importance")
|
||||
print(f"\nEmbedding features extracted: {embedding_array_rf is not None}")
|
||||
if embedding_array_rf is not None:
|
||||
print(f" Dimensions: {embedding_array_rf.dims}")
|
||||
print(f" Shape: {embedding_array_rf.shape}")
|
||||
|
||||
arcticdem_array_rf = extract_arcticdem_features(model_state_rf, importance_type="feature_importance")
|
||||
print(f"\nArcticDEM features extracted: {arcticdem_array_rf is not None}")
|
||||
if arcticdem_array_rf is not None:
|
||||
print(f" Dimensions: {arcticdem_array_rf.dims}")
|
||||
print(f" Shape: {arcticdem_array_rf.shape}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("All tests completed successfully!")
|
||||
print("=" * 80)
|
||||
Loading…
Add table
Add a link
Reference in a new issue