From fb522ddad599f7a909699df76420b8f100b5e08b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 8 Nov 2025 22:44:08 +0100 Subject: [PATCH] Add an inference map --- pixi.lock | 14 +- pyproject.toml | 7 +- src/entropice/dataset.py | 130 +++++ src/entropice/inference.py | 61 +++ src/entropice/training.py | 139 +++--- src/entropice/training_analysis_dashboard.py | 480 ++++++++++++------- 6 files changed, 580 insertions(+), 251 deletions(-) create mode 100644 src/entropice/dataset.py diff --git a/pixi.lock b/pixi.lock index 73c45c5..75bee29 100644 --- a/pixi.lock +++ b/pixi.lock @@ -288,6 +288,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/31/cc/099fab5a73909a117e9689c7da4c39a248595187f0f30dd879ad1d2c34ce/tblib-3.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl @@ -1300,7 +1301,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0 + sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -1348,6 +1349,7 @@ packages: - streamlit>=1.50.0,<2 - altair[all]>=5.5.0,<6 - h5netcdf>=1.7.3,<2 + - streamlit-folium>=0.25.3,<0.26 editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 name: entropy @@ -4560,6 +4562,16 @@ packages: - streamlit[auth,charts,pdf,snowflake,sql] ; extra == 'all' - rich>=11.0.0 ; extra == 'all' requires_python: '>=3.9,!=3.9.7' +- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl + name: streamlit-folium + version: 0.25.3 + sha256: cfdf085764da3f9b5e1e0668f6e4cc0385ff041c98133d023800983a875ca26c + requires_dist: + - streamlit>=1.13.0 + - folium>=0.13,!=0.15.0 + - jinja2 + - branca + requires_python: '>=3.9' - conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda sha256: 09d3b6ac51d437bc996ad006d9f749ca5c645c1900a854a6c8f193cbd13f03a8 md5: 8c09fac3785696e1c477156192d64b91 diff --git a/pyproject.toml b/pyproject.toml index 7cc71d4..f7ee2a1 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,9 @@ dependencies = [ "xvec>=0.5.1", "zarr[remote]>=3.1.3", "geocube>=0.7.1,<0.8", - "streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2", + "streamlit>=1.50.0,<2", + "altair[all]>=5.5.0,<6", + "h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", ] [project.scripts] @@ -57,7 +59,8 @@ create-grid = "entropice.grids:main" darts = "entropice.darts:main" alpha-earth = "entropice.alphaearth:main" era5 = 
"entropice.era5:cli" -train = "entropice.training:cli" +train = "entropice.training:main" +dataset = "entropice.dataset:main" [build-system] requires = ["hatchling"] diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py new file mode 100644 index 0000000..464399d --- /dev/null +++ b/src/entropice/dataset.py @@ -0,0 +1,130 @@ +"""Training dataset preparation and model training.""" + +from typing import Literal + +import cyclopts +import geopandas as gpd +import pandas as pd +import seaborn as sns +import xarray as xr +from rich import pretty, traceback +from sklearn import set_config +from stopuhr import stopwatch + +from entropice.paths import ( + get_darts_rts_file, + get_embeddings_store, + get_era5_stores, + get_train_dataset_file, +) + +traceback.install() +pretty.install() + +set_config(array_api_dispatch=True) + +sns.set_theme("talk", "whitegrid") + + +shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} +seasons = {10: "winter", 4: "summer"} + + +@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"]) +def _prep_era5( + rts: gpd.GeoDataFrame, + temporal: Literal["yearly", "seasonal", "shoulder"], + grid: Literal["hex", "healpix"], + level: int, +) -> pd.DataFrame: + era5_df = [] + era5_store = get_era5_stores(temporal, grid=grid, level=level) + era5 = xr.open_zarr(era5_store, consolidated=False) + era5 = era5.sel(cell_ids=rts["cell_id"].values) + + for var in era5.data_vars: + df = era5[var].drop_vars("spatial_ref").to_dataframe() + if temporal == "yearly": + df["t"] = df.index.get_level_values("time").year + elif temporal == "seasonal": + df["t"] = ( + df.index.get_level_values("time") + .month.map(lambda x: seasons.get(x)) + .str.cat(df.index.get_level_values("time").year.astype(str), sep="_") + ) + elif temporal == "shoulder": + df["t"] = ( + df.index.get_level_values("time") + .month.map(lambda x: shoulder_seasons.get(x)) + .str.cat(df.index.get_level_values("time").year.astype(str), sep="_") + ) + df = ( + df.pivot_table(index="cell_ids", columns="t", values=var) + .rename(columns=lambda x: f"{var}_{x}") + .rename_axis(None, axis=1) + ) + era5_df.append(df) + era5_df = pd.concat(era5_df, axis=1) + era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"}) + return era5_df + + +@stopwatch("Prepare embeddings data") +def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame: + embs_store = get_embeddings_store(grid=grid, level=level) + embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__ + embeddings = embeddings.sel(cell=rts["cell_id"].values) + + embeddings_df = embeddings.to_dataframe(name="value") + embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value") + embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] + + embeddings_df = embeddings_df.rename( + columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"} + ) + return embeddings_df + + +def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False): + """Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings. + + Args: + grid (Literal["hex", "healpix"]): The grid type to use. + level (int): The grid level to use. 
+
+    """
+    rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
+    # Filter to coverage
+    if filter_target:
+        rts = rts[rts["darts_has_coverage"]]
+    # Convert hex cell_id to int
+    if grid == "hex":
+        rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
+
+    # Add the lat / lon of the cell centers
+    rts["lon"] = rts.geometry.centroid.x
+    rts["lat"] = rts.geometry.centroid.y
+
+    # Get ERA5 data
+    era5_yearly = _prep_era5(rts, "yearly", grid, level)
+    era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
+    era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
+
+    # Get embeddings data
+    embeddings = _prep_embeddings(rts, grid, level)
+
+    # Combine datasets by cell id / cell
+    with stopwatch("Combine datasets"):
+        dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
+    print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
+
+    dataset_file = get_train_dataset_file(grid=grid, level=level)
+    dataset.reset_index().to_parquet(dataset_file)
+
+
+def main():
+    cyclopts.run(prepare_dataset)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/entropice/inference.py b/src/entropice/inference.py
index e69de29..9b7263a 100644
--- a/src/entropice/inference.py
+++ b/src/entropice/inference.py
@@ -0,0 +1,61 @@
+# ruff: noqa: N806
+"""Inference runs on trained models."""
+
+from typing import Literal
+
+import geopandas as gpd
+import pandas as pd
+import torch
+from entropy import ESPAClassifier
+from rich import pretty, traceback
+from sklearn import set_config
+
+from entropice.paths import get_train_dataset_file
+
+traceback.install()
+pretty.install()
+
+set_config(array_api_dispatch=True)
+
+
+def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
+    """Get predicted probabilities for each cell.
+
+    Args:
+        grid (Literal["hex", "healpix"]): The grid type to use.
+        level (int): The grid level to use.
+        clf (ESPAClassifier): The trained classifier to use for predictions.
+
+    Returns:
+        gpd.GeoDataFrame: Predicted probabilities per cell, with "cell_id",
+            "predicted_proba", and "geometry" columns (CRS EPSG:3413).
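+
+    Example:
+        A sketch, assuming a classifier pickled by the random-search training run and a
+        CUDA device, since batches are moved to ``device=0`` (the results path is
+        hypothetical):
+
+            import pickle
+
+            with open("results/.../best_estimator_model.pkl", "rb") as f:
+                clf = pickle.load(f)
+            preds = predict_proba(grid="hex", level=5, clf=clf)
+            preds.to_parquet("predicted_probabilities.parquet")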
+ + """ + data = get_train_dataset_file(grid=grid, level=level) + data = gpd.read_parquet(data) + print(f"Predicting probabilities for {len(data)} cells...") + + # Predict in batches to avoid memory issues + batch_size = 100_000 + preds = [] + for i in range(0, len(data), batch_size): + batch = data.iloc[i : i + batch_size] + cols_to_drop = ["cell_id", "geometry", "darts_has_rts"] + cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] + X_batch = batch.drop(columns=cols_to_drop).dropna() + cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy() + cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() + X_batch = X_batch.to_numpy(dtype="float32") + X_batch = torch.asarray(X_batch, device=0) + batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy() + batch_preds = gpd.GeoDataFrame( + { + "cell_id": cell_ids, + "predicted_proba": batch_preds, + "geometry": cell_geoms, + }, + crs="epsg:3413", + ) + preds.append(batch_preds) + preds = gpd.GeoDataFrame(pd.concat(preds)) + + return preds diff --git a/src/entropice/training.py b/src/entropice/training.py index 1350185..44ee0cb 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -1,12 +1,13 @@ # ruff: noqa: N806 """Training dataset preparation and model training.""" +import pickle from typing import Literal import cyclopts import geopandas as gpd import pandas as pd -import seaborn as sns +import toml import torch import xarray as xr from entropy import ESPAClassifier @@ -16,11 +17,9 @@ from sklearn import set_config from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from stopuhr import stopwatch +from entropice.inference import predict_proba from entropice.paths import ( get_cv_results_dir, - get_darts_rts_file, - get_embeddings_store, - get_era5_stores, get_train_dataset_file, ) @@ -29,83 +28,7 @@ pretty.install() set_config(array_api_dispatch=True) -sns.set_theme("talk", "whitegrid") -cli = cyclopts.App() - - -@cli.command() -def prepare_dataset(grid: Literal["hex", "healpix"], level: int): - """Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings. - - Args: - grid (Literal["hex", "healpix"]): The grid type to use. - level (int): The grid level to use. 
- - """ - rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level)) - # Filter to coverage - rts = rts[rts["darts_has_coverage"]] - # Convert hex cell_id to int - if grid == "hex": - rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16)) - - # Get era5 data - era5_df = [] - - shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} - seasons = { - 10: "winter", - 4: "summer", - } - for temporal in ["yearly", "seasonal", "shoulder"]: - era5_store = get_era5_stores(temporal, grid=grid, level=level) - era5 = xr.open_zarr(era5_store, consolidated=False) - era5 = era5.sel(cell_ids=rts["cell_id"].values) - - for var in era5.data_vars: - df = era5[var].drop_vars("spatial_ref").to_dataframe() - if temporal == "yearly": - df["t"] = df.index.get_level_values("time").year - elif temporal == "seasonal": - df["t"] = ( - df.index.get_level_values("time") - .month.map(lambda x: seasons.get(x)) - .str.cat(df.index.get_level_values("time").year.astype(str), sep="_") - ) - elif temporal == "shoulder": - df["t"] = ( - df.index.get_level_values("time") - .month.map(lambda x: shoulder_seasons.get(x)) - .str.cat(df.index.get_level_values("time").year.astype(str), sep="_") - ) - df = ( - df.pivot_table(index="cell_ids", columns="t", values=var) - .rename(columns=lambda x: f"{var}_{x}") - .rename_axis(None, axis=1) - ) - era5_df.append(df) - era5_df = pd.concat(era5_df, axis=1) - - # Get embeddings data - embs_store = get_embeddings_store(grid=grid, level=level) - embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__ - embeddings = embeddings.sel(cell=rts["cell_id"].values) - - embeddings_df = embeddings.to_dataframe(name="value") - embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value") - embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] - - # Combine datasets by cell id / cell - # TODO: use prefixes to easy split the features in analysis again - dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df) - print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") - - dataset_file = get_train_dataset_file(grid=grid, level=level) - dataset.reset_index().to_parquet(dataset_file) - - -@cli.command() def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): """Perform random cross-validation on the training dataset. 
@@ -117,6 +40,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): """ data = get_train_dataset_file(grid=grid, level=level) data = gpd.read_parquet(data) + data = data[data["darts_has_coverage"]] cols_to_drop = ["cell_id", "geometry", "darts_has_rts"] cols_to_drop += [col for col in data.columns if col.startswith("darts_")] @@ -142,7 +66,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): clf, param_grid, n_iter=n_iter, - n_jobs=20, + n_jobs=16, cv=cv, random_state=42, verbose=10, @@ -163,6 +87,45 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}") + results_dir = get_cv_results_dir("random_search", grid=grid, level=level) + + # Store the search settings + settings = { + "grid": grid, + "level": level, + "random_state": 42, + "n_iter": n_iter, + "param_grid": { + "eps_cl": { + "distribution": "loguniform", + "low": param_grid["eps_cl"].a, + "high": param_grid["eps_cl"].b, + }, + "eps_e": { + "distribution": "loguniform", + "low": param_grid["eps_e"].a, + "high": param_grid["eps_e"].b, + }, + "initial_K": { + "distribution": "randint", + "low": param_grid["initial_K"].a, + "high": param_grid["initial_K"].b, + }, + }, + "cv_splits": cv.get_n_splits(), + "metrics": metrics, + } + settings_file = results_dir / "search_settings.toml" + print(f"Storing search settings to {settings_file}") + with open(settings_file, "w") as f: + toml.dump({"settings": settings}, f) + + # Store the best estimator model + best_model_file = results_dir / "best_estimator_model.pkl" + print(f"Storing best estimator model to {best_model_file}") + with open(best_model_file, "wb") as f: + pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL) + # Store the search results results = pd.DataFrame(search.cv_results_) # Parse the params into individual columns @@ -171,7 +134,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): results = pd.concat([results.drop(columns=["params"]), params], axis=1) results["grid"] = grid results["level"] = level - results_dir = get_cv_results_dir("random_search", grid=grid, level=level) results_file = results_dir / "search_results.parquet" print(f"Storing CV results to {results_file}") results.to_parquet(results_file) @@ -219,9 +181,20 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000): print(f"Storing best estimator state to {state_file}") state.to_netcdf(state_file, engine="h5netcdf") + # Predict probabilities for all cells + print("Predicting probabilities for all cells...") + preds = predict_proba(grid=grid, level=level, clf=best_estimator) + preds_file = results_dir / "predicted_probabilities.parquet" + print(f"Storing predicted probabilities to {preds_file}") + preds.to_parquet(preds_file) + stopwatch.summary() print("Done.") +def main(): + cyclopts.run(random_cv) + + if __name__ == "__main__": - cli() + main() diff --git a/src/entropice/training_analysis_dashboard.py b/src/entropice/training_analysis_dashboard.py index aa3c249..5029c17 100644 --- a/src/entropice/training_analysis_dashboard.py +++ b/src/entropice/training_analysis_dashboard.py @@ -4,14 +4,177 @@ from datetime import datetime from pathlib import Path import altair as alt +import folium +import geopandas as gpd import numpy as np import pandas as pd import streamlit as st import xarray as xr 
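+# streamlit-folium bridges folium maps into Streamlit; used below to render the predictions map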
+from streamlit_folium import st_folium from entropice.paths import RESULTS_DIR +def get_available_result_files() -> list[Path]: + """Get all available result files from RESULTS_DIR.""" + if not RESULTS_DIR.exists(): + return [] + + result_files = [] + for search_dir in RESULTS_DIR.iterdir(): + if not search_dir.is_dir(): + continue + + result_file = search_dir / "search_results.parquet" + state_file = search_dir / "best_estimator_state.nc" + preds_file = search_dir / "predicted_probabilities.parquet" + settings_file = search_dir / "search_settings.toml" + if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists(): + result_files.append(search_dir) + + return sorted(result_files, reverse=True) # Most recent first + + +def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame: + """Load results file and prepare binned columns. + + Args: + file_path: Path to the results parquet file. + k_bin_width: Width of bins for initial_K parameter. + + Returns: + DataFrame with added binned columns. + + """ + results = pd.read_parquet(file_path) + + # Automatically determine bin width for initial_K based on data range + k_min = results["initial_K"].min() + k_max = results["initial_K"].max() + # Use configurable bin width, adapted to actual data range + k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) + results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) + + # Automatically create logarithmic bins for epsilon parameters based on data range + # Use 10 bins spanning the actual data range + eps_cl_min = np.log10(results["eps_cl"].min()) + eps_cl_max = np.log10(results["eps_cl"].max()) + eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10) + + eps_e_min = np.log10(results["eps_e"].min()) + eps_e_max = np.log10(results["eps_e"].max()) + eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10) + + results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) + results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) + + return results + + +def load_and_prepare_model_state(file_path: Path) -> xr.Dataset: + """Load a model state from a NetCDF file. + + Args: + file_path (Path): The path to the NetCDF file. + + Returns: + xr.Dataset: The model state as an xarray Dataset. + + """ + return xr.open_dataset(file_path, engine="h5netcdf") + + +def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None: + """Extract embedding features from the model state. + + Args: + model_state: The xarray Dataset containing the model state. + + Returns: + xr.DataArray: The extracted embedding features. This DataArray has dimensions + ('agg', 'band', 'year') corresponding to the different components of the embedding features. + Returns None if no embedding features are found. 
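+
+    Example:
+        A feature named ``embeddings_mean_A01_2020`` (the name is hypothetical) is split
+        on underscores into agg="mean", band="A01", year="2020" and becomes one entry of
+        the returned ('agg', 'band', 'year') cube.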
+ + """ + + def _is_embedding_feature(feature: str) -> bool: + return feature.startswith("embeddings_") + + embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)] + if len(embedding_features) == 0: + return None + + # Split the single feature dimension of embedding features into separate dimensions (agg, band, year) + embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"] + embedding_feature_array = embedding_feature_array.assign_coords( + agg=("feature", [f.split("_")[1] for f in embedding_features]), + band=("feature", [f.split("_")[2] for f in embedding_features]), + year=("feature", [f.split("_")[3] for f in embedding_features]), + ) + embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010 + return embedding_feature_array + + +def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None: + """Extract ERA5 features from the model state. + + Args: + model_state: The xarray Dataset containing the model state. + + Returns: + xr.DataArray: The extracted ERA5 features. This DataArray has dimensions + ('variable', 'time') corresponding to the different components of the ERA5 features. + Returns None if no ERA5 features are found. + + """ + + def _is_era5_feature(feature: str) -> bool: + return feature.startswith("era5_") + + def _extract_var_name(feature: str) -> str: + feature = feature.replace("era5_", "") + if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]): + return feature.rsplit("_", 2)[0] + else: + return feature.rsplit("_", 1)[0] + + def _extract_time_name(feature: str) -> str: + feature = feature.replace("era5_", "") + if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]): + return "_".join(feature.rsplit("_", 2)[-2:]) + else: + return feature.rsplit("_", 1)[-1] + + era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)] + if len(era5_features) == 0: + return None + # Split the single feature dimension of era5 features into separate dimensions (variable, time) + era5_features_array = model_state.sel(feature=era5_features)["feature_weights"] + era5_features_array = era5_features_array.assign_coords( + variable=("feature", [_extract_var_name(f) for f in era5_features]), + time=("feature", [_extract_time_name(f) for f in era5_features]), + ) + era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 + return era5_features_array + + +# TODO: Extract common features, e.g. area or water content + + +def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map: + """Plot predicted probabilities on a map using Streamlit and Folium. + + Args: + preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns. + + Returns: + folium.Map: A Folium map object with the predicted probabilities visualized. + + """ + m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron") + return m + + def _plot_k_binned( results: pd.DataFrame, target: str, @@ -118,171 +281,6 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str): return chart -def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame: - """Load results file and prepare binned columns. - - Args: - file_path: Path to the results parquet file. - k_bin_width: Width of bins for initial_K parameter. - - Returns: - DataFrame with added binned columns. 
- - """ - results = pd.read_parquet(file_path) - - # Automatically determine bin width for initial_K based on data range - k_min = results["initial_K"].min() - k_max = results["initial_K"].max() - # Use configurable bin width, adapted to actual data range - k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width) - results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False) - - # Automatically create logarithmic bins for epsilon parameters based on data range - # Use 10 bins spanning the actual data range - eps_cl_min = np.log10(results["eps_cl"].min()) - eps_cl_max = np.log10(results["eps_cl"].max()) - eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10) - - eps_e_min = np.log10(results["eps_e"].min()) - eps_e_max = np.log10(results["eps_e"].max()) - eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10) - - results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) - results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) - - return results - - -def load_and_prepare_model_state(file_path: Path) -> xr.Dataset: - """Load a model state from a NetCDF file. - - Args: - file_path (Path): The path to the NetCDF file. - - Returns: - xr.Dataset: The model state as an xarray Dataset. - - """ - return xr.open_dataset(file_path, engine="h5netcdf") - - -def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None: - """Extract embedding features from the model state. - - Args: - model_state: The xarray Dataset containing the model state. - - Returns: - xr.DataArray: The extracted embedding features. This DataArray has dimensions - ('agg', 'band', 'year') corresponding to the different components of the embedding features. - Returns None if no embedding features are found. - - """ - - def _is_embedding_feature(feature: str) -> bool: - parts = feature.split("_") - if len(parts) != 3: - return False - _, band, _ = parts - if not band.startswith("A"): - return False - if not band[1:].isdigit(): - return False - return True - - embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)] - if len(embedding_features) == 0: - return None - - # Split the single feature dimension of embedding features into separate dimensions (agg, band, year) - embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"] - embedding_feature_array = embedding_feature_array.assign_coords( - agg=("feature", [f.split("_")[0] for f in embedding_features]), - band=("feature", [f.split("_")[1] for f in embedding_features]), - year=("feature", [f.split("_")[2] for f in embedding_features]), - ) - embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010 - return embedding_feature_array - - -def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None: - """Extract ERA5 features from the model state. - - Args: - model_state: The xarray Dataset containing the model state. - - Returns: - xr.DataArray: The extracted ERA5 features. This DataArray has dimensions - ('variable', 'time') corresponding to the different components of the ERA5 features. - Returns None if no ERA5 features are found. 
- - """ - - def _is_era5_feature(feature: str) -> bool: - # Instant fit if winter or summer in the name - if "winter" in feature or "spring" in feature: - return True - # Instant fit if OND, JFM, AMJ or JAS in the name - if any(season in feature for season in ["OND", "JFM", "AMJ", "JAS"]): - return True - parts = feature.split("_") - if len(parts) == 3: - _, band, year = parts - if band.startswith("A"): - return False - if year.isdigit(): - return True - elif parts[-1].isdigit(): - return True - return False - - def _extract_var_name(feature: str) -> str: - if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]): - return feature.rsplit("_", 2)[0] - else: - return feature.rsplit("_", 1)[0] - - def _extract_time_name(feature: str) -> str: - if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]): - return "_".join(feature.rsplit("_", 2)[-2:]) - else: - return feature.rsplit("_", 1)[-1] - - era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)] - if len(era5_features) == 0: - return None - # Split the single feature dimension of era5 features into separate dimensions (variable, time) - era5_features_array = model_state.sel(feature=era5_features)["feature_weights"] - era5_features_array = era5_features_array.assign_coords( - variable=("feature", [_extract_var_name(f) for f in era5_features]), - time=("feature", [_extract_time_name(f) for f in era5_features]), - ) - era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 - return era5_features_array - - -# TODO: Extract common features, e.g. area or water content - - -def get_available_result_files() -> list[Path]: - """Get all available result files from RESULTS_DIR.""" - if not RESULTS_DIR.exists(): - return [] - - result_files = [] - for search_dir in RESULTS_DIR.iterdir(): - if not search_dir.is_dir(): - continue - - result_file = search_dir / "search_results.parquet" - state_file = search_dir / "best_estimator_state.nc" - if result_file.exists() and state_file.exists(): - result_files.append(search_dir) - - return sorted(result_files, reverse=True) # Most recent first - - def _parse_results_dir_name(results_dir: Path) -> str: gridname, date = results_dir.name.split("_random_search_cv") gridname = gridname.lstrip("permafrost_") @@ -574,6 +572,95 @@ def _plot_era5_summary(era5_array: xr.DataArray): return chart_variable, chart_time +def _plot_box_assignments(model_state: xr.Dataset): + """Create a heatmap showing which boxes are assigned to which labels/classes. + + Args: + model_state: The xarray Dataset containing the model state with box_assignments. + + Returns: + Altair chart showing the box-to-label assignment heatmap. 
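+
+    Note:
+        Assumes ``model_state["box_assignments"]`` is a 2-D variable over the
+        ('class', 'box') dimensions holding the Lambda-matrix entries; the axis and
+        tooltip encodings below rely on those dimension names.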
+ + """ + # Extract box assignments + box_assignments = model_state["box_assignments"] + + # Convert to DataFrame for plotting + df = box_assignments.to_dataframe(name="assignment").reset_index() + + # Create heatmap + chart = ( + alt.Chart(df) + .mark_rect() + .encode( + x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)), + y=alt.Y("class:N", title="Class Label"), + color=alt.Color( + "assignment:Q", + scale=alt.Scale(scheme="viridis"), + title="Assignment Strength", + ), + tooltip=[ + alt.Tooltip("class:N", title="Class"), + alt.Tooltip("box:O", title="Box"), + alt.Tooltip("assignment:Q", format=".4f", title="Assignment"), + ], + ) + .properties( + height=150, + title="Box-to-Label Assignments (Lambda Matrix)", + ) + ) + + return chart + + +def _plot_box_assignment_bars(model_state: xr.Dataset): + """Create a bar chart showing how many boxes are assigned to each class. + + Args: + model_state: The xarray Dataset containing the model state with box_assignments. + + Returns: + Altair chart showing count of boxes per class. + + """ + # Extract box assignments + box_assignments = model_state["box_assignments"] + + # Convert to DataFrame + df = box_assignments.to_dataframe(name="assignment").reset_index() + + # For each box, find which class it's most strongly assigned to + box_to_class = df.groupby("box")["assignment"].idxmax() + primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True) + + # Count boxes per class + counts = primary_classes.groupby("class").size().reset_index(name="count") + + # Create bar chart + chart = ( + alt.Chart(counts) + .mark_bar() + .encode( + x=alt.X("class:N", title="Class Label"), + y=alt.Y("count:Q", title="Number of Boxes"), + color=alt.Color("class:N", title="Class", legend=None), + tooltip=[ + alt.Tooltip("class:N", title="Class"), + alt.Tooltip("count:Q", title="Number of Boxes"), + ], + ) + .properties( + width=600, + height=300, + title="Number of Boxes Assigned to Each Class (by Primary Assignment)", + ) + ) + + return chart + + def main(): """Run Streamlit dashboard application.""" st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") @@ -609,6 +696,7 @@ def main(): model_state["feature_weights"] *= n_features embedding_feature_array = extract_embedding_features(model_state) era5_feature_array = extract_era5_features(model_state) + predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") st.sidebar.success(f"Loaded {len(results)} results") @@ -618,6 +706,12 @@ def main(): "Select Metric", options=available_metrics, help="Choose which metric to visualize" ) + # Map visualization + st.header("Predictions Map") + st.markdown("Map showing predicted classes from the best estimator") + m = _plot_prediction_map(predictions) + st_folium(m, width="100%", height=300) + # Display some basic statistics st.header("Dataset Overview") col1, col2, col3 = st.columns(3) @@ -765,6 +859,62 @@ def main(): """ ) + # Box-to-Label Assignment Visualization + st.subheader("Box-to-Label Assignments") + st.markdown( + """ + This visualization shows how the learned boxes (prototypes in feature space) are + assigned to different class labels. The ESPA classifier learns K boxes and assigns + them to classes through the Lambda matrix. Higher values indicate stronger assignment + of a box to a particular class. 
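+        For a model with K boxes and C classes, the Lambda matrix has one row per class and
+        one column per box; the heatmap below shows it directly, while the bar chart counts,
+        for each class, the boxes whose strongest (argmax) assignment falls on that class.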
+ """ + ) + + with st.spinner("Generating box assignment visualizations..."): + col1, col2 = st.columns([0.7, 0.3]) + + with col1: + st.markdown("### Assignment Heatmap") + box_assignment_heatmap = _plot_box_assignments(model_state) + st.altair_chart(box_assignment_heatmap, use_container_width=True) + + with col2: + st.markdown("### Box Count by Class") + box_assignment_bars = _plot_box_assignment_bars(model_state) + st.altair_chart(box_assignment_bars, use_container_width=True) + + # Show statistics + with st.expander("Box Assignment Statistics"): + box_assignments = model_state["box_assignments"].to_pandas() + st.write("**Assignment Matrix Statistics:**") + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Boxes", len(box_assignments.columns)) + with col2: + st.metric("Number of Classes", len(box_assignments.index)) + with col3: + st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}") + with col4: + st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}") + + # Show which boxes are most strongly assigned to each class + st.write("**Top Box Assignments per Class:**") + for class_label in box_assignments.index: + top_boxes = box_assignments.loc[class_label].nlargest(5) + st.write( + f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} " + f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})" + ) + + st.markdown( + """ + **Interpretation:** + - Each box can be assigned to multiple classes with different strengths + - Boxes with higher assignment values for a class contribute more to that class's predictions + - The distribution shows how the model partitions the feature space for classification + """ + ) + # Embedding features analysis (if present) if embedding_feature_array is not None: st.subheader("Embedding Feature Analysis")