192 lines
6.6 KiB
Python
192 lines
6.6 KiB
Python
|
|
"""Fix XGBoost feature importance in existing model state files.
|
||
|
|
|
||
|
|
This script repairs XGBoost model state files that have all-zero feature importance
|
||
|
|
values due to incorrect feature name lookup. It reloads the pickled models and
|
||
|
|
regenerates the feature importance arrays with the correct feature index mapping.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pickle
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import toml
|
||
|
|
import xarray as xr
|
||
|
|
from rich import print
|
||
|
|
|
||
|
|
from entropice.paths import RESULTS_DIR
|
||
|
|
|
||
|
|
|
||
|
|
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, re-extracts feature importance from
    its booster using the correct positional feature-name mapping
    (``f0``, ``f1``, ...), and rewrites the NetCDF state file, backing up
    the original state file first.

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise.
    """
    # Check if this is an XGBoost model
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False

    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")

    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Check if required files exist
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"

    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False

    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # Load the pickled model. NOTE: pickle is unsafe on untrusted input;
    # these files are assumed to be artifacts of our own pipeline.
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Load the old state to get feature names. The context manager releases
    # the file handle even if reading the coordinate raises.
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    # Get the booster and extract feature importance with correct mapping
    booster = best_estimator.get_booster()

    def _align_importance(importance_dict, features):
        """Align importance dict to feature list using feature indices."""
        # The booster keys importance by positional names "f0", "f1", ...;
        # features absent from the dict were never used in any split.
        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]

    def _importance_array(importance_type: str, name: str, description: str) -> xr.DataArray:
        """Build one aligned feature-importance DataArray from the booster."""
        scores = booster.get_score(importance_type=importance_type)
        return xr.DataArray(
            _align_importance(scores, features),
            dims=["feature"],
            coords={"feature": features},
            name=name,
            attrs={"description": description},
        )

    feature_importance_weight = _importance_array(
        "weight",
        "feature_importance_weight",
        "Number of times a feature is used to split the data across all trees.",
    )
    feature_importance_gain = _importance_array(
        "gain",
        "feature_importance_gain",
        "Average gain across all splits the feature is used in.",
    )
    feature_importance_cover = _importance_array(
        "cover",
        "feature_importance_cover",
        "Average coverage across all splits the feature is used in.",
    )
    feature_importance_total_gain = _importance_array(
        "total_gain",
        "feature_importance_total_gain",
        "Total gain across all splits the feature is used in.",
    )
    feature_importance_total_cover = _importance_array(
        "total_cover",
        "feature_importance_total_cover",
        "Total coverage across all splits the feature is used in.",
    )

    # Create new state dataset
    n_trees = booster.num_boosted_rounds()
    state = xr.Dataset(
        {
            "feature_importance_weight": feature_importance_weight,
            "feature_importance_gain": feature_importance_gain,
            "feature_importance_cover": feature_importance_cover,
            "feature_importance_total_gain": feature_importance_total_gain,
            "feature_importance_total_cover": feature_importance_total_cover,
        },
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": n_trees,
            "objective": str(best_estimator.objective),
        },
    )

    # Backup the old file. Only the first run creates a backup; later runs
    # keep that original backup and just discard the intermediate state.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f"  Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print("  Backup already exists, removing old state file")
        state_file.unlink()

    # Save the fixed state
    print(f"  Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Verify the fix: if the model trained at all, at least one of these
    # importance types must be non-zero for some feature.
    total_importance = (
        feature_importance_weight.sum().item()
        + feature_importance_gain.sum().item()
        + feature_importance_cover.sum().item()
    )

    if total_importance > 0:
        print(f"  ✓ Success! Total importance: {total_importance:.2f}")
        return True
    else:
        print("  ✗ Warning: Total importance is still 0!")
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Find and fix all XGBoost model state files.

    Scans every subdirectory of RESULTS_DIR, attempts the repair on each,
    and prints a summary. Errors in one directory do not abort the batch.
    """
    print("Scanning for XGBoost model results...")

    # Find all result directories
    result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed_count = 0
    skipped_count = 0
    failed_count = 0

    for result_dir in sorted(result_dirs):
        try:
            if fix_xgboost_model_state(result_dir):
                fixed_count += 1
            else:
                # False covers both "not an XGBoost model" and directories
                # that could not be repaired (missing files, or importance
                # still zero after the rewrite) — all land in "Skipped".
                skipped_count += 1
        except Exception as e:
            # Keep going: one corrupt directory must not abort the batch.
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed_count += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  ✓ Fixed: {fixed_count}")
    print(f"  ⊘ Skipped: {skipped_count}")
    print(f"  ✗ Failed: {failed_count}")
    print("=" * 60)
|
||
|
|
|
||
|
|
|
||
|
|
# Run the batch repair only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|