"""Fix XGBoost feature importance in existing model state files. This script repairs XGBoost model state files that have all-zero feature importance values due to incorrect feature name lookup. It reloads the pickled models and regenerates the feature importance arrays with the correct feature index mapping. """ import pickle from pathlib import Path import toml import xarray as xr from rich import print from entropice.paths import RESULTS_DIR def fix_xgboost_model_state(results_dir: Path) -> bool: """Fix a single XGBoost model state file. Args: results_dir: Directory containing the model files. Returns: True if fixed successfully, False otherwise. """ # Check if this is an XGBoost model settings_file = results_dir / "search_settings.toml" if not settings_file.exists(): return False settings = toml.load(settings_file) model_type = settings.get("settings", {}).get("model", "") if model_type != "xgboost": print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})") return False # Check if required files exist model_file = results_dir / "best_estimator_model.pkl" state_file = results_dir / "best_estimator_state.nc" if not model_file.exists(): print(f"⚠️ Missing model file in {results_dir.name}") return False if not state_file.exists(): print(f"⚠️ Missing state file in {results_dir.name}") return False # Load the pickled model print(f"Loading model from {results_dir.name}...") with open(model_file, "rb") as f: best_estimator = pickle.load(f) # Load the old state to get feature names old_state = xr.open_dataset(state_file, engine="h5netcdf") features = old_state.coords["feature"].values.tolist() old_state.close() # Get the booster and extract feature importance with correct mapping booster = best_estimator.get_booster() importance_weight = booster.get_score(importance_type="weight") importance_gain = booster.get_score(importance_type="gain") importance_cover = booster.get_score(importance_type="cover") importance_total_gain = booster.get_score(importance_type="total_gain") importance_total_cover = booster.get_score(importance_type="total_cover") # Align importance using feature indices (f0, f1, ...) def align_importance(importance_dict, features): """Align importance dict to feature list using feature indices.""" return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))] # Create new DataArrays feature_importance_weight = xr.DataArray( align_importance(importance_weight, features), dims=["feature"], coords={"feature": features}, name="feature_importance_weight", attrs={"description": "Number of times a feature is used to split the data across all trees."}, ) feature_importance_gain = xr.DataArray( align_importance(importance_gain, features), dims=["feature"], coords={"feature": features}, name="feature_importance_gain", attrs={"description": "Average gain across all splits the feature is used in."}, ) feature_importance_cover = xr.DataArray( align_importance(importance_cover, features), dims=["feature"], coords={"feature": features}, name="feature_importance_cover", attrs={"description": "Average coverage across all splits the feature is used in."}, ) feature_importance_total_gain = xr.DataArray( align_importance(importance_total_gain, features), dims=["feature"], coords={"feature": features}, name="feature_importance_total_gain", attrs={"description": "Total gain across all splits the feature is used in."}, ) feature_importance_total_cover = xr.DataArray( align_importance(importance_total_cover, features), dims=["feature"], coords={"feature": features}, name="feature_importance_total_cover", attrs={"description": "Total coverage across all splits the feature is used in."}, ) # Create new state dataset n_trees = booster.num_boosted_rounds() state = xr.Dataset( { "feature_importance_weight": feature_importance_weight, "feature_importance_gain": feature_importance_gain, "feature_importance_cover": feature_importance_cover, "feature_importance_total_gain": feature_importance_total_gain, "feature_importance_total_cover": feature_importance_total_cover, }, attrs={ "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.", "n_trees": n_trees, "objective": str(best_estimator.objective), }, ) # Backup the old file backup_file = state_file.with_suffix(".nc.backup") if not backup_file.exists(): print(f" Creating backup: {backup_file.name}") state_file.rename(backup_file) else: print(" Backup already exists, removing old state file") state_file.unlink() # Save the fixed state print(f" Saving fixed state to {state_file.name}") state.to_netcdf(state_file, engine="h5netcdf") # Verify the fix total_importance = ( feature_importance_weight.sum().item() + feature_importance_gain.sum().item() + feature_importance_cover.sum().item() ) if total_importance > 0: print(f" ✓ Success! Total importance: {total_importance:.2f}") return True else: print(" ✗ Warning: Total importance is still 0!") return False def main(): """Find and fix all XGBoost model state files.""" print("Scanning for XGBoost model results...") # Find all result directories result_dirs = [d for d in RESULTS_DIR.iterdir() if d.is_dir()] print(f"Found {len(result_dirs)} result directories") fixed_count = 0 skipped_count = 0 failed_count = 0 for result_dir in sorted(result_dirs): try: success = fix_xgboost_model_state(result_dir) if success: fixed_count += 1 elif success is False: # Explicitly skipped (not XGBoost) skipped_count += 1 except Exception as e: print(f"❌ Error processing {result_dir.name}: {e}") failed_count += 1 print("\n" + "=" * 60) print("Summary:") print(f" ✓ Fixed: {fixed_count}") print(f" ⊘ Skipped: {skipped_count}") print(f" ✗ Failed: {failed_count}") print("=" * 60) if __name__ == "__main__": main()