Fix training and finalize dataset page

This commit is contained in:
Tobias Hölzer 2026-01-18 20:16:47 +01:00
parent c358bb63bc
commit 636c034b55
30 changed files with 533 additions and 851 deletions

View file

@ -19,3 +19,13 @@ pixi run alpha-earth combine-to-zarr --grid healpix --level 7
pixi run alpha-earth combine-to-zarr --grid healpix --level 8 pixi run alpha-earth combine-to-zarr --grid healpix --level 8
pixi run alpha-earth combine-to-zarr --grid healpix --level 9 pixi run alpha-earth combine-to-zarr --grid healpix --level 9
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10 # pixi run alpha-earth combine-to-zarr --grid healpix --level 10
pixi run alpha-earth compute-synopsis --grid hex --level 3
pixi run alpha-earth compute-synopsis --grid hex --level 4
pixi run alpha-earth compute-synopsis --grid hex --level 5
pixi run alpha-earth compute-synopsis --grid hex --level 6
pixi run alpha-earth compute-synopsis --grid healpix --level 6
pixi run alpha-earth compute-synopsis --grid healpix --level 7
pixi run alpha-earth compute-synopsis --grid healpix --level 8
pixi run alpha-earth compute-synopsis --grid healpix --level 9
pixi run alpha-earth compute-synopsis --grid healpix --level 10

View file

@ -3,13 +3,24 @@
# pixi run era5 download # pixi run era5 download
# pixi run era5 enrich # pixi run era5 enrich
pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20 # pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20
pixi run era5 compute-synopsis --grid hex --level 3
pixi run era5 compute-synopsis --grid hex --level 4
pixi run era5 compute-synopsis --grid hex --level 5
pixi run era5 compute-synopsis --grid hex --level 6
pixi run era5 compute-synopsis --grid healpix --level 6
pixi run era5 compute-synopsis --grid healpix --level 7
pixi run era5 compute-synopsis --grid healpix --level 8
pixi run era5 compute-synopsis --grid healpix --level 9
pixi run era5 compute-synopsis --grid healpix --level 10

View file

@ -0,0 +1,37 @@
#!/bin/bash
# Run AutoGluon trainings on synopsis datasets for every grid/level/target/task combination.

# Fail fast when not run inside the pixi environment (autogluon must be on PATH).
# `command -v` replaces the `which ...; if [ $? -ne 0 ]` anti-pattern.
if ! command -v autogluon >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

for grid in hex healpix; do
    # Each grid type has its own set of aggregation levels.
    if [ "$grid" = "hex" ]; then
        levels=(3 4 5 6)
    else
        levels=(6 7 8 9 10)
    fi
    for level in "${levels[@]}"; do
        # Only the coarse levels (hex 3/4, healpix 6/7) carry the extra ERA5
        # dimensions; restrict them to their median aggregation there.
        # An array keeps the extra flags word-split-safe (unlike an unquoted string).
        era5_dimension_filters=()
        if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } ||
           { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
            era5_dimension_filters=(
                --dimension-filters.ERA5-shoulder.aggregations=median
                --dimension-filters.ERA5-seasonal.aggregations=median
                --dimension-filters.ERA5-yearly.aggregations=median
            )
        fi
        for target in darts_v1 darts_mllabels; do
            for task in binary density count; do
                echo
                echo "----------------------------------------"
                echo "Running autogluon training for grid=$grid, level=$level, target=$target, task=$task"
                autogluon --grid "$grid" --level "$level" --target "$target" --task "$task" \
                    --time-limit 600 --temporal-mode synopsis \
                    --experiment "static-variables-autogluon" \
                    --dimension-filters.ArcticDEM.aggregations=median \
                    --dimension-filters.AlphaEarth.agg=median \
                    "${era5_dimension_filters[@]}"
                echo "----------------------------------------"
                echo
            done
        done
    done
done

View file

@ -0,0 +1,49 @@
#!/bin/bash
# Run trainings on synopsis healpix datasets (target darts_mllabels) for all tasks and models.

# Fail fast when not run inside the pixi environment (train must be on PATH).
# `command -v` replaces the `which ...; if [ $? -ne 0 ]` anti-pattern.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="healpix"
target="darts_mllabels"
# levels=(6 7 8 9 10)
levels=(8 9 10)
for level in "${levels[@]}"; do
    # Only the coarse levels (hex 3/4, healpix 6/7) carry the extra ERA5
    # dimensions; restrict them to their median aggregation there.
    # An array keeps the extra flags word-split-safe (unlike an unquoted string).
    era5_dimension_filters=()
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } ||
       { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters=(
            --dimension-filters.ERA5-shoulder.aggregations=median
            --dimension-filters.ERA5-seasonal.aggregations=median
            --dimension-filters.ERA5-yearly.aggregations=median
        )
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # espa only supports binary classification.
            if [ "$model" = "espa" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # rf is too slow for the regression tasks (density/count).
            if [ "$model" = "rf" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # Slow models get fewer hyperparameter-search iterations.
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi
            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" \
                --model "$model" --n-iter "$niter" --temporal-mode synopsis \
                --experiment "static-variables" \
                --dimension-filters.ArcticDEM.aggregations=median \
                --dimension-filters.AlphaEarth.agg=median \
                "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done

View file

@ -0,0 +1,48 @@
#!/bin/bash
# Run trainings on synopsis healpix datasets (target darts_v1) for all tasks and models.

# Fail fast when not run inside the pixi environment (train must be on PATH).
# `command -v` replaces the `which ...; if [ $? -ne 0 ]` anti-pattern.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="healpix"
target="darts_v1"
#levels=(6 7 8 9 10)
levels=(8 9 10)
for level in "${levels[@]}"; do
    # Only the coarse levels (hex 3/4, healpix 6/7) carry the extra ERA5
    # dimensions; restrict them to their median aggregation there.
    # An array keeps the extra flags word-split-safe (unlike an unquoted string).
    era5_dimension_filters=()
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } ||
       { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters=(
            --dimension-filters.ERA5-shoulder.aggregations=median
            --dimension-filters.ERA5-seasonal.aggregations=median
            --dimension-filters.ERA5-yearly.aggregations=median
        )
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # espa only supports binary classification.
            if [ "$model" = "espa" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # rf is too slow for the regression tasks (density/count).
            if [ "$model" = "rf" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # Slow models get fewer hyperparameter-search iterations.
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi
            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" \
                --model "$model" --n-iter "$niter" --temporal-mode synopsis \
                --experiment "static-variables" \
                --dimension-filters.ArcticDEM.aggregations=median \
                --dimension-filters.AlphaEarth.agg=median \
                "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done

View file

@ -0,0 +1,48 @@
#!/bin/bash
# Run trainings on synopsis hex datasets (target darts_mllabels) for all tasks and models.

# Fail fast when not run inside the pixi environment (train must be on PATH).
# `command -v` replaces the `which ...; if [ $? -ne 0 ]` anti-pattern.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="hex"
target="darts_mllabels"
# levels=(3 4 5 6)
levels=(5 6)
for level in "${levels[@]}"; do
    # Only the coarse levels (hex 3/4, healpix 6/7) carry the extra ERA5
    # dimensions; restrict them to their median aggregation there.
    # An array keeps the extra flags word-split-safe (unlike an unquoted string).
    era5_dimension_filters=()
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } ||
       { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters=(
            --dimension-filters.ERA5-shoulder.aggregations=median
            --dimension-filters.ERA5-seasonal.aggregations=median
            --dimension-filters.ERA5-yearly.aggregations=median
        )
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # espa only supports binary classification.
            if [ "$model" = "espa" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # rf is too slow for the regression tasks (density/count).
            if [ "$model" = "rf" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # Slow models get fewer hyperparameter-search iterations.
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi
            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" \
                --model "$model" --n-iter "$niter" --temporal-mode synopsis \
                --experiment "static-variables" \
                --dimension-filters.ArcticDEM.aggregations=median \
                --dimension-filters.AlphaEarth.agg=median \
                "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done

View file

@ -0,0 +1,48 @@
#!/bin/bash
# Run trainings on synopsis hex datasets (target darts_v1) for all tasks and models.

# Fail fast when not run inside the pixi environment (train must be on PATH).
# `command -v` replaces the `which ...; if [ $? -ne 0 ]` anti-pattern.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="hex"
target="darts_v1"
# levels=(3 4 5 6)
levels=(5 6)
for level in "${levels[@]}"; do
    # Only the coarse levels (hex 3/4, healpix 6/7) carry the extra ERA5
    # dimensions; restrict them to their median aggregation there.
    # An array keeps the extra flags word-split-safe (unlike an unquoted string).
    era5_dimension_filters=()
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } ||
       { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters=(
            --dimension-filters.ERA5-shoulder.aggregations=median
            --dimension-filters.ERA5-seasonal.aggregations=median
            --dimension-filters.ERA5-yearly.aggregations=median
        )
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # espa only supports binary classification.
            if [ "$model" = "espa" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # rf is too slow for the regression tasks (density/count).
            if [ "$model" = "rf" ] && { [ "$task" = "density" ] || [ "$task" = "count" ]; }; then
                continue
            fi
            # Slow models get fewer hyperparameter-search iterations.
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi
            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" \
                --model "$model" --n-iter "$niter" --temporal-mode synopsis \
                --experiment "static-variables" \
                --dimension-filters.ArcticDEM.aggregations=median \
                --dimension-filters.AlphaEarth.agg=median \
                "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done

View file

@ -1,191 +0,0 @@
"""Fix XGBoost feature importance in existing model state files.
This script repairs XGBoost model state files that have all-zero feature importance
values due to incorrect feature name lookup. It reloads the pickled models and
regenerates the feature importance arrays with the correct feature index mapping.
"""
import pickle
from pathlib import Path
import toml
import xarray as xr
from rich import print
from entropice.utils.paths import RESULTS_DIR
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled estimator and regenerates the feature-importance
    arrays using the booster's positional feature keys (``f0``, ``f1``, ...),
    then rewrites ``best_estimator_state.nc`` (keeping a ``.backup`` copy).

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise.
    """
    # Only XGBoost runs are affected by the broken feature-name lookup.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False
    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")
    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Both the pickled estimator and the old state file must be present.
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"
    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False
    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # The canonical feature order comes from the old (broken) state file.
    old_state = xr.open_dataset(state_file, engine="h5netcdf")
    features = old_state.coords["feature"].values.tolist()
    old_state.close()

    booster = best_estimator.get_booster()

    # Importance type -> description stored on the corresponding variable.
    importance_specs = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }

    data_vars = {}
    for importance_type, description in importance_specs.items():
        scores = booster.get_score(importance_type=importance_type)
        # Booster scores are keyed by positional names f0, f1, ...; features
        # never used in a split are absent and therefore map to 0.0.
        aligned = [scores.get(f"f{idx}", 0.0) for idx in range(len(features))]
        var_name = f"feature_importance_{importance_type}"
        data_vars[var_name] = xr.DataArray(
            aligned,
            dims=["feature"],
            coords={"feature": features},
            name=var_name,
            attrs={"description": description},
        )

    state = xr.Dataset(
        data_vars,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": booster.num_boosted_rounds(),
            "objective": str(best_estimator.objective),
        },
    )

    # Keep a one-time backup of the original broken state file.
    backup_file = state_file.with_suffix(".nc.backup")
    if backup_file.exists():
        print(" Backup already exists, removing old state file")
        state_file.unlink()
    else:
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)

    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Sanity check: the repaired importances should no longer be all zero.
    total_importance = sum(
        data_vars[f"feature_importance_{t}"].sum().item() for t in ("weight", "gain", "cover")
    )
    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    print(" ✗ Warning: Total importance is still 0!")
    return False
def main():
    """Find and fix all XGBoost model state files."""
    print("Scanning for XGBoost model results...")
    candidates = sorted(d for d in RESULTS_DIR.iterdir() if d.is_dir())
    print(f"Found {len(candidates)} result directories")

    counts = {"fixed": 0, "skipped": 0, "failed": 0}
    for candidate in candidates:
        try:
            if fix_xgboost_model_state(candidate):
                counts["fixed"] += 1
            else:
                # NOTE(review): False covers both "not an XGBoost run" and
                # unrepairable states — both end up counted as skipped.
                counts["skipped"] += 1
        except Exception as e:
            print(f"❌ Error processing {candidate.name}: {e}")
            counts["failed"] += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {counts['fixed']}")
    print(f" ⊘ Skipped: {counts['skipped']}")
    print(f" ✗ Failed: {counts['failed']}")
    print("=" * 60)


if __name__ == "__main__":
    main()

View file

@ -1,195 +0,0 @@
#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.
This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""
import pickle
from pathlib import Path
import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import confusion_matrix
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
set_config(array_api_dispatch=True)
def recalculate_metrics(results_dir: Path):
    """Recalculate test metrics and confusion matrix for a training result.

    Loads the pickled best estimator, rebuilds the dataset configuration it
    was trained on, re-runs prediction on the test split, and stores
    ``test_metrics.toml`` and ``confusion_matrix.nc`` in the same directory.

    Args:
        results_dir: Path to the results directory containing the trained model.
    """
    print(f"\nProcessing: {results_dir}")

    # The search settings describe how the training dataset was assembled.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print(" ❌ Missing search_settings.toml, skipping...")
        return
    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"
    # NOTE(review): the early exit when both outputs already exist is
    # intentionally disabled so existing results get recomputed:
    # if test_metrics_file.exists() and cm_file.exists():
    #     print(" ✓ Metrics already exist, skipping...")
    #     return

    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print(" ❌ Missing best_estimator_model.pkl, skipping...")
        return
    print(f" Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Rebuild the exact dataset ensemble used during training.
    print(" Recreating training dataset...")
    default_members = [
        "AlphaEarth",
        "ArcticDEM",
        "ERA5-yearly",
        "ERA5-seasonal",
        "ERA5-shoulder",
    ]
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get("members", default_members),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    # espa models consume torch tensors; everything else gets CUDA arrays.
    device = "torch" if model in ["espa"] else "cuda"
    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)

    # Match training.py: XGBoost gets its CuPy test labels moved to the CPU.
    print(" Preparing test data...")
    y_test = training_data.y.test
    if model == "xgboost" and isinstance(y_test, cp.ndarray):
        y_test = y_test.get()

    print(" Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)
    # Normalize both sides to CUDA torch tensors so the sklearn metric calls
    # below (array_api_dispatch enabled) see a single array namespace.
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    print(" Computing test metrics...")
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        jaccard_score,
        precision_score,
        recall_score,
    )

    test_metrics = {"accuracy": float(accuracy_score(y_test, y_pred))}
    if task == "binary":
        test_metrics["recall"] = float(recall_score(y_test, y_pred))
        test_metrics["precision"] = float(precision_score(y_test, y_pred))
        test_metrics["f1"] = float(f1_score(y_test, y_pred))
        test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
    else:
        test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
        test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
        test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
        test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
        test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
        test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))

    print(" Computing confusion matrix...")
    labels = torch.as_tensor(np.array(list(range(len(training_data.y.labels)))), device="cuda")
    print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
    print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
    print(" Device of labels:", getattr(labels, "device", "cpu"))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm = cm.cpu().numpy()
    labels = labels.cpu().numpy()
    label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
    cm_xr = xr.DataArray(
        cm,
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )

    # Store the test metrics (an existing file is kept untouched).
    if test_metrics_file.exists():
        print(" ✓ Test metrics already exist")
    else:
        print(f" Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)

    # The confusion matrix is always (re)written — the exists-check was
    # intentionally disabled (`if True:` in the original).
    print(f" Storing confusion matrix to {cm_file.name}...")
    cm_xr.to_netcdf(cm_file, engine="h5netcdf")

    print(" ✓ Done!")
def main():
    """Find all training results and recalculate metrics for those missing them."""
    print("Searching for training results directories...")
    results_dirs = sorted(d for d in RESULTS_DIR.glob("*") if d.is_dir())
    print(f"Found {len(results_dirs)} results directories.\n")
    for results_dir in results_dirs:
        # NOTE(review): errors intentionally propagate; a previous try/except
        # guard that skipped failing runs was commented out.
        recalculate_metrics(results_dir)
    print("\n✅ All done!")


if __name__ == "__main__":
    main()
View file

@ -1,138 +0,0 @@
import dask.distributed as dd
import xarray as xr
from rich import print
import entropice.utils.codecs
from entropice.utils.paths import get_era5_stores
def print_info(daily_raw=None, show_vars: bool = True):
    """Print dimension, chunking, and encoding info for the daily ERA5 dataset.

    Args:
        daily_raw: Dataset to inspect; when None the daily ERA5 zarr store is
            opened from the project's store location.
        show_vars: Also print the per-variable encodings.
    """
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
    print("=== Daily INFO ===")
    print(f" Dims: {daily_raw.sizes}")
    numchunks = 1
    chunksizes = {}
    approxchunksize = 4  # element size in bytes (assumes float32 — TODO confirm)
    for d, cs in daily_raw.chunksizes.items():
        # cs is the tuple of chunk lengths along dimension d; the largest one
        # bounds the chunk extent in that dimension.
        numchunks *= len(cs)
        chunksizes[d] = max(cs)
        approxchunksize *= max(cs)
    # BUG FIX: was `/= 10e6`, i.e. 1e7 — one megabyte is 1e6 bytes, so the
    # reported chunk size was 10x too small.
    approxchunksize /= 1e6  # bytes -> MB
    print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f" Encoding: {daily_raw.encoding}")
    if show_vars:
        print(" Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f" {var} Encoding:")
            print(da.encoding)
            print("")
def rechunk(use_shards: bool = False):
    """Rewrite the daily ERA5 zarr store with time-contiguous chunks.

    Args:
        use_shards: Also write zarr shards. Known to be broken, see
            https://github.com/pydata/xarray/issues/10831
    """
    if use_shards:
        print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
    with (
        dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(f"Dashboard: {client.dashboard_link}")
        daily_store = get_era5_stores("daily")
        daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked_sharded")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
        # 120 time steps per chunk; -1 keeps each spatial dimension in one chunk.
        daily_raw = daily_raw.chunk(
            {
                "time": 120,
                "latitude": -1,  # Should be 337,
                "longitude": -1,  # Should be 3600
            }
        )
        encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
        for var in daily_raw.data_vars:
            # presumably the full grid is 337 x 3600 — TODO confirm against store
            encoding[var]["chunks"] = (120, 337, 3600)
            if use_shards:
                encoding[var]["shards"] = (1200, 337, 3600)
        print(encoding)
        daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
def validate():
    """Compare the original daily ERA5 store against its rechunked copy."""
    daily_store = get_era5_stores("daily")
    daily_raw = xr.open_zarr(daily_store, consolidated=False)
    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
    daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)

    print("\n=== Comparing Datasets ===")
    # Dimension sizes must match exactly.
    if daily_raw.sizes == daily_rechunked.sizes:
        print("✅ Dimensions match")
    else:
        print("❌ Dimensions differ:")
        print(f" Original: {daily_raw.sizes}")
        print(f" Rechunked: {daily_rechunked.sizes}")

    # Both stores must carry the same set of variables.
    raw_vars = set(daily_raw.data_vars)
    rechunked_vars = set(daily_rechunked.data_vars)
    if raw_vars == rechunked_vars:
        print("✅ Variables match")
    else:
        print("❌ Variables differ:")
        print(f" Only in original: {raw_vars - rechunked_vars}")
        print(f" Only in rechunked: {rechunked_vars - raw_vars}")

    # Deep-compare every variable the stores have in common.
    print("\n=== Variable Comparison ===")
    all_equal = True
    for var in raw_vars & rechunked_vars:
        raw_var = daily_raw[var]
        rechunked_var = daily_rechunked[var]
        if raw_var.equals(rechunked_var):
            print(f"{var}: Equal")
            continue
        all_equal = False
        print(f"{var}: NOT Equal")
        # Distinguish genuine value differences from metadata-only drift.
        try:
            if raw_var.values.shape != rechunked_var.values.shape:
                values_equal = False
            else:
                import numpy as np

                values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)
            if values_equal:
                print(" → Values are numerically equal (likely metadata/encoding difference)")
            else:
                print(" → Values differ!")
                print(f" Original shape: {raw_var.values.shape}")
                print(f" Rechunked shape: {rechunked_var.values.shape}")
        except Exception as e:
            print(f" → Error comparing values: {e}")
        if raw_var.attrs != rechunked_var.attrs:
            print(" → Attributes differ:")
            print(f" Original: {raw_var.attrs}")
            print(f" Rechunked: {rechunked_var.attrs}")
        if raw_var.encoding != rechunked_var.encoding:
            print(" → Encoding differs:")
            print(f" Original: {raw_var.encoding}")
            print(f" Rechunked: {rechunked_var.encoding}")

    if all_equal:
        print("\n✅ Validation successful: All datasets are equal.")
    else:
        print("\n❌ Validation failed: Datasets have differences (see above).")


if __name__ == "__main__":
    validate()

View file

@ -1,144 +0,0 @@
#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.
This script searches through training result directories and identifies those that have
a trained model but are missing inference results. It then loads the model and dataset
configuration, reruns inference, and saves the results.
"""
import pickle
from pathlib import Path
import toml
from rich.console import Console
from rich.progress import track
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import RESULTS_DIR
console = Console()
def find_incomplete_trainings() -> list[Path]:
    """Find training result directories missing inference results.

    Returns:
        list[Path]: List of directories with trained models but missing predictions.
    """
    missing: list[Path] = []
    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
        return missing
    # Training runs live in *_cv* subdirectories of the results folder.
    for candidate in RESULTS_DIR.glob("*_cv*"):
        if not candidate.is_dir():
            continue
        has_model = (candidate / "best_estimator_model.pkl").exists()
        has_settings = (candidate / "search_settings.toml").exists()
        has_predictions = (candidate / "predicted_probabilities.parquet").exists()
        # A run is "incomplete" when it trained a model but never ran inference.
        if has_model and has_settings and not has_predictions:
            missing.append(candidate)
    return missing
def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise.
    """
    try:
        console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")

        # Rebuild the exact dataset configuration the model was trained on.
        with open(result_dir / "search_settings.toml") as f:
            settings = toml.load(f)["settings"]
        ensemble = DatasetEnsemble(
            grid=settings["grid"],
            level=settings["level"],
            target=settings["target"],
            members=settings["members"],
            dimension_filters=settings.get("dimension_filters", {}),
            variable_filters=settings.get("variable_filters", {}),
            filter_target=settings.get("filter_target", False),
            add_lonlat=settings.get("add_lonlat", True),
        )

        with open(result_dir / "best_estimator_model.pkl", "rb") as f:
            clf = pickle.load(f)
        console.print("[green]✓[/green] Loaded model and settings")

        # Class labels were persisted alongside the search settings.
        classes = settings["classes"]

        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=classes)

        preds_file = result_dir / "predicted_probabilities.parquet"
        preds.to_parquet(preds_file)
        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
        return True
    except Exception as e:
        console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]")
        import traceback

        console.print(f"[red]{traceback.format_exc()}[/red]")
        return False
def main():
    """Rerun missing inferences for incomplete training results."""
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")
    incomplete_dirs = find_incomplete_trainings()

    # Nothing to do: every training already has its predictions file.
    if not incomplete_dirs:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return

    console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
    for d in incomplete_dirs:
        console.print(f"{d.name}")

    console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")

    # Process each directory in order; collect per-directory success flags.
    outcomes = [rerun_inference(result_dir) for result_dir in track(incomplete_dirs, description="Rerunning inference")]
    successful = sum(outcomes)
    failed = len(outcomes) - successful

    console.print("\n[bold]Summary:[/bold]")
    console.print(f" [green]Successful: {successful}[/green]")
    console.print(f" [red]Failed: {failed}[/red]")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()

View file

@ -32,9 +32,9 @@ def create_climate_map(
""" """
# Subsample if too many cells for performance # Subsample if too many cells for performance
n_cells = len(climate_values["cell_ids"]) n_cells = len(climate_values["cell_ids"])
if n_cells > 100000: if n_cells > 50000:
rng = np.random.default_rng(42) rng = np.random.default_rng(42)
cell_indices = rng.choice(n_cells, size=100000, replace=False) cell_indices = rng.choice(n_cells, size=50000, replace=False)
climate_values = climate_values.isel(cell_ids=cell_indices) climate_values = climate_values.isel(cell_ids=cell_indices)
# Create a copy to avoid modifying the original # Create a copy to avoid modifying the original
gdf = grid_gdf.copy().to_crs("EPSG:4326") gdf = grid_gdf.copy().to_crs("EPSG:4326")

View file

@ -29,9 +29,9 @@ def create_embedding_map(
""" """
# Subsample if too many cells for performance # Subsample if too many cells for performance
n_cells = len(embedding_values["cell_ids"]) n_cells = len(embedding_values["cell_ids"])
if n_cells > 100000: if n_cells > 50000:
rng = np.random.default_rng(42) # Fixed seed for reproducibility rng = np.random.default_rng(42) # Fixed seed for reproducibility
cell_indices = rng.choice(n_cells, size=100000, replace=False) cell_indices = rng.choice(n_cells, size=50000, replace=False)
embedding_values = embedding_values.isel(cell_ids=cell_indices) embedding_values = embedding_values.isel(cell_ids=cell_indices)
# Create a copy to avoid modifying the original # Create a copy to avoid modifying the original

View file

@ -25,6 +25,11 @@ def create_grid_areas_map(
pdk.Deck: A PyDeck map visualization of the specified grid statistic. pdk.Deck: A PyDeck map visualization of the specified grid statistic.
""" """
# Subsample if too many cells for performance
n_cells = len(grid_gdf)
if n_cells > 50000:
grid_gdf = grid_gdf.sample(n=50000, random_state=42)
# Create a copy to avoid modifying the original # Create a copy to avoid modifying the original
gdf = grid_gdf.copy().to_crs("EPSG:4326") gdf = grid_gdf.copy().to_crs("EPSG:4326")

View file

@ -33,9 +33,9 @@ def create_terrain_map(
""" """
# Subsample if too many cells for performance # Subsample if too many cells for performance
n_cells = len(terrain_values) n_cells = len(terrain_values)
if n_cells > 100000: if n_cells > 50000:
rng = np.random.default_rng(42) rng = np.random.default_rng(42)
cell_indices = rng.choice(n_cells, size=100000, replace=False) cell_indices = rng.choice(n_cells, size=50000, replace=False)
terrain_values = terrain_values.iloc[cell_indices] terrain_values = terrain_values.iloc[cell_indices]
# Create a copy to avoid modifying the original # Create a copy to avoid modifying the original
@ -159,9 +159,9 @@ def create_terrain_distribution_plot(arcticdem_ds: xr.Dataset, features: list[st
""" """
# Subsample if too many cells for performance # Subsample if too many cells for performance
n_cells = len(arcticdem_ds.cell_ids) n_cells = len(arcticdem_ds.cell_ids)
if n_cells > 10000: if n_cells > 50000:
rng = np.random.default_rng(42) rng = np.random.default_rng(42)
cell_indices = rng.choice(n_cells, size=10000, replace=False) cell_indices = rng.choice(n_cells, size=50000, replace=False)
arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices) arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices)
# Determine aggregation types available # Determine aggregation types available

View file

@ -78,8 +78,8 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
# Check if subsampling will occur # Check if subsampling will occur
n_cells = len(embedding_values["cell_ids"]) n_cells = len(embedding_values["cell_ids"])
if n_cells > 100000: if n_cells > 50000:
st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.") st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
map_deck = create_embedding_map( map_deck = create_embedding_map(
embedding_values=embedding_values, embedding_values=embedding_values,

View file

@ -78,8 +78,8 @@ def _render_terrain_map(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
# Check if subsampling will occur # Check if subsampling will occur
n_cells = len(terrain_series) n_cells = len(terrain_series)
if n_cells > 100000: if n_cells > 50000:
st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.") st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
# Create map # Create map
map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map) map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map)

View file

@ -23,6 +23,11 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame):
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d")) make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d"))
# Check if subsampling will occur
n_cells = len(grid_gdf)
if n_cells > 50000:
st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map) map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
st.pydeck_chart(map_deck) st.pydeck_chart(map_deck)

View file

@ -702,6 +702,26 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
m: stats.members[m] for m in selected_members if m in stats.members m: stats.members[m] for m in selected_members if m in stats.members
} }
# Write a possible training command
cmd = (
f"pixi run train --grid {selected_grid_config.grid} --level {selected_grid_config.level} "
f"--temporal-mode {selected_temporal_mode} --target {selected_target} --task {selected_task}"
)
if set(selected_members) != set(available_members):
for member in selected_members:
cmd += f" --member {member}"
if dimension_filters:
for member, dims in dimension_filters.items():
for dim_name, dim_values in dims.items():
dim_filter_str = ",".join(str(v) for v in dim_values)
cmd += f" --dimension-filters.{member}.{dim_name}={dim_filter_str}"
# dim_filters_str = json.dumps(dimension_filters)
# cmd += f" --dimension-filters '{dim_filters_str}'"
st.markdown("#### Equivalent Training Command")
st.code(cmd, language="bash")
st.code(cmd.replace("train", "autogluon"), language="bash")
# Render configuration summary and statistics # Render configuration summary and statistics
_render_configuration_summary( _render_configuration_summary(
selected_members=selected_members, selected_members=selected_members,

View file

@ -57,8 +57,8 @@ def _render_climate_variable_map(climate_values: xr.DataArray, grid_gdf: gpd.Geo
# Create map # Create map
n_cells = len(climate_values) n_cells = len(climate_values)
if n_cells > 100000: if n_cells > 50000:
st.info(f"Showing 100,000 / {n_cells:,} cells for performance") st.info(f"Showing 50,000 / {n_cells:,} cells for performance")
deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d) deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
st.pydeck_chart(deck, use_container_width=True) st.pydeck_chart(deck, use_container_width=True)
@ -211,9 +211,9 @@ def render_era5_tab(
climate_values = climate_values.sel(month=selected_month) climate_values = climate_values.sel(month=selected_month)
climate_values = climate_values.compute() climate_values = climate_values.compute()
_render_climate_variable_map(climate_values, grid_gdf, selected_var)
if "year" in climate_values.dims: if "year" in climate_values.dims:
st.divider() st.divider()
_render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month) _render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)
_render_climate_variable_map(climate_values, grid_gdf, selected_var)

View file

@ -113,29 +113,33 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
with col1: with col1:
grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
st.write("**Configuration:**") st.write(
st.write(f"- **Experiment:** {tr.experiment}") "**Configuration:**\n"
st.write(f"- **Task:** {tr.settings.task}") f"- **Experiment:** {tr.experiment}\n"
st.write(f"- **Model:** {tr.settings.model}") f"- **Task:** {tr.settings.task}\n"
st.write(f"- **Grid:** {grid_config.display_name}") f"- **Target:** {tr.settings.target}\n"
st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}") f"- **Model:** {tr.settings.model}\n"
st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}") f"- **Grid:** {grid_config.display_name}\n"
st.write(f"- **Members:** {', '.join(tr.settings.members)}") f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n"
st.write(f"- **CV Splits:** {tr.settings.cv_splits}") f"- **Temporal Mode:** {tr.settings.temporal_mode}\n"
st.write(f"- **Classes:** {tr.settings.classes}") f"- **Members:** {', '.join(tr.settings.members)}\n"
f"- **CV Splits:** {tr.settings.cv_splits}\n"
f"- **Classes:** {tr.settings.classes}\n"
)
st.write("\n**Files:**") file_str = "\n**Files:**\n"
for file in tr.files: for file in tr.files:
if file.name == "search_settings.toml": if file.name == "search_settings.toml":
st.write(f"- ⚙️ `{file.name}`") file_str += f"- ⚙️ `{file.name}`\n"
elif file.name == "best_estimator_model.pkl": elif file.name == "best_estimator_model.pkl":
st.write(f"- 🧮 `{file.name}`") file_str += f"- 🧮 `{file.name}`\n"
elif file.name == "search_results.parquet": elif file.name == "search_results.parquet":
st.write(f"- 📊 `{file.name}`") file_str += f"- 📊 `{file.name}`\n"
elif file.name == "predicted_probabilities.parquet": elif file.name == "predicted_probabilities.parquet":
st.write(f"- 🎯 `{file.name}`") file_str += f"- 🎯 `{file.name}`\n"
else: else:
st.write(f"- 📄 `{file.name}`") file_str += f"- 📄 `{file.name}`\n"
st.write(file_str)
with col2: with col2:
st.write("**CV Score Summary:**") st.write("**CV Score Summary:**")

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Literal
from entropice.utils.types import Grid, Model, Task from entropice.utils.types import Grid, Model, TargetDataset, Task
@dataclass @dataclass
@ -58,6 +58,7 @@ task_display_infos: dict[Task, TaskDisplayInfo] = {
@dataclass @dataclass
class TrainingResultDisplayInfo: class TrainingResultDisplayInfo:
task: Task task: Task
target: TargetDataset
model: Model model: Model
grid: Grid grid: Grid
level: int level: int
@ -65,15 +66,16 @@ class TrainingResultDisplayInfo:
def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str: def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
task = self.task.capitalize() task = self.task.capitalize()
target = self.target.replace("_", " ").title()
model = self.model.upper() model = self.model.upper()
grid = self.grid.capitalize() grid = self.grid.capitalize()
level = self.level level = self.level
timestamp = self.timestamp.strftime("%Y-%m-%d %H:%M") timestamp = self.timestamp.strftime("%Y-%m-%d %H:%M")
if format_type == "model_first": if format_type == "model_first":
return f"{model} - {task} - {grid}-{level} ({timestamp})" return f"{model} - {task}@{target} - {grid}-{level} ({timestamp})"
else: # task_first else: # task_first
return f"{task} - {model} - {grid}-{level} ({timestamp})" return f"{task}@{target} - {model} - {grid}-{level} ({timestamp})"
def format_metric_name(metric: str) -> str: def format_metric_name(metric: str) -> str:

View file

@ -104,6 +104,7 @@ class TrainingResult:
"""Get display information for the training result.""" """Get display information for the training result."""
return TrainingResultDisplayInfo( return TrainingResultDisplayInfo(
task=self.settings.task, task=self.settings.task,
target=self.settings.target,
model=self.settings.model, model=self.settings.model,
grid=self.settings.grid, grid=self.settings.grid,
level=self.settings.level, level=self.settings.level,
@ -198,6 +199,7 @@ class TrainingResult:
record = { record = {
"Experiment": tr.experiment if tr.experiment else "N/A", "Experiment": tr.experiment if tr.experiment else "N/A",
"Task": info.task, "Task": info.task,
"Target": info.target,
"Model": info.model, "Model": info.model,
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name, "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"), "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
@ -228,7 +230,7 @@ def load_all_training_results() -> list[TrainingResult]:
if not experiment_path.is_dir(): if not experiment_path.is_dir():
continue continue
try: try:
experiment_name = experiment_path.name experiment_name = experiment_path.parent.name
training_result = TrainingResult.from_path(experiment_path, experiment_name) training_result = TrainingResult.from_path(experiment_path, experiment_name)
training_results.append(training_result) training_results.append(training_result)
is_experiment_dir = True is_experiment_dir = True

View file

@ -196,6 +196,60 @@ def combine_to_zarr(grid: Grid, level: int):
print(f"Saved combined embeddings to {zarr_path}.") print(f"Saved combined embeddings to {zarr_path}.")
def _safe_polyfit(embeddings: xr.Dataset) -> xr.Dataset:
batch_size = 1000
embedding_trends = []
for i in track(list(range(0, embeddings.sizes["cell_ids"], batch_size))):
embeddings_islice = embeddings.isel(cell_ids=slice(i, i + batch_size))
try:
embeddings_trend = embeddings_islice.polyfit(dim="year", deg=1, skipna=True).sel(degree=1, drop=True)
embedding_trends.append(embeddings_trend)
except Exception as e:
print(f"Error processing embeddings {i} to {i + batch_size}: {e}")
return xr.concat(embedding_trends, dim="cell_ids")
@cli.command
def compute_synopsis(
grid: Grid,
level: int,
):
"""Create synopsis datasets for spatially aggregated AlphaEarth embeddings.
Loads spatially aggregated AlphaEarth embeddings and computes mean and trend for each variable.
The resulting synopsis datasets are saved to new zarr stores.
Args:
grid (Grid): Grid type.
level (int): Grid resolution level.
"""
store = get_embeddings_store(grid=grid, level=level)
embeddings = xr.open_zarr(store, consolidated=False).compute()
embeddings_mean = embeddings.mean(dim="year")
print(f"{embeddings.sizes=} -> {embeddings.nbytes * 1e-9:.2f} GB")
if (grid == "hex" and level >= 6) or (grid == "healpix" and level >= 10):
# We need to process the trend in chunks, because very large arrays can lead to
# numerical issues in the least squares solver used by polyfit.
embeddings_trend = _safe_polyfit(embeddings)
else:
embeddings_trend = embeddings.polyfit(dim="year", deg=1, skipna=True).sel(degree=1, drop=True)
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
embeddings_trend = embeddings_trend.rename(
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in embeddings_trend.data_vars}
)
embeddings_synopsis = xr.merge([embeddings_mean, embeddings_trend])
synopsis_store = get_embeddings_store(grid=grid, level=level, temporal="synopsis")
encoding = codecs.from_ds(embeddings_synopsis)
for var in embeddings_synopsis.data_vars:
encoding[var]["chunks"] = (min(v, 100000) for v in embeddings_synopsis[var].shape)
print(f"Saving AlphaEarth embeddings synopsis data to {synopsis_store}.")
embeddings_synopsis.to_zarr(synopsis_store, mode="w", encoding=encoding, consolidated=False)
def main(): # noqa: D103 def main(): # noqa: D103
cli() cli()

View file

@ -775,6 +775,10 @@ def spatial_agg(
pxbuffer=10, pxbuffer=10,
) )
# These are artifacts from previous processing steps, drop them
if "latitude" in aggregated.coords and "longitude" in aggregated.coords:
aggregated = aggregated.drop_vars(["latitude", "longitude"])
aggregated = aggregated.chunk( aggregated = aggregated.chunk(
{ {
"cell_ids": min(len(aggregated.cell_ids), 10000), "cell_ids": min(len(aggregated.cell_ids), 10000),
@ -790,5 +794,83 @@ def spatial_agg(
stopwatch.summary() stopwatch.summary()
# ===========================
# === Spatial Aggregation ===
# ===========================
def unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
"""Unstacks the time dimension of an ERA5 dataset into year and month/season dimensions.
Args:
era5 (xr.Dataset): The ERA5 dataset with a time dimension to be unstack
aggregation (Literal["yearly", "seasonal", "shoulder"]): The type of aggregation to perform.
- "yearly": No unstacking, just renames time to year.
- "seasonal": Unstacks into year and season (winter/summer).
- "shoulder": Unstacks into year and shoulder season (OND, JFM, AMJ, JAS).
"""
# In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
if aggregation == "yearly":
era5 = era5.rename({"time": "year"})
era5.coords["year"] = era5["year"].dt.year
return era5
# Make the time index a MultiIndex of year and month
era5.coords["year"] = era5.time.dt.year
era5.coords["month"] = era5.time.dt.month
era5["time"] = pd.MultiIndex.from_arrays(
[
era5.time.dt.year.values, # noqa: PD011
era5.time.dt.month.values, # noqa: PD011
],
names=("year", "month"),
)
era5 = era5.unstack("time") # noqa: PD010
seasons = {10: "winter", 4: "summer"}
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
month_map = seasons if aggregation == "seasonal" else shoulder_seasons
era5.coords["month"] = era5["month"].to_series().map(month_map)
return era5
@cli.command
def compute_synopsis(
grid: Grid,
level: int,
):
"""Create synopsis datasets for spatially aggregated ERA5 data.
Loads spatially aggregated ERA5 data and computes mean and trend for each variable.
The resulting synopsis datasets are saved to new zarr stores.
Args:
grid (Grid): Grid type.
level (int): Grid resolution level.
"""
for agg in ("yearly", "seasonal", "shoulder"):
store = get_era5_stores(agg, grid=grid, level=level)
era5 = xr.open_zarr(store, consolidated=False).compute()
# For the trend calculation it is necessary to unstack the time dimension
era5 = unstack_era5_time(era5, agg)
era5_mean = era5.mean(dim="year")
era5_trend = era5.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
era5_trend = era5_trend.rename(
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in era5_trend.data_vars}
)
era5_synopsis = xr.merge([era5_mean, era5_trend])
synopsis_store = get_era5_stores(agg, grid=grid, level=level, temporal="synopsis")
encoding = codecs.from_ds(era5_synopsis)
for var in era5_synopsis.data_vars:
encoding[var]["chunks"] = tuple(era5_synopsis[var].shape)
print(f"Saving ERA5 {agg} synopsis data to {synopsis_store}.")
era5_synopsis.to_zarr(synopsis_store, mode="w", encoding=encoding, consolidated=False)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View file

@ -1,13 +1,10 @@
# ruff: noqa: N803, N806
"""Training with AutoGluon TabularPredictor for automated ML.""" """Training with AutoGluon TabularPredictor for automated ML."""
import json
import pickle import pickle
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
import cyclopts import cyclopts
import pandas as pd import pandas as pd
import shap
import toml import toml
from autogluon.tabular import TabularDataset, TabularPredictor from autogluon.tabular import TabularDataset, TabularPredictor
from rich import pretty, traceback from rich import pretty, traceback
@ -15,7 +12,6 @@ from sklearn import set_config
from stopuhr import stopwatch from stopuhr import stopwatch
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import get_autogluon_results_dir from entropice.utils.paths import get_autogluon_results_dir
from entropice.utils.types import TargetDataset, Task from entropice.utils.types import TargetDataset, Task
@ -76,12 +72,14 @@ def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
def autogluon_train( def autogluon_train(
dataset_ensemble: DatasetEnsemble, dataset_ensemble: DatasetEnsemble,
settings: AutoGluonSettings = AutoGluonSettings(), settings: AutoGluonSettings = AutoGluonSettings(),
experiment: str | None = None,
): ):
"""Train models using AutoGluon TabularPredictor. """Train models using AutoGluon TabularPredictor.
Args: Args:
dataset_ensemble: Dataset ensemble configuration dataset_ensemble: Dataset ensemble configuration
settings: AutoGluon training settings settings: AutoGluon training settings
experiment: Optional experiment name for organizing results
""" """
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target) training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)
@ -104,6 +102,7 @@ def autogluon_train(
# Create results directory # Create results directory
results_dir = get_autogluon_results_dir( results_dir = get_autogluon_results_dir(
experiment=experiment,
grid=dataset_ensemble.grid, grid=dataset_ensemble.grid,
level=dataset_ensemble.level, level=dataset_ensemble.level,
task=settings.task, task=settings.task,
@ -165,64 +164,6 @@ def autogluon_train(
except Exception as e: except Exception as e:
print(f"⚠️ Could not compute feature importance: {e}") print(f"⚠️ Could not compute feature importance: {e}")
# Compute SHAP values for the best model
print("\n🎨 Computing SHAP values...")
with stopwatch("SHAP computation"):
try:
# Use a subset of test data for SHAP (can be slow)
shap_sample_size = min(500, len(test_data))
test_sample = test_data.sample(n=shap_sample_size, random_state=42)
# Get the best model name
best_model = predictor.model_best
# Get predictions function for SHAP
def predict_fn(X):
"""Prediction function for SHAP."""
df = pd.DataFrame(X, columns=training_data.feature_names)
if problem_type == "binary":
# For binary, return probability of positive class
proba = predictor.predict_proba(df, model=best_model)
return proba.iloc[:, 1].to_numpy() # ty:ignore[possibly-missing-attribute, invalid-argument-type]
elif problem_type == "multiclass":
# For multiclass, return all class probabilities
return predictor.predict_proba(df, model=best_model).to_numpy() # ty:ignore[possibly-missing-attribute]
else:
# For regression
return predictor.predict(df, model=best_model).to_numpy() # ty:ignore[possibly-missing-attribute]
# Create SHAP explainer
# Use a background sample (smaller is faster)
background_size = min(100, len(train_data))
background = train_data.drop(columns=["label"]).sample(n=background_size, random_state=42).to_numpy()
explainer = shap.KernelExplainer(predict_fn, background)
# Compute SHAP values
test_sample_X = test_sample.drop(columns=["label"]).to_numpy()
shap_values = explainer.shap_values(test_sample_X)
# Save SHAP values
shap_file = results_dir / "shap_values.pkl"
print(f"💾 Saving SHAP values to {shap_file}")
with open(shap_file, "wb") as f:
pickle.dump(
{
"shap_values": shap_values,
"base_values": explainer.expected_value,
"data": test_sample_X,
"feature_names": training_data.feature_names,
},
f,
protocol=pickle.HIGHEST_PROTOCOL,
)
except Exception as e:
print(f"⚠️ Could not compute SHAP values: {e}")
import traceback as tb
tb.print_exc()
# Save training settings # Save training settings
print("\n💾 Saving training settings...") print("\n💾 Saving training settings...")
combined_settings = AutoGluonTrainingSettings( combined_settings = AutoGluonTrainingSettings(
@ -236,22 +177,15 @@ def autogluon_train(
toml.dump({"settings": asdict(combined_settings)}, f) toml.dump({"settings": asdict(combined_settings)}, f)
# Save test metrics # Save test metrics
test_metrics_file = results_dir / "test_metrics.json" test_metrics_file = results_dir / "test_metrics.pickle"
print(f"💾 Saving test metrics to {test_metrics_file}") print(f"💾 Saving test metrics to {test_metrics_file}")
with open(test_metrics_file, "w") as f: with open(test_metrics_file, "wb") as f:
json.dump(test_score, f, indent=2) pickle.dump(test_score, f, protocol=pickle.HIGHEST_PROTOCOL)
# Generate predictions for all cells # Save the predictor
print("\n🗺️ Generating predictions for all cells...") predictor_file = results_dir / "tabular_predictor.pkl"
with stopwatch("Prediction"): print(f"💾 Saving TabularPredictor to {predictor_file}")
preds = predict_proba(dataset_ensemble, model=predictor, device="cpu") predictor.save()
if training_data.targets["y"].dtype == "category":
preds["predicted"] = preds["predicted"].astype("category")
preds["predicted"].cat.categories = training_data.targets["y"].cat.categories
# Save predictions
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"💾 Saving predictions to {preds_file}")
preds.to_parquet(preds_file)
# Print summary # Print summary
print("\n" + "=" * 80) print("\n" + "=" * 80)

View file

@ -35,6 +35,7 @@ from stopuhr import stopwatch
import entropice.spatial.grids import entropice.spatial.grids
import entropice.utils.paths import entropice.utils.paths
from entropice.ingest.era5 import unstack_era5_time
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
traceback.install() traceback.install()
@ -45,31 +46,6 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid") sns.set_theme("talk", "whitegrid")
def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
# In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
if aggregation == "yearly":
era5 = era5.rename({"time": "year"})
era5.coords["year"] = era5["year"].dt.year
return era5
# Make the time index a MultiIndex of year and month
era5.coords["year"] = era5.time.dt.year
era5.coords["month"] = era5.time.dt.month
era5["time"] = pd.MultiIndex.from_arrays(
[
era5.time.dt.year.values, # noqa: PD011
era5.time.dt.month.values, # noqa: PD011
],
names=("year", "month"),
)
era5 = era5.unstack("time") # noqa: PD010
seasons = {10: "winter", 4: "summer"}
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
month_map = seasons if aggregation == "seasonal" else shoulder_seasons
era5.coords["month"] = era5["month"].to_series().map(month_map)
return era5
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame: def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
collapsed = ds.to_dataframe() collapsed = ds.to_dataframe()
# Make a dummy row to avoid empty dataframe issues # Make a dummy row to avoid empty dataframe issues
@ -255,7 +231,7 @@ class TrainingSet:
pd.DataFrame: The training DataFrame. pd.DataFrame: The training DataFrame.
""" """
dataset = self.targets[["y"]].join(self.features) dataset = self.targets[["y"]].rename(columns={"y": "label"}).join(self.features)
if split is not None: if split is not None:
dataset = dataset[self.split == split] dataset = dataset[self.split == split]
assert len(dataset) > 0, "No valid samples found after joining features and targets." assert len(dataset) > 0, "No valid samples found after joining features and targets."
@ -330,14 +306,18 @@ class DatasetEnsemble:
@cache @cache
def read_grid(self) -> gpd.GeoDataFrame: def read_grid(self) -> gpd.GeoDataFrame:
"""Load the grid dataframe and enrich it with lat-lon information.""" """Load the grid dataframe and enrich it with lat-lon information."""
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level) columns_to_load = ["cell_id", "geometry", "cell_area", "land_area", "water_area", "land_ratio"]
# The name add_lonlat has legacy reasons and should be add_location
# Add the lat / lon of the cell centers # If add_location is true, keep the x and y
# For future reworks: "lat" and "lon" are also available columns
if self.add_lonlat: if self.add_lonlat:
grid_gdf["lon"] = grid_gdf.geometry.centroid.x columns_to_load.extend(["x", "y"])
grid_gdf["lat"] = grid_gdf.geometry.centroid.y
# Convert hex cell_id to int # Reading the data takes for the largest grids ~1.7s
gridfile = entropice.utils.paths.get_grid_file(self.grid, self.level)
grid_gdf = gpd.read_parquet(gridfile, columns=columns_to_load)
# Convert hex cell_id to int (for the largest grids ~0.4s)
if self.grid == "hex": if self.grid == "hex":
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64) grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
@ -410,6 +390,8 @@ class DatasetEnsemble:
raise NotImplementedError(f"Task {task} not supported.") raise NotImplementedError(f"Task {task} not supported.")
cell_ids = targets["cell_ids"].to_series() cell_ids = targets["cell_ids"].to_series()
geometries = self.geometries.loc[cell_ids] geometries = self.geometries.loc[cell_ids]
# ! Warning: For some stupid unknown reason, it is not possible to enforce the uint64 dtype on the index
# This will result in joining issues later on if the dtypes do not match!
return gpd.GeoDataFrame( return gpd.GeoDataFrame(
{ {
"cell_id": cell_ids, "cell_id": cell_ids,
@ -437,10 +419,14 @@ class DatasetEnsemble:
""" """
match member: match member:
case "AlphaEarth": case "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level) store = entropice.utils.paths.get_embeddings_store(
grid=self.grid, level=self.level, temporal=self.temporal_mode
)
case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder": case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level) store = entropice.utils.paths.get_era5_stores(
era5_agg, grid=self.grid, level=self.level, temporal=self.temporal_mode
)
case "ArcticDEM": case "ArcticDEM":
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level) store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
case _: case _:
@ -458,32 +444,18 @@ class DatasetEnsemble:
# Only load target cell ids # Only load target cell ids
if cell_ids is None: if cell_ids is None:
cell_ids = self.cell_ids cell_ids = self.cell_ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy())) is_intersecting_cell_ids = np.isin(ds["cell_ids"].values, cell_ids.to_numpy(), assume_unique=True)
ds = ds.sel(cell_ids=list(intersecting_cell_ids)) if not is_intersecting_cell_ids.all():
ds = ds[{"cell_ids": is_intersecting_cell_ids}]
# Unstack era5 data if needed # Unstack era5 data if needed (already done for synopsis mode)
if member.startswith("ERA5"): if member.startswith("ERA5") and self.temporal_mode != "synopsis":
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
ds = _unstack_era5_time(ds, era5_agg) ds = unstack_era5_time(ds, era5_agg)
# Apply the temporal mode # Apply the temporal mode
match (member.split("-"), self.temporal_mode): if isinstance(self.temporal_mode, int):
case (["ArcticDEM"], _): ds = ds.sel(year=self.temporal_mode, drop=True)
pass # No temporal dimension
case (_, "feature"):
pass
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
ds_mean = ds.mean(dim="year")
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
ds_trend = ds_trend.rename(
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
)
ds = xr.merge([ds_mean, ds_trend])
case (_, int() as year):
ds = ds.sel(year=year, drop=True)
case _:
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
# Actually read data into memory # Actually read data into memory
if not lazy: if not lazy:

View file

@ -42,6 +42,9 @@ def predict_proba(
grid_gdf = e.read_grid() grid_gdf = e.read_grid()
for batch in e.create_inference_df(batch_size=batch_size): for batch in e.create_inference_df(batch_size=batch_size):
# Filter rows containing NaN values
batch = batch.dropna(axis=0, how="any")
# Skip empty batches (all rows had NaN values) # Skip empty batches (all rows had NaN values)
if len(batch) == 0: if len(batch) == 0:
continue continue

View file

@ -315,6 +315,14 @@ def cli(grid: Grid, level: int):
print("No valid grid cells found.") print("No valid grid cells found.")
return return
# Add location to the grid
grid_4326_centroids = grid_gdf.to_crs("EPSG:4326").geometry.centroid
grid_gdf["lon"] = grid_4326_centroids.x
grid_gdf["lat"] = grid_4326_centroids.y
grid_centroids = grid_gdf.geometry.centroid
grid_gdf["x"] = grid_centroids.x
grid_gdf["y"] = grid_centroids.y
grid_file = get_grid_file(grid, level) grid_file = get_grid_file(grid, level)
grid_gdf.to_parquet(grid_file) grid_gdf.to_parquet(grid_file)
print(f"Saved to {grid_file}") print(f"Saved to {grid_file}")

View file

@ -6,7 +6,7 @@ import os
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from entropice.utils.types import Grid, Task from entropice.utils.types import Grid, Task, TemporalMode
DATA_DIR = ( DATA_DIR = (
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice" Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
@ -79,8 +79,10 @@ def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
return embfile return embfile
def get_embeddings_store(grid: Grid, level: int) -> Path: def get_embeddings_store(grid: Grid, level: int, temporal: TemporalMode | None = None) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
if temporal == "synopsis":
gridname += "_synopsis"
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr" embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
return embstore return embstore
@ -89,11 +91,11 @@ def get_era5_stores(
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily", agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
grid: Grid | None = None, grid: Grid | None = None,
level: int | None = None, level: int | None = None,
temporal: Literal["synopsis"] | None = None, temporal: TemporalMode | None = None,
) -> Path: ) -> Path:
pdir = ERA5_DIR pdir = ERA5_DIR
if temporal is not None: if temporal == "synopsis":
agg += f"_{temporal}" # ty:ignore[invalid-assignment] agg += "_synopsis" # ty:ignore[invalid-assignment]
fname = f"{agg}_climate.zarr" fname = f"{agg}_climate.zarr"
if grid is None or level is None: if grid is None or level is None:
@ -165,12 +167,18 @@ def get_cv_results_dir(
def get_autogluon_results_dir( def get_autogluon_results_dir(
experiment: str | None,
grid: Grid, grid: Grid,
level: int, level: int,
task: Task, task: Task,
) -> Path: ) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_autogluon_{now}_{task}" if experiment is not None:
experiment_dir = RESULTS_DIR / experiment
experiment_dir.mkdir(parents=True, exist_ok=True)
else:
experiment_dir = RESULTS_DIR
results_dir = experiment_dir / f"{gridname}_autogluon_{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True) results_dir.mkdir(parents=True, exist_ok=True)
return results_dir return results_dir