#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.

This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""

import pickle
from pathlib import Path

import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    jaccard_score,
    precision_score,
    recall_score,
)

from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR

# Enable array_api_dispatch so the sklearn metric functions accept
# CuPy/torch arrays directly instead of requiring NumPy.
set_config(array_api_dispatch=True)

# Default feature-ensemble members, used when the stored settings predate
# the "members" key.
_DEFAULT_MEMBERS = [
    "AlphaEarth",
    "ArcticDEM",
    "ERA5-yearly",
    "ERA5-seasonal",
    "ERA5-shoulder",
]


def _compute_test_metrics(task: str, y_test, y_pred) -> dict[str, float]:
    """Compute classification metrics for the test split.

    Args:
        task: "binary" selects the un-averaged binary metrics; any other
            value is treated as multiclass and reports several averages.
        y_test: Ground-truth labels (torch tensor on CUDA).
        y_pred: Predicted labels (torch tensor on CUDA).

    Returns:
        Mapping of metric name to a plain float (TOML-serializable).
    """
    if task == "binary":
        return {
            "accuracy": float(accuracy_score(y_test, y_pred)),
            "recall": float(recall_score(y_test, y_pred)),
            "precision": float(precision_score(y_test, y_pred)),
            "f1": float(f1_score(y_test, y_pred)),
            "jaccard": float(jaccard_score(y_test, y_pred)),
        }
    # Multiclass: zero_division=0 silences the warning for classes that
    # never appear in the predictions.
    return {
        "accuracy": float(accuracy_score(y_test, y_pred)),
        "f1_macro": float(f1_score(y_test, y_pred, average="macro")),
        "f1_weighted": float(f1_score(y_test, y_pred, average="weighted")),
        "precision_macro": float(
            precision_score(y_test, y_pred, average="macro", zero_division=0)
        ),
        "precision_weighted": float(
            precision_score(y_test, y_pred, average="weighted", zero_division=0)
        ),
        "recall_macro": float(recall_score(y_test, y_pred, average="macro")),
        "jaccard_micro": float(jaccard_score(y_test, y_pred, average="micro")),
        "jaccard_macro": float(jaccard_score(y_test, y_pred, average="macro")),
        "jaccard_weighted": float(
            jaccard_score(y_test, y_pred, average="weighted")
        ),
    }


def _compute_confusion_matrix(y_test, y_pred, label_names: list) -> xr.DataArray:
    """Compute the confusion matrix as a labelled xarray DataArray.

    Args:
        y_test: Ground-truth labels (torch tensor on CUDA).
        y_pred: Predicted labels (torch tensor on CUDA).
        label_names: Human-readable class names, ordered by class id.

    Returns:
        Square DataArray with "true_label"/"predicted_label" dimensions.
    """
    # labels must live on the same (CUDA) device as y_test/y_pred for the
    # array-api-dispatched sklearn call.
    labels = torch.as_tensor(np.arange(len(label_names)), device="cuda")
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    return xr.DataArray(
        cm.cpu().numpy(),
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )


def recalculate_metrics(results_dir: Path) -> None:
    """Recalculate test metrics and confusion matrix for a training result.

    Loads the pickled best estimator, rebuilds the training dataset from the
    stored search settings, re-runs prediction on the test split, and writes
    whichever of ``test_metrics.toml`` / ``confusion_matrix.nc`` is missing.
    Existing output files are never overwritten.

    Args:
        results_dir: Path to the results directory containing the trained model.
    """
    print(f"\nProcessing: {results_dir}")

    # Load the search settings to get the training configuration.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print("  ❌ Missing search_settings.toml, skipping...")
        return
    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    # Skip early when both outputs already exist — avoids the expensive
    # model load, dataset rebuild and prediction below.
    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"
    if test_metrics_file.exists() and cm_file.exists():
        print("  ✓ Metrics already exist, skipping...")
        return

    # Load the best estimator.
    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print("  ❌ Missing best_estimator_model.pkl, skipping...")
        return
    print(f"  Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # only ever point this script at results directories we produced.
        best_estimator = pickle.load(f)

    # Recreate the dataset ensemble exactly as training did.
    print("  Recreating training dataset...")
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get("members", _DEFAULT_MEMBERS),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    # ESPA models consume torch tensors; everything else gets CuPy ("cuda").
    device = "torch" if model in ["espa"] else "cuda"

    # Create training data.
    training_data = dataset_ensemble.create_cat_training_dataset(
        task=task, device=device
    )

    # Prepare test data — match training.py's approach.
    print("  Preparing test data...")
    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py).
    y_test = (
        training_data.y.test.get()
        if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
        else training_data.y.test
    )

    # Compute predictions on the test set (use original device data).
    print("  Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)

    # Move both onto the same torch CUDA device so the array-api-dispatched
    # sklearn metrics see a single array namespace.
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    print("  Computing test metrics...")
    test_metrics = _compute_test_metrics(task, y_test, y_pred)

    print("  Computing confusion matrix...")
    # Keep the indexed lookup: labels may be an int-keyed mapping rather
    # than a plain sequence — TODO confirm against training_data.y.labels.
    n_classes = len(training_data.y.labels)
    label_names = [training_data.y.labels[i] for i in range(n_classes)]
    cm_xr = _compute_confusion_matrix(y_test, y_pred, label_names)

    # Store the test metrics (never overwrite an existing file).
    if not test_metrics_file.exists():
        print(f"  Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)
    else:
        print("  ✓ Test metrics already exist")

    # Store the confusion matrix (never overwrite an existing file).
    if not cm_file.exists():
        print(f"  Storing confusion matrix to {cm_file.name}...")
        cm_xr.to_netcdf(cm_file, engine="h5netcdf")
    else:
        print("  ✓ Confusion matrix already exists")

    print("  ✓ Done!")


def main() -> None:
    """Find all training results and recalculate metrics for those missing them."""
    print("Searching for training results directories...")

    # Find all results directories.
    results_dirs = sorted(d for d in RESULTS_DIR.glob("*") if d.is_dir())
    print(f"Found {len(results_dirs)} results directories.\n")

    for results_dir in results_dirs:
        try:
            recalculate_metrics(results_dir)
        except Exception as e:
            # One bad results directory must not abort the whole batch.
            print(f"  ❌ Error processing {results_dir.name}: {e}")
            continue

    print("\n✅ All done!")


if __name__ == "__main__":
    main()