Add multiclass training

This commit is contained in:
Tobias Hölzer 2025-11-09 22:34:03 +01:00
parent 553b54bb32
commit d5b35d6da4
7 changed files with 814 additions and 4867 deletions

11
pixi.lock generated
View file

@ -285,6 +285,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
@ -1301,7 +1302,7 @@ packages:
- pypi: ./ - pypi: ./
name: entropice name: entropice
version: 0.1.0 version: 0.1.0
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9 sha256: 852c87cdbd1d452fccaa6253c6ce7410dda9fa32d2f951d35441e454414acfc0
requires_dist: requires_dist:
- aiohttp>=3.12.11 - aiohttp>=3.12.11
- bokeh>=3.7.3 - bokeh>=3.7.3
@ -1350,6 +1351,7 @@ packages:
- altair[all]>=5.5.0,<6 - altair[all]>=5.5.0,<6
- h5netcdf>=1.7.3,<2 - h5netcdf>=1.7.3,<2
- streamlit-folium>=0.25.3,<0.26 - streamlit-folium>=0.25.3,<0.26
- st-theme>=1.2.3,<2
editable: true editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy name: entropy
@ -4507,6 +4509,13 @@ packages:
version: 1.0.1 version: 1.0.1
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718 sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
requires_python: '>=3.6' requires_python: '>=3.6'
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
name: st-theme
version: 1.2.3
sha256: 0a54d9817dd5f8a6d7b0d071b25ae72eacf536c63a5fb97374923938021b1389
requires_dist:
- streamlit>=1.33
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
name: stack-data name: stack-data
version: 0.6.3 version: 0.6.3

View file

@ -51,7 +51,7 @@ dependencies = [
"geocube>=0.7.1,<0.8", "geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2", "streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6", "altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
] ]
[project.scripts] [project.scripts]

View file

@ -67,6 +67,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1) grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1) grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
output_path = get_darts_rts_file(grid, level) output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path) grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}") print(f"Saved RTS labels to {output_path}")

View file

@ -87,9 +87,14 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file return dataset_file
def get_cv_results_dir(
    name: str,
    grid: Literal["hex", "healpix"],
    level: int,
    task: Literal["binary", "multi"] = "binary",
) -> Path:
    """Create and return a fresh, timestamped directory for cross-validation results.

    The directory is created below ``RESULTS_DIR`` and is named
    ``{gridname}_{name}_cv{timestamp}_{task}``, where the timestamp is the
    current local time formatted as ``%Y%m%d-%H%M%S``.

    Args:
        name (str): Name of the search / experiment, embedded in the directory name.
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        task (Literal["binary", "multi"], optional): The classification task type, used as
            directory suffix. Defaults to "binary", matching the defaults of the training
            functions and the suffix that was hard-coded before multiclass support.

    Returns:
        Path: The results directory (created if it did not exist yet).

    """
    gridname = _get_gridname(grid, level)
    # Second-level resolution: two calls within the same second share a directory
    # (mkdir below uses exist_ok=True, so this does not raise).
    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
    results_dir.mkdir(parents=True, exist_ok=True)
    return results_dir

View file

@ -29,7 +29,50 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def create_xy_data(
    grid: Literal["hex", "healpix"],
    level: int,
    task: Literal["binary", "multi"] = "binary",
    n_categories: int = 5,
):
    """Create X and y data from the training dataset.

    Only cells with ``darts_has_coverage`` are kept. Features are all columns except
    ``cell_id``, ``geometry`` and every ``darts_*`` column; rows with missing feature
    values are dropped and the labels are aligned to the remaining rows.

    For the multiclass task the per-cell ``darts_rts_count`` is binned into one class
    for a count of zero plus up to ``n_categories`` quantile-based classes.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
        n_categories (int, optional): Number of quantiles used to bin the RTS counts in the
            multiclass task. Quantiles with identical edges are merged, so fewer classes may
            result. Ignored for the binary task. Defaults to 5.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
        For the multiclass task, ``y`` holds integer class codes and ``labels[code]`` is the
        name of the count interval belonging to that code.

    """
    data = get_train_dataset_file(grid=grid, level=level)
    data = gpd.read_parquet(data)
    data = data[data["darts_has_coverage"]]

    # All label-related columns share the "darts_" prefix, so the prefix match already
    # covers "darts_has_rts" and "darts_rts_count"; naming them explicitly would only make
    # the binary task fail on datasets that lack the count column.
    cols_to_drop = ["cell_id", "geometry"]
    cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
    X_data = data.drop(columns=cols_to_drop).dropna()

    if task == "binary":
        labels = ["No RTS", "RTS"]
        y_data = data.loc[X_data.index, "darts_has_rts"]
    else:
        # Put into n categories (quantile-based)
        y_data = data.loc[X_data.index, "darts_rts_count"]
        quantile_bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
        # Snap the quantile edges to integers and add a dedicated category for a count of 0.
        # NOTE(review): assumes the counts are non-negative integers - confirm.
        # Truncation can collapse an interval to (a, a], which is empty and may be
        # duplicated; such intervals are skipped because they cannot match any count.
        int_edges = [(int(interval.left), int(interval.right)) for interval in quantile_bins]
        bins = pd.IntervalIndex.from_tuples([(-1, 0)] + [(lo, hi) for lo, hi in int_edges if lo < hi])
        y_data = pd.cut(y_data, bins=bins)
        # Drop bins without any sample (e.g. the zero bin if every cell has an RTS) so that
        # the integer codes are contiguous and index directly into `labels`.
        y_data = y_data.cat.remove_unused_categories()
        labels = [str(interval) for interval in y_data.cat.categories]
        y_data = y_data.cat.codes
    return data, X_data, y_data, labels
def random_cv(
grid: Literal["hex", "healpix"],
level: int,
n_iter: int = 2000,
robust: bool = False,
task: Literal["binary", "multi"] = "binary",
):
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
Args: Args:
@ -37,21 +80,17 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False. robust (bool, optional): Whether to use robust training. Defaults to False.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
""" """
data = get_train_dataset_file(grid=grid, level=level) _, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
data = gpd.read_parquet(data) print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
data = data[data["darts_has_coverage"]] print(f"{y_data.describe()=}")
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
X_data = data.drop(columns=cols_to_drop).dropna()
y_data = data.loc[X_data.index, "darts_has_rts"]
X = X_data.to_numpy(dtype="float32") X = X_data.to_numpy(dtype="float32")
y = y_data.to_numpy(dtype="int8") y = y_data.to_numpy(dtype="int8")
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
print(f"{X.shape=}, {y.shape=}") print(f"{X.shape=}, {y.shape=}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
param_grid = { param_grid = {
@ -62,7 +101,21 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU if task == "binary":
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
else:
metrics = [
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
"f1_macro",
"f1_weighted",
"precision_macro",
"precision_weighted",
"recall_macro",
"recall_weighted",
"jaccard_micro",
"jaccard_macro",
"jaccard_weighted",
]
search = RandomizedSearchCV( search = RandomizedSearchCV(
clf, clf,
param_grid, param_grid,
@ -72,7 +125,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
random_state=42, random_state=42,
verbose=3, verbose=3,
scoring=metrics, scoring=metrics,
refit="f1", refit="f1" if task == "binary" else "f1_weighted",
) )
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
@ -88,10 +141,11 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}")
results_dir = get_cv_results_dir("random_search", grid=grid, level=level) results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
# Store the search settings # Store the search settings
settings = { settings = {
"task": task,
"grid": grid, "grid": grid,
"level": level, "level": level,
"random_state": 42, "random_state": 42,
@ -115,7 +169,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
}, },
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": ["No RTS", "RTS"], "classes": labels,
} }
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
@ -144,7 +198,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
best_estimator = search.best_estimator_ best_estimator = search.best_estimator_
# Annotate the state with xarray metadata # Annotate the state with xarray metadata
features = X_data.columns.tolist() features = X_data.columns.tolist()
labels = y_data.unique().tolist()
boxes = list(range(best_estimator.K_)) boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray( box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(), best_estimator.S_.cpu().numpy(),
@ -185,7 +238,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"]) preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
preds_file = results_dir / "predicted_probabilities.parquet" preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}") print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file) preds.to_parquet(preds_file)

File diff suppressed because it is too large Load diff

4742
uv.lock generated

File diff suppressed because it is too large Load diff