Refactor autogluon
This commit is contained in:
parent
2cefe35690
commit
cfb7d65d6d
7 changed files with 305 additions and 232 deletions
|
|
@ -38,7 +38,7 @@ cli = cyclopts.App(name="alpha-earth")
|
||||||
# 7454521782,230147807,10000000.
|
# 7454521782,230147807,10000000.
|
||||||
|
|
||||||
|
|
||||||
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.DataFrame]:
|
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
|
||||||
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
|
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
|
||||||
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
|
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
|
||||||
|
|
||||||
|
|
|
||||||
270
src/entropice/ml/autogluon_training.py
Normal file
270
src/entropice/ml/autogluon_training.py
Normal file
|
|
@ -0,0 +1,270 @@
|
||||||
|
# ruff: noqa: N803, N806
|
||||||
|
"""Training with AutoGluon TabularPredictor for automated ML."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
|
import pandas as pd
|
||||||
|
import shap
|
||||||
|
import toml
|
||||||
|
from autogluon.tabular import TabularDataset, TabularPredictor
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from sklearn import set_config
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.ml.inference import predict_proba
|
||||||
|
from entropice.utils.paths import get_autogluon_results_dir
|
||||||
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
|
# Install rich's pretty tracebacks and repr rendering for nicer terminal output.
traceback.install()
pretty.install()

# Disable sklearn's array-API dispatch for this process
# (presumably so everything stays plain numpy for AutoGluon — TODO confirm).
set_config(array_api_dispatch=False)

# Cyclopts CLI application; `autogluon_train` below is registered as the default command.
cli = cyclopts.App("entropice-autogluon")
|
||||||
|
|
||||||
|
|
||||||
|
@cyclopts.Parameter("*")
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class AutoGluonSettings:
|
||||||
|
"""AutoGluon training settings."""
|
||||||
|
|
||||||
|
task: Task = "binary"
|
||||||
|
target: TargetDataset = "darts_v1"
|
||||||
|
time_limit: int = 3600 # Time limit in seconds (1 hour default)
|
||||||
|
presets: str = "best" # AutoGluon preset: 'best', 'high', 'good', 'medium'
|
||||||
|
eval_metric: str | None = None # Evaluation metric, None for auto-detect
|
||||||
|
num_bag_folds: int = 5 # Number of folds for bagging
|
||||||
|
num_bag_sets: int = 1 # Number of bagging sets
|
||||||
|
num_stack_levels: int = 1 # Number of stacking levels
|
||||||
|
num_gpus: int = 1 # Number of GPUs to use
|
||||||
|
verbosity: int = 2 # Verbosity level (0-4)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
    """Combined settings for AutoGluon training.

    Merges the dataset-ensemble configuration with the AutoGluon settings and
    adds run metadata resolved during training, so the complete configuration
    of a run can be serialized to disk.
    """

    # Class labels of the target variable, resolved from the training data.
    classes: list[str]
    # AutoGluon problem type derived from the task:
    # "binary", "multiclass" or "regression".
    problem_type: str
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
|
||||||
|
"""Determine AutoGluon problem type and appropriate evaluation metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The training task type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (problem_type, eval_metric)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if task == "binary":
|
||||||
|
return ("binary", "balanced_accuracy") # Good for imbalanced datasets
|
||||||
|
elif task in ["count_regimes", "density_regimes"]:
|
||||||
|
return ("multiclass", "f1_weighted") # Weighted F1 for multiclass
|
||||||
|
elif task in ["count", "density"]:
|
||||||
|
return ("regression", "mean_absolute_error")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown task: {task}")
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default
def autogluon_train(
    dataset_ensemble: DatasetEnsemble,
    settings: AutoGluonSettings = AutoGluonSettings(),
):
    """Train models using AutoGluon TabularPredictor.

    Trains an AutoGluon ensemble on the training split, evaluates it on the
    test split, and writes all artifacts (models, leaderboard, feature
    importance, SHAP values, settings, metrics, per-cell predictions) into
    the AutoGluon results directory.

    Args:
        dataset_ensemble: Dataset ensemble configuration
        settings: AutoGluon training settings

    """
    training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target)

    # Convert to AutoGluon TabularDataset (behaves like a pd.DataFrame)
    train_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("train"))  # ty:ignore[invalid-assignment]
    test_data: pd.DataFrame = TabularDataset(training_data.to_dataframe("test"))  # ty:ignore[invalid-assignment]

    print(f"\nTraining data: {len(train_data)} samples")
    print(f"Test data: {len(test_data)} samples")
    print(f"Features: {len(training_data.feature_names)}")
    print(f"Classes: {training_data.target_labels}")

    # Determine problem type and metric (explicit setting wins over the default)
    problem_type, default_metric = _determine_problem_type_and_metric(settings.task)
    eval_metric = settings.eval_metric or default_metric

    print(f"\n🎯 Problem type: {problem_type}")
    print(f"📈 Evaluation metric: {eval_metric}")

    # Create results directory
    results_dir = get_autogluon_results_dir(
        grid=dataset_ensemble.grid,
        level=dataset_ensemble.level,
        task=settings.task,
    )
    print(f"\n💾 Results directory: {results_dir}")

    # Initialize TabularPredictor
    print(f"\n🚀 Initializing AutoGluon TabularPredictor (preset='{settings.presets}')...")
    predictor = TabularPredictor(
        label="label",
        problem_type=problem_type,
        eval_metric=eval_metric,
        path=str(results_dir / "models"),
        verbosity=settings.verbosity,
    )

    # Train models
    print(f"\n⚡ Training models (time_limit={settings.time_limit}s, num_gpus={settings.num_gpus})...")
    with stopwatch("AutoGluon training"):
        predictor.fit(
            train_data=train_data,
            time_limit=settings.time_limit,
            presets=settings.presets,
            num_bag_folds=settings.num_bag_folds,
            num_bag_sets=settings.num_bag_sets,
            num_stack_levels=settings.num_stack_levels,
            num_gpus=settings.num_gpus,
            ag_args_fit={"num_gpus": settings.num_gpus} if settings.num_gpus > 0 else None,
        )

    # Evaluate on test data
    print("\n📊 Evaluating on test data...")
    test_score = predictor.evaluate(test_data, silent=True, detailed_report=True)
    print(f"Test {eval_metric}: {test_score[eval_metric]:.4f}")

    # Get leaderboard
    print("\n🏆 Model Leaderboard:")
    leaderboard = predictor.leaderboard(test_data, silent=True)
    print(leaderboard[["model", "score_test", "score_val", "pred_time_test", "fit_time"]].head(10))

    # Save leaderboard
    leaderboard_file = results_dir / "leaderboard.parquet"
    print(f"\n💾 Saving leaderboard to {leaderboard_file}")
    leaderboard.to_parquet(leaderboard_file)

    # Get feature importance (best effort: permutation importance can fail on
    # small/degenerate test splits, so a failure only prints a warning)
    print("\n🔍 Computing feature importance...")
    with stopwatch("Feature importance"):
        try:
            # Compute feature importance with reduced repeats
            feature_importance = predictor.feature_importance(
                test_data,
                num_shuffle_sets=3,
                subsample_size=min(500, len(test_data)),  # Further subsample if needed
            )
            fi_file = results_dir / "feature_importance.parquet"
            print(f"💾 Saving feature importance to {fi_file}")
            feature_importance.to_parquet(fi_file)
        except Exception as e:
            print(f"⚠️ Could not compute feature importance: {e}")

    # Compute SHAP values for the best model (best effort as well)
    print("\n🎨 Computing SHAP values...")
    with stopwatch("SHAP computation"):
        try:
            # Use a subset of test data for SHAP (can be slow)
            shap_sample_size = min(500, len(test_data))
            test_sample = test_data.sample(n=shap_sample_size, random_state=42)

            # Get the best model name
            best_model = predictor.model_best

            # Get predictions function for SHAP
            def predict_fn(X):
                """Prediction function for SHAP."""
                df = pd.DataFrame(X, columns=training_data.feature_names)
                if problem_type == "binary":
                    # For binary, return probability of positive class
                    proba = predictor.predict_proba(df, model=best_model)
                    return proba.iloc[:, 1].to_numpy()  # ty:ignore[possibly-missing-attribute, invalid-argument-type]
                elif problem_type == "multiclass":
                    # For multiclass, return all class probabilities
                    return predictor.predict_proba(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]
                else:
                    # For regression
                    return predictor.predict(df, model=best_model).to_numpy()  # ty:ignore[possibly-missing-attribute]

            # Create SHAP explainer
            # Use a background sample (smaller is faster)
            background_size = min(100, len(train_data))
            background = train_data.drop(columns=["label"]).sample(n=background_size, random_state=42).to_numpy()

            explainer = shap.KernelExplainer(predict_fn, background)

            # Compute SHAP values
            test_sample_X = test_sample.drop(columns=["label"]).to_numpy()
            shap_values = explainer.shap_values(test_sample_X)

            # Save SHAP values
            shap_file = results_dir / "shap_values.pkl"
            print(f"💾 Saving SHAP values to {shap_file}")
            with open(shap_file, "wb") as f:
                pickle.dump(
                    {
                        "shap_values": shap_values,
                        "base_values": explainer.expected_value,
                        "data": test_sample_X,
                        "feature_names": training_data.feature_names,
                    },
                    f,
                    protocol=pickle.HIGHEST_PROTOCOL,
                )

        except Exception as e:
            print(f"⚠️ Could not compute SHAP values: {e}")
            # Local import: the module-level name `traceback` is rich's traceback module.
            import traceback as tb

            tb.print_exc()

    # Save training settings
    print("\n💾 Saving training settings...")
    combined_settings = AutoGluonTrainingSettings(
        **asdict(settings),
        **asdict(dataset_ensemble),
        classes=training_data.y.labels,
        problem_type=problem_type,
    )
    settings_file = results_dir / "training_settings.toml"
    with open(settings_file, "w") as f:
        toml.dump({"settings": asdict(combined_settings)}, f)

    # Save test metrics
    test_metrics_file = results_dir / "test_metrics.json"
    print(f"💾 Saving test metrics to {test_metrics_file}")
    with open(test_metrics_file, "w") as f:
        json.dump(test_score, f, indent=2)

    # Generate predictions for all cells
    print("\n🗺️ Generating predictions for all cells...")
    with stopwatch("Prediction"):
        preds = predict_proba(dataset_ensemble, model=predictor, device="cpu")
    if training_data.targets["y"].dtype == "category":
        preds["predicted"] = preds["predicted"].astype("category")
        # FIX: `Series.cat.categories` is a read-only property since pandas 2.0;
        # relabel the categories via `rename_categories` instead of assigning.
        preds["predicted"] = preds["predicted"].cat.rename_categories(training_data.targets["y"].cat.categories)
    # Save predictions
    preds_file = results_dir / "predicted_probabilities.parquet"
    print(f"💾 Saving predictions to {preds_file}")
    preds.to_parquet(preds_file)

    # Print summary
    print("\n" + "=" * 80)
    print("✅ AutoGluon Training Complete!")
    print("=" * 80)
    print(f"\n📂 Results saved to: {results_dir}")
    print(f"🏆 Best model: {predictor.model_best}")
    print(f"📈 Test {eval_metric}: {test_score[eval_metric]:.4f}")
    print(f"⏱️ Total models trained: {len(leaderboard)}")

    stopwatch.summary()
    print("\nDone! 🎉")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
|
|
@ -241,6 +241,26 @@ class TrainingSet:
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.z)
|
return len(self.z)
|
||||||
|
|
||||||
|
    def to_dataframe(self, split: Literal["train", "test"] | None) -> pd.DataFrame:
        """Create a simple DataFrame for training with e.g. autogluon.

        This function creates a DataFrame with features and target labels (called "y") for training.
        The resulting dataframe contains no geometry information and drops all samples with NaN values.

        Args:
            split (Literal["train", "test"] | None): If specified, only return the samples for the given split.

        Returns:
            pd.DataFrame: The training DataFrame.

        """
        # Join the "y" target column onto the feature matrix (aligned on the shared index).
        dataset = self.targets[["y"]].join(self.features)
        if split is not None:
            # `self.split` labels each sample "train"/"test"; keep only the requested split.
            dataset = dataset[self.split == split]
        assert len(dataset) > 0, "No valid samples found after joining features and targets."
        return dataset
|
||||||
|
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
|
@ -249,7 +269,6 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
grid: Grid
|
grid: Grid
|
||||||
level: int
|
level: int
|
||||||
target: TargetDataset
|
|
||||||
members: list[L2SourceDataset] = field(
|
members: list[L2SourceDataset] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
"AlphaEarth",
|
"AlphaEarth",
|
||||||
|
|
@ -325,8 +344,8 @@ class DatasetEnsemble:
|
||||||
return grid_gdf
|
return grid_gdf
|
||||||
|
|
||||||
@stopwatch.f("Getting target labels", print_kwargs=["task"])
|
@stopwatch.f("Getting target labels", print_kwargs=["task"])
|
||||||
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
def get_targets(self, target: TargetDataset, task: Task) -> gpd.GeoDataFrame:
|
||||||
"""Create a training target labels for a specific task.
|
"""Create a training target labels for a specific target and task.
|
||||||
|
|
||||||
The function reads the target dataset, filters it based on coverage,
|
The function reads the target dataset, filters it based on coverage,
|
||||||
and prepares the target labels according to the specified task.
|
and prepares the target labels according to the specified task.
|
||||||
|
|
@ -339,6 +358,7 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
target (TargetDataset): The target dataset to use.
|
||||||
task (Task): The task.
|
task (Task): The task.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -349,13 +369,13 @@ class DatasetEnsemble:
|
||||||
For regression tasks, both "y" and "z" are numerical and identical.
|
For regression tasks, both "y" and "z" are numerical and identical.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
match (self.target, self.temporal_mode):
|
match (target, self.temporal_mode):
|
||||||
case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
|
case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
|
||||||
version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
|
version: Literal["v1-l3", "v2-l3"] = target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
|
||||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||||
targets = xr.open_zarr(target_store, consolidated=False)
|
targets = xr.open_zarr(target_store, consolidated=False)
|
||||||
case ("darts_v1" | "darts_v2", int()):
|
case ("darts_v1" | "darts_v2", int()):
|
||||||
version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment]
|
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||||
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
||||||
case ("darts_mllabels", str()): # Years are not supported
|
case ("darts_mllabels", str()): # Years are not supported
|
||||||
|
|
@ -364,7 +384,7 @@ class DatasetEnsemble:
|
||||||
)
|
)
|
||||||
targets = xr.open_zarr(target_store, consolidated=False)
|
targets = xr.open_zarr(target_store, consolidated=False)
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.")
|
raise NotImplementedError(f"Target {target} on {self.temporal_mode} mode not supported.")
|
||||||
targets = cast(xr.Dataset, targets)
|
targets = cast(xr.Dataset, targets)
|
||||||
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
|
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
|
||||||
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
|
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
|
||||||
|
|
@ -599,42 +619,11 @@ class DatasetEnsemble:
|
||||||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||||
|
|
||||||
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"])
|
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
||||||
def create_training_df(
|
|
||||||
self,
|
|
||||||
task: Task,
|
|
||||||
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""Create a simple DataFrame for training with e.g. autogluon.
|
|
||||||
|
|
||||||
This function creates a DataFrame with features and target labels (called "y") for training.
|
|
||||||
The resulting dataframe contains no geometry information and drops all samples with NaN values.
|
|
||||||
Does not split into train and test sets, also does not convert to arrays.
|
|
||||||
For more advanced use cases, use `create_training_set`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (Task): The task.
|
|
||||||
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
|
||||||
"none": No caching.
|
|
||||||
"read": Read from cache if exists, otherwise create and save to cache.
|
|
||||||
"overwrite": Always create and save to cache, overwriting existing cache.
|
|
||||||
Defaults to "read".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pd.DataFrame: The training DataFrame.
|
|
||||||
|
|
||||||
"""
|
|
||||||
targets = self.get_targets(task)
|
|
||||||
assert len(targets) > 0, "No target samples found."
|
|
||||||
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
|
|
||||||
dataset = targets[["y"]].join(features).dropna()
|
|
||||||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
|
|
||||||
def create_training_set(
|
def create_training_set(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
|
target: TargetDataset,
|
||||||
device: Literal["cpu", "cuda", "torch"] = "cpu",
|
device: Literal["cpu", "cuda", "torch"] = "cpu",
|
||||||
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
||||||
) -> TrainingSet:
|
) -> TrainingSet:
|
||||||
|
|
@ -645,6 +634,7 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task (Task): The task.
|
task (Task): The task.
|
||||||
|
target (TargetDataset): The target dataset to use.
|
||||||
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
|
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
|
||||||
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
||||||
"none": No caching.
|
"none": No caching.
|
||||||
|
|
@ -657,7 +647,7 @@ class DatasetEnsemble:
|
||||||
Contains a lot of useful information for training and evaluation.
|
Contains a lot of useful information for training and evaluation.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
targets = self.get_targets(task)
|
targets = self.get_targets(target, task)
|
||||||
assert len(targets) > 0, "No target samples found."
|
assert len(targets) > 0, "No target samples found."
|
||||||
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
|
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
|
||||||
assert len(features) > 0, "No valid features found after dropping NaNs."
|
assert len(features) > 0, "No valid features found after dropping NaNs."
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from entropice.ml.models import (
|
||||||
get_model_hpo_config,
|
get_model_hpo_config,
|
||||||
)
|
)
|
||||||
from entropice.utils.paths import get_cv_results_dir
|
from entropice.utils.paths import get_cv_results_dir
|
||||||
from entropice.utils.types import Model, Task
|
from entropice.utils.types import Model, TargetDataset, Task
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -75,6 +75,7 @@ class CVSettings:
|
||||||
|
|
||||||
n_iter: int = 2000
|
n_iter: int = 2000
|
||||||
task: Task = "binary"
|
task: Task = "binary"
|
||||||
|
target: TargetDataset = "darts_v1"
|
||||||
model: Model = "espa"
|
model: Model = "espa"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -106,7 +107,7 @@ def random_cv(
|
||||||
set_config(array_api_dispatch=use_array_api)
|
set_config(array_api_dispatch=use_array_api)
|
||||||
|
|
||||||
print("Creating training data...")
|
print("Creating training data...")
|
||||||
training_data = dataset_ensemble.create_training_set(task=settings.task, device=device)
|
training_data = dataset_ensemble.create_training_set(task=settings.task, target=settings.target, device=device)
|
||||||
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
model_hpo_config = get_model_hpo_config(settings.model, settings.task)
|
||||||
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
print(f"Using model: {settings.model} with parameters: {model_hpo_config.hp_config}")
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
|
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
"""Debug script to check what _prep_arcticdem returns for a batch."""
|
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
|
||||||
|
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid="healpix",
|
|
||||||
level=10,
|
|
||||||
target="darts_mllabels",
|
|
||||||
members=["ArcticDEM"],
|
|
||||||
add_lonlat=True,
|
|
||||||
filter_target=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get targets
|
|
||||||
targets = ensemble._read_target()
|
|
||||||
print(f"Total targets: {len(targets)}")
|
|
||||||
|
|
||||||
# Get first batch of targets
|
|
||||||
batch_targets = targets.iloc[:100]
|
|
||||||
print(f"\nBatch targets: {len(batch_targets)}")
|
|
||||||
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
|
|
||||||
|
|
||||||
# Try to prep ArcticDEM for this batch
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("Calling _prep_arcticdem...")
|
|
||||||
print("=" * 80)
|
|
||||||
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
|
|
||||||
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
|
|
||||||
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
|
|
||||||
print(
|
|
||||||
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
|
|
||||||
)
|
|
||||||
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
"""Debug script to identify feature mismatch between training and inference."""
|
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
|
||||||
|
|
||||||
# Test with level 6 (the actual level used in production)
|
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid="healpix",
|
|
||||||
level=10,
|
|
||||||
target="darts_mllabels",
|
|
||||||
members=[
|
|
||||||
"AlphaEarth",
|
|
||||||
"ArcticDEM",
|
|
||||||
"ERA5-yearly",
|
|
||||||
"ERA5-seasonal",
|
|
||||||
"ERA5-shoulder",
|
|
||||||
],
|
|
||||||
add_lonlat=True,
|
|
||||||
filter_target=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("=" * 80)
|
|
||||||
print("Creating training dataset...")
|
|
||||||
print("=" * 80)
|
|
||||||
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
|
||||||
training_features = set(training_data.X.data.columns)
|
|
||||||
print(f"\nTraining dataset created with {len(training_features)} features")
|
|
||||||
print(f"Sample features: {sorted(list(training_features))[:10]}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("Creating inference batch...")
|
|
||||||
print("=" * 80)
|
|
||||||
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
|
|
||||||
batch = next(batch_generator, None)
|
|
||||||
# for batch in batch_generator:
|
|
||||||
if batch is None:
|
|
||||||
print("ERROR: No batch created!")
|
|
||||||
else:
|
|
||||||
print(f"\nBatch created with {len(batch.columns)} columns")
|
|
||||||
print(f"Batch columns: {sorted(batch.columns)[:15]}")
|
|
||||||
|
|
||||||
# Simulate the column dropping in predict_proba (inference.py)
|
|
||||||
cols_to_drop = ["geometry"]
|
|
||||||
if ensemble.target == "darts_mllabels":
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
|
||||||
else:
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
|
||||||
|
|
||||||
print(f"\nColumns to drop: {cols_to_drop}")
|
|
||||||
|
|
||||||
inference_batch = batch.drop(columns=cols_to_drop)
|
|
||||||
inference_features = set(inference_batch.columns)
|
|
||||||
|
|
||||||
print(f"\nInference batch after dropping has {len(inference_features)} features")
|
|
||||||
print(f"Sample features: {sorted(list(inference_features))[:10]}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("COMPARISON")
|
|
||||||
print("=" * 80)
|
|
||||||
print(f"Training features: {len(training_features)}")
|
|
||||||
print(f"Inference features: {len(inference_features)}")
|
|
||||||
|
|
||||||
if training_features == inference_features:
|
|
||||||
print("\n✅ SUCCESS: Features match perfectly!")
|
|
||||||
else:
|
|
||||||
print("\n❌ MISMATCH DETECTED!")
|
|
||||||
only_in_training = training_features - inference_features
|
|
||||||
only_in_inference = inference_features - training_features
|
|
||||||
|
|
||||||
if only_in_training:
|
|
||||||
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
|
|
||||||
if only_in_inference:
|
|
||||||
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
import xarray as xr
|
|
||||||
from rich.progress import track
|
|
||||||
|
|
||||||
from entropice.utils.paths import (
|
|
||||||
get_arcticdem_stores,
|
|
||||||
get_embeddings_store,
|
|
||||||
get_era5_stores,
|
|
||||||
)
|
|
||||||
from entropice.utils.types import Grid, L2SourceDataset
|
|
||||||
|
|
||||||
|
|
||||||
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
|
|
||||||
"""Validate if the L2 dataset exists for the given grid and level.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
grid (Grid): The grid type to use.
|
|
||||||
level (int): The grid level to use.
|
|
||||||
l2ds (L2SourceDataset): The L2 source dataset to validate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the dataset exists and does not contain NaNs, False otherwise.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if l2ds == "ArcticDEM":
|
|
||||||
store = get_arcticdem_stores(grid, level)
|
|
||||||
elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly":
|
|
||||||
agg = l2ds.split("-")[1]
|
|
||||||
store = get_era5_stores(agg, grid, level) # type: ignore
|
|
||||||
elif l2ds == "AlphaEarth":
|
|
||||||
store = get_embeddings_store(grid, level)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
|
|
||||||
|
|
||||||
if not store.exists():
|
|
||||||
print("\t Dataset store does not exist")
|
|
||||||
return False
|
|
||||||
|
|
||||||
ds = xr.open_zarr(store, consolidated=False)
|
|
||||||
has_nan = False
|
|
||||||
for var in ds.data_vars:
|
|
||||||
n_nans = ds[var].isnull().sum().compute().item()
|
|
||||||
if n_nans > 0:
|
|
||||||
print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
|
|
||||||
has_nan = True
|
|
||||||
if has_nan:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
grid_levels: set[tuple[Grid, int]] = {
|
|
||||||
("hex", 3),
|
|
||||||
("hex", 4),
|
|
||||||
("hex", 5),
|
|
||||||
("hex", 6),
|
|
||||||
("healpix", 6),
|
|
||||||
("healpix", 7),
|
|
||||||
("healpix", 8),
|
|
||||||
("healpix", 9),
|
|
||||||
("healpix", 10),
|
|
||||||
}
|
|
||||||
l2_source_datasets: list[L2SourceDataset] = [
|
|
||||||
"ArcticDEM",
|
|
||||||
"ERA5-shoulder",
|
|
||||||
"ERA5-seasonal",
|
|
||||||
"ERA5-yearly",
|
|
||||||
"AlphaEarth",
|
|
||||||
]
|
|
||||||
|
|
||||||
for (grid, level), l2ds in track(
|
|
||||||
product(grid_levels, l2_source_datasets),
|
|
||||||
total=len(grid_levels) * len(l2_source_datasets),
|
|
||||||
description="Validating L2 datasets...",
|
|
||||||
):
|
|
||||||
is_valid = validate_l2_dataset(grid, level, l2ds)
|
|
||||||
status = "VALID" if is_valid else "INVALID"
|
|
||||||
print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue