Fix training and finalize dataset page
This commit is contained in:
parent
c358bb63bc
commit
636c034b55
30 changed files with 533 additions and 851 deletions
|
|
@ -19,3 +19,13 @@ pixi run alpha-earth combine-to-zarr --grid healpix --level 7
|
|||
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
||||
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
||||
|
||||
pixi run alpha-earth compute-synopsis --grid hex --level 3
|
||||
pixi run alpha-earth compute-synopsis --grid hex --level 4
|
||||
pixi run alpha-earth compute-synopsis --grid hex --level 5
|
||||
pixi run alpha-earth compute-synopsis --grid hex --level 6
|
||||
pixi run alpha-earth compute-synopsis --grid healpix --level 6
|
||||
pixi run alpha-earth compute-synopsis --grid healpix --level 7
|
||||
pixi run alpha-earth compute-synopsis --grid healpix --level 8
|
||||
pixi run alpha-earth compute-synopsis --grid healpix --level 9
|
||||
pixi run alpha-earth compute-synopsis --grid healpix --level 10
|
||||
|
|
|
|||
|
|
@ -3,13 +3,24 @@
|
|||
# pixi run era5 download
|
||||
# pixi run era5 enrich
|
||||
|
||||
pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
|
||||
|
||||
pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
|
||||
pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
|
||||
# pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20
|
||||
|
||||
pixi run era5 compute-synopsis --grid hex --level 3
|
||||
pixi run era5 compute-synopsis --grid hex --level 4
|
||||
pixi run era5 compute-synopsis --grid hex --level 5
|
||||
pixi run era5 compute-synopsis --grid hex --level 6
|
||||
|
||||
pixi run era5 compute-synopsis --grid healpix --level 6
|
||||
pixi run era5 compute-synopsis --grid healpix --level 7
|
||||
pixi run era5 compute-synopsis --grid healpix --level 8
|
||||
pixi run era5 compute-synopsis --grid healpix --level 9
|
||||
pixi run era5 compute-synopsis --grid healpix --level 10
|
||||
37
scripts/06static/autogluon.sh
Normal file
37
scripts/06static/autogluon.sh
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids + models

# Check if running inside the pixi environment.
# (Was: `which autogluon; if [ $? -ne 0 ]` — `command -v` in the `if` is the
# POSIX-portable idiom and avoids the separate $? check.)
if ! command -v autogluon >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment."
    exit 1
fi

for grid in hex healpix; do
    if [ "$grid" = "hex" ]; then
        levels=(3 4 5 6)
    else
        levels=(6 7 8 9 10)
    fi
    for level in "${levels[@]}"; do
        # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7
        if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } || { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
            era5_dimension_filters="--dimension-filters.ERA5-shoulder.aggregations=median --dimension-filters.ERA5-seasonal.aggregations=median --dimension-filters.ERA5-yearly.aggregations=median"
        else
            era5_dimension_filters=""
        fi

        for target in darts_v1 darts_mllabels; do
            for task in binary density count; do
                echo
                echo "----------------------------------------"
                echo "Running autogluon training for grid=$grid, level=$level, target=$target, task=$task"
                # NOTE: $era5_dimension_filters is intentionally unquoted so it
                # word-splits into separate CLI arguments (empty => no extra args).
                autogluon --grid "$grid" --level "$level" --target "$target" --task "$task" --time-limit 600 --temporal-mode synopsis --experiment "static-variables-autogluon" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median $era5_dimension_filters
                echo "----------------------------------------"
                echo
            done
        done
    done
done
|
||||
49
scripts/06static/healpix_darts_mllabels.sh
Normal file
49
scripts/06static/healpix_darts_mllabels.sh
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids + models

# Check if running inside the pixi environment.
# (Was: `which train; if [ $? -ne 0 ]` — `command -v` in the `if` is the
# POSIX-portable idiom and avoids the separate $? check.)
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment."
    exit 1
fi

grid="healpix"
target="darts_mllabels"
# Full level set for this grid is (6 7 8 9 10); restricted for this run.
levels=(8 9 10)
for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } || { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters="--dimension-filters.ERA5-shoulder.aggregations=median --dimension-filters.ERA5-seasonal.aggregations=median --dimension-filters.ERA5-yearly.aggregations=median"
    else
        era5_dimension_filters=""
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # Skip if task is density or count and model is espa because espa only supports binary
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "espa" ]; then
                continue
            fi
            # Skip if task is count or density and model is rf because rf is super slow for regression tasks
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "rf" ]; then
                continue
            fi

            # Set number of iterations (use less for slow models)
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            # $era5_dimension_filters intentionally unquoted so it word-splits
            # into separate CLI arguments (empty => no extra args).
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median $era5_dimension_filters
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||
48
scripts/06static/healpix_darts_v1.sh
Normal file
48
scripts/06static/healpix_darts_v1.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids + models

# Check if running inside the pixi environment.
# (Was: `which train; if [ $? -ne 0 ]` — `command -v` in the `if` is the
# POSIX-portable idiom and avoids the separate $? check.)
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment."
    exit 1
fi

grid="healpix"
target="darts_v1"
# Full level set for this grid is (6 7 8 9 10); restricted for this run.
levels=(8 9 10)
for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } || { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters="--dimension-filters.ERA5-shoulder.aggregations=median --dimension-filters.ERA5-seasonal.aggregations=median --dimension-filters.ERA5-yearly.aggregations=median"
    else
        era5_dimension_filters=""
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # Skip if task is density or count and model is espa because espa only supports binary
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "espa" ]; then
                continue
            fi
            # Skip if task is count or density and model is rf because rf is super slow for regression tasks
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "rf" ]; then
                continue
            fi

            # Set number of iterations (use less for slow models)
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            # $era5_dimension_filters intentionally unquoted so it word-splits
            # into separate CLI arguments (empty => no extra args).
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median $era5_dimension_filters
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||
48
scripts/06static/hex_darts_mllabels.sh
Normal file
48
scripts/06static/hex_darts_mllabels.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids + models

# Check if running inside the pixi environment.
# (Was: `which train; if [ $? -ne 0 ]` — `command -v` in the `if` is the
# POSIX-portable idiom and avoids the separate $? check.)
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment."
    exit 1
fi

grid="hex"
target="darts_mllabels"
# Full level set for this grid is (3 4 5 6); restricted for this run.
levels=(5 6)
for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } || { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters="--dimension-filters.ERA5-shoulder.aggregations=median --dimension-filters.ERA5-seasonal.aggregations=median --dimension-filters.ERA5-yearly.aggregations=median"
    else
        era5_dimension_filters=""
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # Skip if task is density or count and model is espa because espa only supports binary
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "espa" ]; then
                continue
            fi
            # Skip if task is count or density and model is rf because rf is super slow for regression tasks
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "rf" ]; then
                continue
            fi

            # Set number of iterations (use less for slow models)
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            # $era5_dimension_filters intentionally unquoted so it word-splits
            # into separate CLI arguments (empty => no extra args).
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median $era5_dimension_filters
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||
48
scripts/06static/hex_darts_v1.sh
Normal file
48
scripts/06static/hex_darts_v1.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids + models

# Check if running inside the pixi environment.
# (Was: `which train; if [ $? -ne 0 ]` — `command -v` in the `if` is the
# POSIX-portable idiom and avoids the separate $? check.)
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment."
    exit 1
fi

grid="hex"
target="darts_v1"
# Full level set for this grid is (3 4 5 6); restricted for this run.
levels=(5 6)
for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7
    if { [ "$grid" = "hex" ] && { [ "$level" -eq 3 ] || [ "$level" -eq 4 ]; }; } || { [ "$grid" = "healpix" ] && { [ "$level" -eq 6 ] || [ "$level" -eq 7 ]; }; }; then
        era5_dimension_filters="--dimension-filters.ERA5-shoulder.aggregations=median --dimension-filters.ERA5-seasonal.aggregations=median --dimension-filters.ERA5-yearly.aggregations=median"
    else
        era5_dimension_filters=""
    fi
    for task in binary density count; do
        for model in espa xgboost rf knn; do
            # Skip if task is density or count and model is espa because espa only supports binary
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "espa" ]; then
                continue
            fi
            # Skip if task is count or density and model is rf because rf is super slow for regression tasks
            if { [ "$task" = "density" ] || [ "$task" = "count" ]; } && [ "$model" = "rf" ]; then
                continue
            fi

            # Set number of iterations (use less for slow models)
            if [ "$model" = "knn" ] || [ "$model" = "rf" ]; then
                niter=5
            else
                niter=100
            fi

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            # $era5_dimension_filters intentionally unquoted so it word-splits
            # into separate CLI arguments (empty => no extra args).
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median $era5_dimension_filters
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
"""Fix XGBoost feature importance in existing model state files.
|
||||
|
||||
This script repairs XGBoost model state files that have all-zero feature importance
|
||||
values due to incorrect feature name lookup. It reloads the pickled models and
|
||||
regenerates the feature importance arrays with the correct feature index mapping.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
import xarray as xr
|
||||
from rich import print
|
||||
|
||||
from entropice.utils.paths import RESULTS_DIR
|
||||
|
||||
|
||||
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, re-extracts all five XGBoost feature
    importance types with the correct ``f{i}`` index mapping, and rewrites
    ``best_estimator_state.nc`` (the previous file is kept as ``.nc.backup``).

    Args:
        results_dir: Directory containing the model files.

    Returns:
        True if fixed successfully, False otherwise (also False when the
        directory is skipped because it is not an XGBoost result).

    """
    # Only XGBoost results have the broken importance arrays.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False

    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")

    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Both the pickled model and the existing state file must be present.
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"

    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False

    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # Load the pickled model
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # The old state still holds the correct feature name list.
    old_state = xr.open_dataset(state_file, engine="h5netcdf")
    features = old_state.coords["feature"].values.tolist()
    old_state.close()

    # Get the booster and extract feature importance with correct mapping
    booster = best_estimator.get_booster()

    def align_importance(importance_dict, features):
        """Align importance dict to feature list using feature indices."""
        # Booster.get_score keys are "f0", "f1", ...; absent features get 0.0.
        return [importance_dict.get(f"f{i}", 0.0) for i in range(len(features))]

    # One aligned DataArray per XGBoost importance type.  (Deduplicated: the
    # original built these five arrays with copy-pasted blocks.)
    descriptions = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }
    importance_arrays = {}
    for imp_type, description in descriptions.items():
        raw_scores = booster.get_score(importance_type=imp_type)
        var_name = f"feature_importance_{imp_type}"
        importance_arrays[var_name] = xr.DataArray(
            align_importance(raw_scores, features),
            dims=["feature"],
            coords={"feature": features},
            name=var_name,
            attrs={"description": description},
        )

    # Create new state dataset
    n_trees = booster.num_boosted_rounds()
    state = xr.Dataset(
        importance_arrays,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": n_trees,
            "objective": str(best_estimator.objective),
        },
    )

    # Backup the old file (keep only the first backup ever made).
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    # Save the fixed state
    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Verify the fix: weight + gain + cover should no longer all be zero.
    total_importance = sum(
        importance_arrays[f"feature_importance_{t}"].sum().item() for t in ("weight", "gain", "cover")
    )

    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True
    else:
        print(" ✗ Warning: Total importance is still 0!")
        return False
|
||||
|
||||
|
||||
def main():
    """Find and fix all XGBoost model state files."""
    print("Scanning for XGBoost model results...")

    # Every subdirectory of RESULTS_DIR is one training result.
    result_dirs = [entry for entry in RESULTS_DIR.iterdir() if entry.is_dir()]
    print(f"Found {len(result_dirs)} result directories")

    fixed = skipped = failed = 0
    for result_dir in sorted(result_dirs):
        try:
            if fix_xgboost_model_state(result_dir):
                fixed += 1
            else:
                # Not an XGBoost result, missing files, or importance still zero.
                skipped += 1
        except Exception as e:
            print(f"❌ Error processing {result_dir.name}: {e}")
            failed += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {fixed}")
    print(f" ⊘ Skipped: {skipped}")
    print(f" ✗ Failed: {failed}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""Recalculate test metrics and confusion matrix for existing training results.
|
||||
|
||||
This script loads previously trained models and recalculates test metrics
|
||||
and confusion matrices for training runs that were completed before these
|
||||
outputs were added to the training pipeline.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import toml
|
||||
import torch
|
||||
import xarray as xr
|
||||
from sklearn import set_config
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.paths import RESULTS_DIR
|
||||
|
||||
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
|
||||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def recalculate_metrics(results_dir: Path):
    """Recalculate test metrics and confusion matrix for a training result.

    Args:
        results_dir: Path to the results directory containing the trained model.

    """
    print(f"\nProcessing: {results_dir}")

    # The search settings describe how the training dataset was built.
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print(" ❌ Missing search_settings.toml, skipping...")
        return

    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"

    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print(" ❌ Missing best_estimator_model.pkl, skipping...")
        return

    print(f" Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Rebuild the exact dataset ensemble the model was trained on.
    print(" Recreating training dataset...")
    default_members = [
        "AlphaEarth",
        "ArcticDEM",
        "ERA5-yearly",
        "ERA5-seasonal",
        "ERA5-shoulder",
    ]
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get("members", default_members),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    # espa runs on torch tensors; every other model uses CUDA (CuPy) arrays.
    device = "torch" if model in ["espa"] else "cuda"

    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)

    # Prepare test data - match training.py's approach
    print(" Preparing test data...")
    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
    if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray):
        y_test = training_data.y.test.get()
    else:
        y_test = training_data.y.test

    # Compute predictions on the test set (use original device data)
    print(" Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)

    # Move both arrays onto the GPU via torch so sklearn's array-api dispatch
    # keeps everything on one device.
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    print(" Computing test metrics...")
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        jaccard_score,
        precision_score,
        recall_score,
    )

    if task == "binary":
        test_metrics = {
            "accuracy": float(accuracy_score(y_test, y_pred)),
            "recall": float(recall_score(y_test, y_pred)),
            "precision": float(precision_score(y_test, y_pred)),
            "f1": float(f1_score(y_test, y_pred)),
            "jaccard": float(jaccard_score(y_test, y_pred)),
        }
    else:
        test_metrics = {
            "accuracy": float(accuracy_score(y_test, y_pred)),
            "f1_macro": float(f1_score(y_test, y_pred, average="macro")),
            "f1_weighted": float(f1_score(y_test, y_pred, average="weighted")),
            "precision_macro": float(precision_score(y_test, y_pred, average="macro", zero_division=0)),
            "precision_weighted": float(precision_score(y_test, y_pred, average="weighted", zero_division=0)),
            "recall_macro": float(recall_score(y_test, y_pred, average="macro")),
            "jaccard_micro": float(jaccard_score(y_test, y_pred, average="micro")),
            "jaccard_macro": float(jaccard_score(y_test, y_pred, average="macro")),
            "jaccard_weighted": float(jaccard_score(y_test, y_pred, average="weighted")),
        }

    # Get confusion matrix
    print(" Computing confusion matrix...")
    labels = torch.as_tensor(np.array(list(range(len(training_data.y.labels)))), device="cuda")
    print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
    print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
    print(" Device of labels:", getattr(labels, "device", "cpu"))
    cm = confusion_matrix(y_test, y_pred, labels=labels).cpu().numpy()
    labels = labels.cpu().numpy()
    label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
    cm_xr = xr.DataArray(
        cm,
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )

    # Store the test metrics (never overwrite an existing file).
    if test_metrics_file.exists():
        print(" ✓ Test metrics already exist")
    else:
        print(f" Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)

    # The confusion matrix is always rewritten: the exists-check was
    # deliberately disabled in the original (`if True:`).
    print(f" Storing confusion matrix to {cm_file.name}...")
    cm_xr.to_netcdf(cm_file, engine="h5netcdf")

    print(" ✓ Done!")
|
||||
|
||||
|
||||
def main():
    """Find all training results and recalculate metrics for those missing them."""
    print("Searching for training results directories...")

    # Every subdirectory of RESULTS_DIR is one candidate training run.
    candidates = sorted(p for p in RESULTS_DIR.glob("*") if p.is_dir())
    print(f"Found {len(candidates)} results directories.\n")

    for candidate in candidates:
        recalculate_metrics(candidate)

    print("\n✅ All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
import dask.distributed as dd
|
||||
import xarray as xr
|
||||
from rich import print
|
||||
|
||||
import entropice.utils.codecs
|
||||
from entropice.utils.paths import get_era5_stores
|
||||
|
||||
|
||||
def print_info(daily_raw=None, show_vars: bool = True):
    """Print a summary of the daily ERA5 store: dims, chunking, and encoding.

    Args:
        daily_raw: An already-opened daily dataset. If None, the store is
            opened from ``get_era5_stores("daily")``.
        show_vars: Whether to also print the per-variable encodings.
    """
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
    print("=== Daily INFO ===")
    print(f" Dims: {daily_raw.sizes}")
    numchunks = 1
    chunksizes = {}
    approxchunksize = 4  # 4 Bytes = float32
    for d, cs in daily_raw.chunksizes.items():
        numchunks *= len(cs)
        # Hoisted: max(cs) was computed twice per dimension in the original.
        largest = max(cs)
        chunksizes[d] = largest
        approxchunksize *= largest
    # BUGFIX: was `/= 10e6` (== 1e7), which under-reported the MB size by 10x.
    approxchunksize /= 1e6  # bytes -> MB
    print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f" Encoding: {daily_raw.encoding}")
    if show_vars:
        print(" Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f" {var} Encoding:")
            print(da.encoding)
            print("")
|
||||
|
||||
|
||||
def rechunk(use_shards: bool = False):
    """Rechunk the daily ERA5 zarr store to (time=120, full lat, full lon) chunks.

    Args:
        use_shards: Also write zarr shards — known to be broken, see warning.
    """
    if use_shards:
        # ! MEEEP: https://github.com/pydata/xarray/issues/10831
        print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")

    with (
        dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(f"Dashboard: {client.dashboard_link}")
        source_store = get_era5_stores("daily")
        target_store = source_store.with_stem(f"{source_store.stem}_rechunked_sharded")

        ds = xr.open_zarr(source_store, consolidated=False)
        # -1 keeps the whole dimension in one chunk (337 lat x 3600 lon).
        target_chunks = {"time": 120, "latitude": -1, "longitude": -1}
        ds = ds.chunk(target_chunks)

        encoding = entropice.utils.codecs.from_ds(ds, filter_existing=False)
        for var in ds.data_vars:
            encoding[var]["chunks"] = (120, 337, 3600)
            if use_shards:
                encoding[var]["shards"] = (1200, 337, 3600)
        print(encoding)

        ds.to_zarr(target_store, mode="w", consolidated=False, encoding=encoding)
|
||||
|
||||
|
||||
def validate():
    """Compare the original daily store against its ``_rechunked`` sibling.

    Prints a diagnostic report covering dimensions, the variable sets, and a
    per-variable comparison (values, attributes, encoding), finishing with an
    overall verdict. Nothing is modified on disk.

    NOTE(review): the value comparison loads each full variable into memory
    via ``.values`` — fine for spot checks, heavy for very large stores.
    """
    import numpy as np  # hoisted: was imported inside the comparison loop

    daily_store = get_era5_stores("daily")
    daily_raw = xr.open_zarr(daily_store, consolidated=False)

    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
    daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)

    print("\n=== Comparing Datasets ===")

    # Compare dimensions
    if daily_raw.sizes != daily_rechunked.sizes:
        print("❌ Dimensions differ:")
        print(f"  Original: {daily_raw.sizes}")
        print(f"  Rechunked: {daily_rechunked.sizes}")
    else:
        print("✅ Dimensions match")

    # Compare variables
    raw_vars = set(daily_raw.data_vars)
    rechunked_vars = set(daily_rechunked.data_vars)
    if raw_vars != rechunked_vars:
        print("❌ Variables differ:")
        print(f"  Only in original: {raw_vars - rechunked_vars}")
        print(f"  Only in rechunked: {rechunked_vars - raw_vars}")
    else:
        print("✅ Variables match")

    # Compare each variable
    print("\n=== Variable Comparison ===")
    all_equal = True
    for var in raw_vars & rechunked_vars:
        raw_var = daily_raw[var]
        rechunked_var = daily_rechunked[var]

        if raw_var.equals(rechunked_var):
            print(f"✅ {var}: Equal")
        else:
            all_equal = False
            print(f"❌ {var}: NOT Equal")

            # Check if values are equal
            try:
                # .shape is metadata-only; don't force a data load just for this.
                values_equal = raw_var.shape == rechunked_var.shape
                if values_equal:
                    values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)

                if values_equal:
                    print("  → Values are numerically equal (likely metadata/encoding difference)")
                else:
                    print("  → Values differ!")
                    print(f"    Original shape: {raw_var.shape}")
                    print(f"    Rechunked shape: {rechunked_var.shape}")
            except Exception as e:
                print(f"  → Error comparing values: {e}")

            # Check attributes
            if raw_var.attrs != rechunked_var.attrs:
                print("  → Attributes differ:")
                print(f"    Original: {raw_var.attrs}")
                print(f"    Rechunked: {rechunked_var.attrs}")

            # Check encoding
            if raw_var.encoding != rechunked_var.encoding:
                print("  → Encoding differs:")
                print(f"    Original: {raw_var.encoding}")
                print(f"    Rechunked: {rechunked_var.encoding}")

    if all_equal:
        print("\n✅ Validation successful: All datasets are equal.")
    else:
        print("\n❌ Validation failed: Datasets have differences (see above).")
if __name__ == "__main__":
    # Script entry point: compare the original and rechunked daily stores.
    validate()
|
@ -1,144 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""Rerun inference for training results that are missing predicted probabilities.
|
||||
|
||||
This script searches through training result directories and identifies those that have
|
||||
a trained model but are missing inference results. It then loads the model and dataset
|
||||
configuration, reruns inference, and saves the results.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
from rich.console import Console
|
||||
from rich.progress import track
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.inference import predict_proba
|
||||
from entropice.utils.paths import RESULTS_DIR
|
||||
|
||||
# Shared Rich console for all status output in this script.
console = Console()
|
||||
|
||||
|
||||
def find_incomplete_trainings() -> list[Path]:
    """Find training result directories missing inference results.

    A directory counts as incomplete when it holds a trained model and its
    search settings, but no predicted-probabilities file.

    Returns:
        list[Path]: Directories with trained models but missing predictions.

    """
    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
        return []

    # Model + settings present, predictions absent => needs a rerun.
    return [
        candidate
        for candidate in RESULTS_DIR.glob("*_cv*")
        if candidate.is_dir()
        and (candidate / "best_estimator_model.pkl").exists()
        and (candidate / "search_settings.toml").exists()
        and not (candidate / "predicted_probabilities.parquet").exists()
    ]
def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Rebuilds the dataset ensemble from ``search_settings.toml``, loads the
    pickled best estimator, predicts class probabilities, and writes them to
    ``predicted_probabilities.parquet`` inside ``result_dir``.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise.

    """
    try:
        console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")

        # Reload the search settings that produced this training run.
        with open(result_dir / "search_settings.toml") as fh:
            settings = toml.load(fh)["settings"]

        # Rebuild the exact dataset configuration used during training.
        ensemble = DatasetEnsemble(
            grid=settings["grid"],
            level=settings["level"],
            target=settings["target"],
            members=settings["members"],
            dimension_filters=settings.get("dimension_filters", {}),
            variable_filters=settings.get("variable_filters", {}),
            filter_target=settings.get("filter_target", False),
            add_lonlat=settings.get("add_lonlat", True),
        )

        # NOTE(review): pickle.load on a locally produced artifact — never
        # point this at untrusted files.
        with open(result_dir / "best_estimator_model.pkl", "rb") as fh:
            clf = pickle.load(fh)

        console.print("[green]✓[/green] Loaded model and settings")

        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=settings["classes"])

        preds_file = result_dir / "predicted_probabilities.parquet"
        preds.to_parquet(preds_file)

        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
        return True

    except Exception as e:
        # Boundary handler: log the failure loudly, keep the batch going.
        console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]")
        import traceback

        console.print(f"[red]{traceback.format_exc()}[/red]")
        return False
def main():
    """Rerun missing inferences for incomplete training results."""
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")

    incomplete_dirs = find_incomplete_trainings()

    if not incomplete_dirs:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return

    console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
    for d in incomplete_dirs:
        console.print(f"  • {d.name}")

    console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")

    # Collect per-directory outcomes, then tally; rerun_inference returns
    # True on success and False on failure.
    outcomes = [
        rerun_inference(result_dir)
        for result_dir in track(incomplete_dirs, description="Rerunning inference")
    ]
    successful = sum(outcomes)
    failed = len(outcomes) - successful

    console.print("\n[bold]Summary:[/bold]")
    console.print(f"  [green]Successful: {successful}[/green]")
    console.print(f"  [red]Failed: {failed}[/red]")
if __name__ == "__main__":
    # Script entry point: rerun inference for every incomplete training.
    main()
Loading…
Add table
Add a link
Reference in a new issue