Fix training and finalize dataset page
This commit is contained in:
parent
c358bb63bc
commit
636c034b55
30 changed files with 533 additions and 851 deletions
|
|
@ -19,3 +19,13 @@ pixi run alpha-earth combine-to-zarr --grid healpix --level 7
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
||||||
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
||||||
|
|
||||||
|
pixi run alpha-earth compute-synopsis --grid hex --level 3
|
||||||
|
pixi run alpha-earth compute-synopsis --grid hex --level 4
|
||||||
|
pixi run alpha-earth compute-synopsis --grid hex --level 5
|
||||||
|
pixi run alpha-earth compute-synopsis --grid hex --level 6
|
||||||
|
pixi run alpha-earth compute-synopsis --grid healpix --level 6
|
||||||
|
pixi run alpha-earth compute-synopsis --grid healpix --level 7
|
||||||
|
pixi run alpha-earth compute-synopsis --grid healpix --level 8
|
||||||
|
pixi run alpha-earth compute-synopsis --grid healpix --level 9
|
||||||
|
pixi run alpha-earth compute-synopsis --grid healpix --level 10
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,24 @@
|
||||||
# pixi run era5 download
|
# pixi run era5 download
|
||||||
# pixi run era5 enrich
|
# pixi run era5 enrich
|
||||||
|
|
||||||
pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
|
||||||
|
|
||||||
pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
|
||||||
pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20
|
# pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20
|
||||||
|
|
||||||
|
pixi run era5 compute-synopsis --grid hex --level 3
|
||||||
|
pixi run era5 compute-synopsis --grid hex --level 4
|
||||||
|
pixi run era5 compute-synopsis --grid hex --level 5
|
||||||
|
pixi run era5 compute-synopsis --grid hex --level 6
|
||||||
|
|
||||||
|
pixi run era5 compute-synopsis --grid healpix --level 6
|
||||||
|
pixi run era5 compute-synopsis --grid healpix --level 7
|
||||||
|
pixi run era5 compute-synopsis --grid healpix --level 8
|
||||||
|
pixi run era5 compute-synopsis --grid healpix --level 9
|
||||||
|
pixi run era5 compute-synopsis --grid healpix --level 10
|
||||||
37
scripts/06static/autogluon.sh
Normal file
37
scripts/06static/autogluon.sh
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
#!/bin/bash

# Run some trainings on synopsis datasets for all different tasks + targets + grids
# using the `autogluon` command.

# Check if running inside the pixi environment: the `autogluon` command must be
# resolvable, otherwise abort before starting any training.
if ! command -v autogluon >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

for grid in hex healpix; do
    if [ "$grid" = "hex" ]; then
        levels=(3 4 5 6)
    else
        levels=(6 7 8 9 10)
    fi

    for level in "${levels[@]}"; do
        # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7.
        # The options are stored in an array so that each one is passed as exactly
        # one argument, without relying on word splitting of an unquoted string.
        case "$grid-$level" in
            hex-3 | hex-4 | healpix-6 | healpix-7)
                era5_dimension_filters=(
                    --dimension-filters.ERA5-shoulder.aggregations=median
                    --dimension-filters.ERA5-seasonal.aggregations=median
                    --dimension-filters.ERA5-yearly.aggregations=median
                )
                ;;
            *)
                era5_dimension_filters=()
                ;;
        esac

        for target in darts_v1 darts_mllabels; do
            for task in binary density count; do
                echo
                echo "----------------------------------------"
                echo "Running autogluon training for grid=$grid, level=$level, target=$target, task=$task"
                autogluon --grid "$grid" --level "$level" --target "$target" --task "$task" --time-limit 600 --temporal-mode synopsis --experiment "static-variables-autogluon" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median "${era5_dimension_filters[@]}"
                echo "----------------------------------------"
                echo
            done
        done
    done
done
|
||||||
49
scripts/06static/healpix_darts_mllabels.sh
Normal file
49
scripts/06static/healpix_darts_mllabels.sh
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
#!/bin/bash

# Run some trainings on synopsis datasets for the healpix grid and the
# darts_mllabels target, for all different levels + tasks + models.

# Check if running inside the pixi environment: the `train` command must be
# resolvable, otherwise abort before starting any training.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="healpix"
target="darts_mllabels"
# The full level list would be (6 7 8 9 10); only a subset is run here.
levels=(8 9 10)

for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7.
    # With the levels configured above none of these match, so no ERA5 filter
    # is passed unless the level list is extended.
    # The options are stored in an array so that each one is passed as exactly
    # one argument, without relying on word splitting of an unquoted string.
    case "$grid-$level" in
        hex-3 | hex-4 | healpix-6 | healpix-7)
            era5_dimension_filters=(
                --dimension-filters.ERA5-shoulder.aggregations=median
                --dimension-filters.ERA5-seasonal.aggregations=median
                --dimension-filters.ERA5-yearly.aggregations=median
            )
            ;;
        *)
            era5_dimension_filters=()
            ;;
    esac

    for task in binary density count; do
        for model in espa xgboost rf knn; do
            case "$task:$model" in
                # espa only supports binary
                density:espa | count:espa) continue ;;
                # rf is super slow for regression tasks
                density:rf | count:rf) continue ;;
            esac

            # Set number of iterations (use fewer for slow models)
            case "$model" in
                knn | rf) niter=5 ;;
                *) niter=100 ;;
            esac

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||||
48
scripts/06static/healpix_darts_v1.sh
Normal file
48
scripts/06static/healpix_darts_v1.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash

# Run some trainings on synopsis datasets for the healpix grid and the
# darts_v1 target, for all different levels + tasks + models.

# Check if running inside the pixi environment: the `train` command must be
# resolvable, otherwise abort before starting any training.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="healpix"
target="darts_v1"
# The full level list would be (6 7 8 9 10); only a subset is run here.
levels=(8 9 10)

for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7.
    # With the levels configured above none of these match, so no ERA5 filter
    # is passed unless the level list is extended.
    # The options are stored in an array so that each one is passed as exactly
    # one argument, without relying on word splitting of an unquoted string.
    case "$grid-$level" in
        hex-3 | hex-4 | healpix-6 | healpix-7)
            era5_dimension_filters=(
                --dimension-filters.ERA5-shoulder.aggregations=median
                --dimension-filters.ERA5-seasonal.aggregations=median
                --dimension-filters.ERA5-yearly.aggregations=median
            )
            ;;
        *)
            era5_dimension_filters=()
            ;;
    esac

    for task in binary density count; do
        for model in espa xgboost rf knn; do
            case "$task:$model" in
                # espa only supports binary
                density:espa | count:espa) continue ;;
                # rf is super slow for regression tasks
                density:rf | count:rf) continue ;;
            esac

            # Set number of iterations (use fewer for slow models)
            case "$model" in
                knn | rf) niter=5 ;;
                *) niter=100 ;;
            esac

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||||
48
scripts/06static/hex_darts_mllabels.sh
Normal file
48
scripts/06static/hex_darts_mllabels.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash

# Run some trainings on synopsis datasets for the hex grid and the
# darts_mllabels target, for all different levels + tasks + models.

# Check if running inside the pixi environment: the `train` command must be
# resolvable, otherwise abort before starting any training.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="hex"
target="darts_mllabels"
# The full level list would be (3 4 5 6); only a subset is run here.
levels=(5 6)

for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7.
    # With the levels configured above none of these match, so no ERA5 filter
    # is passed unless the level list is extended.
    # The options are stored in an array so that each one is passed as exactly
    # one argument, without relying on word splitting of an unquoted string.
    case "$grid-$level" in
        hex-3 | hex-4 | healpix-6 | healpix-7)
            era5_dimension_filters=(
                --dimension-filters.ERA5-shoulder.aggregations=median
                --dimension-filters.ERA5-seasonal.aggregations=median
                --dimension-filters.ERA5-yearly.aggregations=median
            )
            ;;
        *)
            era5_dimension_filters=()
            ;;
    esac

    for task in binary density count; do
        for model in espa xgboost rf knn; do
            case "$task:$model" in
                # espa only supports binary
                density:espa | count:espa) continue ;;
                # rf is super slow for regression tasks
                density:rf | count:rf) continue ;;
            esac

            # Set number of iterations (use fewer for slow models)
            case "$model" in
                knn | rf) niter=5 ;;
                *) niter=100 ;;
            esac

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||||
48
scripts/06static/hex_darts_v1.sh
Normal file
48
scripts/06static/hex_darts_v1.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash

# Run some trainings on synopsis datasets for the hex grid and the
# darts_v1 target, for all different levels + tasks + models.

# Check if running inside the pixi environment: the `train` command must be
# resolvable, otherwise abort before starting any training.
if ! command -v train >/dev/null 2>&1; then
    echo "This script must be run inside the pixi environment." >&2
    exit 1
fi

grid="hex"
target="darts_v1"
# The full level list would be (3 4 5 6); only a subset is run here.
levels=(5 6)

for level in "${levels[@]}"; do
    # Only apply ERA5 filter for hex-3, hex-4, healpix-6 and healpix-7.
    # With the levels configured above none of these match, so no ERA5 filter
    # is passed unless the level list is extended.
    # The options are stored in an array so that each one is passed as exactly
    # one argument, without relying on word splitting of an unquoted string.
    case "$grid-$level" in
        hex-3 | hex-4 | healpix-6 | healpix-7)
            era5_dimension_filters=(
                --dimension-filters.ERA5-shoulder.aggregations=median
                --dimension-filters.ERA5-seasonal.aggregations=median
                --dimension-filters.ERA5-yearly.aggregations=median
            )
            ;;
        *)
            era5_dimension_filters=()
            ;;
    esac

    for task in binary density count; do
        for model in espa xgboost rf knn; do
            case "$task:$model" in
                # espa only supports binary
                density:espa | count:espa) continue ;;
                # rf is super slow for regression tasks
                density:rf | count:rf) continue ;;
            esac

            # Set number of iterations (use fewer for slow models)
            case "$model" in
                knn | rf) niter=5 ;;
                *) niter=100 ;;
            esac

            echo
            echo "----------------------------------------"
            echo "Running training for grid=$grid, level=$level, target=$target, task=$task, model=$model"
            train --grid "$grid" --level "$level" --target "$target" --task "$task" --model "$model" --n-iter "$niter" --temporal-mode synopsis --experiment "static-variables" --dimension-filters.ArcticDEM.aggregations=median --dimension-filters.AlphaEarth.agg=median "${era5_dimension_filters[@]}"
            echo "----------------------------------------"
            echo
        done
    done
done
|
||||||
|
|
@ -1,191 +0,0 @@
|
||||||
"""Fix XGBoost feature importance in existing model state files.
|
|
||||||
|
|
||||||
This script repairs XGBoost model state files that have all-zero feature importance
|
|
||||||
values due to incorrect feature name lookup. It reloads the pickled models and
|
|
||||||
regenerates the feature importance arrays with the correct feature index mapping.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import toml
|
|
||||||
import xarray as xr
|
|
||||||
from rich import print
|
|
||||||
|
|
||||||
from entropice.utils.paths import RESULTS_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def fix_xgboost_model_state(results_dir: Path) -> bool:
    """Fix a single XGBoost model state file.

    Reloads the pickled best estimator, reads the feature names from the
    existing state file and rewrites that state file with feature importance
    values looked up by positional booster keys (``f0``, ``f1``, ...).  The
    previous state file is kept as ``*.nc.backup`` the first time a directory
    is processed.

    Args:
        results_dir: Directory containing the model files
            (``search_settings.toml``, ``best_estimator_model.pkl`` and
            ``best_estimator_state.nc``).

    Returns:
        True if the state was rewritten and the summed weight, gain and cover
        importance is positive.  False if the directory was skipped (no
        settings file, not an XGBoost model, missing model or state file) or
        if the rewritten importance still sums to zero.

    """
    # Check if this is an XGBoost model
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        return False

    settings = toml.load(settings_file)
    model_type = settings.get("settings", {}).get("model", "")
    if model_type != "xgboost":
        print(f"Skipping {results_dir.name} - not an XGBoost model (model={model_type})")
        return False

    # Check if required files exist
    model_file = results_dir / "best_estimator_model.pkl"
    state_file = results_dir / "best_estimator_state.nc"
    if not model_file.exists():
        print(f"⚠️ Missing model file in {results_dir.name}")
        return False
    if not state_file.exists():
        print(f"⚠️ Missing state file in {results_dir.name}")
        return False

    # Load the pickled model.
    # NOTE(security): unpickling executes arbitrary code; only run this on
    # result directories produced by trusted training runs.
    print(f"Loading model from {results_dir.name}...")
    with open(model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Load the old state to get feature names.  The context manager closes the
    # file handle even if reading the coordinate fails, which matters because
    # the file is renamed or removed further down.
    with xr.open_dataset(state_file, engine="h5netcdf") as old_state:
        features = old_state.coords["feature"].values.tolist()

    booster = best_estimator.get_booster()

    # One entry per importance type requested from the booster.
    descriptions = {
        "weight": "Number of times a feature is used to split the data across all trees.",
        "gain": "Average gain across all splits the feature is used in.",
        "cover": "Average coverage across all splits the feature is used in.",
        "total_gain": "Total gain across all splits the feature is used in.",
        "total_cover": "Total coverage across all splits the feature is used in.",
    }
    importances = {}
    for importance_type, description in descriptions.items():
        scores = booster.get_score(importance_type=importance_type)
        name = f"feature_importance_{importance_type}"
        importances[name] = xr.DataArray(
            # Align by feature index (f0, f1, ...); features absent from the
            # booster's score dict get an importance of 0.0.
            [scores.get(f"f{i}", 0.0) for i in range(len(features))],
            dims=["feature"],
            coords={"feature": features},
            name=name,
            attrs={"description": description},
        )

    # Create new state dataset
    state = xr.Dataset(
        importances,
        attrs={
            "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
            "n_trees": booster.num_boosted_rounds(),
            "objective": str(best_estimator.objective),
        },
    )

    # Backup the old file; an existing backup is never overwritten so that the
    # very first state file is the one preserved across repeated runs.
    backup_file = state_file.with_suffix(".nc.backup")
    if not backup_file.exists():
        print(f" Creating backup: {backup_file.name}")
        state_file.rename(backup_file)
    else:
        print(" Backup already exists, removing old state file")
        state_file.unlink()

    # Save the fixed state
    print(f" Saving fixed state to {state_file.name}")
    state.to_netcdf(state_file, engine="h5netcdf")

    # Verify the fix: all-zero importance is the symptom being repaired.
    total_importance = sum(
        importances[f"feature_importance_{kind}"].sum().item() for kind in ("weight", "gain", "cover")
    )
    if total_importance > 0:
        print(f" ✓ Success! Total importance: {total_importance:.2f}")
        return True

    print(" ✗ Warning: Total importance is still 0!")
    return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Run the state-file repair over every directory below ``RESULTS_DIR``.

    Each directory is handed to ``fix_xgboost_model_state``.  A True result is
    counted as fixed, a False result as skipped, and any raised exception as
    failed; a summary of the three counters is printed at the end.
    """
    print("Scanning for XGBoost model results...")

    # Collect all result directories
    candidates = [entry for entry in RESULTS_DIR.iterdir() if entry.is_dir()]
    print(f"Found {len(candidates)} result directories")

    tally = {"fixed": 0, "skipped": 0, "failed": 0}
    for candidate in sorted(candidates):
        try:
            outcome = fix_xgboost_model_state(candidate)
        except Exception as exc:
            # Keep going: one broken result directory must not stop the scan.
            print(f"❌ Error processing {candidate.name}: {exc}")
            tally["failed"] += 1
            continue

        if outcome:
            tally["fixed"] += 1
        elif outcome is False:
            # False covers every non-success return of the fixer.
            tally["skipped"] += 1

    print("\n" + "=" * 60)
    print("Summary:")
    print(f" ✓ Fixed: {tally['fixed']}")
    print(f" ⊘ Skipped: {tally['skipped']}")
    print(f" ✗ Failed: {tally['failed']}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|
||||||
|
|
@ -1,195 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
"""Recalculate test metrics and confusion matrix for existing training results.
|
|
||||||
|
|
||||||
This script loads previously trained models and recalculates test metrics
|
|
||||||
and confusion matrices for training runs that were completed before these
|
|
||||||
outputs were added to the training pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cupy as cp
|
|
||||||
import numpy as np
|
|
||||||
import toml
|
|
||||||
import torch
|
|
||||||
import xarray as xr
|
|
||||||
from sklearn import set_config
|
|
||||||
from sklearn.metrics import confusion_matrix
|
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
|
||||||
from entropice.utils.paths import RESULTS_DIR
|
|
||||||
|
|
||||||
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
|
|
||||||
set_config(array_api_dispatch=True)
|
|
||||||
|
|
||||||
|
|
||||||
def recalculate_metrics(results_dir: Path):
    """Recalculate test metrics and confusion matrix for a training result.

    The stored search settings are used to rebuild the dataset, the pickled
    best estimator predicts on the test split, and the resulting metrics and
    confusion matrix are written into ``results_dir``.  An existing
    ``test_metrics.toml`` is left untouched, while ``confusion_matrix.nc`` is
    always rewritten.  Directories without ``search_settings.toml`` or
    ``best_estimator_model.pkl`` are reported and skipped.

    Args:
        results_dir: Path to the results directory containing the trained model.

    """
    print(f"\nProcessing: {results_dir}")

    # Load the search settings to get training configuration
    settings_file = results_dir / "search_settings.toml"
    if not settings_file.exists():
        print(" ❌ Missing search_settings.toml, skipping...")
        return

    with open(settings_file) as f:
        config = toml.load(f)
    settings = config["settings"]

    test_metrics_file = results_dir / "test_metrics.toml"
    cm_file = results_dir / "confusion_matrix.nc"

    # Load the best estimator
    best_model_file = results_dir / "best_estimator_model.pkl"
    if not best_model_file.exists():
        print(" ❌ Missing best_estimator_model.pkl, skipping...")
        return

    # NOTE(security): unpickling executes arbitrary code; only run this on
    # result directories produced by trusted training runs.
    print(f" Loading best estimator from {best_model_file.name}...")
    with open(best_model_file, "rb") as f:
        best_estimator = pickle.load(f)

    # Recreate the dataset ensemble
    print(" Recreating training dataset...")
    dataset_ensemble = DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings.get(
            "members",
            [
                "AlphaEarth",
                "ArcticDEM",
                "ERA5-yearly",
                "ERA5-seasonal",
                "ERA5-shoulder",
            ],
        ),
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )

    task = settings["task"]
    model = settings["model"]
    device = "torch" if model in ["espa"] else "cuda"

    # Create training data
    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)

    # Prepare test data - match training.py's approach
    print(" Preparing test data...")
    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
    y_test = (
        training_data.y.test.get()
        if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
        else training_data.y.test
    )

    # Compute predictions on the test set (use original device data)
    print(" Computing predictions on test set...")
    y_pred = best_estimator.predict(training_data.X.test)

    # Move both arrays into torch tensors on the same device for the metrics
    y_pred = torch.as_tensor(y_pred, device="cuda")
    y_test = torch.as_tensor(y_test, device="cuda")

    # Compute metrics manually to avoid device issues
    print(" Computing test metrics...")
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        jaccard_score,
        precision_score,
        recall_score,
    )

    test_metrics = {}
    if task == "binary":
        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
        test_metrics["recall"] = float(recall_score(y_test, y_pred))
        test_metrics["precision"] = float(precision_score(y_test, y_pred))
        test_metrics["f1"] = float(f1_score(y_test, y_pred))
        test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
    else:
        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
        test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
        test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
        test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
        test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
        test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
        test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
        test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))

    # Get confusion matrix
    print(" Computing confusion matrix...")
    n_labels = len(training_data.y.labels)
    labels = torch.as_tensor(np.array(list(range(n_labels))), device="cuda")
    # Report the devices of the tensors that are actually passed on below.
    print(" Device of y_test:", getattr(y_test, "device", "cpu"))
    print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
    print(" Device of labels:", getattr(labels, "device", "cpu"))
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm = cm.cpu().numpy()
    label_names = [training_data.y.labels[i] for i in range(n_labels)]
    cm_xr = xr.DataArray(
        cm,
        dims=["true_label", "predicted_label"],
        coords={"true_label": label_names, "predicted_label": label_names},
        name="confusion_matrix",
    )

    # Store the test metrics (never overwrite an existing file)
    if not test_metrics_file.exists():
        print(f" Storing test metrics to {test_metrics_file.name}...")
        with open(test_metrics_file, "w") as f:
            toml.dump({"test_metrics": test_metrics}, f)
    else:
        print(" ✓ Test metrics already exist")

    # Store the confusion matrix; it is always rewritten, even if one exists.
    print(f" Storing confusion matrix to {cm_file.name}...")
    cm_xr.to_netcdf(cm_file, engine="h5netcdf")

    print(" ✓ Done!")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Recalculate metrics for every directory found below ``RESULTS_DIR``.

    Directories are processed in sorted order.  An exception raised by
    ``recalculate_metrics`` is not caught here and aborts the run.
    """
    print("Searching for training results directories...")

    # Collect all results directories
    candidates = sorted(entry for entry in RESULTS_DIR.glob("*") if entry.is_dir())
    print(f"Found {len(candidates)} results directories.\n")

    for candidate in candidates:
        recalculate_metrics(candidate)

    print("\n✅ All done!")


if __name__ == "__main__":
    main()
|
|
||||||
|
|
@ -1,138 +0,0 @@
|
||||||
import dask.distributed as dd
|
|
||||||
import xarray as xr
|
|
||||||
from rich import print
|
|
||||||
|
|
||||||
import entropice.utils.codecs
|
|
||||||
from entropice.utils.paths import get_era5_stores
|
|
||||||
|
|
||||||
|
|
||||||
def print_info(daily_raw=None, show_vars: bool = True):
    """Print dimensions, chunking and encoding of the daily ERA5 dataset.

    Args:
        daily_raw: Dataset to describe. When ``None``, the "daily" store returned by
            ``get_era5_stores`` is opened with ``xr.open_zarr``.
        show_vars (bool): Also print the encoding of every data variable.

    """
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)

    print("=== Daily INFO ===")
    print(f" Dims: {daily_raw.sizes}")

    numchunks = 1
    chunksizes = {}
    approxchunksize = 4  # bytes per element, assuming float32
    for dim, sizes in daily_raw.chunksizes.items():
        largest = max(sizes)
        numchunks *= len(sizes)
        chunksizes[dim] = largest
        approxchunksize *= largest
    # Bytes -> megabytes. The previous divisor 10e6 equals 1e7 and under-reported the size tenfold.
    approxchunksize /= 1e6
    print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f" Encoding: {daily_raw.encoding}")

    if show_vars:
        print(" Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f" {var} Encoding:")
            print(da.encoding)
            print("")
|
|
||||||
|
|
||||||
|
|
||||||
def rechunk(use_shards: bool = False):
    """Rewrite the daily ERA5 store with 120-step time chunks spanning the full spatial extent.

    The result is written next to the original store: to ``<stem>_rechunked_sharded`` when
    shards are requested and to ``<stem>_rechunked`` otherwise.

    Args:
        use_shards (bool): Additionally request shards of 1200 time steps for every variable.
            A warning is printed because this is reported as broken upstream (see link below).

    """
    if use_shards:
        # ! MEEEP: https://github.com/pydata/xarray/issues/10831
        print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")

    # Only label the output as sharded when shards are actually written. The plain
    # "_rechunked" name is the store that validate() opens.
    suffix = "_rechunked_sharded" if use_shards else "_rechunked"

    with (
        dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(f"Dashboard: {client.dashboard_link}")
        daily_store = get_era5_stores("daily")
        daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}{suffix}")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
        daily_raw = daily_raw.chunk(
            {
                "time": 120,
                "latitude": -1,  # Should be 337,
                "longitude": -1,  # Should be 3600
            }
        )
        encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
        for var in daily_raw.data_vars:
            encoding[var]["chunks"] = (120, 337, 3600)
            if use_shards:
                encoding[var]["shards"] = (1200, 337, 3600)
        print(encoding)
        daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
|
|
||||||
|
|
||||||
|
|
||||||
def validate(use_shards: bool = False):
    """Compare a rechunked daily ERA5 store against the original store and print a report.

    Dimensions, the set of data variables and every shared variable are compared. The final
    summary reports success only when all of these checks pass.

    Args:
        use_shards (bool): Validate the ``<stem>_rechunked_sharded`` store instead of the
            default ``<stem>_rechunked`` store.

    """
    import numpy as np

    daily_store = get_era5_stores("daily")
    daily_raw = xr.open_zarr(daily_store, consolidated=False)

    suffix = "_rechunked_sharded" if use_shards else "_rechunked"
    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}{suffix}")
    daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)

    print("\n=== Comparing Datasets ===")
    all_equal = True

    # Compare dimensions
    if daily_raw.sizes != daily_rechunked.sizes:
        all_equal = False  # a dimension mismatch must fail the overall validation
        print("❌ Dimensions differ:")
        print(f" Original: {daily_raw.sizes}")
        print(f" Rechunked: {daily_rechunked.sizes}")
    else:
        print("✅ Dimensions match")

    # Compare variables
    raw_vars = set(daily_raw.data_vars)
    rechunked_vars = set(daily_rechunked.data_vars)
    if raw_vars != rechunked_vars:
        all_equal = False  # missing or extra variables must fail the overall validation
        print("❌ Variables differ:")
        print(f" Only in original: {raw_vars - rechunked_vars}")
        print(f" Only in rechunked: {rechunked_vars - raw_vars}")
    else:
        print("✅ Variables match")

    # Compare each variable present in both stores, in the original dataset's order
    print("\n=== Variable Comparison ===")
    for var in daily_raw.data_vars:
        if var not in rechunked_vars:
            continue
        raw_var = daily_raw[var]
        rechunked_var = daily_rechunked[var]

        if raw_var.equals(rechunked_var):
            print(f"✅ {var}: Equal")
            continue

        all_equal = False
        print(f"❌ {var}: NOT Equal")

        # Check if values are equal. This is a best-effort diagnostic, so any failure is only reported.
        try:
            values_equal = raw_var.shape == rechunked_var.shape
            if values_equal:
                # Load each array once instead of on every ``.values`` access.
                raw_values = raw_var.values
                rechunked_values = rechunked_var.values
                values_equal = np.allclose(raw_values, rechunked_values, equal_nan=True)

            if values_equal:
                print(" → Values are numerically equal (likely metadata/encoding difference)")
            else:
                print(" → Values differ!")
                print(f" Original shape: {raw_var.shape}")
                print(f" Rechunked shape: {rechunked_var.shape}")
        except Exception as e:
            print(f" → Error comparing values: {e}")

        # Check attributes
        if raw_var.attrs != rechunked_var.attrs:
            print(" → Attributes differ:")
            print(f" Original: {raw_var.attrs}")
            print(f" Rechunked: {rechunked_var.attrs}")

        # Check encoding
        if raw_var.encoding != rechunked_var.encoding:
            print(" → Encoding differs:")
            print(f" Original: {raw_var.encoding}")
            print(f" Rechunked: {rechunked_var.encoding}")

    if all_equal:
        print("\n✅ Validation successful: All datasets are equal.")
    else:
        print("\n❌ Validation failed: Datasets have differences (see above).")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
validate()
|
|
||||||
|
|
@ -1,144 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
"""Rerun inference for training results that are missing predicted probabilities.
|
|
||||||
|
|
||||||
This script searches through training result directories and identifies those that have
|
|
||||||
a trained model but are missing inference results. It then loads the model and dataset
|
|
||||||
configuration, reruns inference, and saves the results.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import toml
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.progress import track
|
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
|
||||||
from entropice.ml.inference import predict_proba
|
|
||||||
from entropice.utils.paths import RESULTS_DIR
|
|
||||||
|
|
||||||
console = Console()
|
|
||||||
|
|
||||||
|
|
||||||
def find_incomplete_trainings() -> list[Path]:
    """Collect training result directories that still lack inference output.

    A directory matching ``*_cv*`` below ``RESULTS_DIR`` qualifies when it contains
    ``best_estimator_model.pkl`` and ``search_settings.toml`` but no
    ``predicted_probabilities.parquet``.

    Returns:
        list[Path]: Directories with a trained model but missing predictions. Empty when
            ``RESULTS_DIR`` does not exist.

    """
    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
        return []

    def _needs_inference(candidate: Path) -> bool:
        # Model and settings must be present, predictions must be absent.
        if not candidate.is_dir():
            return False
        return (
            (candidate / "best_estimator_model.pkl").exists()
            and (candidate / "search_settings.toml").exists()
            and not (candidate / "predicted_probabilities.parquet").exists()
        )

    return [candidate for candidate in RESULTS_DIR.glob("*_cv*") if _needs_inference(candidate)]
|
|
||||||
|
|
||||||
|
|
||||||
def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Rebuilds the ``DatasetEnsemble`` from ``search_settings.toml``, loads the pickled
    estimator, runs ``predict_proba`` and writes ``predicted_probabilities.parquet``
    into ``result_dir``.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise. Errors are printed, not raised.

    """
    import traceback

    try:
        console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")

        # Load settings
        settings_file = result_dir / "search_settings.toml"
        with open(settings_file, encoding="utf-8") as f:
            settings_data = toml.load(f)

        settings = settings_data["settings"]

        # Reconstruct DatasetEnsemble from settings
        ensemble = DatasetEnsemble(
            grid=settings["grid"],
            level=settings["level"],
            target=settings["target"],
            members=settings["members"],
            dimension_filters=settings.get("dimension_filters", {}),
            variable_filters=settings.get("variable_filters", {}),
            filter_target=settings.get("filter_target", False),
            add_lonlat=settings.get("add_lonlat", True),
        )

        # Load trained model
        # NOTE: pickle.load can execute arbitrary code. Only point this script at result
        # directories whose model files come from a trusted source.
        model_file = result_dir / "best_estimator_model.pkl"
        with open(model_file, "rb") as f:
            clf = pickle.load(f)

        console.print("[green]✓[/green] Loaded model and settings")

        # Get class labels
        classes = settings["classes"]

        # Run inference
        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=classes)

        # Save predictions
        preds_file = result_dir / "predicted_probabilities.parquet"
        preds.to_parquet(preds_file)

        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
        return True

    except Exception as e:
        # Top-level boundary for one result directory: report the failure and keep going.
        # The message and traceback are arbitrary text and may contain square brackets, so
        # they are printed with markup disabled and coloured through ``style`` instead.
        console.print(f"✗ Error processing {result_dir.name}: {e}", style="red", markup=False)
        console.print(traceback.format_exc(), style="red", markup=False)
        return False
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Find training results without predictions and rerun inference for each of them."""
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")

    incomplete_dirs = find_incomplete_trainings()
    if not incomplete_dirs:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return

    n_pending = len(incomplete_dirs)
    console.print(f"[yellow]Found {n_pending} training(s) missing predictions:[/yellow]")
    for pending_dir in incomplete_dirs:
        console.print(f" • {pending_dir.name}")

    console.print(f"\n[bold]Processing {n_pending} training result(s)...[/bold]\n")

    # rerun_inference reports success as a bool, so the outcomes can be tallied afterwards.
    outcomes = [rerun_inference(result_dir) for result_dir in track(incomplete_dirs, description="Rerunning inference")]
    successful = sum(1 for ok in outcomes if ok)
    failed = len(outcomes) - successful

    console.print("\n[bold]Summary:[/bold]")
    console.print(f" [green]Successful: {successful}[/green]")
    console.print(f" [red]Failed: {failed}[/red]")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -32,9 +32,9 @@ def create_climate_map(
|
||||||
"""
|
"""
|
||||||
# Subsample if too many cells for performance
|
# Subsample if too many cells for performance
|
||||||
n_cells = len(climate_values["cell_ids"])
|
n_cells = len(climate_values["cell_ids"])
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
rng = np.random.default_rng(42)
|
rng = np.random.default_rng(42)
|
||||||
cell_indices = rng.choice(n_cells, size=100000, replace=False)
|
cell_indices = rng.choice(n_cells, size=50000, replace=False)
|
||||||
climate_values = climate_values.isel(cell_ids=cell_indices)
|
climate_values = climate_values.isel(cell_ids=cell_indices)
|
||||||
# Create a copy to avoid modifying the original
|
# Create a copy to avoid modifying the original
|
||||||
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,9 @@ def create_embedding_map(
|
||||||
"""
|
"""
|
||||||
# Subsample if too many cells for performance
|
# Subsample if too many cells for performance
|
||||||
n_cells = len(embedding_values["cell_ids"])
|
n_cells = len(embedding_values["cell_ids"])
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
rng = np.random.default_rng(42) # Fixed seed for reproducibility
|
rng = np.random.default_rng(42) # Fixed seed for reproducibility
|
||||||
cell_indices = rng.choice(n_cells, size=100000, replace=False)
|
cell_indices = rng.choice(n_cells, size=50000, replace=False)
|
||||||
embedding_values = embedding_values.isel(cell_ids=cell_indices)
|
embedding_values = embedding_values.isel(cell_ids=cell_indices)
|
||||||
|
|
||||||
# Create a copy to avoid modifying the original
|
# Create a copy to avoid modifying the original
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,11 @@ def create_grid_areas_map(
|
||||||
pdk.Deck: A PyDeck map visualization of the specified grid statistic.
|
pdk.Deck: A PyDeck map visualization of the specified grid statistic.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# Subsample if too many cells for performance
|
||||||
|
n_cells = len(grid_gdf)
|
||||||
|
if n_cells > 50000:
|
||||||
|
grid_gdf = grid_gdf.sample(n=50000, random_state=42)
|
||||||
|
|
||||||
# Create a copy to avoid modifying the original
|
# Create a copy to avoid modifying the original
|
||||||
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,9 +33,9 @@ def create_terrain_map(
|
||||||
"""
|
"""
|
||||||
# Subsample if too many cells for performance
|
# Subsample if too many cells for performance
|
||||||
n_cells = len(terrain_values)
|
n_cells = len(terrain_values)
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
rng = np.random.default_rng(42)
|
rng = np.random.default_rng(42)
|
||||||
cell_indices = rng.choice(n_cells, size=100000, replace=False)
|
cell_indices = rng.choice(n_cells, size=50000, replace=False)
|
||||||
terrain_values = terrain_values.iloc[cell_indices]
|
terrain_values = terrain_values.iloc[cell_indices]
|
||||||
|
|
||||||
# Create a copy to avoid modifying the original
|
# Create a copy to avoid modifying the original
|
||||||
|
|
@ -159,9 +159,9 @@ def create_terrain_distribution_plot(arcticdem_ds: xr.Dataset, features: list[st
|
||||||
"""
|
"""
|
||||||
# Subsample if too many cells for performance
|
# Subsample if too many cells for performance
|
||||||
n_cells = len(arcticdem_ds.cell_ids)
|
n_cells = len(arcticdem_ds.cell_ids)
|
||||||
if n_cells > 10000:
|
if n_cells > 50000:
|
||||||
rng = np.random.default_rng(42)
|
rng = np.random.default_rng(42)
|
||||||
cell_indices = rng.choice(n_cells, size=10000, replace=False)
|
cell_indices = rng.choice(n_cells, size=50000, replace=False)
|
||||||
arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices)
|
arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices)
|
||||||
|
|
||||||
# Determine aggregation types available
|
# Determine aggregation types available
|
||||||
|
|
|
||||||
|
|
@ -78,8 +78,8 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
|
||||||
|
|
||||||
# Check if subsampling will occur
|
# Check if subsampling will occur
|
||||||
n_cells = len(embedding_values["cell_ids"])
|
n_cells = len(embedding_values["cell_ids"])
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")
|
st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
|
||||||
|
|
||||||
map_deck = create_embedding_map(
|
map_deck = create_embedding_map(
|
||||||
embedding_values=embedding_values,
|
embedding_values=embedding_values,
|
||||||
|
|
|
||||||
|
|
@ -78,8 +78,8 @@ def _render_terrain_map(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
|
||||||
|
|
||||||
# Check if subsampling will occur
|
# Check if subsampling will occur
|
||||||
n_cells = len(terrain_series)
|
n_cells = len(terrain_series)
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")
|
st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
|
||||||
|
|
||||||
# Create map
|
# Create map
|
||||||
map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map)
|
map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,11 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame):
|
||||||
|
|
||||||
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d"))
|
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d"))
|
||||||
|
|
||||||
|
# Check if subsampling will occur
|
||||||
|
n_cells = len(grid_gdf)
|
||||||
|
if n_cells > 50000:
|
||||||
|
st.info(f"🗺️ **Map subsampled:** Displaying 50,000 randomly selected cells out of {n_cells:,} for performance.")
|
||||||
|
|
||||||
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
|
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
|
||||||
st.pydeck_chart(map_deck)
|
st.pydeck_chart(map_deck)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -702,6 +702,26 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||||
m: stats.members[m] for m in selected_members if m in stats.members
|
m: stats.members[m] for m in selected_members if m in stats.members
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Write a possible training command
|
||||||
|
cmd = (
|
||||||
|
f"pixi run train --grid {selected_grid_config.grid} --level {selected_grid_config.level} "
|
||||||
|
f"--temporal-mode {selected_temporal_mode} --target {selected_target} --task {selected_task}"
|
||||||
|
)
|
||||||
|
if set(selected_members) != set(available_members):
|
||||||
|
for member in selected_members:
|
||||||
|
cmd += f" --member {member}"
|
||||||
|
if dimension_filters:
|
||||||
|
for member, dims in dimension_filters.items():
|
||||||
|
for dim_name, dim_values in dims.items():
|
||||||
|
dim_filter_str = ",".join(str(v) for v in dim_values)
|
||||||
|
cmd += f" --dimension-filters.{member}.{dim_name}={dim_filter_str}"
|
||||||
|
# dim_filters_str = json.dumps(dimension_filters)
|
||||||
|
# cmd += f" --dimension-filters '{dim_filters_str}'"
|
||||||
|
|
||||||
|
st.markdown("#### Equivalent Training Command")
|
||||||
|
st.code(cmd, language="bash")
|
||||||
|
st.code(cmd.replace("train", "autogluon"), language="bash")
|
||||||
|
|
||||||
# Render configuration summary and statistics
|
# Render configuration summary and statistics
|
||||||
_render_configuration_summary(
|
_render_configuration_summary(
|
||||||
selected_members=selected_members,
|
selected_members=selected_members,
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,8 @@ def _render_climate_variable_map(climate_values: xr.DataArray, grid_gdf: gpd.Geo
|
||||||
|
|
||||||
# Create map
|
# Create map
|
||||||
n_cells = len(climate_values)
|
n_cells = len(climate_values)
|
||||||
if n_cells > 100000:
|
if n_cells > 50000:
|
||||||
st.info(f"Showing 100,000 / {n_cells:,} cells for performance")
|
st.info(f"Showing 50,000 / {n_cells:,} cells for performance")
|
||||||
|
|
||||||
deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
|
deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
|
||||||
st.pydeck_chart(deck, use_container_width=True)
|
st.pydeck_chart(deck, use_container_width=True)
|
||||||
|
|
@ -211,9 +211,9 @@ def render_era5_tab(
|
||||||
climate_values = climate_values.sel(month=selected_month)
|
climate_values = climate_values.sel(month=selected_month)
|
||||||
climate_values = climate_values.compute()
|
climate_values = climate_values.compute()
|
||||||
|
|
||||||
_render_climate_variable_map(climate_values, grid_gdf, selected_var)
|
|
||||||
|
|
||||||
if "year" in climate_values.dims:
|
if "year" in climate_values.dims:
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
_render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)
|
_render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)
|
||||||
|
|
||||||
|
_render_climate_variable_map(climate_values, grid_gdf, selected_var)
|
||||||
|
|
|
||||||
|
|
@ -113,29 +113,33 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
|
grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
|
||||||
st.write("**Configuration:**")
|
st.write(
|
||||||
st.write(f"- **Experiment:** {tr.experiment}")
|
"**Configuration:**\n"
|
||||||
st.write(f"- **Task:** {tr.settings.task}")
|
f"- **Experiment:** {tr.experiment}\n"
|
||||||
st.write(f"- **Model:** {tr.settings.model}")
|
f"- **Task:** {tr.settings.task}\n"
|
||||||
st.write(f"- **Grid:** {grid_config.display_name}")
|
f"- **Target:** {tr.settings.target}\n"
|
||||||
st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}")
|
f"- **Model:** {tr.settings.model}\n"
|
||||||
st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}")
|
f"- **Grid:** {grid_config.display_name}\n"
|
||||||
st.write(f"- **Members:** {', '.join(tr.settings.members)}")
|
f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n"
|
||||||
st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
|
f"- **Temporal Mode:** {tr.settings.temporal_mode}\n"
|
||||||
st.write(f"- **Classes:** {tr.settings.classes}")
|
f"- **Members:** {', '.join(tr.settings.members)}\n"
|
||||||
|
f"- **CV Splits:** {tr.settings.cv_splits}\n"
|
||||||
|
f"- **Classes:** {tr.settings.classes}\n"
|
||||||
|
)
|
||||||
|
|
||||||
st.write("\n**Files:**")
|
file_str = "\n**Files:**\n"
|
||||||
for file in tr.files:
|
for file in tr.files:
|
||||||
if file.name == "search_settings.toml":
|
if file.name == "search_settings.toml":
|
||||||
st.write(f"- ⚙️ `{file.name}`")
|
file_str += f"- ⚙️ `{file.name}`\n"
|
||||||
elif file.name == "best_estimator_model.pkl":
|
elif file.name == "best_estimator_model.pkl":
|
||||||
st.write(f"- 🧮 `{file.name}`")
|
file_str += f"- 🧮 `{file.name}`\n"
|
||||||
elif file.name == "search_results.parquet":
|
elif file.name == "search_results.parquet":
|
||||||
st.write(f"- 📊 `{file.name}`")
|
file_str += f"- 📊 `{file.name}`\n"
|
||||||
elif file.name == "predicted_probabilities.parquet":
|
elif file.name == "predicted_probabilities.parquet":
|
||||||
st.write(f"- 🎯 `{file.name}`")
|
file_str += f"- 🎯 `{file.name}`\n"
|
||||||
else:
|
else:
|
||||||
st.write(f"- 📄 `{file.name}`")
|
file_str += f"- 📄 `{file.name}`\n"
|
||||||
|
st.write(file_str)
|
||||||
with col2:
|
with col2:
|
||||||
st.write("**CV Score Summary:**")
|
st.write("**CV Score Summary:**")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from entropice.utils.types import Grid, Model, Task
|
from entropice.utils.types import Grid, Model, TargetDataset, Task
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -58,6 +58,7 @@ task_display_infos: dict[Task, TaskDisplayInfo] = {
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingResultDisplayInfo:
|
class TrainingResultDisplayInfo:
|
||||||
task: Task
|
task: Task
|
||||||
|
target: TargetDataset
|
||||||
model: Model
|
model: Model
|
||||||
grid: Grid
|
grid: Grid
|
||||||
level: int
|
level: int
|
||||||
|
|
@ -65,15 +66,16 @@ class TrainingResultDisplayInfo:
|
||||||
|
|
||||||
def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
|
def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
|
||||||
task = self.task.capitalize()
|
task = self.task.capitalize()
|
||||||
|
target = self.target.replace("_", " ").title()
|
||||||
model = self.model.upper()
|
model = self.model.upper()
|
||||||
grid = self.grid.capitalize()
|
grid = self.grid.capitalize()
|
||||||
level = self.level
|
level = self.level
|
||||||
timestamp = self.timestamp.strftime("%Y-%m-%d %H:%M")
|
timestamp = self.timestamp.strftime("%Y-%m-%d %H:%M")
|
||||||
|
|
||||||
if format_type == "model_first":
|
if format_type == "model_first":
|
||||||
return f"{model} - {task} - {grid}-{level} ({timestamp})"
|
return f"{model} - {task}@{target} - {grid}-{level} ({timestamp})"
|
||||||
else: # task_first
|
else: # task_first
|
||||||
return f"{task} - {model} - {grid}-{level} ({timestamp})"
|
return f"{task}@{target} - {model} - {grid}-{level} ({timestamp})"
|
||||||
|
|
||||||
|
|
||||||
def format_metric_name(metric: str) -> str:
|
def format_metric_name(metric: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,7 @@ class TrainingResult:
|
||||||
"""Get display information for the training result."""
|
"""Get display information for the training result."""
|
||||||
return TrainingResultDisplayInfo(
|
return TrainingResultDisplayInfo(
|
||||||
task=self.settings.task,
|
task=self.settings.task,
|
||||||
|
target=self.settings.target,
|
||||||
model=self.settings.model,
|
model=self.settings.model,
|
||||||
grid=self.settings.grid,
|
grid=self.settings.grid,
|
||||||
level=self.settings.level,
|
level=self.settings.level,
|
||||||
|
|
@ -198,6 +199,7 @@ class TrainingResult:
|
||||||
record = {
|
record = {
|
||||||
"Experiment": tr.experiment if tr.experiment else "N/A",
|
"Experiment": tr.experiment if tr.experiment else "N/A",
|
||||||
"Task": info.task,
|
"Task": info.task,
|
||||||
|
"Target": info.target,
|
||||||
"Model": info.model,
|
"Model": info.model,
|
||||||
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
|
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
|
||||||
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
||||||
|
|
@ -228,7 +230,7 @@ def load_all_training_results() -> list[TrainingResult]:
|
||||||
if not experiment_path.is_dir():
|
if not experiment_path.is_dir():
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
experiment_name = experiment_path.name
|
experiment_name = experiment_path.parent.name
|
||||||
training_result = TrainingResult.from_path(experiment_path, experiment_name)
|
training_result = TrainingResult.from_path(experiment_path, experiment_name)
|
||||||
training_results.append(training_result)
|
training_results.append(training_result)
|
||||||
is_experiment_dir = True
|
is_experiment_dir = True
|
||||||
|
|
|
||||||
|
|
@ -196,6 +196,60 @@ def combine_to_zarr(grid: Grid, level: int):
|
||||||
print(f"Saved combined embeddings to {zarr_path}.")
|
print(f"Saved combined embeddings to {zarr_path}.")
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_polyfit(embeddings: xr.Dataset) -> xr.Dataset:
    """Fit the linear trend over ``year`` in batches of 1000 cells and concatenate the slopes.

    Args:
        embeddings (xr.Dataset): Dataset with ``cell_ids`` and ``year`` dimensions.

    Returns:
        xr.Dataset: Degree-1 polyfit coefficients of all successfully fitted batches,
            concatenated along ``cell_ids``.

    """
    batch_size = 1000
    n_cells = embeddings.sizes["cell_ids"]
    batch_starts = list(range(0, n_cells, batch_size))

    slopes = []
    for start in track(batch_starts):
        stop = start + batch_size
        batch = embeddings.isel(cell_ids=slice(start, stop))
        try:
            fit = batch.polyfit(dim="year", deg=1, skipna=True)
            slopes.append(fit.sel(degree=1, drop=True))
        except Exception as e:
            # Best effort: a failing batch is reported and left out of the result.
            print(f"Error processing embeddings {start} to {stop}: {e}")
    return xr.concat(slopes, dim="cell_ids")
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command
def compute_synopsis(
    grid: Grid,
    level: int,
):
    """Create synopsis datasets for spatially aggregated AlphaEarth embeddings.

    Loads spatially aggregated AlphaEarth embeddings and computes mean and trend for each variable.
    The resulting synopsis datasets are saved to new zarr stores.

    Args:
        grid (Grid): Grid type.
        level (int): Grid resolution level.

    """
    store = get_embeddings_store(grid=grid, level=level)
    embeddings = xr.open_zarr(store, consolidated=False).compute()

    embeddings_mean = embeddings.mean(dim="year")
    print(f"{embeddings.sizes=} -> {embeddings.nbytes * 1e-9:.2f} GB")

    if (grid == "hex" and level >= 6) or (grid == "healpix" and level >= 10):
        # We need to process the trend in chunks, because very large arrays can lead to
        # numerical issues in the least squares solver used by polyfit.
        embeddings_trend = _safe_polyfit(embeddings)
    else:
        embeddings_trend = embeddings.polyfit(dim="year", deg=1, skipna=True).sel(degree=1, drop=True)

    # Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
    embeddings_trend = embeddings_trend.rename(
        {var: str(var).replace("_polyfit_coefficients", "_trend") for var in embeddings_trend.data_vars}
    )
    embeddings_synopsis = xr.merge([embeddings_mean, embeddings_trend])
    synopsis_store = get_embeddings_store(grid=grid, level=level, temporal="synopsis")
    encoding = codecs.from_ds(embeddings_synopsis)
    for var in embeddings_synopsis.data_vars:
        # Cap every chunk dimension at 100000 elements. This must be a real tuple: a bare
        # generator is exhausted after the first iteration and is not a valid chunk spec.
        encoding[var]["chunks"] = tuple(min(v, 100000) for v in embeddings_synopsis[var].shape)
    print(f"Saving AlphaEarth embeddings synopsis data to {synopsis_store}.")
    embeddings_synopsis.to_zarr(synopsis_store, mode="w", encoding=encoding, consolidated=False)
|
||||||
|
|
||||||
|
|
||||||
def main(): # noqa: D103
|
def main(): # noqa: D103
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -775,6 +775,10 @@ def spatial_agg(
|
||||||
pxbuffer=10,
|
pxbuffer=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# These are artifacts from previous processing steps, drop them
|
||||||
|
if "latitude" in aggregated.coords and "longitude" in aggregated.coords:
|
||||||
|
aggregated = aggregated.drop_vars(["latitude", "longitude"])
|
||||||
|
|
||||||
aggregated = aggregated.chunk(
|
aggregated = aggregated.chunk(
|
||||||
{
|
{
|
||||||
"cell_ids": min(len(aggregated.cell_ids), 10000),
|
"cell_ids": min(len(aggregated.cell_ids), 10000),
|
||||||
|
|
@ -790,5 +794,83 @@ def spatial_agg(
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================
|
||||||
|
# ==== Synopsis Datasets ====
|
||||||
|
# ===========================
|
||||||
|
|
||||||
|
|
||||||
|
def unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
    """Split the time dimension of an ERA5 dataset into year and month/season dimensions.

    Note that for "seasonal" and "shoulder" the passed dataset is modified in place
    (year/month coordinates and the time index are assigned) before it is unstacked.

    Args:
        era5 (xr.Dataset): The ERA5 dataset whose time dimension should be unstacked.
        aggregation (Literal["yearly", "seasonal", "shoulder"]): The type of aggregation to perform.
            - "yearly": No unstacking, time is renamed to year and holds the calendar year.
            - "seasonal": Unstacks into year and season (winter/summer).
            - "shoulder": Unstacks into year and shoulder season (OND, JFM, AMJ, JAS).

    Returns:
        xr.Dataset: The dataset with a ``year`` dimension and, unless yearly, a ``month``
            dimension labelled with season names.

    """
    if aggregation == "yearly":
        # No unstacking needed: rename the dimension and keep only the calendar year as coordinate.
        yearly = era5.rename({"time": "year"})
        yearly.coords["year"] = yearly["year"].dt.year
        return yearly

    # Replace the time index by a (year, month) MultiIndex so that it can be unstacked.
    era5.coords["year"] = era5.time.dt.year
    era5.coords["month"] = era5.time.dt.month
    year_month_index = pd.MultiIndex.from_arrays(
        [
            era5.time.dt.year.values,  # noqa: PD011
            era5.time.dt.month.values,  # noqa: PD011
        ],
        names=("year", "month"),
    )
    era5["time"] = year_month_index
    unstacked = era5.unstack("time")  # noqa: PD010

    # Keys are month numbers found in the month coordinate -- presumably the first month
    # of each season; months without an entry map to NaN.
    if aggregation == "seasonal":
        month_labels = {10: "winter", 4: "summer"}
    else:
        month_labels = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
    unstacked.coords["month"] = unstacked["month"].to_series().map(month_labels)
    return unstacked
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command
def compute_synopsis(
    grid: Grid,
    level: int,
):
    """Write mean-and-trend summaries ("synopsis") of the spatially aggregated ERA5 data.

    For each of the yearly, seasonal and shoulder aggregations the aggregated store is loaded
    into memory, its time axis is unstacked, and two quantities are derived along the year
    dimension for every variable: the mean and the slope of a degree-1 polynomial fit.
    Both are merged into one dataset and written to the matching synopsis zarr store,
    overwriting existing content.

    Args:
        grid (Grid): Grid type.
        level (int): Grid resolution level.

    """
    fit_suffix = "_polyfit_coefficients"
    trend_suffix = "_trend"

    for agg in ("yearly", "seasonal", "shoulder"):
        source_store = get_era5_stores(agg, grid=grid, level=level)
        aggregated = xr.open_zarr(source_store, consolidated=False).compute()

        # The trend is fitted along "year", which only exists after unstacking the time dimension
        per_year = unstack_era5_time(aggregated, agg)

        yearly_mean = per_year.mean(dim="year")

        # Keep only the degree-1 coefficient of the linear fit, i.e. the slope
        linear_fit = per_year.polyfit(dim="year", deg=1)
        slopes = linear_fit.sel(degree=1, drop=True)

        # polyfit names its outputs "{var}_polyfit_coefficients"; store them as "{var}_trend"
        new_names = {}
        for fitted_name in slopes.data_vars:
            new_names[fitted_name] = str(fitted_name).replace(fit_suffix, trend_suffix)
        slopes = slopes.rename(new_names)

        era5_synopsis = xr.merge([yearly_mean, slopes])
        synopsis_store = get_era5_stores(agg, grid=grid, level=level, temporal="synopsis")

        # Use a single chunk per variable by setting the chunk shape to the full array shape
        encoding = codecs.from_ds(era5_synopsis)
        for name, array in era5_synopsis.data_vars.items():
            encoding[name]["chunks"] = tuple(array.shape)

        print(f"Saving ERA5 {agg} synopsis data to {synopsis_store}.")
        era5_synopsis.to_zarr(synopsis_store, mode="w", encoding=encoding, consolidated=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
# ruff: noqa: N803, N806
|
|
||||||
"""Training with AutoGluon TabularPredictor for automated ML."""
|
"""Training with AutoGluon TabularPredictor for automated ML."""
|
||||||
|
|
||||||
import json
|
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import shap
|
|
||||||
import toml
|
import toml
|
||||||
from autogluon.tabular import TabularDataset, TabularPredictor
|
from autogluon.tabular import TabularDataset, TabularPredictor
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
|
|
@ -15,7 +12,6 @@ from sklearn import set_config
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.ml.inference import predict_proba
|
|
||||||
from entropice.utils.paths import get_autogluon_results_dir
|
from entropice.utils.paths import get_autogluon_results_dir
|
||||||
from entropice.utils.types import TargetDataset, Task
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
|
|
@ -76,12 +72,14 @@ def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
|
||||||
def autogluon_train(
|
def autogluon_train(
|
||||||
dataset_ensemble: DatasetEnsemble,
|
dataset_ensemble: DatasetEnsemble,
|
||||||
settings: AutoGluonSettings = AutoGluonSettings(),
|
settings: AutoGluonSettings = AutoGluonSettings(),
|
||||||
|
experiment: str | None = None,
|
||||||
):
|
):
|
||||||
"""Train models using AutoGluon TabularPredictor.
|
"""Train models using AutoGluon TabularPredictor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_ensemble: Dataset ensemble configuration
|
dataset_ensemble: Dataset ensemble configuration
|
||||||
settings: AutoGluon training settings
|
settings: AutoGluon training settings
|
||||||
|
experiment: Optional experiment name for organizing results
|
||||||
|
|
||||||
"""
|
"""
|
||||||
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)
|
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)
|
||||||
|
|
@ -104,6 +102,7 @@ def autogluon_train(
|
||||||
|
|
||||||
# Create results directory
|
# Create results directory
|
||||||
results_dir = get_autogluon_results_dir(
|
results_dir = get_autogluon_results_dir(
|
||||||
|
experiment=experiment,
|
||||||
grid=dataset_ensemble.grid,
|
grid=dataset_ensemble.grid,
|
||||||
level=dataset_ensemble.level,
|
level=dataset_ensemble.level,
|
||||||
task=settings.task,
|
task=settings.task,
|
||||||
|
|
@ -165,64 +164,6 @@ def autogluon_train(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ Could not compute feature importance: {e}")
|
print(f"⚠️ Could not compute feature importance: {e}")
|
||||||
|
|
||||||
# Compute SHAP values for the best model
|
|
||||||
print("\n🎨 Computing SHAP values...")
|
|
||||||
with stopwatch("SHAP computation"):
|
|
||||||
try:
|
|
||||||
# Use a subset of test data for SHAP (can be slow)
|
|
||||||
shap_sample_size = min(500, len(test_data))
|
|
||||||
test_sample = test_data.sample(n=shap_sample_size, random_state=42)
|
|
||||||
|
|
||||||
# Get the best model name
|
|
||||||
best_model = predictor.model_best
|
|
||||||
|
|
||||||
# Get predictions function for SHAP
|
|
||||||
def predict_fn(X):
|
|
||||||
"""Prediction function for SHAP."""
|
|
||||||
df = pd.DataFrame(X, columns=training_data.feature_names)
|
|
||||||
if problem_type == "binary":
|
|
||||||
# For binary, return probability of positive class
|
|
||||||
proba = predictor.predict_proba(df, model=best_model)
|
|
||||||
return proba.iloc[:, 1].to_numpy() # ty:ignore[possibly-missing-attribute, invalid-argument-type]
|
|
||||||
elif problem_type == "multiclass":
|
|
||||||
# For multiclass, return all class probabilities
|
|
||||||
return predictor.predict_proba(df, model=best_model).to_numpy() # ty:ignore[possibly-missing-attribute]
|
|
||||||
else:
|
|
||||||
# For regression
|
|
||||||
return predictor.predict(df, model=best_model).to_numpy() # ty:ignore[possibly-missing-attribute]
|
|
||||||
|
|
||||||
# Create SHAP explainer
|
|
||||||
# Use a background sample (smaller is faster)
|
|
||||||
background_size = min(100, len(train_data))
|
|
||||||
background = train_data.drop(columns=["label"]).sample(n=background_size, random_state=42).to_numpy()
|
|
||||||
|
|
||||||
explainer = shap.KernelExplainer(predict_fn, background)
|
|
||||||
|
|
||||||
# Compute SHAP values
|
|
||||||
test_sample_X = test_sample.drop(columns=["label"]).to_numpy()
|
|
||||||
shap_values = explainer.shap_values(test_sample_X)
|
|
||||||
|
|
||||||
# Save SHAP values
|
|
||||||
shap_file = results_dir / "shap_values.pkl"
|
|
||||||
print(f"💾 Saving SHAP values to {shap_file}")
|
|
||||||
with open(shap_file, "wb") as f:
|
|
||||||
pickle.dump(
|
|
||||||
{
|
|
||||||
"shap_values": shap_values,
|
|
||||||
"base_values": explainer.expected_value,
|
|
||||||
"data": test_sample_X,
|
|
||||||
"feature_names": training_data.feature_names,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
protocol=pickle.HIGHEST_PROTOCOL,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ Could not compute SHAP values: {e}")
|
|
||||||
import traceback as tb
|
|
||||||
|
|
||||||
tb.print_exc()
|
|
||||||
|
|
||||||
# Save training settings
|
# Save training settings
|
||||||
print("\n💾 Saving training settings...")
|
print("\n💾 Saving training settings...")
|
||||||
combined_settings = AutoGluonTrainingSettings(
|
combined_settings = AutoGluonTrainingSettings(
|
||||||
|
|
@ -236,22 +177,15 @@ def autogluon_train(
|
||||||
toml.dump({"settings": asdict(combined_settings)}, f)
|
toml.dump({"settings": asdict(combined_settings)}, f)
|
||||||
|
|
||||||
# Save test metrics
|
# Save test metrics
|
||||||
test_metrics_file = results_dir / "test_metrics.json"
|
test_metrics_file = results_dir / "test_metrics.pickle"
|
||||||
print(f"💾 Saving test metrics to {test_metrics_file}")
|
print(f"💾 Saving test metrics to {test_metrics_file}")
|
||||||
with open(test_metrics_file, "w") as f:
|
with open(test_metrics_file, "wb") as f:
|
||||||
json.dump(test_score, f, indent=2)
|
pickle.dump(test_score, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
# Generate predictions for all cells
|
# Save the predictor
|
||||||
print("\n🗺️ Generating predictions for all cells...")
|
predictor_file = results_dir / "tabular_predictor.pkl"
|
||||||
with stopwatch("Prediction"):
|
print(f"💾 Saving TabularPredictor to {predictor_file}")
|
||||||
preds = predict_proba(dataset_ensemble, model=predictor, device="cpu")
|
predictor.save()
|
||||||
if training_data.targets["y"].dtype == "category":
|
|
||||||
preds["predicted"] = preds["predicted"].astype("category")
|
|
||||||
preds["predicted"].cat.categories = training_data.targets["y"].cat.categories
|
|
||||||
# Save predictions
|
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
|
||||||
print(f"💾 Saving predictions to {preds_file}")
|
|
||||||
preds.to_parquet(preds_file)
|
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ from stopuhr import stopwatch
|
||||||
|
|
||||||
import entropice.spatial.grids
|
import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
|
from entropice.ingest.era5 import unstack_era5_time
|
||||||
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
|
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
|
|
@ -45,31 +46,6 @@ set_config(array_api_dispatch=True)
|
||||||
sns.set_theme("talk", "whitegrid")
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
|
|
||||||
# In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
|
|
||||||
if aggregation == "yearly":
|
|
||||||
era5 = era5.rename({"time": "year"})
|
|
||||||
era5.coords["year"] = era5["year"].dt.year
|
|
||||||
return era5
|
|
||||||
|
|
||||||
# Make the time index a MultiIndex of year and month
|
|
||||||
era5.coords["year"] = era5.time.dt.year
|
|
||||||
era5.coords["month"] = era5.time.dt.month
|
|
||||||
era5["time"] = pd.MultiIndex.from_arrays(
|
|
||||||
[
|
|
||||||
era5.time.dt.year.values, # noqa: PD011
|
|
||||||
era5.time.dt.month.values, # noqa: PD011
|
|
||||||
],
|
|
||||||
names=("year", "month"),
|
|
||||||
)
|
|
||||||
era5 = era5.unstack("time") # noqa: PD010
|
|
||||||
seasons = {10: "winter", 4: "summer"}
|
|
||||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
|
||||||
month_map = seasons if aggregation == "seasonal" else shoulder_seasons
|
|
||||||
era5.coords["month"] = era5["month"].to_series().map(month_map)
|
|
||||||
return era5
|
|
||||||
|
|
||||||
|
|
||||||
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
||||||
collapsed = ds.to_dataframe()
|
collapsed = ds.to_dataframe()
|
||||||
# Make a dummy row to avoid empty dataframe issues
|
# Make a dummy row to avoid empty dataframe issues
|
||||||
|
|
@ -255,7 +231,7 @@ class TrainingSet:
|
||||||
pd.DataFrame: The training DataFrame.
|
pd.DataFrame: The training DataFrame.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dataset = self.targets[["y"]].join(self.features)
|
dataset = self.targets[["y"]].rename(columns={"y": "label"}).join(self.features)
|
||||||
if split is not None:
|
if split is not None:
|
||||||
dataset = dataset[self.split == split]
|
dataset = dataset[self.split == split]
|
||||||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
||||||
|
|
@ -330,14 +306,18 @@ class DatasetEnsemble:
|
||||||
@cache
|
@cache
|
||||||
def read_grid(self) -> gpd.GeoDataFrame:
|
def read_grid(self) -> gpd.GeoDataFrame:
|
||||||
"""Load the grid dataframe and enrich it with lat-lon information."""
|
"""Load the grid dataframe and enrich it with lat-lon information."""
|
||||||
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
|
columns_to_load = ["cell_id", "geometry", "cell_area", "land_area", "water_area", "land_ratio"]
|
||||||
|
# The name add_lonlat has legacy reasons and should be add_location
|
||||||
# Add the lat / lon of the cell centers
|
# If add_location is true, keep the x and y
|
||||||
|
# For future reworks: "lat" and "lon" are also available columns
|
||||||
if self.add_lonlat:
|
if self.add_lonlat:
|
||||||
grid_gdf["lon"] = grid_gdf.geometry.centroid.x
|
columns_to_load.extend(["x", "y"])
|
||||||
grid_gdf["lat"] = grid_gdf.geometry.centroid.y
|
|
||||||
|
|
||||||
# Convert hex cell_id to int
|
# Reading the data takes for the largest grids ~1.7s
|
||||||
|
gridfile = entropice.utils.paths.get_grid_file(self.grid, self.level)
|
||||||
|
grid_gdf = gpd.read_parquet(gridfile, columns=columns_to_load)
|
||||||
|
|
||||||
|
# Convert hex cell_id to int (for the largest grids ~0.4s)
|
||||||
if self.grid == "hex":
|
if self.grid == "hex":
|
||||||
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
|
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
|
||||||
|
|
||||||
|
|
@ -410,6 +390,8 @@ class DatasetEnsemble:
|
||||||
raise NotImplementedError(f"Task {task} not supported.")
|
raise NotImplementedError(f"Task {task} not supported.")
|
||||||
cell_ids = targets["cell_ids"].to_series()
|
cell_ids = targets["cell_ids"].to_series()
|
||||||
geometries = self.geometries.loc[cell_ids]
|
geometries = self.geometries.loc[cell_ids]
|
||||||
|
# ! Warning: For some stupid unknown reason, it is not possible to enforce the uint64 dtype on the index
|
||||||
|
# This will result in joining issues later on if the dtypes do not match!
|
||||||
return gpd.GeoDataFrame(
|
return gpd.GeoDataFrame(
|
||||||
{
|
{
|
||||||
"cell_id": cell_ids,
|
"cell_id": cell_ids,
|
||||||
|
|
@ -437,10 +419,14 @@ class DatasetEnsemble:
|
||||||
"""
|
"""
|
||||||
match member:
|
match member:
|
||||||
case "AlphaEarth":
|
case "AlphaEarth":
|
||||||
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_embeddings_store(
|
||||||
|
grid=self.grid, level=self.level, temporal=self.temporal_mode
|
||||||
|
)
|
||||||
case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
|
case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
|
||||||
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_era5_stores(
|
||||||
|
era5_agg, grid=self.grid, level=self.level, temporal=self.temporal_mode
|
||||||
|
)
|
||||||
case "ArcticDEM":
|
case "ArcticDEM":
|
||||||
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
||||||
case _:
|
case _:
|
||||||
|
|
@ -458,32 +444,18 @@ class DatasetEnsemble:
|
||||||
# Only load target cell ids
|
# Only load target cell ids
|
||||||
if cell_ids is None:
|
if cell_ids is None:
|
||||||
cell_ids = self.cell_ids
|
cell_ids = self.cell_ids
|
||||||
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy()))
|
is_intersecting_cell_ids = np.isin(ds["cell_ids"].values, cell_ids.to_numpy(), assume_unique=True)
|
||||||
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
|
if not is_intersecting_cell_ids.all():
|
||||||
|
ds = ds[{"cell_ids": is_intersecting_cell_ids}]
|
||||||
|
|
||||||
# Unstack era5 data if needed
|
# Unstack era5 data if needed (already done for synopsis mode)
|
||||||
if member.startswith("ERA5"):
|
if member.startswith("ERA5") and self.temporal_mode != "synopsis":
|
||||||
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
ds = _unstack_era5_time(ds, era5_agg)
|
ds = unstack_era5_time(ds, era5_agg)
|
||||||
|
|
||||||
# Apply the temporal mode
|
# Apply the temporal mode
|
||||||
match (member.split("-"), self.temporal_mode):
|
if isinstance(self.temporal_mode, int):
|
||||||
case (["ArcticDEM"], _):
|
ds = ds.sel(year=self.temporal_mode, drop=True)
|
||||||
pass # No temporal dimension
|
|
||||||
case (_, "feature"):
|
|
||||||
pass
|
|
||||||
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
|
|
||||||
ds_mean = ds.mean(dim="year")
|
|
||||||
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
|
|
||||||
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
|
|
||||||
ds_trend = ds_trend.rename(
|
|
||||||
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
|
|
||||||
)
|
|
||||||
ds = xr.merge([ds_mean, ds_trend])
|
|
||||||
case (_, int() as year):
|
|
||||||
ds = ds.sel(year=year, drop=True)
|
|
||||||
case _:
|
|
||||||
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
|
|
||||||
|
|
||||||
# Actually read data into memory
|
# Actually read data into memory
|
||||||
if not lazy:
|
if not lazy:
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,9 @@ def predict_proba(
|
||||||
|
|
||||||
grid_gdf = e.read_grid()
|
grid_gdf = e.read_grid()
|
||||||
for batch in e.create_inference_df(batch_size=batch_size):
|
for batch in e.create_inference_df(batch_size=batch_size):
|
||||||
|
# Filter rows containing NaN values
|
||||||
|
batch = batch.dropna(axis=0, how="any")
|
||||||
|
|
||||||
# Skip empty batches (all rows had NaN values)
|
# Skip empty batches (all rows had NaN values)
|
||||||
if len(batch) == 0:
|
if len(batch) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -315,6 +315,14 @@ def cli(grid: Grid, level: int):
|
||||||
print("No valid grid cells found.")
|
print("No valid grid cells found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Add location to the grid
|
||||||
|
grid_4326_centroids = grid_gdf.to_crs("EPSG:4326").geometry.centroid
|
||||||
|
grid_gdf["lon"] = grid_4326_centroids.x
|
||||||
|
grid_gdf["lat"] = grid_4326_centroids.y
|
||||||
|
grid_centroids = grid_gdf.geometry.centroid
|
||||||
|
grid_gdf["x"] = grid_centroids.x
|
||||||
|
grid_gdf["y"] = grid_centroids.y
|
||||||
|
|
||||||
grid_file = get_grid_file(grid, level)
|
grid_file = get_grid_file(grid, level)
|
||||||
grid_gdf.to_parquet(grid_file)
|
grid_gdf.to_parquet(grid_file)
|
||||||
print(f"Saved to {grid_file}")
|
print(f"Saved to {grid_file}")
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from entropice.utils.types import Grid, Task
|
from entropice.utils.types import Grid, Task, TemporalMode
|
||||||
|
|
||||||
DATA_DIR = (
|
DATA_DIR = (
|
||||||
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||||
|
|
@ -79,8 +79,10 @@ def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
|
||||||
return embfile
|
return embfile
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings_store(grid: Grid, level: int) -> Path:
|
def get_embeddings_store(grid: Grid, level: int, temporal: TemporalMode | None = None) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
|
if temporal == "synopsis":
|
||||||
|
gridname += "_synopsis"
|
||||||
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
|
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
|
||||||
return embstore
|
return embstore
|
||||||
|
|
||||||
|
|
@ -89,11 +91,11 @@ def get_era5_stores(
|
||||||
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
||||||
grid: Grid | None = None,
|
grid: Grid | None = None,
|
||||||
level: int | None = None,
|
level: int | None = None,
|
||||||
temporal: Literal["synopsis"] | None = None,
|
temporal: TemporalMode | None = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
pdir = ERA5_DIR
|
pdir = ERA5_DIR
|
||||||
if temporal is not None:
|
if temporal == "synopsis":
|
||||||
agg += f"_{temporal}" # ty:ignore[invalid-assignment]
|
agg += "_synopsis" # ty:ignore[invalid-assignment]
|
||||||
fname = f"{agg}_climate.zarr"
|
fname = f"{agg}_climate.zarr"
|
||||||
|
|
||||||
if grid is None or level is None:
|
if grid is None or level is None:
|
||||||
|
|
@ -165,12 +167,18 @@ def get_cv_results_dir(
|
||||||
|
|
||||||
|
|
||||||
def get_autogluon_results_dir(
|
def get_autogluon_results_dir(
|
||||||
|
experiment: str | None,
|
||||||
grid: Grid,
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
task: Task,
|
task: Task,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
results_dir = RESULTS_DIR / f"{gridname}_autogluon_{now}_{task}"
|
if experiment is not None:
|
||||||
|
experiment_dir = RESULTS_DIR / experiment
|
||||||
|
experiment_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
experiment_dir = RESULTS_DIR
|
||||||
|
results_dir = experiment_dir / f"{gridname}_autogluon_{now}_{task}"
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return results_dir
|
return results_dir
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue