entropice/scripts/fix_xgboost_importance.py
2025-12-28 20:48:49 +01:00

191 lines
6.6 KiB
Python

"""Fix XGBoost feature importance in existing model state files.
This script repairs XGBoost model state files that have all-zero feature importance
values due to incorrect feature name lookup. It reloads the pickled models and
regenerates the feature importance arrays with the correct feature index mapping.
"""
import pickle
from pathlib import Path
import toml
import xarray as xr
from rich import print
from entropice.utils.paths import RESULTS_DIR
# Maps each XGBoost importance_type (passed to Booster.get_score) to the
# human-readable description stored in the resulting DataArray's attrs.
# Insertion order fixes the variable order of the output Dataset.
_IMPORTANCE_DESCRIPTIONS = {
    "weight": "Number of times a feature is used to split the data across all trees.",
    "gain": "Average gain across all splits the feature is used in.",
    "cover": "Average coverage across all splits the feature is used in.",
    "total_gain": "Total gain across all splits the feature is used in.",
    "total_cover": "Total coverage across all splits the feature is used in.",
}


def _align_importance(importance_dict: dict, n_features: int) -> list:
    """Align an XGBoost score dict to the ordered feature list.

    Booster.get_score keys scores by feature index ("f0", "f1", ...) when the
    booster carries no feature names. Features absent from the dict were never
    used in a split, so they get 0.0.

    Args:
        importance_dict: Mapping from "f<i>" keys to importance scores.
        n_features: Number of features in the original feature list.

    Returns:
        Scores as a list aligned to feature positions 0..n_features-1.
    """
    return [importance_dict.get(f"f{i}", 0.0) for i in range(n_features)]


def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, regenerates all five feature
    importance arrays with the correct feature-index mapping, backs up the
    old state file, and writes a repaired state file in its place.

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise.
    """
    # Only directories whose settings declare an XGBoost model are candidates.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False
    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")
    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Both the pickled estimator and the old state file must be present.
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"
    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False
    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # NOTE: pickle.load on a trusted local artifact produced by this project;
    # do not point this script at untrusted files.
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Context manager guarantees the dataset handle is closed even if
    # reading the coordinate raises.
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    # Rebuild one aligned DataArray per importance type.
    booster = best_estimator.get_booster()
    importance_arrays = {}
    for importance_type, description in _IMPORTANCE_DESCRIPTIONS.items():
        scores = booster.get_score(importance_type=importance_type)
        name = f"feature_importance_{importance_type}"
        importance_arrays[name] = xr.DataArray(
            _align_importance(scores, len(features)),
            dims=["feature"],
            coords={"feature": features},
            name=name,
            attrs={"description": description},
        )

    state = xr.Dataset(
        importance_arrays,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": booster.num_boosted_rounds(),
            "objective": str(best_estimator.objective),
        },
    )

    # Keep the first backup ever made; later runs just replace the state file.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Sanity check: the repair worked only if at least one of the primary
    # importance metrics is non-zero.
    total_importance = sum(
        importance_arrays[f"feature_importance_{t}"].sum().item()
        for t in ("weight", "gain", "cover")
    )
    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    print(" ✗ Warning: Total importance is still 0!")
    return False
def main():
    """Find and fix all XGBoost model state files under RESULTS_DIR.

    Scans every subdirectory of RESULTS_DIR, attempts to repair each one,
    and prints a fixed/skipped/failed summary at the end.
    """
    print("Scanning for XGBoost model results...")
    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed_count = 0
    skipped_count = 0
    failed_count = 0
    for result_dir in sorted(result_dirs):
        try:
            if fix_xgboost_model_state(result_dir):
                fixed_count += 1
            else:
                # False covers every non-fixed outcome: not an XGBoost model,
                # missing files, or importance still zero after the rewrite.
                skipped_count += 1
        except Exception as e:
            # Top-level boundary: report the error and continue with the
            # remaining directories rather than aborting the whole scan.
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed_count += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {fixed_count}")
    print(f" ⊘ Skipped: {skipped_count}")
    print(f" ✗ Failed: {failed_count}")
    print("=" * 60)
# Entry point when executed as a standalone script.
if __name__ == "__main__":
    main()