Make the Model State Page great again
This commit is contained in:
parent
591da6992e
commit
1919cc6a7e
13 changed files with 1375 additions and 142 deletions
191
scripts/fix_xgboost_importance.py
Normal file
191
scripts/fix_xgboost_importance.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Fix XGBoost feature importance in existing model state files.
|
||||
|
||||
This script repairs XGBoost model state files that have all-zero feature importance
|
||||
values due to incorrect feature name lookup. It reloads the pickled models and
|
||||
regenerates the feature importance arrays with the correct feature index mapping.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
import xarray as xr
|
||||
from rich import print
|
||||
|
||||
from entropice.paths import RESULTS_DIR
|
||||
|
||||
|
||||
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, re-extracts feature importance from
    the booster using positional feature keys (``f0``, ``f1``, ...), and
    rewrites ``best_estimator_state.nc`` with correctly aligned values.

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise (missing settings or
        model files, a non-XGBoost run, or importance still all-zero after
        the rewrite).

    """
    # Check if this is an XGBoost model
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False

    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")

    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Check if required files exist
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"

    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False

    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # Load the pickled model. NOTE: pickle is unsafe on untrusted input;
    # these files are assumed to be artifacts of our own pipeline.
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Load the old state only to recover the ordered feature names. The
    # context manager guarantees the handle is released before the file is
    # renamed/overwritten below.
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    # Get the booster and extract feature importance with correct mapping
    booster = best_estimator.get_booster()

    def align_importance(importance_dict, features):
        """Align importance dict to feature list using feature indices."""
        # get_score() keys importance by positional names "f<i>"; features
        # the booster never split on are absent, hence the 0.0 default.
        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]

    # One entry per importance type; building the DataArrays in a loop
    # replaces five near-identical construction blocks.
    importance_specs = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }
    data_vars = {}
    for importance_type, description in importance_specs.items():
        raw_importance = booster.get_score(importance_type=importance_type)
        var_name = f"feature_importance_{importance_type}"
        data_vars[var_name] = xr.DataArray(
            align_importance(raw_importance, features),
            dims=["feature"],
            coords={"feature": features},
            name=var_name,
            attrs={"description": description},
        )

    # Create new state dataset
    n_trees = booster.num_boosted_rounds()
    state = xr.Dataset(
        data_vars,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": n_trees,
            "objective": str(best_estimator.objective),
        },
    )

    # Backup the old file; an existing backup (from an earlier run of this
    # script) is kept, and only the current state file is replaced.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    # Save the fixed state
    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Verify the fix: if weight + gain + cover still sum to zero, the
    # original lookup bug is still present.
    total_importance = (
        data_vars["feature_importance_weight"].sum().item()
        + data_vars["feature_importance_gain"].sum().item()
        + data_vars["feature_importance_cover"].sum().item()
    )

    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    else:
        print(" ✗ Warning: Total importance is still 0!")
        return False
|
||||
|
||||
def main():
    """Find and fix all XGBoost model state files under RESULTS_DIR."""
    print("Scanning for XGBoost model results...")

    # Find all result directories
    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed_count = 0
    skipped_count = 0
    failed_count = 0

    for result_dir in sorted(result_dirs):
        try:
            if fix_xgboost_model_state(result_dir):
                fixed_count += 1
            else:
                # NOTE(review): a False return covers both "not an XGBoost
                # run" and soft failures (missing files, importance still
                # zero), so all of those land in the skipped bucket;
                # failed_count only counts unexpected exceptions.
                skipped_count += 1
        except Exception as e:
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed_count += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {fixed_count}")
    print(f" ⊘ Skipped: {skipped_count}")
    print(f" ✗ Failed: {failed_count}")
    print("=" * 60)


if __name__ == "__main__":
    main()
||||
Loading…
Add table
Add a link
Reference in a new issue