entropice/tests/debug_feature_mismatch.py

73 lines
2.5 KiB
Python
Raw Normal View History

"""Debug script to identify feature mismatch between training and inference."""
from entropice.ml.dataset import DatasetEnsemble
# Build the ensemble whose training and inference feature sets are compared.
# NOTE(review): this spot previously said "level 6 (the actual level used in
# production)", but the ensemble is built with level=10 -- confirm which level
# production really uses and align the value or this note accordingly.
ensemble = DatasetEnsemble(
    grid="healpix",
    level=10,
    target="darts_mllabels",
    members=[
        "AlphaEarth",
        "ArcticDEM",
        "ERA5-yearly",
        "ERA5-seasonal",
        "ERA5-shoulder",
    ],
    add_lonlat=True,
    filter_target=False,
)

# Horizontal rule reused by every section banner below.
SEPARATOR = "=" * 80

print(SEPARATOR)
print("Creating training dataset...")
print(SEPARATOR)

# Feature names as seen by the training path.
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_features = set(training_data.X.data.columns)
print(f"\nTraining dataset created with {len(training_features)} features")
print(f"Sample features: {sorted(training_features)[:10]}")

print("\n" + SEPARATOR)
print("Creating inference batch...")
print(SEPARATOR)

# Only the first batch is inspected; the default of None makes an empty
# generator show up as an explicit error message instead of StopIteration.
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)

if batch is None:
    print("ERROR: No batch created!")
else:
    print(f"\nBatch created with {len(batch.columns)} columns")
    print(f"Batch columns: {sorted(batch.columns)[:15]}")

    # Simulate the column dropping in predict_proba (inference.py): the
    # geometry column plus every column carrying the target's prefix.
    target_prefix = "dartsml_" if ensemble.target == "darts_mllabels" else "darts_"
    cols_to_drop = ["geometry"]
    cols_to_drop += [col for col in batch.columns if col.startswith(target_prefix)]
    print(f"\nColumns to drop: {cols_to_drop}")

    # Feature names as seen by the inference path after dropping.
    inference_batch = batch.drop(columns=cols_to_drop)
    inference_features = set(inference_batch.columns)
    print(f"\nInference batch after dropping has {len(inference_features)} features")
    print(f"Sample features: {sorted(inference_features)[:10]}")

    print("\n" + SEPARATOR)
    print("COMPARISON")
    print(SEPARATOR)
    print(f"Training features: {len(training_features)}")
    print(f"Inference features: {len(inference_features)}")

    if training_features == inference_features:
        print("\n✅ SUCCESS: Features match perfectly!")
    else:
        print("\n❌ MISMATCH DETECTED!")
        # Set differences pinpoint which side carries the extra columns.
        only_in_training = training_features - inference_features
        only_in_inference = inference_features - training_features
        if only_in_training:
            print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
        if only_in_inference:
            print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")