Refactor autogluon

This commit is contained in:
Tobias Hölzer 2026-01-11 20:51:53 +01:00
parent 2cefe35690
commit cfb7d65d6d
7 changed files with 305 additions and 232 deletions

View file

@@ -1,33 +0,0 @@
"""Debug script to check what _prep_arcticdem returns for a batch."""
from entropice.ml.dataset import DatasetEnsemble
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=["ArcticDEM"],
add_lonlat=True,
filter_target=False,
)
# Get targets
targets = ensemble._read_target()
print(f"Total targets: {len(targets)}")
# Get first batch of targets
batch_targets = targets.iloc[:100]
print(f"\nBatch targets: {len(batch_targets)}")
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
# Try to prep ArcticDEM for this batch
print("\n" + "=" * 80)
print("Calling _prep_arcticdem...")
print("=" * 80)
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
print(
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
)
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")

View file

@@ -1,72 +0,0 @@
"""Debug script to identify feature mismatch between training and inference."""
from entropice.ml.dataset import DatasetEnsemble
# Test with level 6 (the actual level used in production)
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
add_lonlat=True,
filter_target=False,
)
print("=" * 80)
print("Creating training dataset...")
print("=" * 80)
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_features = set(training_data.X.data.columns)
print(f"\nTraining dataset created with {len(training_features)} features")
print(f"Sample features: {sorted(list(training_features))[:10]}")
print("\n" + "=" * 80)
print("Creating inference batch...")
print("=" * 80)
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
# for batch in batch_generator:
if batch is None:
print("ERROR: No batch created!")
else:
print(f"\nBatch created with {len(batch.columns)} columns")
print(f"Batch columns: {sorted(batch.columns)[:15]}")
# Simulate the column dropping in predict_proba (inference.py)
cols_to_drop = ["geometry"]
if ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
print(f"\nColumns to drop: {cols_to_drop}")
inference_batch = batch.drop(columns=cols_to_drop)
inference_features = set(inference_batch.columns)
print(f"\nInference batch after dropping has {len(inference_features)} features")
print(f"Sample features: {sorted(list(inference_features))[:10]}")
print("\n" + "=" * 80)
print("COMPARISON")
print("=" * 80)
print(f"Training features: {len(training_features)}")
print(f"Inference features: {len(inference_features)}")
if training_features == inference_features:
print("\n✅ SUCCESS: Features match perfectly!")
else:
print("\n❌ MISMATCH DETECTED!")
only_in_training = training_features - inference_features
only_in_inference = inference_features - training_features
if only_in_training:
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
if only_in_inference:
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")

View file

@@ -1,83 +0,0 @@
from itertools import product
import xarray as xr
from rich.progress import track
from entropice.utils.paths import (
get_arcticdem_stores,
get_embeddings_store,
get_era5_stores,
)
from entropice.utils.types import Grid, L2SourceDataset
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
    """Validate if the L2 dataset exists for the given grid and level.

    Resolves the zarr store for the requested L2 source, checks that it
    exists on disk, then scans every data variable for NaNs.

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.
        l2ds (L2SourceDataset): The L2 source dataset to validate.

    Returns:
        bool: True if the dataset exists and does not contain NaNs, False otherwise.

    Raises:
        ValueError: If ``l2ds`` is not a recognized L2 source dataset.
    """
    # Map the dataset name to its backing store.
    if l2ds == "ArcticDEM":
        store = get_arcticdem_stores(grid, level)
    elif l2ds in ("ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly"):
        # The suffix after "ERA5-" selects the aggregation period.
        agg = l2ds.split("-")[1]
        store = get_era5_stores(agg, grid, level)  # type: ignore
    elif l2ds == "AlphaEarth":
        store = get_embeddings_store(grid, level)
    else:
        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")

    if not store.exists():
        print("\t Dataset store does not exist")
        return False

    ds = xr.open_zarr(store, consolidated=False)

    # Report NaN counts for every affected variable before deciding, so the
    # output lists all problems rather than just the first one found.
    clean = True
    for var in ds.data_vars:
        n_nans = ds[var].isnull().sum().compute().item()
        if n_nans > 0:
            print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
            clean = False
    return clean
def main():
    """Validate every (grid, level) x L2-source combination and print a status line each.

    Fix: ``grid_levels`` was a ``set``, so the iteration order of the
    validation report was arbitrary and could differ between runs. Using an
    ordered list makes the output deterministic and diffable.
    """
    grid_levels: list[tuple[Grid, int]] = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    l2_source_datasets: list[L2SourceDataset] = [
        "ArcticDEM",
        "ERA5-shoulder",
        "ERA5-seasonal",
        "ERA5-yearly",
        "AlphaEarth",
    ]
    # product() is lazy, so track() needs the explicit total to render progress.
    for (grid, level), l2ds in track(
        product(grid_levels, l2_source_datasets),
        total=len(grid_levels) * len(l2_source_datasets),
        description="Validating L2 datasets...",
    ):
        is_valid = validate_l2_dataset(grid, level, l2ds)
        status = "VALID" if is_valid else "INVALID"
        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")


if __name__ == "__main__":
    main()