Enhance training analysis page with test metrics and confusion matrix

- Added a section to display test metrics for model performance on the held-out test set.
- Implemented confusion matrix visualization to analyze prediction breakdown.
- Refactored sidebar settings to streamline metric selection and improve user experience.
- Updated cross-validation statistics to compare CV performance with test metrics.
- Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully.
- Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation.
- Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
This commit is contained in:
Tobias Hölzer 2026-01-07 15:56:02 +01:00
parent 4fecac535c
commit c92e856c55
23 changed files with 1845 additions and 484 deletions

View file

@ -0,0 +1,195 @@
#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.
This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""
import pickle
from pathlib import Path
import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import confusion_matrix
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
set_config(array_api_dispatch=True)
def recalculate_metrics(results_dir: Path):
"""Recalculate test metrics and confusion matrix for a training result.
Args:
results_dir: Path to the results directory containing the trained model.
"""
print(f"\nProcessing: {results_dir}")
# Load the search settings to get training configuration
settings_file = results_dir / "search_settings.toml"
if not settings_file.exists():
print(" ❌ Missing search_settings.toml, skipping...")
return
with open(settings_file) as f:
config = toml.load(f)
settings = config["settings"]
# Check if metrics already exist
test_metrics_file = results_dir / "test_metrics.toml"
cm_file = results_dir / "confusion_matrix.nc"
# if test_metrics_file.exists() and cm_file.exists():
# print(" ✓ Metrics already exist, skipping...")
# return
# Load the best estimator
best_model_file = results_dir / "best_estimator_model.pkl"
if not best_model_file.exists():
print(" ❌ Missing best_estimator_model.pkl, skipping...")
return
print(f" Loading best estimator from {best_model_file.name}...")
with open(best_model_file, "rb") as f:
best_estimator = pickle.load(f)
# Recreate the dataset ensemble
print(" Recreating training dataset...")
dataset_ensemble = DatasetEnsemble(
grid=settings["grid"],
level=settings["level"],
target=settings["target"],
members=settings.get(
"members",
[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
),
dimension_filters=settings.get("dimension_filters", {}),
variable_filters=settings.get("variable_filters", {}),
filter_target=settings.get("filter_target", False),
add_lonlat=settings.get("add_lonlat", True),
)
task = settings["task"]
model = settings["model"]
device = "torch" if model in ["espa"] else "cuda"
# Create training data
training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)
# Prepare test data - match training.py's approach
print(" Preparing test data...")
# For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
y_test = (
training_data.y.test.get()
if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
else training_data.y.test
)
# Compute predictions on the test set (use original device data)
print(" Computing predictions on test set...")
y_pred = best_estimator.predict(training_data.X.test)
# Use torch
y_pred = torch.as_tensor(y_pred, device="cuda")
y_test = torch.as_tensor(y_test, device="cuda")
# Compute metrics manually to avoid device issues
print(" Computing test metrics...")
from sklearn.metrics import (
accuracy_score,
f1_score,
jaccard_score,
precision_score,
recall_score,
)
test_metrics = {}
if task == "binary":
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["recall"] = float(recall_score(y_test, y_pred))
test_metrics["precision"] = float(precision_score(y_test, y_pred))
test_metrics["f1"] = float(f1_score(y_test, y_pred))
test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
else:
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))
# Get confusion matrix
print(" Computing confusion matrix...")
labels = list(range(len(training_data.y.labels)))
labels = torch.as_tensor(np.array(labels), device="cuda")
print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
print(" Device of labels:", getattr(labels, "device", "cpu"))
cm = confusion_matrix(y_test, y_pred, labels=labels)
cm = cm.cpu().numpy()
labels = labels.cpu().numpy()
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
cm_xr = xr.DataArray(
cm,
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",
)
# Store the test metrics
if not test_metrics_file.exists():
print(f" Storing test metrics to {test_metrics_file.name}...")
with open(test_metrics_file, "w") as f:
toml.dump({"test_metrics": test_metrics}, f)
else:
print(" ✓ Test metrics already exist")
# Store the confusion matrix
if True:
# if not cm_file.exists():
print(f" Storing confusion matrix to {cm_file.name}...")
cm_xr.to_netcdf(cm_file, engine="h5netcdf")
else:
print(" ✓ Confusion matrix already exists")
print(" ✓ Done!")
def main():
"""Find all training results and recalculate metrics for those missing them."""
print("Searching for training results directories...")
# Find all results directories
results_dirs = sorted([d for d in RESULTS_DIR.glob("*") if d.is_dir()])
print(f"Found {len(results_dirs)} results directories.\n")
for results_dir in results_dirs:
recalculate_metrics(results_dir)
# try:
# except Exception as e:
# print(f" ❌ Error processing {results_dir.name}: {e}")
# continue
print("\n✅ All done!")
if __name__ == "__main__":
main()

58
scripts/rechunk_zarr.py Normal file
View file

@ -0,0 +1,58 @@
import xarray as xr
import zarr
from rich import print
import dask.distributed as dd
from entropice.utils.paths import get_era5_stores
import entropice.utils.codecs
def print_info(daily_raw = None, show_vars: bool = True):
if daily_raw is None:
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print("=== Daily INFO ===")
print(f" Dims: {daily_raw.sizes}")
numchunks = 1
chunksizes = {}
approxchunksize = 4 # 4 Bytes = float32
for d, cs in daily_raw.chunksizes.items():
numchunks *= len(cs)
chunksizes[d] = max(cs)
approxchunksize *= max(cs)
approxchunksize /= 10e6 # MB
print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
print(f" Encoding: {daily_raw.encoding}")
if show_vars:
print(" Variables:")
for var in daily_raw.data_vars:
da = daily_raw[var]
print(f" {var} Encoding:")
print(da.encoding)
print("")
def rechunk():
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print_info(daily_raw, False)
daily_raw = daily_raw.chunk({
"time": 120,
"latitude": -1, # Should be 337,
"longitude": -1 # Should be 3600
})
print_info(daily_raw, False)
encoding = entropice.utils.codecs.from_ds(daily_raw)
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
if __name__ == "__main__":
with (
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
rechunk()
print("Done.")

View file

@ -0,0 +1,144 @@
#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.
This script searches through training result directories and identifies those that have
a trained model but are missing inference results. It then loads the model and dataset
configuration, reruns inference, and saves the results.
"""
import pickle
from pathlib import Path
import toml
from rich.console import Console
from rich.progress import track
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import RESULTS_DIR
console = Console()
def find_incomplete_trainings() -> list[Path]:
"""Find training result directories missing inference results.
Returns:
list[Path]: List of directories with trained models but missing predictions.
"""
incomplete = []
if not RESULTS_DIR.exists():
console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
return incomplete
# Search for all training result directories
for result_dir in RESULTS_DIR.glob("*_cv*"):
if not result_dir.is_dir():
continue
model_file = result_dir / "best_estimator_model.pkl"
settings_file = result_dir / "search_settings.toml"
predictions_file = result_dir / "predicted_probabilities.parquet"
# Check if model and settings exist but predictions are missing
if model_file.exists() and settings_file.exists() and not predictions_file.exists():
incomplete.append(result_dir)
return incomplete
def rerun_inference(result_dir: Path) -> bool:
"""Rerun inference for a training result directory.
Args:
result_dir (Path): Path to the training result directory.
Returns:
bool: True if successful, False otherwise.
"""
try:
console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")
# Load settings
settings_file = result_dir / "search_settings.toml"
with open(settings_file) as f:
settings_data = toml.load(f)
settings = settings_data["settings"]
# Reconstruct DatasetEnsemble from settings
ensemble = DatasetEnsemble(
grid=settings["grid"],
level=settings["level"],
target=settings["target"],
members=settings["members"],
dimension_filters=settings.get("dimension_filters", {}),
variable_filters=settings.get("variable_filters", {}),
filter_target=settings.get("filter_target", False),
add_lonlat=settings.get("add_lonlat", True),
)
# Load trained model
model_file = result_dir / "best_estimator_model.pkl"
with open(model_file, "rb") as f:
clf = pickle.load(f)
console.print("[green]✓[/green] Loaded model and settings")
# Get class labels
classes = settings["classes"]
# Run inference
console.print("[yellow]Running inference...[/yellow]")
preds = predict_proba(ensemble, clf=clf, classes=classes)
# Save predictions
preds_file = result_dir / "predicted_probabilities.parquet"
preds.to_parquet(preds_file)
console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
return True
except Exception as e:
console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]")
import traceback
console.print(f"[red]{traceback.format_exc()}[/red]")
return False
def main():
"""Rerun missing inferences for incomplete training results."""
console.print("[bold blue]Searching for incomplete training results...[/bold blue]")
incomplete_dirs = find_incomplete_trainings()
if not incomplete_dirs:
console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
return
console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
for d in incomplete_dirs:
console.print(f"{d.name}")
console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")
successful = 0
failed = 0
for result_dir in track(incomplete_dirs, description="Rerunning inference"):
if rerun_inference(result_dir):
successful += 1
else:
failed += 1
console.print("\n[bold]Summary:[/bold]")
console.print(f" [green]Successful: {successful}[/green]")
console.print(f" [red]Failed: {failed}[/red]")
if __name__ == "__main__":
main()