Make the Model State Page great again

This commit is contained in:
Tobias Hölzer 2025-12-25 18:19:11 +01:00
parent 591da6992e
commit 1919cc6a7e
13 changed files with 1375 additions and 142 deletions

View file

@@ -1,11 +1,22 @@
#! /bin/bash
pixi run darts extract_darts_mllabels --grid hex --level 3
pixi run darts extract_darts_mllabels --grid hex --level 4
pixi run darts extract_darts_mllabels --grid hex --level 5
pixi run darts extract_darts_mllabels --grid hex --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 7
pixi run darts extract_darts_mllabels --grid healpix --level 8
pixi run darts extract_darts_mllabels --grid healpix --level 9
pixi run darts extract_darts_mllabels --grid healpix --level 10
# pixi shell
darts extract-darts-rts --grid hex --level 3
darts extract-darts-rts --grid hex --level 4
darts extract-darts-rts --grid hex --level 5
darts extract-darts-rts --grid hex --level 6
darts extract-darts-rts --grid healpix --level 6
darts extract-darts-rts --grid healpix --level 7
darts extract-darts-rts --grid healpix --level 8
darts extract-darts-rts --grid healpix --level 9
darts extract-darts-rts --grid healpix --level 10
darts extract-darts-mllabels --grid hex --level 3
darts extract-darts-mllabels --grid hex --level 4
darts extract-darts-mllabels --grid hex --level 5
darts extract-darts-mllabels --grid hex --level 6
darts extract-darts-mllabels --grid healpix --level 6
darts extract-darts-mllabels --grid healpix --level 7
darts extract-darts-mllabels --grid healpix --level 8
darts extract-darts-mllabels --grid healpix --level 9
darts extract-darts-mllabels --grid healpix --level 10

4
scripts/05train.sh Normal file
View file

@@ -0,0 +1,4 @@
#!/bin/bash
pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model espa
pixi run train --grid hex --level 5 --target darts_mllabels --n-iter 1000 --task density --model xgboost

View file

@@ -0,0 +1,191 @@
"""Fix XGBoost feature importance in existing model state files.
This script repairs XGBoost model state files that have all-zero feature importance
values due to incorrect feature name lookup. It reloads the pickled models and
regenerates the feature importance arrays with the correct feature index mapping.
"""
import pickle
from pathlib import Path
import toml
import xarray as xr
from rich import print
from entropice.paths import RESULTS_DIR
def fix_xgboost_model_state(results_dir: Path) -> bool:
"""Fix a single XGBoost model state file.
Args:
results_dir: Directory containing the model files.
Returns:
True if fixed successfully, False otherwise.
"""
# Check if this is an XGBoost model
settings_file = results_dir / "search_settings.toml"
if not settings_file.exists():
return False
settings = toml.load(settings_file)
model_type = settings.get("settings", {}).get("model", "")
if model_type != "xgboost":
print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
return False
# Check if required files exist
model_file = results_dir / "best_estimator_model.pkl"
state_file = results_dir / "best_estimator_state.nc"
if not model_file.exists():
print(f"⚠️ Missing model file in {results_dir.name}")
return False
if not state_file.exists():
print(f"⚠️ Missing state file in {results_dir.name}")
return False
# Load the pickled model
print(f"Loading model from {results_dir.name}...")
with open(model_file, "rb") as f:
best_estimator = pickle.load(f)
# Load the old state to get feature names
old_state = xr.open_dataset(state_file, engine="h5netcdf")
features = old_state.coords["feature"].values.tolist()
old_state.close()
# Get the booster and extract feature importance with correct mapping
booster = best_estimator.get_booster()
importance_weight = booster.get_score(importance_type="weight")
importance_gain = booster.get_score(importance_type="gain")
importance_cover = booster.get_score(importance_type="cover")
importance_total_gain = booster.get_score(importance_type="total_gain")
importance_total_cover = booster.get_score(importance_type="total_cover")
# Align importance using feature indices (f0, f1, ...)
def align_importance(importance_dict, features):
"""Align importance dict to feature list using feature indices."""
return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]
# Create new DataArrays
feature_importance_weight = xr.DataArray(
align_importance(importance_weight, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_weight",
attrs={"description": "Number of times a feature is used to split the data across all trees."},
)
feature_importance_gain = xr.DataArray(
align_importance(importance_gain, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_gain",
attrs={"description": "Average gain across all splits the feature is used in."},
)
feature_importance_cover = xr.DataArray(
align_importance(importance_cover, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_cover",
attrs={"description": "Average coverage across all splits the feature is used in."},
)
feature_importance_total_gain = xr.DataArray(
align_importance(importance_total_gain, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_total_gain",
attrs={"description": "Total gain across all splits the feature is used in."},
)
feature_importance_total_cover = xr.DataArray(
align_importance(importance_total_cover, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_total_cover",
attrs={"description": "Total coverage across all splits the feature is used in."},
)
# Create new state dataset
n_trees = booster.num_boosted_rounds()
state = xr.Dataset(
{
"feature_importance_weight": feature_importance_weight,
"feature_importance_gain": feature_importance_gain,
"feature_importance_cover": feature_importance_cover,
"feature_importance_total_gain": feature_importance_total_gain,
"feature_importance_total_cover": feature_importance_total_cover,
},
attrs={
"description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
"n_trees": n_trees,
"objective": str(best_estimator.objective),
},
)
# Backup the old file
backup_file = state_file.with_suffix(".nc.backup")
if not backup_file.exists():
print(f" Creating backup: {backup_file.name}")
state_file.rename(backup_file)
else:
print(" Backup already exists, removing old state file")
state_file.unlink()
# Save the fixed state
print(f" Saving fixed state to {state_file.name}")
state.to_netcdf(state_file, engine="h5netcdf")
# Verify the fix
total_importance = (
feature_importance_weight.sum().item()
+ feature_importance_gain.sum().item()
+ feature_importance_cover.sum().item()
)
if total_importance > 0:
print(f" ✓ Success! Total importance: {total_importance:.2f}")
return True
else:
print(" ✗ Warning: Total importance is still 0!")
return False
def main():
"""Find and fix all XGBoost model state files."""
print("Scanning for XGBoost model results...")
# Find all result directories
result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
print(f"Found {len(result_dirs)} result directories")
fixed_count = 0
skipped_count = 0
failed_count = 0
for result_dir in sorted(result_dirs):
try:
success = fix_xgboost_model_state(result_dir)
if success:
fixed_count += 1
elif success is False:
# Explicitly skipped (not XGBoost)
skipped_count += 1
except Exception as e:
print(f"❌ Error processing {result_dir.name}: {e}")
failed_count += 1
print("\n" + "=" * 60)
print("Summary:")
print(f" ✓ Fixed: {fixed_count}")
print(f" ⊘ Skipped: {skipped_count}")
print(f" ✗ Failed: {failed_count}")
print("=" * 60)
if __name__ == "__main__":
main()