Refactor autogluon
This commit is contained in:
parent
2cefe35690
commit
cfb7d65d6d
7 changed files with 305 additions and 232 deletions
|
|
@ -38,7 +38,7 @@ cli = cyclopts.App(name="alpha-earth")
|
|||
# 7454521782,230147807,10000000.
|
||||
|
||||
|
||||
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.DataFrame]:
|
||||
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
|
||||
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
|
||||
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
|
||||
|
||||
|
|
|
|||
270
src/entropice/ml/autogluon_training.py
Normal file
270
src/entropice/ml/autogluon_training.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
# ruff: noqa: N803, N806
"""Training with AutoGluon TabularPredictor for automated ML."""

import json
import pickle
from dataclasses import asdict, dataclass

import cyclopts
import pandas as pd
import shap
import toml
from autogluon.tabular import TabularDataset, TabularPredictor
from rich import pretty, traceback
from sklearn import set_config
from stopuhr import stopwatch

from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import get_autogluon_results_dir
from entropice.utils.types import TargetDataset, Task

# Rich tracebacks and pretty repr for nicer CLI output.
traceback.install()
pretty.install()

# Disable sklearn's array-API dispatch for this module.
# NOTE(review): other training entry points appear to toggle this per-run — confirm this
# module-level default is intentional.
set_config(array_api_dispatch=False)

# cyclopts CLI application; `autogluon_train` below is registered as its default command.
cli = cyclopts.App("entropice-autogluon")
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class AutoGluonSettings:
    """AutoGluon training settings.

    Frozen, keyword-only dataclass exposed as flattened CLI parameters via cyclopts.
    """

    task: Task = "binary"  # Training task; maps to an AutoGluon problem type below
    target: TargetDataset = "darts_v1"  # Target dataset used to build labels
    time_limit: int = 3600  # Time limit in seconds (1 hour default)
    presets: str = "best"  # AutoGluon preset: 'best', 'high', 'good', 'medium'
    eval_metric: str | None = None  # Evaluation metric, None for auto-detect by task
    num_bag_folds: int = 5  # Number of folds for bagging
    num_bag_sets: int = 1  # Number of bagging sets
    num_stack_levels: int = 1  # Number of stacking levels
    num_gpus: int = 1  # Number of GPUs to use (0 disables GPU-specific fit args)
    verbosity: int = 2  # Verbosity level (0-4), passed through to TabularPredictor
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
    """Combined settings for AutoGluon training.

    Merges the dataset-ensemble configuration with the AutoGluon settings plus the
    values resolved at runtime; serialized to ``training_settings.toml`` for provenance.
    """

    # Class labels of the training target, as resolved from the training data.
    classes: list[str]
    # AutoGluon problem type actually used: "binary", "multiclass" or "regression".
    problem_type: str
||||
|
||||
|
||||
def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
    """Determine AutoGluon problem type and appropriate evaluation metric.

    Args:
        task: The training task type

    Returns:
        Tuple of (problem_type, eval_metric)

    Raises:
        ValueError: If the task is not one of the known task names.

    """
    # Dispatch table: task name -> (AutoGluon problem type, default metric).
    task_config: dict[str, tuple[str, str]] = {
        "binary": ("binary", "balanced_accuracy"),  # Good for imbalanced datasets
        "count_regimes": ("multiclass", "f1_weighted"),  # Weighted F1 for multiclass
        "density_regimes": ("multiclass", "f1_weighted"),
        "count": ("regression", "mean_absolute_error"),
        "density": ("regression", "mean_absolute_error"),
    }
    if task not in task_config:
        raise ValueError(f"Unknown task: {task}")
    return task_config[task]
|
||||
|
||||
|
||||
@cli.default
def autogluon_train(
    dataset_ensemble: DatasetEnsemble,
    settings: AutoGluonSettings = AutoGluonSettings(),
):
    """Train models using AutoGluon TabularPredictor.

    End-to-end pipeline: builds the training set, fits an AutoGluon predictor,
    evaluates on the held-out split, computes feature importance and SHAP values
    (both best-effort), persists all artifacts to the results directory, and
    finally generates predictions for all cells.

    Args:
        dataset_ensemble: Dataset ensemble configuration
        settings: AutoGluon training settings

    """
    training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)

    # Convert to AutoGluon TabularDataset (a thin pd.DataFrame subclass)
    train_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("train"))  # ty:ignore[invalid-assignment]
    test_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("test"))  # ty:ignore[invalid-assignment]

    print(f"\nTraining data: {len(train_data)} samples")
    print(f"Test data: {len(test_data)} samples")
    print(f"Features: {len(training_data.feature_names)}")
    print(f"Classes: {training_data.target_labels}")

    # Determine problem type and metric; an explicit eval_metric overrides the default
    problem_type, default_metric = _determine_problem_type_and_metric(settings.task)
    eval_metric = settings.eval_metric or default_metric

    print(f"\n🎯 Problem type: {problem_type}")
    print(f"📈 Evaluation metric: {eval_metric}")

    # Create results directory
    results_dir = get_autogluon_results_dir(
        grid=dataset_ensemble.grid,
        level=dataset_ensemble.level,
        task=settings.task,
    )
    print(f"\n💾 Results directory: {results_dir}")

    # Initialize TabularPredictor; NOTE: the target column must be named "label"
    # in the DataFrames produced by to_dataframe() — TODO confirm upstream naming.
    print(f"\n🚀 Initializing AutoGluon TabularPredictor (preset='{settings.presets}')...")
    predictor = TabularPredictor(
        label="label",
        problem_type=problem_type,
        eval_metric=eval_metric,
        path=str(results_dir / "models"),
        verbosity=settings.verbosity,
    )

    # Train models
    print(f"\n⚡ Training models (time_limit={settings.time_limit}s, num_gpus={settings.num_gpus})...")
    with stopwatch("AutoGluon training"):
        predictor.fit(
            train_data=train_data,
            time_limit=settings.time_limit,
            presets=settings.presets,
            num_bag_folds=settings.num_bag_folds,
            num_bag_sets=settings.num_bag_sets,
            num_stack_levels=settings.num_stack_levels,
            num_gpus=settings.num_gpus,
            ag_args_fit={"num_gpus": settings.num_gpus} if settings.num_gpus > 0 else None,
        )

    # Evaluate on test data
    print("\n📊 Evaluating on test data...")
    test_score = predictor.evaluate(test_data, silent=True, detailed_report=True)
    print(f"Test {eval_metric}: {test_score[eval_metric]:.4f}")

    # Get leaderboard
    print("\n🏆 Model Leaderboard:")
    leaderboard = predictor.leaderboard(test_data, silent=True)
    print(leaderboard[["model", "score_test", "score_val", "pred_time_test", "fit_time"]].head(10))

    # Save leaderboard
    leaderboard_file = results_dir / "leaderboard.parquet"
    print(f"\n💾 Saving leaderboard to {leaderboard_file}")
    leaderboard.to_parquet(leaderboard_file)

    # Get feature importance (best-effort: permutation importance can fail on some models)
    print("\n🔍 Computing feature importance...")
    with stopwatch("Feature importance"):
        try:
            # Compute feature importance with reduced repeats
            feature_importance = predictor.feature_importance(
                test_data,
                num_shuffle_sets=3,
                subsample_size=min(500, len(test_data)),  # Further subsample if needed
            )
            fi_file = results_dir / "feature_importance.parquet"
            print(f"💾 Saving feature importance to {fi_file}")
            feature_importance.to_parquet(fi_file)
        except Exception as e:
            # Deliberately broad: importance is a nice-to-have, not a pipeline blocker.
            print(f"⚠️ Could not compute feature importance: {e}")

    # Compute SHAP values for the best model (best-effort, KernelExplainer is slow)
    print("\n🎨 Computing SHAP values...")
    with stopwatch("SHAP computation"):
        try:
            # Use a subset of test data for SHAP (can be slow)
            shap_sample_size = min(500, len(test_data))
            test_sample = test_data.sample(n=shap_sample_size, random_state=42)

            # Get the best model name
            best_model = predictor.model_best

            # Get predictions function for SHAP
            def predict_fn(X):
                """Prediction function for SHAP.

                Wraps the predictor so KernelExplainer can call it on raw arrays.
                """
                df = pd.DataFrame(X, columns=training_data.feature_names)
                if problem_type == "binary":
                    # For binary, return probability of positive class
                    proba = predictor.predict_proba(df, model=best_model)
                    return proba.iloc[:, 1].to_numpy()  # ty:ignore[possibly-missing-attribute, invalid-argument-type]
                elif problem_type == "multiclass":
                    # For multiclass, return all class probabilities
                    return predictor.predict_proba(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]
                else:
                    # For regression
                    return predictor.predict(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]

            # Create SHAP explainer
            # Use a background sample (smaller is faster)
            background_size = min(100, len(train_data))
            background = train_data.drop(columns=["label"]).sample(n=background_size, random_state=42).to_numpy()

            explainer = shap.KernelExplainer(predict_fn, background)

            # Compute SHAP values
            test_sample_X = test_sample.drop(columns=["label"]).to_numpy()
            shap_values = explainer.shap_values(test_sample_X)

            # Save SHAP values (pickle: contains numpy arrays of arbitrary shape)
            shap_file = results_dir / "shap_values.pkl"
            print(f"💾 Saving SHAP values to {shap_file}")
            with open(shap_file, "wb") as f:
                pickle.dump(
                    {
                        "shap_values": shap_values,
                        "base_values": explainer.expected_value,
                        "data": test_sample_X,
                        "feature_names": training_data.feature_names,
                    },
                    f,
                    protocol=pickle.HIGHEST_PROTOCOL,
                )

        except Exception as e:
            print(f"⚠️ Could not compute SHAP values: {e}")
            import traceback as tb

            tb.print_exc()

    # Save training settings for provenance
    print("\n💾 Saving training settings...")
    combined_settings = AutoGluonTrainingSettings(
        **asdict(settings),
        **asdict(dataset_ensemble),
        classes=training_data.y.labels,
        problem_type=problem_type,
    )
    settings_file = results_dir / "training_settings.toml"
    with open(settings_file, "w") as f:
        toml.dump({"settings": asdict(combined_settings)}, f)

    # Save test metrics
    test_metrics_file = results_dir / "test_metrics.json"
    print(f"💾 Saving test metrics to {test_metrics_file}")
    with open(test_metrics_file, "w") as f:
        json.dump(test_score, f, indent=2)

    # Generate predictions for all cells
    print("\n🗺️ Generating predictions for all cells...")
    with stopwatch("Prediction"):
        preds = predict_proba(dataset_ensemble, model=predictor, device="cpu")
    if training_data.targets["y"].dtype == "category":
        preds["predicted"] = preds["predicted"].astype("category")
        # FIX: assigning to `.cat.categories` is read-only in pandas >= 2.0
        # (raises AttributeError); use the supported rename_categories API instead.
        preds["predicted"] = preds["predicted"].cat.rename_categories(
            training_data.targets["y"].cat.categories
        )
    # Save predictions
    preds_file = results_dir / "predicted_probabilities.parquet"
    print(f"💾 Saving predictions to {preds_file}")
    preds.to_parquet(preds_file)

    # Print summary
    print("\n" + "=" * 80)
    print("✅ AutoGluon Training Complete!")
    print("=" * 80)
    print(f"\n📂 Results saved to: {results_dir}")
    print(f"🏆 Best model: {predictor.model_best}")
    print(f"📈 Test {eval_metric}: {test_score[eval_metric]:.4f}")
    print(f"⏱️ Total models trained: {len(leaderboard)}")

    stopwatch.summary()
    print("\nDone! 🎉")
|
||||
|
||||
|
||||
# Script entry point: dispatch to the cyclopts CLI application.
if __name__ == "__main__":
    cli()
|
||||
|
|
@ -241,6 +241,26 @@ class TrainingSet:
|
|||
def __len__(self):
|
||||
return len(self.z)
|
||||
|
||||
def to_dataframe(self, split: Literal["train", "test"] | None) -> pd.DataFrame:
|
||||
"""Create a simple DataFrame for training with e.g. autogluon.
|
||||
|
||||
This function creates a DataFrame with features and target labels (called "y") for training.
|
||||
The resulting dataframe contains no geometry information and drops all samples with NaN values.
|
||||
|
||||
Args:
|
||||
task (Task): The task.
|
||||
split (Literal["train", "test"] | None): If specified, only return the samples for the given split.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The training DataFrame.
|
||||
|
||||
"""
|
||||
dataset = self.targets[["y"]].join(self.features)
|
||||
if split is not None:
|
||||
dataset = dataset[self.split == split]
|
||||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
||||
return dataset
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
|
|
@ -249,7 +269,6 @@ class DatasetEnsemble:
|
|||
|
||||
grid: Grid
|
||||
level: int
|
||||
target: TargetDataset
|
||||
members: list[L2SourceDataset] = field(
|
||||
default_factory=lambda: [
|
||||
"AlphaEarth",
|
||||
|
|
@ -325,8 +344,8 @@ class DatasetEnsemble:
|
|||
return grid_gdf
|
||||
|
||||
@stopwatch.f("Getting target labels", print_kwargs=["task"])
|
||||
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
||||
"""Create a training target labels for a specific task.
|
||||
def get_targets(self, target: TargetDataset, task: Task) -> gpd.GeoDataFrame:
|
||||
"""Create a training target labels for a specific target and task.
|
||||
|
||||
The function reads the target dataset, filters it based on coverage,
|
||||
and prepares the target labels according to the specified task.
|
||||
|
|
@ -339,6 +358,7 @@ class DatasetEnsemble:
|
|||
|
||||
|
||||
Args:
|
||||
target (TargetDataset): The target dataset to use.
|
||||
task (Task): The task.
|
||||
|
||||
Returns:
|
||||
|
|
@ -349,13 +369,13 @@ class DatasetEnsemble:
|
|||
For regression tasks, both "y" and "z" are numerical and identical.
|
||||
|
||||
"""
|
||||
match (self.target, self.temporal_mode):
|
||||
match (target, self.temporal_mode):
|
||||
case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
|
||||
version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
|
||||
version: Literal["v1-l3", "v2-l3"] = target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
|
||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||
targets = xr.open_zarr(target_store, consolidated=False)
|
||||
case ("darts_v1" | "darts_v2", int()):
|
||||
version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
||||
case ("darts_mllabels", str()): # Years are not supported
|
||||
|
|
@ -364,7 +384,7 @@ class DatasetEnsemble:
|
|||
)
|
||||
targets = xr.open_zarr(target_store, consolidated=False)
|
||||
case _:
|
||||
raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.")
|
||||
raise NotImplementedError(f"Target {target} on {self.temporal_mode} mode not supported.")
|
||||
targets = cast(xr.Dataset, targets)
|
||||
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
|
||||
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
|
||||
|
|
@ -599,42 +619,11 @@ class DatasetEnsemble:
|
|||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||
|
||||
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"])
|
||||
def create_training_df(
|
||||
self,
|
||||
task: Task,
|
||||
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
||||
) -> pd.DataFrame:
|
||||
"""Create a simple DataFrame for training with e.g. autogluon.
|
||||
|
||||
This function creates a DataFrame with features and target labels (called "y") for training.
|
||||
The resulting dataframe contains no geometry information and drops all samples with NaN values.
|
||||
Does not split into train and test sets, also does not convert to arrays.
|
||||
For more advanced use cases, use `create_training_set`.
|
||||
|
||||
Args:
|
||||
task (Task): The task.
|
||||
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
||||
"none": No caching.
|
||||
"read": Read from cache if exists, otherwise create and save to cache.
|
||||
"overwrite": Always create and save to cache, overwriting existing cache.
|
||||
Defaults to "read".
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The training DataFrame.
|
||||
|
||||
"""
|
||||
targets = self.get_targets(task)
|
||||
assert len(targets) > 0, "No target samples found."
|
||||
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
|
||||
dataset = targets[["y"]].join(features).dropna()
|
||||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
||||
return dataset
|
||||
|
||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
|
||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
||||
def create_training_set(
|
||||
self,
|
||||
task: Task,
|
||||
target: TargetDataset,
|
||||
device: Literal["cpu", "cuda", "torch"] = "cpu",
|
||||
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
||||
) -> TrainingSet:
|
||||
|
|
@ -645,6 +634,7 @@ class DatasetEnsemble:
|
|||
|
||||
Args:
|
||||
task (Task): The task.
|
||||
target (TargetDataset): The target dataset to use.
|
||||
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
|
||||
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
||||
"none": No caching.
|
||||
|
|
@ -657,7 +647,7 @@ class DatasetEnsemble:
|
|||
Contains a lot of useful information for training and evaluation.
|
||||
|
||||
"""
|
||||
targets = self.get_targets(task)
|
||||
targets = self.get_targets(target, task)
|
||||
assert len(targets) > 0, "No target samples found."
|
||||
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
|
||||
assert len(features) > 0, "No valid features found after dropping NaNs."
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from entropice.ml.models import (
|
|||
get_model_hpo_config,
|
||||
)
|
||||
from entropice.utils.paths import get_cv_results_dir
|
||||
from entropice.utils.types import Model, Task
|
||||
from entropice.utils.types import Model, TargetDataset, Task
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -75,6 +75,7 @@ class CVSettings:
|
|||
|
||||
n_iter: int = 2000
|
||||
task: Task = "binary"
|
||||
target: TargetDataset = "darts_v1"
|
||||
model: Model = "espa"
|
||||
|
||||
|
||||
|
|
@ -106,7 +107,7 @@ def random_cv(
|
|||
set_config(array_api_dispatch=use_array_api)
|
||||
|
||||
print("Creating training data...")
|
||||
training_data = dataset_ensemble.create_training_set(task=settings.task, device=device)
|
||||
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target, device=device)
|
||||
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
||||
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
"""Debug script to check what _prep_arcticdem returns for a batch."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=["ArcticDEM"],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
# Get targets
|
||||
targets = ensemble._read_target()
|
||||
print(f"Total targets: {len(targets)}")
|
||||
|
||||
# Get first batch of targets
|
||||
batch_targets = targets.iloc[:100]
|
||||
print(f"\nBatch targets: {len(batch_targets)}")
|
||||
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
|
||||
|
||||
# Try to prep ArcticDEM for this batch
|
||||
print("\n" + "=" * 80)
|
||||
print("Calling _prep_arcticdem...")
|
||||
print("=" * 80)
|
||||
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
|
||||
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
|
||||
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
|
||||
print(
|
||||
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
|
||||
)
|
||||
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
"""Debug script to identify feature mismatch between training and inference."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
# Test with level 6 (the actual level used in production)
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=[
|
||||
"AlphaEarth",
|
||||
"ArcticDEM",
|
||||
"ERA5-yearly",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-shoulder",
|
||||
],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
print("Creating training dataset...")
|
||||
print("=" * 80)
|
||||
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_features = set(training_data.X.data.columns)
|
||||
print(f"\nTraining dataset created with {len(training_features)} features")
|
||||
print(f"Sample features: {sorted(list(training_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Creating inference batch...")
|
||||
print("=" * 80)
|
||||
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
# for batch in batch_generator:
|
||||
if batch is None:
|
||||
print("ERROR: No batch created!")
|
||||
else:
|
||||
print(f"\nBatch created with {len(batch.columns)} columns")
|
||||
print(f"Batch columns: {sorted(batch.columns)[:15]}")
|
||||
|
||||
# Simulate the column dropping in predict_proba (inference.py)
|
||||
cols_to_drop = ["geometry"]
|
||||
if ensemble.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
print(f"\nColumns to drop: {cols_to_drop}")
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_features = set(inference_batch.columns)
|
||||
|
||||
print(f"\nInference batch after dropping has {len(inference_features)} features")
|
||||
print(f"Sample features: {sorted(list(inference_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("COMPARISON")
|
||||
print("=" * 80)
|
||||
print(f"Training features: {len(training_features)}")
|
||||
print(f"Inference features: {len(inference_features)}")
|
||||
|
||||
if training_features == inference_features:
|
||||
print("\n✅ SUCCESS: Features match perfectly!")
|
||||
else:
|
||||
print("\n❌ MISMATCH DETECTED!")
|
||||
only_in_training = training_features - inference_features
|
||||
only_in_inference = inference_features - training_features
|
||||
|
||||
if only_in_training:
|
||||
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
|
||||
if only_in_inference:
|
||||
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
from itertools import product
|
||||
|
||||
import xarray as xr
|
||||
from rich.progress import track
|
||||
|
||||
from entropice.utils.paths import (
|
||||
get_arcticdem_stores,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
)
|
||||
from entropice.utils.types import Grid, L2SourceDataset
|
||||
|
||||
|
||||
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
|
||||
"""Validate if the L2 dataset exists for the given grid and level.
|
||||
|
||||
Args:
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
l2ds (L2SourceDataset): The L2 source dataset to validate.
|
||||
|
||||
Returns:
|
||||
bool: True if the dataset exists and does not contain NaNs, False otherwise.
|
||||
|
||||
"""
|
||||
if l2ds == "ArcticDEM":
|
||||
store = get_arcticdem_stores(grid, level)
|
||||
elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly":
|
||||
agg = l2ds.split("-")[1]
|
||||
store = get_era5_stores(agg, grid, level) # type: ignore
|
||||
elif l2ds == "AlphaEarth":
|
||||
store = get_embeddings_store(grid, level)
|
||||
else:
|
||||
raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
|
||||
|
||||
if not store.exists():
|
||||
print("\t Dataset store does not exist")
|
||||
return False
|
||||
|
||||
ds = xr.open_zarr(store, consolidated=False)
|
||||
has_nan = False
|
||||
for var in ds.data_vars:
|
||||
n_nans = ds[var].isnull().sum().compute().item()
|
||||
if n_nans > 0:
|
||||
print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
|
||||
has_nan = True
|
||||
if has_nan:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
grid_levels: set[tuple[Grid, int]] = {
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
("hex", 6),
|
||||
("healpix", 6),
|
||||
("healpix", 7),
|
||||
("healpix", 8),
|
||||
("healpix", 9),
|
||||
("healpix", 10),
|
||||
}
|
||||
l2_source_datasets: list[L2SourceDataset] = [
|
||||
"ArcticDEM",
|
||||
"ERA5-shoulder",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-yearly",
|
||||
"AlphaEarth",
|
||||
]
|
||||
|
||||
for (grid, level), l2ds in track(
|
||||
product(grid_levels, l2_source_datasets),
|
||||
total=len(grid_levels) * len(l2_source_datasets),
|
||||
description="Validating L2 datasets...",
|
||||
):
|
||||
is_valid = validate_l2_dataset(grid, level, l2ds)
|
||||
status = "VALID" if is_valid else "INVALID"
|
||||
print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue