Refactor autogluon
This commit is contained in:
parent
2cefe35690
commit
cfb7d65d6d
7 changed files with 305 additions and 232 deletions
|
|
@ -1,33 +0,0 @@
|
|||
"""Debug script to check what _prep_arcticdem returns for a batch."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=["ArcticDEM"],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
# Get targets
|
||||
targets = ensemble._read_target()
|
||||
print(f"Total targets: {len(targets)}")
|
||||
|
||||
# Get first batch of targets
|
||||
batch_targets = targets.iloc[:100]
|
||||
print(f"\nBatch targets: {len(batch_targets)}")
|
||||
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
|
||||
|
||||
# Try to prep ArcticDEM for this batch
|
||||
print("\n" + "=" * 80)
|
||||
print("Calling _prep_arcticdem...")
|
||||
print("=" * 80)
|
||||
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
|
||||
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
|
||||
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
|
||||
print(
|
||||
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
|
||||
)
|
||||
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
"""Debug script to identify feature mismatch between training and inference."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
# Test with level 6 (the actual level used in production)
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=[
|
||||
"AlphaEarth",
|
||||
"ArcticDEM",
|
||||
"ERA5-yearly",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-shoulder",
|
||||
],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
print("Creating training dataset...")
|
||||
print("=" * 80)
|
||||
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_features = set(training_data.X.data.columns)
|
||||
print(f"\nTraining dataset created with {len(training_features)} features")
|
||||
print(f"Sample features: {sorted(list(training_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Creating inference batch...")
|
||||
print("=" * 80)
|
||||
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
# for batch in batch_generator:
|
||||
if batch is None:
|
||||
print("ERROR: No batch created!")
|
||||
else:
|
||||
print(f"\nBatch created with {len(batch.columns)} columns")
|
||||
print(f"Batch columns: {sorted(batch.columns)[:15]}")
|
||||
|
||||
# Simulate the column dropping in predict_proba (inference.py)
|
||||
cols_to_drop = ["geometry"]
|
||||
if ensemble.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
print(f"\nColumns to drop: {cols_to_drop}")
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_features = set(inference_batch.columns)
|
||||
|
||||
print(f"\nInference batch after dropping has {len(inference_features)} features")
|
||||
print(f"Sample features: {sorted(list(inference_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("COMPARISON")
|
||||
print("=" * 80)
|
||||
print(f"Training features: {len(training_features)}")
|
||||
print(f"Inference features: {len(inference_features)}")
|
||||
|
||||
if training_features == inference_features:
|
||||
print("\n✅ SUCCESS: Features match perfectly!")
|
||||
else:
|
||||
print("\n❌ MISMATCH DETECTED!")
|
||||
only_in_training = training_features - inference_features
|
||||
only_in_inference = inference_features - training_features
|
||||
|
||||
if only_in_training:
|
||||
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
|
||||
if only_in_inference:
|
||||
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
from itertools import product
|
||||
|
||||
import xarray as xr
|
||||
from rich.progress import track
|
||||
|
||||
from entropice.utils.paths import (
|
||||
get_arcticdem_stores,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
)
|
||||
from entropice.utils.types import Grid, L2SourceDataset
|
||||
|
||||
|
||||
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
    """Validate if the L2 dataset exists for the given grid and level.

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.
        l2ds (L2SourceDataset): The L2 source dataset to validate.

    Returns:
        bool: True if the dataset exists and does not contain NaNs, False otherwise.

    Raises:
        ValueError: If ``l2ds`` is not a known L2 source dataset.

    """
    # Resolve the store location for the requested dataset.
    if l2ds == "ArcticDEM":
        store = get_arcticdem_stores(grid, level)
    elif l2ds in ("ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly"):
        # The aggregation name is the suffix after "ERA5-".
        agg = l2ds.split("-")[1]
        store = get_era5_stores(agg, grid, level)  # type: ignore
    elif l2ds == "AlphaEarth":
        store = get_embeddings_store(grid, level)
    else:
        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")

    if not store.exists():
        print("\t Dataset store does not exist")
        return False

    ds = xr.open_zarr(store, consolidated=False)
    # Check every variable (no early exit) so that all NaN-carrying
    # variables are reported, not just the first one found.
    has_nan = False
    for var in ds.data_vars:
        n_nans = ds[var].isnull().sum().compute().item()
        if n_nans > 0:
            print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
            has_nan = True
    return not has_nan
|
||||
|
||||
|
||||
def main():
    """Sweep every grid/level x L2-source combination and print its validation status."""
    grid_levels: set[tuple[Grid, int]] = {
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    }
    l2_source_datasets: list[L2SourceDataset] = [
        "ArcticDEM",
        "ERA5-shoulder",
        "ERA5-seasonal",
        "ERA5-yearly",
        "AlphaEarth",
    ]

    # Cartesian product of all grid/level pairs with all source datasets;
    # the explicit total lets the progress bar work with the lazy iterator.
    combos = product(grid_levels, l2_source_datasets)
    n_combos = len(grid_levels) * len(l2_source_datasets)
    for (grid, level), l2ds in track(combos, total=n_combos, description="Validating L2 datasets..."):
        status = "VALID" if validate_l2_dataset(grid, level, l2ds) else "INVALID"
        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue