Make the training work again
This commit is contained in:
parent
2be2244cdb
commit
d22b857722
4 changed files with 156 additions and 133 deletions
|
|
@ -129,8 +129,8 @@ def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
|
||||||
covered = ~grid_gdf["dartsml_coverage"].isna()
|
covered = ~grid_gdf["dartsml_coverage"].isna()
|
||||||
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
||||||
|
|
||||||
grid_gdf["darts_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
|
grid_gdf["dartsml_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
|
||||||
grid_gdf["darts_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
|
grid_gdf["dartsml_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
|
||||||
|
|
||||||
output_path = get_darts_rts_file(grid, level, labels=True)
|
output_path = get_darts_rts_file(grid, level, labels=True)
|
||||||
grid_gdf.to_parquet(output_path)
|
grid_gdf.to_parquet(output_path)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ import json
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
@ -55,14 +56,17 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "
|
||||||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
|
|
||||||
|
|
||||||
|
@cyclopts.Parameter("*")
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetEnsemble:
|
class DatasetEnsemble:
|
||||||
grid: Literal["hex", "healpix"]
|
grid: Literal["hex", "healpix"]
|
||||||
level: int
|
level: int
|
||||||
target: Literal["darts_rts", "darts_mllabels"]
|
target: Literal["darts_rts", "darts_mllabels"]
|
||||||
members: list[L2Dataset]
|
members: list[L2Dataset] = field(
|
||||||
dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict)
|
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict)
|
)
|
||||||
|
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
||||||
|
variable_filters: dict[str, list[str]] = field(default_factory=dict)
|
||||||
filter_target: str | Literal[False] = False
|
filter_target: str | Literal[False] = False
|
||||||
add_lonlat: bool = True
|
add_lonlat: bool = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,14 @@
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
"""Inference runs on trained models."""
|
"""Inference runs on trained models."""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from entropy import ESPAClassifier
|
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
|
from sklearn.base import BaseEstimator
|
||||||
|
|
||||||
from entropice.paths import get_train_dataset_file
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -18,32 +16,34 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
|
def predict_proba(e: DatasetEnsemble, clf: BaseEstimator, classes: list) -> gpd.GeoDataFrame:
|
||||||
"""Get predicted probabilities for each cell.
|
"""Get predicted probabilities for each cell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
e (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
level (int): The grid level to use.
|
clf (BaseEstimator): The trained classifier to use for predictions.
|
||||||
clf (ESPAClassifier): The trained classifier to use for predictions.
|
|
||||||
classes (list): List of class names.
|
classes (list): List of class names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of predicted probabilities for each cell.
|
list: A list of predicted probabilities for each cell.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
data = e.create()
|
||||||
data = gpd.read_parquet(data)
|
|
||||||
print(f"Predicting probabilities for {len(data)} cells...")
|
print(f"Predicting probabilities for {len(data)} cells...")
|
||||||
|
|
||||||
# Predict in batches to avoid memory issues
|
# Predict in batches to avoid memory issues
|
||||||
batch_size = 10_000
|
batch_size = 10_000
|
||||||
preds = []
|
preds = []
|
||||||
|
|
||||||
|
cols_to_drop = ["geometry"]
|
||||||
|
if e.target == "darts_mllabels":
|
||||||
|
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
||||||
|
else:
|
||||||
|
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||||
for i in range(0, len(data), batch_size):
|
for i in range(0, len(data), batch_size):
|
||||||
batch = data.iloc[i : i + batch_size]
|
batch = data.iloc[i : i + batch_size]
|
||||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
|
||||||
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
||||||
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
|
cell_ids = X_batch.index.to_numpy()
|
||||||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||||
X_batch = X_batch.to_numpy(dtype="float64")
|
X_batch = X_batch.to_numpy(dtype="float64")
|
||||||
X_batch = torch.asarray(X_batch, device=0)
|
X_batch = torch.asarray(X_batch, device=0)
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,10 @@
|
||||||
"""Training of classification models training."""
|
"""Training of classification models training."""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import toml
|
import toml
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -15,63 +15,85 @@ from cuml.neighbors import KNeighborsClassifier
|
||||||
from entropy import ESPAClassifier
|
from entropy import ESPAClassifier
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from scipy.stats import loguniform, randint
|
from scipy.stats import loguniform, randint
|
||||||
|
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
from xgboost import XGBClassifier
|
from xgboost import XGBClassifier
|
||||||
|
|
||||||
|
from entropice.dataset import DatasetEnsemble
|
||||||
from entropice.inference import predict_proba
|
from entropice.inference import predict_proba
|
||||||
from entropice.paths import (
|
from entropice.paths import get_cv_results_dir
|
||||||
get_cv_results_dir,
|
|
||||||
get_train_dataset_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml"))
|
||||||
|
|
||||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"):
|
_metrics = {
|
||||||
"""Create X and y data from the training dataset.
|
"binary": ["accuracy", "recall", "precision", "f1", "jaccard"],
|
||||||
|
"multiclass": [
|
||||||
|
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||||
|
"f1_macro",
|
||||||
|
"f1_weighted",
|
||||||
|
"precision_macro",
|
||||||
|
"precision_weighted",
|
||||||
|
"recall_macro",
|
||||||
|
"jaccard_micro",
|
||||||
|
"jaccard_macro",
|
||||||
|
"jaccard_weighted",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
Args:
|
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
|
||||||
level (int): The grid level to use.
|
|
||||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
|
||||||
|
|
||||||
Returns:
|
@cyclopts.Parameter("*")
|
||||||
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
@dataclass
|
||||||
|
class CVSettings:
|
||||||
|
n_iter: int = 2000
|
||||||
|
robust: bool = False
|
||||||
|
task: Literal["binary", "count", "density"] = "binary"
|
||||||
|
model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
||||||
|
|
||||||
"""
|
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
|
||||||
data = gpd.read_parquet(data)
|
|
||||||
data = data[data["darts_has_coverage"]]
|
|
||||||
|
|
||||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
|
def _create_xy_data(e: DatasetEnsemble, task: Literal["binary", "count", "density"] = "binary"):
|
||||||
|
data = e.create()
|
||||||
|
|
||||||
|
covcol = "dartsml_has_coverage" if e.target == "darts_mllabels" else "darts_has_coverage"
|
||||||
|
bincol = "dartsml_has_rts" if e.target == "darts_mllabels" else "darts_has_rts"
|
||||||
|
countcol = "dartsml_rts_count" if e.target == "darts_mllabels" else "darts_rts_count"
|
||||||
|
densitycol = "dartsml_rts_density" if e.target == "darts_mllabels" else "darts_rts_density"
|
||||||
|
|
||||||
|
data = data[data[covcol]].reset_index(drop=True)
|
||||||
|
|
||||||
|
cols_to_drop = ["geometry"]
|
||||||
|
if e.target == "darts_mllabels":
|
||||||
|
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
||||||
|
else:
|
||||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
X_data = data.drop(columns=cols_to_drop).dropna()
|
||||||
if task == "binary":
|
if task == "binary":
|
||||||
labels = ["No RTS", "RTS"]
|
labels = ["No RTS", "RTS"]
|
||||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
y_data = data.loc[X_data.index, bincol]
|
||||||
elif task == "count":
|
elif task == "count":
|
||||||
# Put into n categories (log scaled)
|
# Put into n categories (log scaled)
|
||||||
y_data = data.loc[X_data.index, "darts_rts_count"]
|
y_data = data.loc[X_data.index, countcol]
|
||||||
n_categories = 5
|
n_categories = 5
|
||||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||||
# Change the first interval to start at 1 and add a category for 0
|
# Change the first interval to start at 1 and add a category for 0
|
||||||
bins = pd.IntervalIndex.from_tuples(
|
bins = pd.IntervalIndex.from_tuples(
|
||||||
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
|
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
|
||||||
)
|
)
|
||||||
|
print(f"{bins=}")
|
||||||
y_data = pd.cut(y_data, bins=bins)
|
y_data = pd.cut(y_data, bins=bins)
|
||||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||||
y_data = y_data.cat.codes
|
y_data = y_data.cat.codes
|
||||||
elif task == "density":
|
elif task == "density":
|
||||||
y_data = data.loc[X_data.index, "darts_rts_density"]
|
y_data = data.loc[X_data.index, densitycol]
|
||||||
n_categories = 5
|
n_categories = 5
|
||||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||||
# Change the first interval to start at 0
|
print(f"{bins=}")
|
||||||
bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins])
|
|
||||||
y_data = pd.cut(y_data, bins=bins)
|
y_data = pd.cut(y_data, bins=bins)
|
||||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||||
y_data = y_data.cat.codes
|
y_data = y_data.cat.codes
|
||||||
|
|
@ -80,37 +102,11 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
||||||
return data, X_data, y_data, labels
|
return data, X_data, y_data, labels
|
||||||
|
|
||||||
|
|
||||||
def random_cv(
|
def _create_clf(
|
||||||
grid: Literal["hex", "healpix"],
|
settings: CVSettings,
|
||||||
level: int,
|
|
||||||
n_iter: int = 2000,
|
|
||||||
robust: bool = False,
|
|
||||||
task: Literal["binary", "count", "density"] = "binary",
|
|
||||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa",
|
|
||||||
):
|
):
|
||||||
"""Perform random cross-validation on the training dataset.
|
if settings.model == "espa":
|
||||||
|
if settings.task == "binary":
|
||||||
Args:
|
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
|
||||||
level (int): The grid level to use.
|
|
||||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
|
||||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
|
||||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
|
||||||
|
|
||||||
"""
|
|
||||||
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
|
||||||
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
|
||||||
print(f"{y_data.describe()=}")
|
|
||||||
X = X_data.to_numpy(dtype="float64")
|
|
||||||
y = y_data.to_numpy(dtype="int8")
|
|
||||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
|
||||||
print(f"{X.shape=}, {y.shape=}")
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
|
||||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
|
||||||
|
|
||||||
if model == "espa":
|
|
||||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
|
||||||
if task == "binary":
|
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"eps_cl": loguniform(1e-4, 1e1),
|
"eps_cl": loguniform(1e-4, 1e1),
|
||||||
"eps_e": loguniform(1e1, 1e7),
|
"eps_e": loguniform(1e1, 1e7),
|
||||||
|
|
@ -122,7 +118,9 @@ def random_cv(
|
||||||
"eps_e": loguniform(1e4, 1e8),
|
"eps_e": loguniform(1e4, 1e8),
|
||||||
"initial_K": randint(400, 800),
|
"initial_K": randint(400, 800),
|
||||||
}
|
}
|
||||||
elif model == "xgboost":
|
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=settings.robust)
|
||||||
|
fit_params = {"max_iter": 300}
|
||||||
|
elif settings.model == "xgboost":
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"learning_rate": loguniform(1e-4, 1e-1),
|
"learning_rate": loguniform(1e-4, 1e-1),
|
||||||
"max_depth": randint(3, 15),
|
"max_depth": randint(3, 15),
|
||||||
|
|
@ -131,57 +129,102 @@ def random_cv(
|
||||||
"colsample_bytree": loguniform(0.5, 1.0),
|
"colsample_bytree": loguniform(0.5, 1.0),
|
||||||
}
|
}
|
||||||
clf = XGBClassifier(
|
clf = XGBClassifier(
|
||||||
objective="multi:softprob" if task != "binary" else "binary:logistic",
|
objective="multi:softprob" if settings.task != "binary" else "binary:logistic",
|
||||||
eval_metric="mlogloss" if task != "binary" else "logloss",
|
eval_metric="mlogloss" if settings.task != "binary" else "logloss",
|
||||||
random_state=42,
|
random_state=42,
|
||||||
tree_method="gpu_hist",
|
tree_method="gpu_hist",
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
elif model == "rf":
|
fit_params = {}
|
||||||
|
elif settings.model == "rf":
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"max_depth": randint(5, 50),
|
"max_depth": randint(5, 50),
|
||||||
"n_estimators": randint(50, 500),
|
"n_estimators": randint(50, 500),
|
||||||
}
|
}
|
||||||
clf = RandomForestClassifier(random_state=42)
|
clf = RandomForestClassifier(random_state=42)
|
||||||
elif model == "knn":
|
fit_params = {}
|
||||||
|
elif settings.model == "knn":
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"n_neighbors": randint(3, 15),
|
"n_neighbors": randint(3, 15),
|
||||||
"weights": ["uniform", "distance"],
|
"weights": ["uniform", "distance"],
|
||||||
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
||||||
}
|
}
|
||||||
clf = KNeighborsClassifier(random_state=42)
|
clf = KNeighborsClassifier(random_state=42)
|
||||||
|
fit_params = {}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown model: {model}")
|
raise ValueError(f"Unknown model: {settings.model}")
|
||||||
|
|
||||||
|
return clf, param_grid, fit_params
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_param_grid(param_grid):
|
||||||
|
param_grid_serializable = {}
|
||||||
|
for key, dist in param_grid.items():
|
||||||
|
# ! Hacky, but I can't find a better way to serialize scipy distributions once they are created
|
||||||
|
if isinstance(dist, rv_continuous_frozen):
|
||||||
|
param_grid_serializable[key] = {
|
||||||
|
"distribution": "loguniform",
|
||||||
|
"low": dist.a,
|
||||||
|
"high": dist.b,
|
||||||
|
}
|
||||||
|
elif isinstance(dist, rv_discrete_frozen):
|
||||||
|
param_grid_serializable[key] = {
|
||||||
|
"distribution": "randint",
|
||||||
|
"low": dist.a,
|
||||||
|
"high": dist.b,
|
||||||
|
}
|
||||||
|
elif isinstance(dist, list):
|
||||||
|
param_grid_serializable[key] = dist
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
||||||
|
return param_grid_serializable
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default
|
||||||
|
def random_cv(
|
||||||
|
dataset_ensemble: DatasetEnsemble,
|
||||||
|
settings: CVSettings = CVSettings(),
|
||||||
|
):
|
||||||
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
|
level (int): The grid level to use.
|
||||||
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
|
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||||
|
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||||
|
|
||||||
|
"""
|
||||||
|
print("Creating training data...")
|
||||||
|
_, X_data, y_data, labels = _create_xy_data(dataset_ensemble, task=settings.task)
|
||||||
|
print(f"Using {settings.task}-class classification with {len(labels)} classes: {labels}")
|
||||||
|
print(f"{y_data.describe()=}")
|
||||||
|
X = X_data.to_numpy(dtype="float64")
|
||||||
|
y = y_data.to_numpy(dtype="int8")
|
||||||
|
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||||
|
print(f"{X.shape=}, {y.shape=}")
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||||
|
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||||
|
|
||||||
|
clf, param_grid, fit_params = _create_clf(settings)
|
||||||
|
print(f"Using model: {settings.model} with parameters: {param_grid}")
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
if task == "binary":
|
metrics = _metrics["binary" if settings.task == "binary" else "multiclass"]
|
||||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
|
||||||
else:
|
|
||||||
metrics = [
|
|
||||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
|
||||||
"f1_macro",
|
|
||||||
"f1_weighted",
|
|
||||||
"precision_macro",
|
|
||||||
"precision_weighted",
|
|
||||||
"recall_macro",
|
|
||||||
"jaccard_micro",
|
|
||||||
"jaccard_macro",
|
|
||||||
"jaccard_weighted",
|
|
||||||
]
|
|
||||||
search = RandomizedSearchCV(
|
search = RandomizedSearchCV(
|
||||||
clf,
|
clf,
|
||||||
param_grid,
|
param_grid,
|
||||||
n_iter=n_iter,
|
n_iter=settings.n_iter,
|
||||||
n_jobs=8,
|
n_jobs=8,
|
||||||
cv=cv,
|
cv=cv,
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=10,
|
verbose=10,
|
||||||
scoring=metrics,
|
scoring=metrics,
|
||||||
refit="f1" if task == "binary" else "f1_weighted",
|
refit="f1" if settings.task == "binary" else "f1_weighted",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||||
search.fit(X_train, y_train, max_iter=300)
|
search.fit(X_train, y_train, **fit_params)
|
||||||
|
|
||||||
print("Best parameters combination found:")
|
print("Best parameters combination found:")
|
||||||
best_parameters = search.best_estimator_.get_params()
|
best_parameters = search.best_estimator_.get_params()
|
||||||
|
|
@ -192,35 +235,19 @@ def random_cv(
|
||||||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||||
|
|
||||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
results_dir = get_cv_results_dir(
|
||||||
|
"random_search",
|
||||||
|
grid=dataset_ensemble.grid,
|
||||||
|
level=dataset_ensemble.level,
|
||||||
|
task=settings.task,
|
||||||
|
)
|
||||||
|
|
||||||
# Store the search settings
|
# Store the search settings
|
||||||
# First convert the param_grid distributions to a serializable format
|
# First convert the param_grid distributions to a serializable format
|
||||||
param_grid_serializable = {}
|
param_grid_serializable = _serialize_param_grid(param_grid)
|
||||||
for key, dist in param_grid.items():
|
combined_settings = {
|
||||||
if isinstance(dist, loguniform):
|
**asdict(settings),
|
||||||
param_grid_serializable[key] = {
|
**asdict(dataset_ensemble),
|
||||||
"distribution": "loguniform",
|
|
||||||
"low": dist.a,
|
|
||||||
"high": dist.b,
|
|
||||||
}
|
|
||||||
elif isinstance(dist, randint):
|
|
||||||
param_grid_serializable[key] = {
|
|
||||||
"distribution": "randint",
|
|
||||||
"low": dist.a,
|
|
||||||
"high": dist.b,
|
|
||||||
}
|
|
||||||
elif isinstance(dist, list):
|
|
||||||
param_grid_serializable[key] = dist
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
|
||||||
settings = {
|
|
||||||
"task": task,
|
|
||||||
"model": model,
|
|
||||||
"grid": grid,
|
|
||||||
"level": level,
|
|
||||||
"random_state": 42,
|
|
||||||
"n_iter": n_iter,
|
|
||||||
"param_grid": param_grid_serializable,
|
"param_grid": param_grid_serializable,
|
||||||
"cv_splits": cv.get_n_splits(),
|
"cv_splits": cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
"metrics": metrics,
|
||||||
|
|
@ -229,7 +256,7 @@ def random_cv(
|
||||||
settings_file = results_dir / "search_settings.toml"
|
settings_file = results_dir / "search_settings.toml"
|
||||||
print(f"Storing search settings to {settings_file}")
|
print(f"Storing search settings to {settings_file}")
|
||||||
with open(settings_file, "w") as f:
|
with open(settings_file, "w") as f:
|
||||||
toml.dump({"settings": settings}, f)
|
toml.dump({"settings": combined_settings}, f)
|
||||||
|
|
||||||
# Store the best estimator model
|
# Store the best estimator model
|
||||||
best_model_file = results_dir / "best_estimator_model.pkl"
|
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||||
|
|
@ -243,14 +270,12 @@ def random_cv(
|
||||||
params = pd.json_normalize(results["params"])
|
params = pd.json_normalize(results["params"])
|
||||||
# Concatenate the params columns with the original DataFrame
|
# Concatenate the params columns with the original DataFrame
|
||||||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||||
results["grid"] = grid
|
|
||||||
results["level"] = level
|
|
||||||
results_file = results_dir / "search_results.parquet"
|
results_file = results_dir / "search_results.parquet"
|
||||||
print(f"Storing CV results to {results_file}")
|
print(f"Storing CV results to {results_file}")
|
||||||
results.to_parquet(results_file)
|
results.to_parquet(results_file)
|
||||||
|
|
||||||
# Get the inner state of the best estimator
|
# Get the inner state of the best estimator
|
||||||
if model == "espa":
|
if settings.model == "espa":
|
||||||
best_estimator = search.best_estimator_
|
best_estimator = search.best_estimator_
|
||||||
# Annotate the state with xarray metadata
|
# Annotate the state with xarray metadata
|
||||||
features = X_data.columns.tolist()
|
features = X_data.columns.tolist()
|
||||||
|
|
@ -284,8 +309,6 @@ def random_cv(
|
||||||
},
|
},
|
||||||
attrs={
|
attrs={
|
||||||
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
||||||
"grid": grid,
|
|
||||||
"level": level,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
state_file = results_dir / "best_estimator_state.nc"
|
state_file = results_dir / "best_estimator_state.nc"
|
||||||
|
|
@ -294,7 +317,7 @@ def random_cv(
|
||||||
|
|
||||||
# Predict probabilities for all cells
|
# Predict probabilities for all cells
|
||||||
print("Predicting probabilities for all cells...")
|
print("Predicting probabilities for all cells...")
|
||||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
|
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=labels)
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
print(f"Storing predicted probabilities to {preds_file}")
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
preds.to_parquet(preds_file)
|
preds.to_parquet(preds_file)
|
||||||
|
|
@ -303,9 +326,5 @@ def random_cv(
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
cyclopts.run(random_cv)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
cli()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue