Refactor autogluon

This commit is contained in:
Tobias Hölzer 2026-01-11 20:51:53 +01:00
parent 2cefe35690
commit cfb7d65d6d
7 changed files with 305 additions and 232 deletions

View file

@ -38,7 +38,7 @@ cli = cyclopts.App(name="alpha-earth")
# 7454521782,230147807,10000000. # 7454521782,230147807,10000000.
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.DataFrame]: def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts # Simple partitioning by splitting the GeoDataFrame into n_partitions parts
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y}) centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})

View file

@ -0,0 +1,270 @@
# ruff: noqa: N803, N806
"""Training with AutoGluon TabularPredictor for automated ML."""
import json
import pickle
from dataclasses import asdict, dataclass
import cyclopts
import pandas as pd
import shap
import toml
from autogluon.tabular import TabularDataset, TabularPredictor
from rich import pretty, traceback
from sklearn import set_config
from stopuhr import stopwatch
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import get_autogluon_results_dir
from entropice.utils.types import TargetDataset, Task
# Rich tracebacks and pretty repr output for CLI runs
traceback.install()
pretty.install()
# sklearn global config: keep array-API dispatch off for this module
set_config(array_api_dispatch=False)

# cyclopts CLI application; `autogluon_train` below is registered as its default command
cli = cyclopts.App("entropice-autogluon")
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class AutoGluonSettings:
    """AutoGluon training settings.

    Frozen, keyword-only dataclass; exposed as flat CLI options via cyclopts
    (the ``"*"`` parameter name flattens the fields into top-level flags).
    """

    task: Task = "binary"  # Training task; determines AutoGluon problem type and default metric
    target: TargetDataset = "darts_v1"  # Target dataset to train against
    time_limit: int = 3600  # Time limit in seconds (1 hour default)
    presets: str = "best"  # AutoGluon preset: 'best', 'high', 'good', 'medium'
    eval_metric: str | None = None  # Evaluation metric, None for auto-detect
    num_bag_folds: int = 5  # Number of folds for bagging
    num_bag_sets: int = 1  # Number of bagging sets
    num_stack_levels: int = 1  # Number of stacking levels
    num_gpus: int = 1  # Number of GPUs to use
    verbosity: int = 2  # Verbosity level (0-4)
@dataclass(frozen=True, kw_only=True)
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
    """Combined settings for AutoGluon training.

    Merges the dataset-ensemble configuration with the AutoGluon settings plus
    two values resolved at training time, so the complete run configuration can
    be serialized (see the TOML dump in ``autogluon_train``).
    """

    # Class labels resolved from the training data at run time
    classes: list[str]
    # AutoGluon problem type: "binary", "multiclass", or "regression"
    problem_type: str
def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
"""Determine AutoGluon problem type and appropriate evaluation metric.
Args:
task: The training task type
Returns:
Tuple of (problem_type, eval_metric)
"""
if task == "binary":
return ("binary", "balanced_accuracy") # Good for imbalanced datasets
elif task in ["count_regimes", "density_regimes"]:
return ("multiclass", "f1_weighted") # Weighted F1 for multiclass
elif task in ["count", "density"]:
return ("regression", "mean_absolute_error")
else:
raise ValueError(f"Unknown task: {task}")
@cli.default
def autogluon_train(
    dataset_ensemble: DatasetEnsemble,
    settings: AutoGluonSettings = AutoGluonSettings(),
):
    """Train models using AutoGluon TabularPredictor.

    End-to-end pipeline: builds the training set from the ensemble, fits an
    AutoGluon ``TabularPredictor``, evaluates it on the held-out test split,
    computes feature importance and SHAP values (both best effort), saves all
    artifacts (leaderboard, metrics, settings, predictions) to the results
    directory, and finally generates predictions for all cells.

    Args:
        dataset_ensemble: Dataset ensemble configuration
        settings: AutoGluon training settings
    """
    training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)

    # Convert to AutoGluon TabularDataset (a thin pd.DataFrame subclass)
    train_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("train"))  # ty:ignore[invalid-assignment]
    test_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("test"))  # ty:ignore[invalid-assignment]
    print(f"\nTraining data: {len(train_data)} samples")
    print(f"Test data: {len(test_data)} samples")
    print(f"Features: {len(training_data.feature_names)}")
    print(f"Classes: {training_data.target_labels}")

    # Determine problem type and metric
    problem_type, default_metric = _determine_problem_type_and_metric(settings.task)
    eval_metric = settings.eval_metric or default_metric
    print(f"\n🎯 Problem type: {problem_type}")
    print(f"📈 Evaluation metric: {eval_metric}")

    # Create results directory
    results_dir = get_autogluon_results_dir(
        grid=dataset_ensemble.grid,
        level=dataset_ensemble.level,
        task=settings.task,
    )
    print(f"\n💾 Results directory: {results_dir}")

    # Initialize TabularPredictor
    print(f"\n🚀 Initializing AutoGluon TabularPredictor (preset='{settings.presets}')...")
    # FIX: TrainingSet.to_dataframe names the target column "y" (not "label"),
    # so the predictor must be configured with label="y" for fit() to find it.
    predictor = TabularPredictor(
        label="y",
        problem_type=problem_type,
        eval_metric=eval_metric,
        path=str(results_dir / "models"),
        verbosity=settings.verbosity,
    )

    # Train models
    print(f"\n⚡ Training models (time_limit={settings.time_limit}s, num_gpus={settings.num_gpus})...")
    with stopwatch("AutoGluon training"):
        predictor.fit(
            train_data=train_data,
            time_limit=settings.time_limit,
            presets=settings.presets,
            num_bag_folds=settings.num_bag_folds,
            num_bag_sets=settings.num_bag_sets,
            num_stack_levels=settings.num_stack_levels,
            num_gpus=settings.num_gpus,
            ag_args_fit={"num_gpus": settings.num_gpus} if settings.num_gpus > 0 else None,
        )

    # Evaluate on test data
    print("\n📊 Evaluating on test data...")
    test_score = predictor.evaluate(test_data, silent=True, detailed_report=True)
    print(f"Test {eval_metric}: {test_score[eval_metric]:.4f}")

    # Get leaderboard
    print("\n🏆 Model Leaderboard:")
    leaderboard = predictor.leaderboard(test_data, silent=True)
    print(leaderboard[["model", "score_test", "score_val", "pred_time_test", "fit_time"]].head(10))

    # Save leaderboard
    leaderboard_file = results_dir / "leaderboard.parquet"
    print(f"\n💾 Saving leaderboard to {leaderboard_file}")
    leaderboard.to_parquet(leaderboard_file)

    # Get feature importance (best effort: permutation importance can fail for
    # some model types or tiny test sets, so errors are reported, not raised)
    print("\n🔍 Computing feature importance...")
    with stopwatch("Feature importance"):
        try:
            # Compute feature importance with reduced repeats
            feature_importance = predictor.feature_importance(
                test_data,
                num_shuffle_sets=3,
                subsample_size=min(500, len(test_data)),  # Further subsample if needed
            )
            fi_file = results_dir / "feature_importance.parquet"
            print(f"💾 Saving feature importance to {fi_file}")
            feature_importance.to_parquet(fi_file)
        except Exception as e:
            print(f"⚠️ Could not compute feature importance: {e}")

    # Compute SHAP values for the best model (best effort; KernelExplainer is slow)
    print("\n🎨 Computing SHAP values...")
    with stopwatch("SHAP computation"):
        try:
            # Use a subset of test data for SHAP (can be slow)
            shap_sample_size = min(500, len(test_data))
            test_sample = test_data.sample(n=shap_sample_size, random_state=42)

            # Get the best model name
            best_model = predictor.model_best

            # Get predictions function for SHAP
            def predict_fn(X):
                """Prediction function for SHAP."""
                df = pd.DataFrame(X, columns=training_data.feature_names)
                if problem_type == "binary":
                    # For binary, return probability of positive class
                    proba = predictor.predict_proba(df, model=best_model)
                    return proba.iloc[:, 1].to_numpy()  # ty:ignore[possibly-missing-attribute, invalid-argument-type]
                elif problem_type == "multiclass":
                    # For multiclass, return all class probabilities
                    return predictor.predict_proba(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]
                else:
                    # For regression
                    return predictor.predict(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]

            # Create SHAP explainer
            # Use a background sample (smaller is faster)
            background_size = min(100, len(train_data))
            # Drop the target column "y" (matches the predictor label above)
            background = train_data.drop(columns=["y"]).sample(n=background_size, random_state=42).to_numpy()
            explainer = shap.KernelExplainer(predict_fn, background)

            # Compute SHAP values
            test_sample_X = test_sample.drop(columns=["y"]).to_numpy()
            shap_values = explainer.shap_values(test_sample_X)

            # Save SHAP values
            shap_file = results_dir / "shap_values.pkl"
            print(f"💾 Saving SHAP values to {shap_file}")
            with open(shap_file, "wb") as f:
                pickle.dump(
                    {
                        "shap_values": shap_values,
                        "base_values": explainer.expected_value,
                        "data": test_sample_X,
                        "feature_names": training_data.feature_names,
                    },
                    f,
                    protocol=pickle.HIGHEST_PROTOCOL,
                )
        except Exception as e:
            print(f"⚠️ Could not compute SHAP values: {e}")
            import traceback as tb

            tb.print_exc()

    # Save training settings
    print("\n💾 Saving training settings...")
    combined_settings = AutoGluonTrainingSettings(
        **asdict(settings),
        **asdict(dataset_ensemble),
        # NOTE(review): `training_data.target_labels` is used above for printing;
        # confirm `.y.labels` is the intended accessor here.
        classes=training_data.y.labels,
        problem_type=problem_type,
    )
    settings_file = results_dir / "training_settings.toml"
    with open(settings_file, "w") as f:
        toml.dump({"settings": asdict(combined_settings)}, f)

    # Save test metrics
    test_metrics_file = results_dir / "test_metrics.json"
    print(f"💾 Saving test metrics to {test_metrics_file}")
    with open(test_metrics_file, "w") as f:
        json.dump(test_score, f, indent=2)

    # Generate predictions for all cells
    print("\n🗺️ Generating predictions for all cells...")
    with stopwatch("Prediction"):
        preds = predict_proba(dataset_ensemble, model=predictor, device="cpu")
    if training_data.targets["y"].dtype == "category":
        preds["predicted"] = preds["predicted"].astype("category")
        # FIX: assigning to `.cat.categories` was removed in pandas 2.0;
        # `rename_categories` is the supported (positional) equivalent.
        preds["predicted"] = preds["predicted"].cat.rename_categories(
            training_data.targets["y"].cat.categories
        )

    # Save predictions
    preds_file = results_dir / "predicted_probabilities.parquet"
    print(f"💾 Saving predictions to {preds_file}")
    preds.to_parquet(preds_file)

    # Print summary
    print("\n" + "=" * 80)
    print("✅ AutoGluon Training Complete!")
    print("=" * 80)
    print(f"\n📂 Results saved to: {results_dir}")
    print(f"🏆 Best model: {predictor.model_best}")
    print(f"📈 Test {eval_metric}: {test_score[eval_metric]:.4f}")
    print(f"⏱️ Total models trained: {len(leaderboard)}")
    stopwatch.summary()
    print("\nDone! 🎉")


if __name__ == "__main__":
    cli()

View file

@ -241,6 +241,26 @@ class TrainingSet:
def __len__(self): def __len__(self):
return len(self.z) return len(self.z)
def to_dataframe(self, split: Literal["train", "test"] | None) -> pd.DataFrame:
    """Create a simple DataFrame for training with e.g. autogluon.

    This function creates a DataFrame with features and the target labels
    (column "y") by joining `self.targets` with `self.features`. The resulting
    dataframe contains no geometry information.

    Args:
        split (Literal["train", "test"] | None): If specified, only return the
            samples for the given split (selected via `self.split`).

    Returns:
        pd.DataFrame: The training DataFrame.

    Raises:
        AssertionError: If the (optionally filtered) join yields no rows.
    """
    dataset = self.targets[["y"]].join(self.features)
    if split is not None:
        dataset = dataset[self.split == split]
    assert len(dataset) > 0, "No valid samples found after joining features and targets."
    return dataset
@cyclopts.Parameter("*") @cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@ -249,7 +269,6 @@ class DatasetEnsemble:
grid: Grid grid: Grid
level: int level: int
target: TargetDataset
members: list[L2SourceDataset] = field( members: list[L2SourceDataset] = field(
default_factory=lambda: [ default_factory=lambda: [
"AlphaEarth", "AlphaEarth",
@ -325,8 +344,8 @@ class DatasetEnsemble:
return grid_gdf return grid_gdf
@stopwatch.f("Getting target labels", print_kwargs=["task"]) @stopwatch.f("Getting target labels", print_kwargs=["task"])
def get_targets(self, task: Task) -> gpd.GeoDataFrame: def get_targets(self, target: TargetDataset, task: Task) -> gpd.GeoDataFrame:
"""Create a training target labels for a specific task. """Create a training target labels for a specific target and task.
The function reads the target dataset, filters it based on coverage, The function reads the target dataset, filters it based on coverage,
and prepares the target labels according to the specified task. and prepares the target labels according to the specified task.
@ -339,6 +358,7 @@ class DatasetEnsemble:
Args: Args:
target (TargetDataset): The target dataset to use.
task (Task): The task. task (Task): The task.
Returns: Returns:
@ -349,13 +369,13 @@ class DatasetEnsemble:
For regression tasks, both "y" and "z" are numerical and identical. For regression tasks, both "y" and "z" are numerical and identical.
""" """
match (self.target, self.temporal_mode): match (target, self.temporal_mode):
case ("darts_v1" | "darts_v2", "feature" | "synopsis"): case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment] version: Literal["v1-l3", "v2-l3"] = target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version) target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
targets = xr.open_zarr(target_store, consolidated=False) targets = xr.open_zarr(target_store, consolidated=False)
case ("darts_v1" | "darts_v2", int()): case ("darts_v1" | "darts_v2", int()):
version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment] version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version) target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode) targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
case ("darts_mllabels", str()): # Years are not supported case ("darts_mllabels", str()): # Years are not supported
@ -364,7 +384,7 @@ class DatasetEnsemble:
) )
targets = xr.open_zarr(target_store, consolidated=False) targets = xr.open_zarr(target_store, consolidated=False)
case _: case _:
raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.") raise NotImplementedError(f"Target {target} on {self.temporal_mode} mode not supported.")
targets = cast(xr.Dataset, targets) targets = cast(xr.Dataset, targets)
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series() covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy()) targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
@ -599,42 +619,11 @@ class DatasetEnsemble:
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size] batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode) yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"]) @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
def create_training_df(
self,
task: Task,
cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> pd.DataFrame:
"""Create a simple DataFrame for training with e.g. autogluon.
This function creates a DataFrame with features and target labels (called "y") for training.
The resulting dataframe contains no geometry information and drops all samples with NaN values.
Does not split into train and test sets, also does not convert to arrays.
For more advanced use cases, use `create_training_set`.
Args:
task (Task): The task.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "read".
Returns:
pd.DataFrame: The training DataFrame.
"""
targets = self.get_targets(task)
assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
dataset = targets[["y"]].join(features).dropna()
assert len(dataset) > 0, "No valid samples found after joining features and targets."
return dataset
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
def create_training_set( def create_training_set(
self, self,
task: Task, task: Task,
target: TargetDataset,
device: Literal["cpu", "cuda", "torch"] = "cpu", device: Literal["cpu", "cuda", "torch"] = "cpu",
cache_mode: Literal["none", "overwrite", "read"] = "read", cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> TrainingSet: ) -> TrainingSet:
@ -645,6 +634,7 @@ class DatasetEnsemble:
Args: Args:
task (Task): The task. task (Task): The task.
target (TargetDataset): The target dataset to use.
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu". device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation. cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching. "none": No caching.
@ -657,7 +647,7 @@ class DatasetEnsemble:
Contains a lot of useful information for training and evaluation. Contains a lot of useful information for training and evaluation.
""" """
targets = self.get_targets(task) targets = self.get_targets(target, task)
assert len(targets) > 0, "No target samples found." assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna() features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
assert len(features) > 0, "No valid features found after dropping NaNs." assert len(features) > 0, "No valid features found after dropping NaNs."

View file

@ -24,7 +24,7 @@ from entropice.ml.models import (
get_model_hpo_config, get_model_hpo_config,
) )
from entropice.utils.paths import get_cv_results_dir from entropice.utils.paths import get_cv_results_dir
from entropice.utils.types import Model, Task from entropice.utils.types import Model, TargetDataset, Task
traceback.install() traceback.install()
pretty.install() pretty.install()
@ -75,6 +75,7 @@ class CVSettings:
n_iter: int = 2000 n_iter: int = 2000
task: Task = "binary" task: Task = "binary"
target: TargetDataset = "darts_v1"
model: Model = "espa" model: Model = "espa"
@ -106,7 +107,7 @@ def random_cv(
set_config(array_api_dispatch=use_array_api) set_config(array_api_dispatch=use_array_api)
print("Creating training data...") print("Creating training data...")
training_data = dataset_ensemble.create_training_set(task=settings.task, device=device) training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target, device=device)
model_hpo_config = get_model_hpo_config(settings.model, settings.task) model_hpo_config = get_model_hpo_config(settings.model, settings.task)
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}") print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)

View file

@ -1,33 +0,0 @@
"""Debug script to check what _prep_arcticdem returns for a batch."""
from entropice.ml.dataset import DatasetEnsemble
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=["ArcticDEM"],
add_lonlat=True,
filter_target=False,
)
# Get targets
targets = ensemble._read_target()
print(f"Total targets: {len(targets)}")
# Get first batch of targets
batch_targets = targets.iloc[:100]
print(f"\nBatch targets: {len(batch_targets)}")
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
# Try to prep ArcticDEM for this batch
print("\n" + "=" * 80)
print("Calling _prep_arcticdem...")
print("=" * 80)
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
print(
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
)
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")

View file

@ -1,72 +0,0 @@
"""Debug script to identify feature mismatch between training and inference."""
from entropice.ml.dataset import DatasetEnsemble
# Test with level 6 (the actual level used in production)
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
add_lonlat=True,
filter_target=False,
)
print("=" * 80)
print("Creating training dataset...")
print("=" * 80)
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_features = set(training_data.X.data.columns)
print(f"\nTraining dataset created with {len(training_features)} features")
print(f"Sample features: {sorted(list(training_features))[:10]}")
print("\n" + "=" * 80)
print("Creating inference batch...")
print("=" * 80)
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
# for batch in batch_generator:
if batch is None:
print("ERROR: No batch created!")
else:
print(f"\nBatch created with {len(batch.columns)} columns")
print(f"Batch columns: {sorted(batch.columns)[:15]}")
# Simulate the column dropping in predict_proba (inference.py)
cols_to_drop = ["geometry"]
if ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
print(f"\nColumns to drop: {cols_to_drop}")
inference_batch = batch.drop(columns=cols_to_drop)
inference_features = set(inference_batch.columns)
print(f"\nInference batch after dropping has {len(inference_features)} features")
print(f"Sample features: {sorted(list(inference_features))[:10]}")
print("\n" + "=" * 80)
print("COMPARISON")
print("=" * 80)
print(f"Training features: {len(training_features)}")
print(f"Inference features: {len(inference_features)}")
if training_features == inference_features:
print("\n✅ SUCCESS: Features match perfectly!")
else:
print("\n❌ MISMATCH DETECTED!")
only_in_training = training_features - inference_features
only_in_inference = inference_features - training_features
if only_in_training:
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
if only_in_inference:
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")

View file

@ -1,83 +0,0 @@
from itertools import product
import xarray as xr
from rich.progress import track
from entropice.utils.paths import (
get_arcticdem_stores,
get_embeddings_store,
get_era5_stores,
)
from entropice.utils.types import Grid, L2SourceDataset
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
    """Validate if the L2 dataset exists for the given grid and level.

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.
        l2ds (L2SourceDataset): The L2 source dataset to validate.

    Returns:
        bool: True if the dataset exists and does not contain NaNs, False otherwise.
    """
    # Resolve the zarr store for the requested dataset kind
    if l2ds == "AlphaEarth":
        store = get_embeddings_store(grid, level)
    elif l2ds == "ArcticDEM":
        store = get_arcticdem_stores(grid, level)
    elif l2ds in ("ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly"):
        agg = l2ds.split("-")[1]
        store = get_era5_stores(agg, grid, level)  # type: ignore
    else:
        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")

    if not store.exists():
        print("\t Dataset store does not exist")
        return False

    # Report every variable that contains NaNs before returning the verdict,
    # so a single run surfaces all problems at once.
    ds = xr.open_zarr(store, consolidated=False)
    is_clean = True
    for var in ds.data_vars:
        n_nans = ds[var].isnull().sum().compute().item()
        if n_nans > 0:
            print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
            is_clean = False
    return is_clean
def main():
    """Validate every (grid, level) × L2-dataset combination and print a verdict."""
    grid_levels: set[tuple[Grid, int]] = {
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    }
    l2_source_datasets: list[L2SourceDataset] = [
        "ArcticDEM",
        "ERA5-shoulder",
        "ERA5-seasonal",
        "ERA5-yearly",
        "AlphaEarth",
    ]
    # Total is computed up front so the progress bar has a known length
    # (product() yields lazily and has no len()).
    n_combinations = len(grid_levels) * len(l2_source_datasets)
    combinations = product(grid_levels, l2_source_datasets)
    for (grid, level), l2ds in track(
        combinations,
        total=n_combinations,
        description="Validating L2 datasets...",
    ):
        status = "VALID" if validate_l2_dataset(grid, level, l2ds) else "INVALID"
        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
if __name__ == "__main__":
main()