Add multiclass training
This commit is contained in:
parent
553b54bb32
commit
d5b35d6da4
7 changed files with 814 additions and 4867 deletions
11
pixi.lock
generated
11
pixi.lock
generated
|
|
@ -285,6 +285,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
||||
|
|
@ -1301,7 +1302,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
|
||||
sha256: 852c87cdbd1d452fccaa6253c6ce7410dda9fa32d2f951d35441e454414acfc0
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -1350,6 +1351,7 @@ packages:
|
|||
- altair[all]>=5.5.0,<6
|
||||
- h5netcdf>=1.7.3,<2
|
||||
- streamlit-folium>=0.25.3,<0.26
|
||||
- st-theme>=1.2.3,<2
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
name: entropy
|
||||
|
|
@ -4507,6 +4509,13 @@ packages:
|
|||
version: 1.0.1
|
||||
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
|
||||
name: st-theme
|
||||
version: 1.2.3
|
||||
sha256: 0a54d9817dd5f8a6d7b0d071b25ae72eacf536c63a5fb97374923938021b1389
|
||||
requires_dist:
|
||||
- streamlit>=1.33
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||
name: stack-data
|
||||
version: 0.6.3
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ dependencies = [
|
|||
"geocube>=0.7.1,<0.8",
|
||||
"streamlit>=1.50.0,<2",
|
||||
"altair[all]>=5.5.0,<6",
|
||||
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
|
||||
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
|||
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
||||
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
||||
|
||||
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
|
||||
darts_counts = grid_gdf[darts_counts_columns]
|
||||
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
|
||||
|
||||
output_path = get_darts_rts_file(grid, level)
|
||||
grid_gdf.to_parquet(output_path)
|
||||
print(f"Saved RTS labels to {output_path}")
|
||||
|
|
|
|||
|
|
@ -87,9 +87,14 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
|||
return dataset_file
|
||||
|
||||
|
||||
def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_cv_results_dir(
|
||||
name: str,
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
task: Literal["binary", "multi"],
|
||||
) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_binary"
|
||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
return results_dir
|
||||
|
|
|
|||
|
|
@ -29,7 +29,50 @@ pretty.install()
|
|||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
|
||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
|
||||
"""Create X and y data from the training dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
Returns:
|
||||
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
data = data[data["darts_has_coverage"]]
|
||||
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
||||
if task == "binary":
|
||||
labels = ["No RTS", "RTS"]
|
||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
||||
else:
|
||||
# Put into n categories (log scaled)
|
||||
y_data = data.loc[X_data.index, "darts_rts_count"]
|
||||
n_categories = 5
|
||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||
# Change the first interval to start at 1 and add a category for 0
|
||||
bins = pd.IntervalIndex.from_tuples(
|
||||
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
|
||||
)
|
||||
y_data = pd.cut(y_data, bins=bins)
|
||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||
y_data = y_data.cat.codes
|
||||
return data, X_data, y_data, labels
|
||||
|
||||
|
||||
def random_cv(
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
n_iter: int = 2000,
|
||||
robust: bool = False,
|
||||
task: Literal["binary", "multi"] = "binary",
|
||||
):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
|
|
@ -37,21 +80,17 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
level (int): The grid level to use.
|
||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
data = data[data["darts_has_coverage"]]
|
||||
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
||||
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
||||
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
||||
print(f"{y_data.describe()=}")
|
||||
X = X_data.to_numpy(dtype="float32")
|
||||
y = y_data.to_numpy(dtype="int8")
|
||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||
print(f"{X.shape=}, {y.shape=}")
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||
|
||||
param_grid = {
|
||||
|
|
@ -62,7 +101,21 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||
if task == "binary":
|
||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||
else:
|
||||
metrics = [
|
||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
|
||||
"f1_macro",
|
||||
"f1_weighted",
|
||||
"precision_macro",
|
||||
"precision_weighted",
|
||||
"recall_macro",
|
||||
"recall_weighted",
|
||||
"jaccard_micro",
|
||||
"jaccard_macro",
|
||||
"jaccard_weighted",
|
||||
]
|
||||
search = RandomizedSearchCV(
|
||||
clf,
|
||||
param_grid,
|
||||
|
|
@ -72,7 +125,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
random_state=42,
|
||||
verbose=3,
|
||||
scoring=metrics,
|
||||
refit="f1",
|
||||
refit="f1" if task == "binary" else "f1_weighted",
|
||||
)
|
||||
|
||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||
|
|
@ -88,10 +141,11 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||
|
||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
||||
|
||||
# Store the search settings
|
||||
settings = {
|
||||
"task": task,
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"random_state": 42,
|
||||
|
|
@ -115,7 +169,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
},
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
"classes": ["No RTS", "RTS"],
|
||||
"classes": labels,
|
||||
}
|
||||
settings_file = results_dir / "search_settings.toml"
|
||||
print(f"Storing search settings to {settings_file}")
|
||||
|
|
@ -144,7 +198,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
best_estimator = search.best_estimator_
|
||||
# Annotate the state with xarray metadata
|
||||
features = X_data.columns.tolist()
|
||||
labels = y_data.unique().tolist()
|
||||
boxes = list(range(best_estimator.K_))
|
||||
box_centers = xr.DataArray(
|
||||
best_estimator.S_.cpu().numpy(),
|
||||
|
|
@ -185,7 +238,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
|||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
|
||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||
print(f"Storing predicted probabilities to {preds_file}")
|
||||
preds.to_parquet(preds_file)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue