From d22b857722074384eae879555b1dbd3df19c6a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 11 Dec 2025 17:14:02 +0100 Subject: [PATCH] Make the training work again --- src/entropice/darts.py | 4 +- src/entropice/dataset.py | 10 +- src/entropice/inference.py | 26 ++-- src/entropice/training.py | 249 ++++++++++++++++++++----------------- 4 files changed, 156 insertions(+), 133 deletions(-) diff --git a/src/entropice/darts.py b/src/entropice/darts.py index f45e415..74638e7 100644 --- a/src/entropice/darts.py +++ b/src/entropice/darts.py @@ -129,8 +129,8 @@ def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int): covered = ~grid_gdf["dartsml_coverage"].isna() grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0) - grid_gdf["darts_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna() - grid_gdf["darts_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna() + grid_gdf["dartsml_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna() + grid_gdf["dartsml_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna() output_path = get_darts_rts_file(grid, level, labels=True) grid_gdf.to_parquet(output_path) diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index dfcd36c..ef54dc7 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -16,6 +16,7 @@ import json from dataclasses import asdict, dataclass, field from typing import Literal +import cyclopts import geopandas as gpd import pandas as pd import seaborn as sns @@ -55,14 +56,17 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", " type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] +@cyclopts.Parameter("*") @dataclass class DatasetEnsemble: grid: Literal["hex", "healpix"] level: int target: Literal["darts_rts", "darts_mllabels"] - members: list[L2Dataset] - dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict) - variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict) + members: list[L2Dataset] = field( + default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] + ) + dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) + variable_filters: dict[str, list[str]] = field(default_factory=dict) filter_target: str | Literal[False] = False add_lonlat: bool = True diff --git a/src/entropice/inference.py b/src/entropice/inference.py index 846b29c..1e4ce8b 100644 --- a/src/entropice/inference.py +++ b/src/entropice/inference.py @@ -1,16 +1,14 @@ # ruff: noqa: N806 """Inference runs on trained models.""" -from typing import Literal - import geopandas as gpd import pandas as pd import torch -from entropy import ESPAClassifier from rich import pretty, traceback from sklearn import set_config +from sklearn.base import BaseEstimator -from entropice.paths import get_train_dataset_file +from entropice.dataset import DatasetEnsemble traceback.install() pretty.install() @@ -18,32 +16,34 @@ pretty.install() set_config(array_api_dispatch=True) -def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame: +def predict_proba(e: DatasetEnsemble, clf: BaseEstimator, classes: list) -> gpd.GeoDataFrame: """Get predicted probabilities for each cell. Args: - grid (Literal["hex", "healpix"]): The grid type to use. - level (int): The grid level to use. - clf (ESPAClassifier): The trained classifier to use for predictions. + e (DatasetEnsemble): The dataset ensemble configuration. + clf (BaseEstimator): The trained classifier to use for predictions. classes (list): List of class names. Returns: list: A list of predicted probabilities for each cell. """ - data = get_train_dataset_file(grid=grid, level=level) - data = gpd.read_parquet(data) + data = e.create() print(f"Predicting probabilities for {len(data)} cells...") # Predict in batches to avoid memory issues batch_size = 10_000 preds = [] + + cols_to_drop = ["geometry"] + if e.target == "darts_mllabels": + cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")] + else: + cols_to_drop += [col for col in data.columns if col.startswith("darts_")] for i in range(0, len(data), batch_size): batch = data.iloc[i : i + batch_size] - cols_to_drop = ["cell_id", "geometry", "darts_has_rts"] - cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] X_batch = batch.drop(columns=cols_to_drop).dropna() - cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy() + cell_ids = X_batch.index.to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() X_batch = X_batch.to_numpy(dtype="float64") X_batch = torch.asarray(X_batch, device=0) diff --git a/src/entropice/training.py b/src/entropice/training.py index 182783b..9cadf22 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -2,10 +2,10 @@ """Training of classification models training.""" import pickle +from dataclasses import asdict, dataclass from typing import Literal import cyclopts -import geopandas as gpd import pandas as pd import toml import torch @@ -15,63 +15,85 @@ from cuml.neighbors import KNeighborsClassifier from entropy import ESPAClassifier from rich import pretty, traceback from scipy.stats import loguniform, randint +from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen from sklearn import set_config from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from stopuhr import stopwatch from xgboost import XGBClassifier +from entropice.dataset import DatasetEnsemble from entropice.inference import predict_proba -from entropice.paths import ( - get_cv_results_dir, - get_train_dataset_file, -) +from entropice.paths import get_cv_results_dir traceback.install() pretty.install() set_config(array_api_dispatch=True) +cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) -def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"): - """Create X and y data from the training dataset. +_metrics = { + "binary": ["accuracy", "recall", "precision", "f1", "jaccard"], + "multiclass": [ + "accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted" + "f1_macro", + "f1_weighted", + "precision_macro", + "precision_weighted", + "recall_macro", + "jaccard_micro", + "jaccard_macro", + "jaccard_weighted", + ], +} - Args: - grid (Literal["hex", "healpix"]): The grid type to use. - level (int): The grid level to use. - task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary". - Returns: - Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names. +@cyclopts.Parameter("*") +@dataclass +class CVSettings: + n_iter: int = 2000 + robust: bool = False + task: Literal["binary", "count", "density"] = "binary" + model: Literal["espa", "xgboost", "rf", "knn"] = "espa" - """ - data = get_train_dataset_file(grid=grid, level=level) - data = gpd.read_parquet(data) - data = data[data["darts_has_coverage"]] - cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"] - cols_to_drop += [col for col in data.columns if col.startswith("darts_")] +def _create_xy_data(e: DatasetEnsemble, task: Literal["binary", "count", "density"] = "binary"): + data = e.create() + + covcol = "dartsml_has_coverage" if e.target == "darts_mllabels" else "darts_has_coverage" + bincol = "dartsml_has_rts" if e.target == "darts_mllabels" else "darts_has_rts" + countcol = "dartsml_rts_count" if e.target == "darts_mllabels" else "darts_rts_count" + densitycol = "dartsml_rts_density" if e.target == "darts_mllabels" else "darts_rts_density" + + data = data[data[covcol]].reset_index(drop=True) + + cols_to_drop = ["geometry"] + if e.target == "darts_mllabels": + cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")] + else: + cols_to_drop += [col for col in data.columns if col.startswith("darts_")] X_data = data.drop(columns=cols_to_drop).dropna() if task == "binary": labels = ["No RTS", "RTS"] - y_data = data.loc[X_data.index, "darts_has_rts"] + y_data = data.loc[X_data.index, bincol] elif task == "count": # Put into n categories (log scaled) - y_data = data.loc[X_data.index, "darts_rts_count"] + y_data = data.loc[X_data.index, countcol] n_categories = 5 bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories # Change the first interval to start at 1 and add a category for 0 bins = pd.IntervalIndex.from_tuples( [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins] ) + print(f"{bins=}") y_data = pd.cut(y_data, bins=bins) labels = [str(v) for v in y_data.sort_values().unique()] y_data = y_data.cat.codes elif task == "density": - y_data = data.loc[X_data.index, "darts_rts_density"] + y_data = data.loc[X_data.index, densitycol] n_categories = 5 bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories - # Change the first interval to start at 0 - bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins]) + print(f"{bins=}") y_data = pd.cut(y_data, bins=bins) labels = [str(v) for v in y_data.sort_values().unique()] y_data = y_data.cat.codes @@ -80,37 +102,11 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b return data, X_data, y_data, labels -def random_cv( - grid: Literal["hex", "healpix"], - level: int, - n_iter: int = 2000, - robust: bool = False, - task: Literal["binary", "count", "density"] = "binary", - model: Literal["espa", "xgboost", "rf", "knn"] = "espa", +def _create_clf( + settings: CVSettings, ): - """Perform random cross-validation on the training dataset. - - Args: - grid (Literal["hex", "healpix"]): The grid type to use. - level (int): The grid level to use. - n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. - robust (bool, optional): Whether to use robust training. Defaults to False. - task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary". - - """ - _, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task) - print(f"Using {task}-class classification with {len(labels)} classes: {labels}") - print(f"{y_data.describe()=}") - X = X_data.to_numpy(dtype="float64") - y = y_data.to_numpy(dtype="int8") - X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) - print(f"{X.shape=}, {y.shape=}") - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) - print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") - - if model == "espa": - clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) - if task == "binary": + if settings.model == "espa": + if settings.task == "binary": param_grid = { "eps_cl": loguniform(1e-4, 1e1), "eps_e": loguniform(1e1, 1e7), @@ -122,7 +118,9 @@ def random_cv( "eps_e": loguniform(1e4, 1e8), "initial_K": randint(400, 800), } - elif model == "xgboost": + clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=settings.robust) + fit_params = {"max_iter": 300} + elif settings.model == "xgboost": param_grid = { "learning_rate": loguniform(1e-4, 1e-1), "max_depth": randint(3, 15), @@ -131,57 +129,102 @@ def random_cv( "colsample_bytree": loguniform(0.5, 1.0), } clf = XGBClassifier( - objective="multi:softprob" if task != "binary" else "binary:logistic", - eval_metric="mlogloss" if task != "binary" else "logloss", + objective="multi:softprob" if settings.task != "binary" else "binary:logistic", + eval_metric="mlogloss" if settings.task != "binary" else "logloss", random_state=42, tree_method="gpu_hist", device="cuda", ) - elif model == "rf": + fit_params = {} + elif settings.model == "rf": param_grid = { "max_depth": randint(5, 50), "n_estimators": randint(50, 500), } clf = RandomForestClassifier(random_state=42) - elif model == "knn": + fit_params = {} + elif settings.model == "knn": param_grid = { "n_neighbors": randint(3, 15), "weights": ["uniform", "distance"], "algorithm": ["brute", "kd_tree", "ball_tree"], } clf = KNeighborsClassifier(random_state=42) + fit_params = {} else: - raise ValueError(f"Unknown model: {model}") + raise ValueError(f"Unknown model: {settings.model}") + + return clf, param_grid, fit_params + + +def _serialize_param_grid(param_grid): + param_grid_serializable = {} + for key, dist in param_grid.items(): + # ! Hacky, but I can't find a better way to serialize scipy distributions once they are created + if isinstance(dist, rv_continuous_frozen): + param_grid_serializable[key] = { + "distribution": "loguniform", + "low": dist.a, + "high": dist.b, + } + elif isinstance(dist, rv_discrete_frozen): + param_grid_serializable[key] = { + "distribution": "randint", + "low": dist.a, + "high": dist.b, + } + elif isinstance(dist, list): + param_grid_serializable[key] = dist + else: + raise ValueError(f"Unknown distribution type for {key}: {type(dist)}") + return param_grid_serializable + + +@cli.default +def random_cv( + dataset_ensemble: DatasetEnsemble, + settings: CVSettings = CVSettings(), +): + """Perform random cross-validation on the training dataset. + + Args: + grid (Literal["hex", "healpix"]): The grid type to use. + level (int): The grid level to use. + n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. + robust (bool, optional): Whether to use robust training. Defaults to False. + task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary". + + """ + print("Creating training data...") + _, X_data, y_data, labels = _create_xy_data(dataset_ensemble, task=settings.task) + print(f"Using {settings.task}-class classification with {len(labels)} classes: {labels}") + print(f"{y_data.describe()=}") + X = X_data.to_numpy(dtype="float64") + y = y_data.to_numpy(dtype="int8") + X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) + print(f"{X.shape=}, {y.shape=}") + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") + + clf, param_grid, fit_params = _create_clf(settings) + print(f"Using model: {settings.model} with parameters: {param_grid}") cv = KFold(n_splits=5, shuffle=True, random_state=42) - if task == "binary": - metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU - else: - metrics = [ - "accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted" - "f1_macro", - "f1_weighted", - "precision_macro", - "precision_weighted", - "recall_macro", - "jaccard_micro", - "jaccard_macro", - "jaccard_weighted", - ] + metrics = _metrics["binary" if settings.task == "binary" else "multiclass"] search = RandomizedSearchCV( clf, param_grid, - n_iter=n_iter, + n_iter=settings.n_iter, n_jobs=8, cv=cv, random_state=42, verbose=10, scoring=metrics, - refit="f1" if task == "binary" else "f1_weighted", + refit="f1" if settings.task == "binary" else "f1_weighted", ) print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"): - search.fit(X_train, y_train, max_iter=300) + search.fit(X_train, y_train, **fit_params) print("Best parameters combination found:") best_parameters = search.best_estimator_.get_params() @@ -192,35 +235,19 @@ def random_cv( print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}") - results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task) + results_dir = get_cv_results_dir( + "random_search", + grid=dataset_ensemble.grid, + level=dataset_ensemble.level, + task=settings.task, + ) # Store the search settings # First convert the param_grid distributions to a serializable format - param_grid_serializable = {} - for key, dist in param_grid.items(): - if isinstance(dist, loguniform): - param_grid_serializable[key] = { - "distribution": "loguniform", - "low": dist.a, - "high": dist.b, - } - elif isinstance(dist, randint): - param_grid_serializable[key] = { - "distribution": "randint", - "low": dist.a, - "high": dist.b, - } - elif isinstance(dist, list): - param_grid_serializable[key] = dist - else: - raise ValueError(f"Unknown distribution type for {key}: {type(dist)}") - settings = { - "task": task, - "model": model, - "grid": grid, - "level": level, - "random_state": 42, - "n_iter": n_iter, + param_grid_serializable = _serialize_param_grid(param_grid) + combined_settings = { + **asdict(settings), + **asdict(dataset_ensemble), "param_grid": param_grid_serializable, "cv_splits": cv.get_n_splits(), "metrics": metrics, @@ -229,7 +256,7 @@ def random_cv( settings_file = results_dir / "search_settings.toml" print(f"Storing search settings to {settings_file}") with open(settings_file, "w") as f: - toml.dump({"settings": settings}, f) + toml.dump({"settings": combined_settings}, f) # Store the best estimator model best_model_file = results_dir / "best_estimator_model.pkl" @@ -243,14 +270,12 @@ def random_cv( params = pd.json_normalize(results["params"]) # Concatenate the params columns with the original DataFrame results = pd.concat([results.drop(columns=["params"]), params], axis=1) - results["grid"] = grid - results["level"] = level results_file = results_dir / "search_results.parquet" print(f"Storing CV results to {results_file}") results.to_parquet(results_file) # Get the inner state of the best estimator - if model == "espa": + if settings.model == "espa": best_estimator = search.best_estimator_ # Annotate the state with xarray metadata features = X_data.columns.tolist() @@ -284,8 +309,6 @@ def random_cv( }, attrs={ "description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.", - "grid": grid, - "level": level, }, ) state_file = results_dir / "best_estimator_state.nc" @@ -294,7 +317,7 @@ def random_cv( # Predict probabilities for all cells print("Predicting probabilities for all cells...") - preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels) + preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=labels) preds_file = results_dir / "predicted_probabilities.parquet" print(f"Storing predicted probabilities to {preds_file}") preds.to_parquet(preds_file) @@ -303,9 +326,5 @@ def random_cv( print("Done.") -def main(): - cyclopts.run(random_cv) - - if __name__ == "__main__": - main() + cli()