#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.
This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""
import pickle
from pathlib import Path

import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    jaccard_score,
    precision_score,
    recall_score,
)

from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR
# Enable sklearn's array API dispatch so metric functions accept CuPy/torch
# arrays directly (single consistent namespace) instead of requiring NumPy.
set_config(array_api_dispatch=True)
def _compute_test_metrics(y_test, y_pred, task: str) -> dict:
    """Build the test-split metric dict for a binary or multiclass task.

    Args:
        y_test: Ground-truth labels (torch tensor / array-API compatible).
        y_pred: Predicted labels on the same device/namespace as ``y_test``.
        task: ``"binary"`` selects single-class metrics; any other value
            selects macro/weighted/micro-averaged multiclass metrics.

    Returns:
        Mapping of metric name to plain Python float (TOML-serializable).
    """
    if task == "binary":
        return {
            "accuracy": float(accuracy_score(y_test, y_pred)),
            "recall": float(recall_score(y_test, y_pred)),
            "precision": float(precision_score(y_test, y_pred)),
            "f1": float(f1_score(y_test, y_pred)),
            "jaccard": float(jaccard_score(y_test, y_pred)),
        }
    # zero_division=0 keeps precision well-defined for classes never predicted.
    return {
        "accuracy": float(accuracy_score(y_test, y_pred)),
        "f1_macro": float(f1_score(y_test, y_pred, average="macro")),
        "f1_weighted": float(f1_score(y_test, y_pred, average="weighted")),
        "precision_macro": float(precision_score(y_test, y_pred, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_test, y_pred, average="weighted", zero_division=0)),
        "recall_macro": float(recall_score(y_test, y_pred, average="macro")),
        "jaccard_micro": float(jaccard_score(y_test, y_pred, average="micro")),
        "jaccard_macro": float(jaccard_score(y_test, y_pred, average="macro")),
        "jaccard_weighted": float(jaccard_score(y_test, y_pred, average="weighted")),
    }


def recalculate_metrics(results_dir: Path) -> None:
    """Recalculate test metrics and confusion matrix for a training result.

    Loads the pickled best estimator and the original search settings from
    ``results_dir``, rebuilds the training dataset the same way training did,
    re-runs prediction on the test split, then writes ``test_metrics.toml``
    (only if absent) and ``confusion_matrix.nc`` (always overwritten).

    Args:
        results_dir: Path to the results directory containing the trained model.
    """
    print(f"\nProcessing: {results_dir}")

    # Load the search settings to get training configuration
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print(" ❌ Missing search_settings.toml, skipping...")
        return
    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    # Output targets. NOTE(review): recomputation is intentionally NOT skipped
    # when both files already exist — the confusion matrix is always refreshed.
    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"

    # Load the best estimator.
    # SECURITY: pickle.load executes arbitrary code on load; only open models
    # from trusted, locally produced results directories.
    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print(" ❌ Missing best_estimator_model.pkl, skipping...")
        return
    print(f" Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Recreate the dataset ensemble with the same settings as the original run,
    # falling back to the defaults the training pipeline used.
    print(" Recreating training dataset...")
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get(
            "members",
            [
                "AlphaEarth",
                "ArcticDEM",
                "ERA5-yearly",
                "ERA5-seasonal",
                "ERA5-shoulder",
            ],
        ),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    # ESPA models consume torch tensors; everything else gets CUDA (CuPy) arrays.
    device = "torch" if model in ["espa"] else "cuda"

    # Create training data
    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)

    # Prepare test data - match training.py's approach
    print(" Preparing test data...")
    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
    y_test = (
        training_data.y.test.get()
        if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
        else training_data.y.test
    )

    # Compute predictions on the test set (use original device data)
    print(" Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)

    # Move labels/predictions onto the GPU as torch tensors so the
    # array-API-dispatched sklearn metrics see one consistent namespace.
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    print(" Computing test metrics...")
    test_metrics = _compute_test_metrics(y_test, y_pred, task)

    # Get confusion matrix
    print(" Computing confusion matrix...")
    labels = torch.as_tensor(np.arange(len(training_data.y.labels)), device="cuda")
    print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
    print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
    print(" Device of labels:", getattr(labels, "device", "cpu"))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    # With array_api_dispatch and torch inputs confusion_matrix returns a torch
    # tensor; bring it back to host memory for xarray/netCDF.
    cm = cm.cpu().numpy()
    label_names = list(training_data.y.labels)
    cm_xr = xr.DataArray(
        cm,
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )

    # Store the test metrics (never clobber an existing file)
    if not test_metrics_file.exists():
        print(f" Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)
    else:
        print(" ✓ Test metrics already exist")

    # Store the confusion matrix. Always overwritten: the original `if True:`
    # made its exists-check unreachable, so overwriting is the effective (and
    # here preserved) behavior.
    print(f" Storing confusion matrix to {cm_file.name}...")
    cm_xr.to_netcdf(cm_file, engine="h5netcdf")

    print(" ✓ Done!")
def main() -> None:
    """Find all training results and recalculate metrics for each.

    Iterates every directory directly under ``RESULTS_DIR``; directories
    missing the required input files are skipped inside
    :func:`recalculate_metrics` itself.
    """
    print("Searching for training results directories...")
    # Only immediate subdirectories of RESULTS_DIR are training runs.
    results_dirs = sorted(d for d in RESULTS_DIR.glob("*") if d.is_dir())
    print(f"Found {len(results_dirs)} results directories.\n")
    for results_dir in results_dirs:
        # NOTE(review): a previously commented-out try/except suggested
        # per-directory error tolerance was once intended; currently any
        # failure aborts the whole run — confirm which behavior is wanted.
        recalculate_metrics(results_dir)
    print("\n✅ All done!")


if __name__ == "__main__":
    main()