Make the training work again
This commit is contained in:
parent
2be2244cdb
commit
d22b857722
4 changed files with 156 additions and 133 deletions
|
|
@ -129,8 +129,8 @@ def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
|
|||
covered = ~grid_gdf["dartsml_coverage"].isna()
|
||||
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
||||
|
||||
grid_gdf["darts_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
|
||||
grid_gdf["darts_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
|
||||
grid_gdf["dartsml_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
|
||||
grid_gdf["dartsml_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
|
||||
|
||||
output_path = get_darts_rts_file(grid, level, labels=True)
|
||||
grid_gdf.to_parquet(output_path)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import json
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
|
@ -55,14 +56,17 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "
|
|||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
|
||||
@dataclass
|
||||
class DatasetEnsemble:
|
||||
grid: Literal["hex", "healpix"]
|
||||
level: int
|
||||
target: Literal["darts_rts", "darts_mllabels"]
|
||||
members: list[L2Dataset]
|
||||
dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict)
|
||||
variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict)
|
||||
members: list[L2Dataset] = field(
|
||||
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
)
|
||||
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
||||
variable_filters: dict[str, list[str]] = field(default_factory=dict)
|
||||
filter_target: str | Literal[False] = False
|
||||
add_lonlat: bool = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
# ruff: noqa: N806
|
||||
"""Inference runs on trained models."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import torch
|
||||
from entropy import ESPAClassifier
|
||||
from rich import pretty, traceback
|
||||
from sklearn import set_config
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from entropice.paths import get_train_dataset_file
|
||||
from entropice.dataset import DatasetEnsemble
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -18,32 +16,34 @@ pretty.install()
|
|||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier, classes: list) -> gpd.GeoDataFrame:
|
||||
def predict_proba(e: DatasetEnsemble, clf: BaseEstimator, classes: list) -> gpd.GeoDataFrame:
|
||||
"""Get predicted probabilities for each cell.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
clf (ESPAClassifier): The trained classifier to use for predictions.
|
||||
e (DatasetEnsemble): The dataset ensemble configuration.
|
||||
clf (BaseEstimator): The trained classifier to use for predictions.
|
||||
classes (list): List of class names.
|
||||
|
||||
Returns:
|
||||
gpd.GeoDataFrame: The predicted probabilities for each cell.
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
data = e.create()
|
||||
print(f"Predicting probabilities for {len(data)} cells...")
|
||||
|
||||
# Predict in batches to avoid memory issues
|
||||
batch_size = 10_000
|
||||
preds = []
|
||||
|
||||
cols_to_drop = ["geometry"]
|
||||
if e.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data.iloc[i : i + batch_size]
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
||||
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
|
||||
cell_ids = X_batch.index.to_numpy()
|
||||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||
X_batch = X_batch.to_numpy(dtype="float64")
|
||||
X_batch = torch.asarray(X_batch, device=0)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@
|
|||
"""Training of classification models training."""
|
||||
|
||||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import toml
|
||||
import torch
|
||||
|
|
@ -15,63 +15,85 @@ from cuml.neighbors import KNeighborsClassifier
|
|||
from entropy import ESPAClassifier
|
||||
from rich import pretty, traceback
|
||||
from scipy.stats import loguniform, randint
|
||||
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
||||
from sklearn import set_config
|
||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||
from stopuhr import stopwatch
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
from entropice.dataset import DatasetEnsemble
|
||||
from entropice.inference import predict_proba
|
||||
from entropice.paths import (
|
||||
get_cv_results_dir,
|
||||
get_train_dataset_file,
|
||||
)
|
||||
from entropice.paths import get_cv_results_dir
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
set_config(array_api_dispatch=True)
|
||||
|
||||
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml"))
|
||||
|
||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"):
|
||||
"""Create X and y data from the training dataset.
|
||||
_metrics = {
|
||||
"binary": ["accuracy", "recall", "precision", "f1", "jaccard"],
|
||||
"multiclass": [
|
||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||
"f1_macro",
|
||||
"f1_weighted",
|
||||
"precision_macro",
|
||||
"precision_weighted",
|
||||
"recall_macro",
|
||||
"jaccard_micro",
|
||||
"jaccard_macro",
|
||||
"jaccard_weighted",
|
||||
],
|
||||
}
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
Returns:
|
||||
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
||||
# NOTE(review): `Parameter("*")` presumably exposes the fields as top-level CLI options instead of
# nesting them under the argument name — confirm against the cyclopts documentation.
@cyclopts.Parameter("*")
@dataclass
class CVSettings:
    """Settings of the randomized hyperparameter search run by `random_cv`."""

    # Number of sampled candidates; forwarded as `n_iter` to RandomizedSearchCV.
    n_iter: int = 2000
    # Forwarded as `robust` to ESPAClassifier when `model` is "espa".
    robust: bool = False
    # Which target is derived from the data; "binary" uses the binary metric set, anything else the multiclass one.
    task: Literal["binary", "count", "density"] = "binary"
    # Which estimator and parameter grid `_create_clf` builds.
    model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
data = data[data["darts_has_coverage"]]
|
||||
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
def _create_xy_data(e: DatasetEnsemble, task: Literal["binary", "count", "density"] = "binary"):
|
||||
data = e.create()
|
||||
|
||||
covcol = "dartsml_has_coverage" if e.target == "darts_mllabels" else "darts_has_coverage"
|
||||
bincol = "dartsml_has_rts" if e.target == "darts_mllabels" else "darts_has_rts"
|
||||
countcol = "dartsml_rts_count" if e.target == "darts_mllabels" else "darts_rts_count"
|
||||
densitycol = "dartsml_rts_density" if e.target == "darts_mllabels" else "darts_rts_density"
|
||||
|
||||
data = data[data[covcol]].reset_index(drop=True)
|
||||
|
||||
cols_to_drop = ["geometry"]
|
||||
if e.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
||||
if task == "binary":
|
||||
labels = ["No RTS", "RTS"]
|
||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
||||
y_data = data.loc[X_data.index, bincol]
|
||||
elif task == "count":
|
||||
# Put into n categories (log scaled)
|
||||
y_data = data.loc[X_data.index, "darts_rts_count"]
|
||||
y_data = data.loc[X_data.index, countcol]
|
||||
n_categories = 5
|
||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||
# Change the first interval to start at 1 and add a category for 0
|
||||
bins = pd.IntervalIndex.from_tuples(
|
||||
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
|
||||
)
|
||||
print(f"{bins=}")
|
||||
y_data = pd.cut(y_data, bins=bins)
|
||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||
y_data = y_data.cat.codes
|
||||
elif task == "density":
|
||||
y_data = data.loc[X_data.index, "darts_rts_density"]
|
||||
y_data = data.loc[X_data.index, densitycol]
|
||||
n_categories = 5
|
||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||
# Change the first interval to start at 0
|
||||
bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins])
|
||||
print(f"{bins=}")
|
||||
y_data = pd.cut(y_data, bins=bins)
|
||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||
y_data = y_data.cat.codes
|
||||
|
|
@ -80,37 +102,11 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
|||
return data, X_data, y_data, labels
|
||||
|
||||
|
||||
def random_cv(
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
n_iter: int = 2000,
|
||||
robust: bool = False,
|
||||
task: Literal["binary", "count", "density"] = "binary",
|
||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa",
|
||||
def _create_clf(
|
||||
settings: CVSettings,
|
||||
):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
"""
|
||||
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
||||
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
||||
print(f"{y_data.describe()=}")
|
||||
X = X_data.to_numpy(dtype="float64")
|
||||
y = y_data.to_numpy(dtype="int8")
|
||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||
print(f"{X.shape=}, {y.shape=}")
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||
|
||||
if model == "espa":
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||
if task == "binary":
|
||||
if settings.model == "espa":
|
||||
if settings.task == "binary":
|
||||
param_grid = {
|
||||
"eps_cl": loguniform(1e-4, 1e1),
|
||||
"eps_e": loguniform(1e1, 1e7),
|
||||
|
|
@ -122,7 +118,9 @@ def random_cv(
|
|||
"eps_e": loguniform(1e4, 1e8),
|
||||
"initial_K": randint(400, 800),
|
||||
}
|
||||
elif model == "xgboost":
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=settings.robust)
|
||||
fit_params = {"max_iter": 300}
|
||||
elif settings.model == "xgboost":
|
||||
param_grid = {
|
||||
"learning_rate": loguniform(1e-4, 1e-1),
|
||||
"max_depth": randint(3, 15),
|
||||
|
|
@ -131,57 +129,102 @@ def random_cv(
|
|||
"colsample_bytree": loguniform(0.5, 1.0),
|
||||
}
|
||||
clf = XGBClassifier(
|
||||
objective="multi:softprob" if task != "binary" else "binary:logistic",
|
||||
eval_metric="mlogloss" if task != "binary" else "logloss",
|
||||
objective="multi:softprob" if settings.task != "binary" else "binary:logistic",
|
||||
eval_metric="mlogloss" if settings.task != "binary" else "logloss",
|
||||
random_state=42,
|
||||
tree_method="gpu_hist",
|
||||
device="cuda",
|
||||
)
|
||||
elif model == "rf":
|
||||
fit_params = {}
|
||||
elif settings.model == "rf":
|
||||
param_grid = {
|
||||
"max_depth": randint(5, 50),
|
||||
"n_estimators": randint(50, 500),
|
||||
}
|
||||
clf = RandomForestClassifier(random_state=42)
|
||||
elif model == "knn":
|
||||
fit_params = {}
|
||||
elif settings.model == "knn":
|
||||
param_grid = {
|
||||
"n_neighbors": randint(3, 15),
|
||||
"weights": ["uniform", "distance"],
|
||||
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
||||
}
|
||||
clf = KNeighborsClassifier(random_state=42)
|
||||
fit_params = {}
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
raise ValueError(f"Unknown model: {settings.model}")
|
||||
|
||||
return clf, param_grid, fit_params
|
||||
|
||||
|
||||
def _serialize_param_grid(param_grid):
|
||||
param_grid_serializable = {}
|
||||
for key, dist in param_grid.items():
|
||||
# ! Hacky, but I can't find a better way to serialize scipy distributions once they are created
|
||||
if isinstance(dist, rv_continuous_frozen):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "loguniform",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, rv_discrete_frozen):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "randint",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, list):
|
||||
param_grid_serializable[key] = dist
|
||||
else:
|
||||
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
||||
return param_grid_serializable
|
||||
|
||||
|
||||
@cli.default
|
||||
def random_cv(
|
||||
dataset_ensemble: DatasetEnsemble,
|
||||
settings: CVSettings = CVSettings(),
|
||||
):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration used to build the training data.
|
||||
settings (CVSettings, optional): The search settings (n_iter, robust, task and model). Defaults to CVSettings().
|
||||
|
||||
"""
|
||||
print("Creating training data...")
|
||||
_, X_data, y_data, labels = _create_xy_data(dataset_ensemble, task=settings.task)
|
||||
print(f"Using {settings.task}-class classification with {len(labels)} classes: {labels}")
|
||||
print(f"{y_data.describe()=}")
|
||||
X = X_data.to_numpy(dtype="float64")
|
||||
y = y_data.to_numpy(dtype="int8")
|
||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||
print(f"{X.shape=}, {y.shape=}")
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||
|
||||
clf, param_grid, fit_params = _create_clf(settings)
|
||||
print(f"Using model: {settings.model} with parameters: {param_grid}")
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
if task == "binary":
|
||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||
else:
|
||||
metrics = [
|
||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||
"f1_macro",
|
||||
"f1_weighted",
|
||||
"precision_macro",
|
||||
"precision_weighted",
|
||||
"recall_macro",
|
||||
"jaccard_micro",
|
||||
"jaccard_macro",
|
||||
"jaccard_weighted",
|
||||
]
|
||||
metrics = _metrics["binary" if settings.task == "binary" else "multiclass"]
|
||||
search = RandomizedSearchCV(
|
||||
clf,
|
||||
param_grid,
|
||||
n_iter=n_iter,
|
||||
n_iter=settings.n_iter,
|
||||
n_jobs=8,
|
||||
cv=cv,
|
||||
random_state=42,
|
||||
verbose=10,
|
||||
scoring=metrics,
|
||||
refit="f1" if task == "binary" else "f1_weighted",
|
||||
refit="f1" if settings.task == "binary" else "f1_weighted",
|
||||
)
|
||||
|
||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||
search.fit(X_train, y_train, max_iter=300)
|
||||
search.fit(X_train, y_train, **fit_params)
|
||||
|
||||
print("Best parameters combination found:")
|
||||
best_parameters = search.best_estimator_.get_params()
|
||||
|
|
@ -192,35 +235,19 @@ def random_cv(
|
|||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||
|
||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
||||
results_dir = get_cv_results_dir(
|
||||
"random_search",
|
||||
grid=dataset_ensemble.grid,
|
||||
level=dataset_ensemble.level,
|
||||
task=settings.task,
|
||||
)
|
||||
|
||||
# Store the search settings
|
||||
# First convert the param_grid distributions to a serializable format
|
||||
param_grid_serializable = {}
|
||||
for key, dist in param_grid.items():
|
||||
if isinstance(dist, loguniform):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "loguniform",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, randint):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "randint",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, list):
|
||||
param_grid_serializable[key] = dist
|
||||
else:
|
||||
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
||||
settings = {
|
||||
"task": task,
|
||||
"model": model,
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"random_state": 42,
|
||||
"n_iter": n_iter,
|
||||
param_grid_serializable = _serialize_param_grid(param_grid)
|
||||
combined_settings = {
|
||||
**asdict(settings),
|
||||
**asdict(dataset_ensemble),
|
||||
"param_grid": param_grid_serializable,
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
|
|
@ -229,7 +256,7 @@ def random_cv(
|
|||
settings_file = results_dir / "search_settings.toml"
|
||||
print(f"Storing search settings to {settings_file}")
|
||||
with open(settings_file, "w") as f:
|
||||
toml.dump({"settings": settings}, f)
|
||||
toml.dump({"settings": combined_settings}, f)
|
||||
|
||||
# Store the best estimator model
|
||||
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||
|
|
@ -243,14 +270,12 @@ def random_cv(
|
|||
params = pd.json_normalize(results["params"])
|
||||
# Concatenate the params columns with the original DataFrame
|
||||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||
results["grid"] = grid
|
||||
results["level"] = level
|
||||
results_file = results_dir / "search_results.parquet"
|
||||
print(f"Storing CV results to {results_file}")
|
||||
results.to_parquet(results_file)
|
||||
|
||||
# Get the inner state of the best estimator
|
||||
if model == "espa":
|
||||
if settings.model == "espa":
|
||||
best_estimator = search.best_estimator_
|
||||
# Annotate the state with xarray metadata
|
||||
features = X_data.columns.tolist()
|
||||
|
|
@ -284,8 +309,6 @@ def random_cv(
|
|||
},
|
||||
attrs={
|
||||
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
},
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
|
|
@ -294,7 +317,7 @@ def random_cv(
|
|||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
|
||||
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=labels)
|
||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||
print(f"Storing predicted probabilities to {preds_file}")
|
||||
preds.to_parquet(preds_file)
|
||||
|
|
@ -303,9 +326,5 @@ def random_cv(
|
|||
print("Done.")
|
||||
|
||||
|
||||
def main():
|
||||
cyclopts.run(random_cv)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue