Enhance training analysis page with test metrics and confusion matrix

- Added a section to display test metrics for model performance on the held-out test set.
- Implemented confusion matrix visualization to analyze prediction breakdown.
- Refactored sidebar settings to streamline metric selection and improve user experience.
- Updated cross-validation statistics to compare CV performance with test metrics.
- Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully.
- Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation.
- Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
This commit is contained in:
Tobias Hölzer 2026-01-07 15:56:02 +01:00
parent 4fecac535c
commit c92e856c55
23 changed files with 1845 additions and 484 deletions

View file

@ -0,0 +1,195 @@
#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.
This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""
import pickle
from pathlib import Path
import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import confusion_matrix
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
set_config(array_api_dispatch=True)
def recalculate_metrics(results_dir: Path):
"""Recalculate test metrics and confusion matrix for a training result.
Args:
results_dir: Path to the results directory containing the trained model.
"""
print(f"\nProcessing: {results_dir}")
# Load the search settings to get training configuration
settings_file = results_dir / "search_settings.toml"
if not settings_file.exists():
print(" ❌ Missing search_settings.toml, skipping...")
return
with open(settings_file) as f:
config = toml.load(f)
settings = config["settings"]
# Check if metrics already exist
test_metrics_file = results_dir / "test_metrics.toml"
cm_file = results_dir / "confusion_matrix.nc"
# if test_metrics_file.exists() and cm_file.exists():
# print(" ✓ Metrics already exist, skipping...")
# return
# Load the best estimator
best_model_file = results_dir / "best_estimator_model.pkl"
if not best_model_file.exists():
print(" ❌ Missing best_estimator_model.pkl, skipping...")
return
print(f" Loading best estimator from {best_model_file.name}...")
with open(best_model_file, "rb") as f:
best_estimator = pickle.load(f)
# Recreate the dataset ensemble
print(" Recreating training dataset...")
dataset_ensemble = DatasetEnsemble(
grid=settings["grid"],
level=settings["level"],
target=settings["target"],
members=settings.get(
"members",
[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
),
dimension_filters=settings.get("dimension_filters", {}),
variable_filters=settings.get("variable_filters", {}),
filter_target=settings.get("filter_target", False),
add_lonlat=settings.get("add_lonlat", True),
)
task = settings["task"]
model = settings["model"]
device = "torch" if model in ["espa"] else "cuda"
# Create training data
training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)
# Prepare test data - match training.py's approach
print(" Preparing test data...")
# For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
y_test = (
training_data.y.test.get()
if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
else training_data.y.test
)
# Compute predictions on the test set (use original device data)
print(" Computing predictions on test set...")
y_pred = best_estimator.predict(training_data.X.test)
# Use torch
y_pred = torch.as_tensor(y_pred, device="cuda")
y_test = torch.as_tensor(y_test, device="cuda")
# Compute metrics manually to avoid device issues
print(" Computing test metrics...")
from sklearn.metrics import (
accuracy_score,
f1_score,
jaccard_score,
precision_score,
recall_score,
)
test_metrics = {}
if task == "binary":
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["recall"] = float(recall_score(y_test, y_pred))
test_metrics["precision"] = float(precision_score(y_test, y_pred))
test_metrics["f1"] = float(f1_score(y_test, y_pred))
test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
else:
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))
# Get confusion matrix
print(" Computing confusion matrix...")
labels = list(range(len(training_data.y.labels)))
labels = torch.as_tensor(np.array(labels), device="cuda")
print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
print(" Device of labels:", getattr(labels, "device", "cpu"))
cm = confusion_matrix(y_test, y_pred, labels=labels)
cm = cm.cpu().numpy()
labels = labels.cpu().numpy()
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
cm_xr = xr.DataArray(
cm,
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",
)
# Store the test metrics
if not test_metrics_file.exists():
print(f" Storing test metrics to {test_metrics_file.name}...")
with open(test_metrics_file, "w") as f:
toml.dump({"test_metrics": test_metrics}, f)
else:
print(" ✓ Test metrics already exist")
# Store the confusion matrix
if True:
# if not cm_file.exists():
print(f" Storing confusion matrix to {cm_file.name}...")
cm_xr.to_netcdf(cm_file, engine="h5netcdf")
else:
print(" ✓ Confusion matrix already exists")
print(" ✓ Done!")
def main():
"""Find all training results and recalculate metrics for those missing them."""
print("Searching for training results directories...")
# Find all results directories
results_dirs = sorted([d for d in RESULTS_DIR.glob("*") if d.is_dir()])
print(f"Found {len(results_dirs)} results directories.\n")
for results_dir in results_dirs:
recalculate_metrics(results_dir)
# try:
# except Exception as e:
# print(f" ❌ Error processing {results_dir.name}: {e}")
# continue
print("\n✅ All done!")
if __name__ == "__main__":
main()