From 150f14ed52c69c74e1b01dc8313b99a934a8ba00 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Sat, 8 Nov 2025 03:07:38 +0100
Subject: [PATCH] Extend the dashboard to use altair

---
 pixi.lock                                    |  61 +-
 pyproject.toml                               |   3 +-
 src/entropice/inference.py                   |   0
 src/entropice/paths.py                       |   5 +-
 src/entropice/training.py                    | 236 ++----
 src/entropice/training_analysis_dashboard.py | 846 ++++++++++++++++---
 6 files changed, 886 insertions(+), 265 deletions(-)
 create mode 100644 src/entropice/inference.py

diff --git a/pixi.lock b/pixi.lock
index b8785c5..73c45c5 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -94,6 +94,7 @@ environments:
 - pypi: https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
@@ -209,6 +210,7 @@
 - pypi: https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/c6/2d/f0b184fa88d6630aa267680bdb8623fb69cb0d024b8c6f0d23f9a0f406d3/multidict-6.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
@@ -300,6 +302,9 @@
 - pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl
+ - pypi: 
https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl @@ -417,6 +422,27 @@ packages: - sphinxext-altair ; extra == 'doc' - vl-convert-python>=1.7.0 ; extra == 'save' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl + name: altair-tiles + version: 0.4.0 + sha256: eeed1a6d89800f6cf5aafa6a59ee735bc7c243cd133acebccabfdbf69cc7e33c + requires_dist: + - altair + - mercantile + - xyzservices + - geopandas ; extra == 'dev' + - ghp-import ; extra == 'dev' + - hatch ; extra == 'dev' + - ipykernel ; extra == 'dev' + - ipython ; extra == 'dev' + - mypy ; extra == 'dev' + - pytest ; extra == 'dev' + - ruff>=0.1.4 ; extra == 'dev' + - vega-datasets ; extra == 'dev' + - vl-convert-python ; extra == 'dev' + - jupyter-book ; extra == 'doc' + - vl-convert-python ; extra == 'doc' + requires_python: '>=3.9' - pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl name: anywidget version: 0.9.18 @@ -1274,7 +1300,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: 39f2dabdc6891e121e03650dfde2f69b084370df8561d63478c6b3b518530e54 + sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -1292,7 +1318,6 @@ packages: - geemap>=0.36.3 - geopandas>=1.1.0 - h3>=4.2.2 - - h5netcdf>=1.6.4 - ipycytoscape>=1.3.3 - ipykernel>=6.29.5 - ipywidgets>=8.1.7 @@ -1321,6 +1346,8 @@ packages: - zarr[remote]>=3.1.3 - geocube>=0.7.1,<0.8 - streamlit>=1.50.0,<2 + - altair[all]>=5.5.0,<6 + - h5netcdf>=1.7.3,<2 editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 name: entropy @@ -3159,6 +3186,15 @@ packages: version: 0.1.2 sha256: 84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 requires_python: '>=3.7' +- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl + name: mercantile + version: 1.2.1 + sha256: 30f457a73ee88261aab787b7069d85961a5703bb09dc57a170190bc042cd023f + requires_dist: + - click>=3.0 + - check-manifest ; extra == 'dev' + - hypothesis ; extra == 'test' + - pytest ; extra == 'test' - conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda sha256: 1e59d0dc811f150d39c2ff2da930d69dcb91cb05966b7df5b7d85133006668ed md5: e4ab075598123e783b788b995afbdad0 @@ -4728,6 +4764,27 @@ packages: - 
pysocks>=1.5.6,!=1.5.7,<2.0 ; extra == 'socks' - zstandard>=0.18.0 ; extra == 'zstd' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl + name: vega-datasets + version: 0.9.0 + sha256: 3d7c63917be6ca9b154b565f4779a31fedce57b01b5b9d99d8a34a7608062a1d + requires_dist: + - pandas + requires_python: '>=3.5' +- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + name: vegafusion + version: 2.0.3 + sha256: 0b11c19a70f1bfe3d23d0a09aeecaac7bd03fac01a966d69fbd4dd8679dcb7e7 + requires_dist: + - arro3-core + - packaging + - narwhals>=1.42 + requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + name: vl-convert-python + version: 1.8.0 + sha256: b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617 + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl name: watchdog version: 6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 055db6b..7cc71d4 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "geemap>=0.36.3", "geopandas>=1.1.0", "h3>=4.2.2", - "h5netcdf>=1.6.4", "ipycytoscape>=1.3.3", "ipykernel>=6.29.5", "ipywidgets>=8.1.7", @@ -50,7 +49,7 @@ dependencies = [ "xvec>=0.5.1", "zarr[remote]>=3.1.3", "geocube>=0.7.1,<0.8", - "streamlit>=1.50.0,<2", + "streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2", ] [project.scripts] diff --git a/src/entropice/inference.py b/src/entropice/inference.py new file mode 100644 index 0000000..e69de29 diff --git a/src/entropice/paths.py b/src/entropice/paths.py index 5eeb3d3..9fe9039 100644 --- a/src/entropice/paths.py +++ b/src/entropice/paths.py @@ -87,10 +87,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: return dataset_file -def get_cv_results_file(name: str, grid: Literal["hex", "healpix"], level: int) -> Path: +def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path: gridname = _get_gridname(grid, level) now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}" results_dir.mkdir(parents=True, exist_ok=True) - results_file = results_dir / "search_results.parquet" - return results_file + return results_dir diff --git a/src/entropice/training.py b/src/entropice/training.py index 9caf913..1350185 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -1,14 +1,10 @@ # ruff: noqa: N806 """Training dataset preparation and model training.""" -from pathlib import Path from typing import Literal import cyclopts import geopandas as gpd -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt -import numpy as np import pandas as pd import seaborn as sns import torch @@ -21,7 +17,7 @@ from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from stopuhr import stopwatch from entropice.paths import ( - get_cv_results_file, + get_cv_results_dir, get_darts_rts_file, get_embeddings_store, get_era5_stores, @@ -55,23 +51,41 @@ def prepare_dataset(grid: Literal["hex", "healpix"], 
level: int):
     rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
 
     # Get era5 data
-    era5_store = get_era5_stores("yearly", grid=grid, level=level)
-    era5 = xr.open_zarr(era5_store, consolidated=False)
-    era5 = era5.sel(cell_ids=rts["cell_id"].values)
-    era5_df = []
-    for var in era5.data_vars:
-        df = era5[var].drop_vars("spatial_ref").to_dataframe()
-        df["year"] = df.index.get_level_values("time").year
-        df = (
-            df.pivot_table(index="cell_ids", columns="year", values=var)
-            .rename(columns=lambda x: f"{var}_{x}")
-            .rename_axis(None, axis=1)
-        )
-        era5_df.append(df)
-    era5_df = pd.concat(era5_df, axis=1)
-    # TODO: season and shoulder data
+    shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
+    seasons = {
+        10: "winter",
+        4: "summer",
+    }
+    era5_df = []
+    for temporal in ["yearly", "seasonal", "shoulder"]:
+        era5_store = get_era5_stores(temporal, grid=grid, level=level)
+        era5 = xr.open_zarr(era5_store, consolidated=False)
+        era5 = era5.sel(cell_ids=rts["cell_id"].values)
+
+        for var in era5.data_vars:
+            df = era5[var].drop_vars("spatial_ref").to_dataframe()
+            if temporal == "yearly":
+                df["t"] = df.index.get_level_values("time").year
+            elif temporal == "seasonal":
+                df["t"] = (
+                    df.index.get_level_values("time")
+                    .month.map(lambda x: seasons.get(x))
+                    .str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
+                )
+            elif temporal == "shoulder":
+                df["t"] = (
+                    df.index.get_level_values("time")
+                    .month.map(lambda x: shoulder_seasons.get(x))
+                    .str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
+                )
+            df = (
+                df.pivot_table(index="cell_ids", columns="t", values=var)
+                .rename(columns=lambda x: f"{var}_{x}")
+                .rename_axis(None, axis=1)
+            )
+            era5_df.append(df)
+    era5_df = pd.concat(era5_df, axis=1)
 
     # Get embeddings data
     embs_store = get_embeddings_store(grid=grid, level=level)
@@ -83,6 +97,7 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
     embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
 
     # Combine datasets by cell id / cell
+    # TODO: use prefixes to easily split the features in analysis again
     dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
 
     print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
@@ -91,12 +106,13 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
 
 
 @cli.command()
-def random_cv(grid: Literal["hex", "healpix"], level: int):
+def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
     """Perform random cross-validation on the training dataset.
 
     Args:
         grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
+        n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
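+
+    Example (hypothetical invocation; assumes cyclopts' default kebab-case command
+    names and that the package is importable as a module):
+
+        python -m entropice.training random-cv hex 5 --n-iter 500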
""" data = get_train_dataset_file(grid=grid, level=level) @@ -125,8 +141,8 @@ def random_cv(grid: Literal["hex", "healpix"], level: int): search = RandomizedSearchCV( clf, param_grid, - n_iter=20, - n_jobs=24, + n_iter=n_iter, + n_jobs=20, cv=cv, random_state=42, verbose=10, @@ -155,141 +171,57 @@ def random_cv(grid: Literal["hex", "healpix"], level: int): results = pd.concat([results.drop(columns=["params"]), params], axis=1) results["grid"] = grid results["level"] = level - results_file = get_cv_results_file("random_search", grid=grid, level=level) + results_dir = get_cv_results_dir("random_search", grid=grid, level=level) + results_file = results_dir / "search_results.parquet" print(f"Storing CV results to {results_file}") results.to_parquet(results_file) + # Get the inner state of the best estimator + best_estimator = search.best_estimator_ + # Annotate the state with xarray metadata + features = X_data.columns.tolist() + labels = y_data.unique().tolist() + boxes = list(range(best_estimator.K_)) + box_centers = xr.DataArray( + best_estimator.S_.cpu().numpy(), + dims=["feature", "box"], + coords={"feature": features, "box": boxes}, + name="box_centers", + attrs={"description": "Centers of the boxes in feature space."}, + ) + box_assignments = xr.DataArray( + best_estimator.Lambda_.cpu().numpy(), + dims=["class", "box"], + coords={"class": labels, "box": boxes}, + name="box_assignments", + attrs={"description": "Assignments of samples to boxes."}, + ) + feature_weights = xr.DataArray( + best_estimator.W_.cpu().numpy(), + dims=["feature"], + coords={"feature": features}, + name="feature_weights", + attrs={"description": "Feature weights for each box."}, + ) + state = xr.Dataset( + { + "box_centers": box_centers, + "box_assignments": box_assignments, + "feature_weights": feature_weights, + }, + attrs={ + "description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.", + "grid": grid, + "level": level, + }, + ) + state_file = results_dir / "best_estimator_state.nc" + print(f"Storing best estimator state to {state_file}") + state.to_netcdf(state_file, engine="h5netcdf") + stopwatch.summary() print("Done.") - plot_random_cv_results(results_file) - - -def _plot_k_binned( - results: pd.DataFrame, target: str, *, vmin_percentile: float | None = None, vmax_percentile: float | None = None -): - assert vmin_percentile is None or vmax_percentile is None, ( - "Only one of vmin_percentile or vmax_percentile can be set." - ) - assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results." - assert target in results.columns, f"{target} column not found in results." - assert "eps_e" in results.columns, "eps_e column not found in results." - assert "eps_cl" in results.columns, "eps_cl column not found in results." 
- - # add a colorbar instead of the sampled legend - cmap = sns.color_palette("ch:", as_cmap=True) - # sufisticated normalization - if vmin_percentile is not None: - vmin = np.percentile(results[target], vmin_percentile) - norm = mcolors.Normalize(vmin=vmin) - elif vmax_percentile is not None: - vmax = np.percentile(results[target], vmax_percentile) - norm = mcolors.Normalize(vmax=vmax) - else: - norm = mcolors.Normalize() - sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) - - # nice col-wrap based on columns - n_cols = results["initial_K_binned"].unique().size - col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3) - - scatter = sns.relplot( - data=results, - x="eps_e", - y="eps_cl", - hue=target, - hue_norm=sm.norm, - palette=cmap, - legend=False, - col="initial_K_binned", - col_wrap=col_wrap, - ) - - # Apply log scale to all axes - for ax in scatter.axes.flat: - ax.set_xscale("log") - ax.set_yscale("log") - - # Tight layout - scatter.figure.tight_layout() - - # Add a shared colorbar at the bottom - scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar - cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height] - cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal") - cbar.set_label(target) - - return scatter - - -def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str): - assert "initial_K" in results.columns, "initial_K column not found in results." - assert metric in results.columns, f"{metric} not found in results." - - if target == "eps_cl": - hue = "eps_cl" - col = "eps_e_binned" - elif target == "eps_e": - hue = "eps_e" - col = "eps_cl_binned" - assert hue in results.columns, f"{hue} column not found in results." - assert col in results.columns, f"{col} column not found in results." - - return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm()) - - -@cli.command() -def plot_random_cv_results(file: Path): - """Plot analysis of the results from the RandomCVSearch. - - Args: - file (Path): The file of the results. 
- - """ - print(f"Plotting random CV results from {file}...") - results = pd.read_parquet(file) - # Bin the initial_K into 40er bins - results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False) - # Bin the eps_cl and eps_e into logarithmic bins - eps_cl_bins = np.logspace(-3, 7, num=10) - eps_e_bins = np.logspace(-3, 7, num=10) - results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins) - results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins) - - figdir = file.parent - - # K-Plots - metrics = ["f1"] - for metric in metrics: - _plot_k_binned( - results, - f"mean_test_{metric}", - vmin_percentile=50, - ).figure.savefig(figdir / f"params3d-mean_{metric}.pdf") - _plot_k_binned( - results, - f"std_test_{metric}", - vmax_percentile=50, - ).figure.savefig(figdir / f"params3d-std_{metric}.pdf") - _plot_k_binned(results, f"mean_test_{metric}").figure.savefig(figdir / f"params3d-mean_{metric}-noperc.pdf") - _plot_k_binned(results, f"std_test_{metric}").figure.savefig(figdir / f"params3d-std_{metric}-noperc.pdf") - - # eps-Plots - _plot_eps_binned( - results, - "eps_cl", - f"mean_test_{metric}", - ).figure.savefig(figdir / f"k-eps_cl-mean_{metric}.pdf") - _plot_eps_binned( - results, - "eps_e", - f"mean_test_{metric}", - ).figure.savefig(figdir / f"k-eps_e-mean_{metric}.pdf") - - # Close all figures - plt.close("all") - print("Done.") - if __name__ == "__main__": cli() diff --git a/src/entropice/training_analysis_dashboard.py b/src/entropice/training_analysis_dashboard.py index b1fa78a..aa3c249 100644 --- a/src/entropice/training_analysis_dashboard.py +++ b/src/entropice/training_analysis_dashboard.py @@ -1,18 +1,16 @@ """Streamlit dashboard for training analysis results visualization.""" +from datetime import datetime from pathlib import Path -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt +import altair as alt import numpy as np import pandas as pd -import seaborn as sns import streamlit as st +import xarray as xr from entropice.paths import RESULTS_DIR -sns.set_theme("talk", "whitegrid") - def _plot_k_binned( results: pd.DataFrame, @@ -30,50 +28,47 @@ def _plot_k_binned( assert "eps_e" in results.columns, "eps_e column not found in results." assert "eps_cl" in results.columns, "eps_cl column not found in results." 
- # add a colorbar instead of the sampled legend - cmap = sns.color_palette("ch:", as_cmap=True) - # sophisticated normalization + # Prepare data + plot_data = results[["eps_e", "eps_cl", "initial_K_binned", target]].copy() + + # Sort bins by their left value and convert to string with sorted categories + plot_data = plot_data.sort_values("initial_K_binned") + plot_data["initial_K_binned"] = plot_data["initial_K_binned"].astype(str) + bin_order = plot_data["initial_K_binned"].unique().tolist() + + # Determine color scale domain if vmin_percentile is not None: vmin = np.percentile(results[target], vmin_percentile) - norm = mcolors.Normalize(vmin=vmin) + color_scale = alt.Scale(scheme="viridis", domain=[vmin, plot_data[target].max()]) elif vmax_percentile is not None: vmax = np.percentile(results[target], vmax_percentile) - norm = mcolors.Normalize(vmax=vmax) + color_scale = alt.Scale(scheme="viridis", domain=[plot_data[target].min(), vmax]) else: - norm = mcolors.Normalize() - sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + color_scale = alt.Scale(scheme="viridis") - # nice col-wrap based on columns - n_cols = results["initial_K_binned"].unique().size - col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3) - - scatter = sns.relplot( - data=results, - x="eps_e", - y="eps_cl", - hue=target, - hue_norm=sm.norm, - palette=cmap, - legend=False, - col="initial_K_binned", - col_wrap=col_wrap, + # Create the chart + chart = ( + alt.Chart(plot_data) + .mark_circle(size=60, opacity=0.7) + .encode( + x=alt.X( + "eps_e:Q", + scale=alt.Scale(type="log"), + axis=alt.Axis(title="eps_e", grid=True, gridOpacity=0.5), + ), + y=alt.Y( + "eps_cl:Q", + scale=alt.Scale(type="log"), + axis=alt.Axis(title="eps_cl", grid=True, gridOpacity=0.5), + ), + color=alt.Color(f"{target}:Q", scale=color_scale, title=target), + tooltip=["eps_e:Q", "eps_cl:Q", alt.Tooltip(f"{target}:Q", format=".4f"), "initial_K_binned:N"], + ) + .properties(width=200, height=200) + .facet(facet=alt.Facet("initial_K_binned:N", title="Initial K", sort=bin_order), columns=5) ) - # Apply log scale to all axes - for ax in scatter.axes.flat: - ax.set_xscale("log") - ax.set_yscale("log") - - # Tight layout - scatter.figure.tight_layout() - - # Add a shared colorbar at the bottom - scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar - cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height] - cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal") - cbar.set_label(target) - - return scatter + return chart def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str): @@ -93,25 +88,183 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str): assert hue in results.columns, f"{hue} column not found in results." assert col in results.columns, f"{col} column not found in results." 
-    return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
+    # Prepare data
+    plot_data = results[["initial_K", metric, hue, col]].copy()
+
+    # Sort bins by their left value and convert to string with sorted categories
+    plot_data = plot_data.sort_values(col)
+    plot_data[col] = plot_data[col].astype(str)
+    bin_order = plot_data[col].unique().tolist()
+
+    # Create the chart
+    chart = (
+        alt.Chart(plot_data)
+        .mark_circle(size=60, opacity=0.7)
+        .encode(
+            x=alt.X("initial_K:Q", title="Initial K"),
+            y=alt.Y(f"{metric}:Q", title=metric),
+            color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue),
+            tooltip=[
+                "initial_K:Q",
+                alt.Tooltip(f"{metric}:Q", format=".4f"),
+                alt.Tooltip(f"{hue}:Q", format=".2e"),
+                f"{col}:N",
+            ],
+        )
+        .properties(width=200, height=200)
+        .facet(facet=alt.Facet(f"{col}:N", title=col.replace("_binned", ""), sort=bin_order), columns=5)
+    )
+
+    return chart
 
 
-def load_and_prepare_results(file_path: Path) -> pd.DataFrame:
-    """Load results file and prepare binned columns."""
+def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
+    """Load results file and prepare binned columns.
+
+    Args:
+        file_path: Path to the results parquet file.
+        k_bin_width: Width of bins for the initial_K parameter.
+
+    Returns:
+        DataFrame with added binned columns.
+
+    """
     results = pd.read_parquet(file_path)
 
-    # Bin the initial_K into 40er bins
-    results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
+    # Determine bin edges for initial_K from the actual data range
+    k_min = results["initial_K"].min()
+    k_max = results["initial_K"].max()
+    # Use the configurable bin width, adapted to the actual data range
+    k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
+    results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
+
+    # Automatically create logarithmic bins for the epsilon parameters based on the data range
+    # Use 10 bins spanning the actual data range
+    eps_cl_min = np.log10(results["eps_cl"].min())
+    eps_cl_max = np.log10(results["eps_cl"].max())
+    eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
+
+    eps_e_min = np.log10(results["eps_e"].min())
+    eps_e_max = np.log10(results["eps_e"].max())
+    eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
 
-    # Bin the eps_cl and eps_e into logarithmic bins
-    eps_cl_bins = np.logspace(-3, 7, num=10)
-    eps_e_bins = np.logspace(-3, 7, num=10)
     results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
     results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
 
     return results
 
 
+def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
+    """Load a model state from a NetCDF file.
+
+    Args:
+        file_path (Path): The path to the NetCDF file.
+
+    Returns:
+        xr.Dataset: The model state as an xarray Dataset.
+
+    """
+    return xr.open_dataset(file_path, engine="h5netcdf")
+
+
+def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
+    """Extract embedding features from the model state.
+
+    Args:
+        model_state: The xarray Dataset containing the model state.
+
+    Returns:
+        xr.DataArray: The extracted embedding features. This DataArray has dimensions
+        ('agg', 'band', 'year') corresponding to the different components of the embedding features.
+        Returns None if no embedding features are found.
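+
+    A minimal sketch: feature names follow the "{agg}_{band}_{year}" pattern built in
+    prepare_dataset; the concrete values ("mean", "A01", "2020") below are hypothetical.
+
+        >>> state = load_and_prepare_model_state(path)
+        >>> emb = extract_embedding_features(state)
+        >>> emb.sel(agg="mean", band="A01", year="2020")  # weight of one combination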
+
+    """
+
+    def _is_embedding_feature(feature: str) -> bool:
+        parts = feature.split("_")
+        if len(parts) != 3:
+            return False
+        _, band, _ = parts
+        if not band.startswith("A"):
+            return False
+        if not band[1:].isdigit():
+            return False
+        return True
+
+    embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
+    if len(embedding_features) == 0:
+        return None
+
+    # Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
+    embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
+    embedding_feature_array = embedding_feature_array.assign_coords(
+        agg=("feature", [f.split("_")[0] for f in embedding_features]),
+        band=("feature", [f.split("_")[1] for f in embedding_features]),
+        year=("feature", [f.split("_")[2] for f in embedding_features]),
+    )
+    embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
+    return embedding_feature_array
+
+
+def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
+    """Extract ERA5 features from the model state.
+
+    Args:
+        model_state: The xarray Dataset containing the model state.
+
+    Returns:
+        xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
+        ('variable', 'time') corresponding to the different components of the ERA5 features.
+        Returns None if no ERA5 features are found.
+
+    """
+
+    def _is_era5_feature(feature: str) -> bool:
+        # Immediate match if winter or summer is in the name
+        if "winter" in feature or "summer" in feature:
+            return True
+        # Immediate match if OND, JFM, AMJ or JAS is in the name
+        if any(season in feature for season in ["OND", "JFM", "AMJ", "JAS"]):
+            return True
+        parts = feature.split("_")
+        if len(parts) == 3:
+            _, band, year = parts
+            if band.startswith("A"):
+                return False
+            if year.isdigit():
+                return True
+        elif parts[-1].isdigit():
+            return True
+        return False
+
+    def _extract_var_name(feature: str) -> str:
+        if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
+            return feature.rsplit("_", 2)[0]
+        else:
+            return feature.rsplit("_", 1)[0]
+
+    def _extract_time_name(feature: str) -> str:
+        if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
+            return "_".join(feature.rsplit("_", 2)[-2:])
+        else:
+            return feature.rsplit("_", 1)[-1]
+
+    era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
+    if len(era5_features) == 0:
+        return None
+    # Split the single feature dimension of era5 features into separate dimensions (variable, time)
+    era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
+    era5_features_array = era5_features_array.assign_coords(
+        variable=("feature", [_extract_var_name(f) for f in era5_features]),
+        time=("feature", [_extract_time_name(f) for f in era5_features]),
+    )
+    era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
+    return era5_features_array
+
+
+# TODO: Extract common features, e.g. area or water content
+
+
 def get_available_result_files() -> list[Path]:
     """Get all available result files from RESULTS_DIR."""
     if not RESULTS_DIR.exists():
@@ -119,14 +272,308 @@ def get_available_result_files() -> list[Path]:
     result_files = []
     for search_dir in RESULTS_DIR.iterdir():
-        if search_dir.is_dir():
-            result_file = search_dir / "search_results.parquet"
-            if result_file.exists():
-                result_files.append(result_file)
+        if not search_dir.is_dir():
+            continue
+
+        result_file = search_dir / "search_results.parquet"
+        state_file = search_dir / "best_estimator_state.nc"
+        if result_file.exists() and state_file.exists():
+            result_files.append(search_dir)
 
     return sorted(result_files, reverse=True)  # Most recent first
 
 
+def _parse_results_dir_name(results_dir: Path) -> str:
+    gridname, date = results_dir.name.split("_random_search_cv")
+    gridname = gridname.removeprefix("permafrost_")
+    date = datetime.strptime(date, "%Y%m%d-%H%M%S")
+    date = date.strftime("%Y-%m-%d %H:%M:%S")
+    return f"{gridname} ({date})"
+
+
+def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
+    """Plot the top N most important features based on feature weights.
+
+    Args:
+        model_state: The xarray Dataset containing the model state.
+        top_n: Number of top features to display.
+
+    Returns:
+        Altair chart showing the top features by importance.
+
+    """
+    # Extract feature weights
+    feature_weights = model_state["feature_weights"].to_pandas()
+
+    # Sort by absolute weight and take top N
+    top_features = feature_weights.abs().nlargest(top_n).sort_values(ascending=True)
+
+    # Create DataFrame for plotting with original (signed) weights
+    plot_data = pd.DataFrame(
+        {
+            "feature": top_features.index,
+            "weight": feature_weights.loc[top_features.index].to_numpy(),
+            "abs_weight": top_features.to_numpy(),
+        }
+    )
+
+    # Create horizontal bar chart
+    chart = (
+        alt.Chart(plot_data)
+        .mark_bar()
+        .encode(
+            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
+            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
+            color=alt.condition(
+                alt.datum.weight > 0,
+                alt.value("steelblue"),  # Positive weights
+                alt.value("coral"),  # Negative weights
+            ),
+            tooltip=[
+                alt.Tooltip("feature:N", title="Feature"),
+                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
+            ],
+        )
+        .properties(
+            width=600,
+            height=400,
+            title=f"Top {top_n} Most Important Features",
+        )
+    )
+
+    return chart
+
+
+def _plot_embedding_heatmap(embedding_array: xr.DataArray):
+    """Create a heatmap showing embedding feature weights across bands and years.
+
+    Args:
+        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.
+
+    Returns:
+        Altair chart showing the heatmap.
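+
+    A minimal usage sketch (rendered via Streamlit, as main() does below):
+
+        >>> chart = _plot_embedding_heatmap(embedding_feature_array)
+        >>> st.altair_chart(chart, use_container_width=True)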
+
+    """
+    # Convert to DataFrame for plotting
+    df = embedding_array.to_dataframe(name="weight").reset_index()
+
+    # Create faceted heatmap
+    chart = (
+        alt.Chart(df)
+        .mark_rect()
+        .encode(
+            x=alt.X("year:O", title="Year"),
+            y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
+            color=alt.Color(
+                "weight:Q",
+                scale=alt.Scale(scheme="redblue", domainMid=0),
+                title="Weight",
+            ),
+            tooltip=[
+                alt.Tooltip("agg:N", title="Aggregation"),
+                alt.Tooltip("band:N", title="Band"),
+                alt.Tooltip("year:O", title="Year"),
+                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+            ],
+        )
+        .properties(width=200, height=200)
+        .facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11)
+    )
+
+    return chart
+
+
+def _plot_embedding_aggregation_summary(embedding_array: xr.DataArray):
+    """Create bar charts summarizing embedding weights by aggregation, band, and year.
+
+    Args:
+        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.
+
+    Returns:
+        Tuple of three Altair charts (by_agg, by_band, by_year).
+
+    """
+    # Mean absolute weight per dimension (take abs before averaging so opposite signs don't cancel)
+    by_agg = np.abs(embedding_array).mean(dim=["band", "year"]).to_pandas()
+    by_band = np.abs(embedding_array).mean(dim=["agg", "year"]).to_pandas()
+    by_year = np.abs(embedding_array).mean(dim=["agg", "band"]).to_pandas()
+
+    # Create DataFrames
+    df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()})
+    df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()})
+    df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()})
+
+    # Sort by weight
+    df_agg = df_agg.sort_values("mean_abs_weight", ascending=True)
+    df_band = df_band.sort_values("mean_abs_weight", ascending=True)
+    df_year = df_year.sort_values("mean_abs_weight", ascending=True)
+
+    # Create charts with different colors
+    chart_agg = (
+        alt.Chart(df_agg)
+        .mark_bar()
+        .encode(
+            y=alt.Y("dimension:N", title="Aggregation", sort="-x"),
+            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+            color=alt.Color(
+                "mean_abs_weight:Q",
+                scale=alt.Scale(scheme="blues"),
+                legend=None,
+            ),
+            tooltip=[
+                alt.Tooltip("dimension:N", title="Aggregation"),
+                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+            ],
+        )
+        .properties(width=250, height=200, title="By Aggregation")
+    )
+
+    chart_band = (
+        alt.Chart(df_band)
+        .mark_bar()
+        .encode(
+            y=alt.Y("dimension:N", title="Band", sort="-x"),
+            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+            color=alt.Color(
+                "mean_abs_weight:Q",
+                scale=alt.Scale(scheme="greens"),
+                legend=None,
+            ),
+            tooltip=[
+                alt.Tooltip("dimension:N", title="Band"),
+                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+            ],
+        )
+        .properties(width=250, height=200, title="By Band")
+    )
+
+    chart_year = (
+        alt.Chart(df_year)
+        .mark_bar()
+        .encode(
+            y=alt.Y("dimension:O", title="Year", sort="-x"),
+            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+            color=alt.Color(
+                "mean_abs_weight:Q",
+                scale=alt.Scale(scheme="oranges"),
+                legend=None,
+            ),
+            tooltip=[
+                alt.Tooltip("dimension:O", title="Year"),
+                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+            ],
+        )
+        .properties(width=250, height=200, title="By Year")
+    )
+
+    return chart_agg, chart_band, chart_year
+
+
+def _plot_era5_heatmap(era5_array: xr.DataArray):
+    """Create a heatmap showing ERA5 feature weights across variables and time.
+
+    Args:
+        era5_array: DataArray with dimensions (variable, time) containing feature weights.
+
+    Returns:
+        Altair chart showing the heatmap.
+
+    """
+    # Convert to DataFrame for plotting
+    df = era5_array.to_dataframe(name="weight").reset_index()
+
+    # Create heatmap
+    chart = (
+        alt.Chart(df)
+        .mark_rect()
+        .encode(
+            x=alt.X("time:N", title="Time", sort=None),
+            y=alt.Y("variable:N", title="Variable", sort="-color"),
+            color=alt.Color(
+                "weight:Q",
+                scale=alt.Scale(scheme="redblue", domainMid=0),
+                title="Weight",
+            ),
+            tooltip=[
+                alt.Tooltip("variable:N", title="Variable"),
+                alt.Tooltip("time:N", title="Time"),
+                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
+            ],
+        )
+        .properties(
+            height=400,
+            title="ERA5 Feature Weights Heatmap",
+        )
+    )
+
+    return chart
+
+
+def _plot_era5_summary(era5_array: xr.DataArray):
+    """Create bar charts summarizing ERA5 weights by variable and time.
+
+    Args:
+        era5_array: DataArray with dimensions (variable, time) containing feature weights.
+
+    Returns:
+        Tuple of two Altair charts (by_variable, by_time).
+
+    """
+    # Mean absolute weight per dimension (take abs before averaging so opposite signs don't cancel)
+    by_variable = np.abs(era5_array).mean(dim="time").to_pandas()
+    by_time = np.abs(era5_array).mean(dim="variable").to_pandas()
+
+    # Create DataFrames
+    df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()})
+    df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()})
+
+    # Sort by weight
+    df_variable = df_variable.sort_values("mean_abs_weight", ascending=True)
+    df_time = df_time.sort_values("mean_abs_weight", ascending=True)
+
+    # Create charts with different colors
+    chart_variable = (
+        alt.Chart(df_variable)
+        .mark_bar()
+        .encode(
+            y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)),
+            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+            color=alt.Color(
+                "mean_abs_weight:Q",
+                scale=alt.Scale(scheme="purples"),
+                legend=None,
+            ),
+            tooltip=[
+                alt.Tooltip("dimension:N", title="Variable"),
+                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+            ],
+        )
+        .properties(width=400, height=300, title="By Variable")
+    )
+
+    chart_time = (
+        alt.Chart(df_time)
+        .mark_bar()
+        .encode(
+            y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)),
+            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
+            color=alt.Color(
+                "mean_abs_weight:Q",
+                scale=alt.Scale(scheme="teals"),
+                legend=None,
+            ),
+            tooltip=[
+                alt.Tooltip("dimension:N", title="Time"),
+                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
+            ],
+        )
+        .properties(width=400, height=300, title="By Time")
+    )
+
+    return chart_variable, chart_time
+
+
 def main():
     """Run Streamlit dashboard application."""
     st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@@ -138,23 +585,30 @@ def main():
     st.sidebar.header("Configuration")
 
     # Get available result files
-    result_files = get_available_result_files()
+    result_dirs = get_available_result_files()
 
-    if not result_files:
+    if not result_dirs:
         st.error(f"No result files found in {RESULTS_DIR}")
         st.info("Please run a random CV search first to generate results.")
         return
 
-    # File selection
-    file_options = {str(f.parent.name): f for f in result_files}
-    selected_file_name = st.sidebar.selectbox(
-        "Select Result File", options=list(file_options.keys()), help="Choose a search result file to visualize"
+    # Directory selection
+    dir_options = {_parse_results_dir_name(f): f for f in 
result_dirs} + selected_dir_name = st.sidebar.selectbox( + "Select Result Directory", + options=list(dir_options.keys()), + help="Choose a search result directory to visualize", ) - selected_file = file_options[selected_file_name] + results_dir = dir_options[selected_dir_name] - # Load and prepare data + # Load and prepare data with default bin width (will be reloaded with custom width later) with st.spinner("Loading results..."): - results = load_and_prepare_results(selected_file) + results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40) + model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc") + n_features = model_state.sizes["feature"] + model_state["feature_weights"] *= n_features + embedding_feature_array = extract_embedding_features(model_state) + era5_feature_array = extract_era5_features(model_state) st.sidebar.success(f"Loaded {len(results)} results") @@ -164,11 +618,6 @@ def main(): "Select Metric", options=available_metrics, help="Choose which metric to visualize" ) - # Percentile normalization option - use_percentile = st.sidebar.checkbox( - "Use Percentile Normalization", value=True, help="Apply percentile-based color normalization to plots" - ) - # Display some basic statistics st.header("Dataset Overview") col1, col2, col3 = st.columns(3) @@ -186,50 +635,235 @@ def main(): with st.expander("Best Parameters"): best_idx = results[f"mean_test_{selected_metric}"].idxmax() best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]] - st.dataframe(best_params.to_frame().T, use_container_width=True) + st.dataframe(best_params.to_frame().T, width="content") - # Main plots - st.header(f"Visualization for {selected_metric.capitalize()}") + # Create tabs for different visualizations + tab1, tab2 = st.tabs(["Search Results", "Model State"]) - # K-binned plots - st.subheader("K-Binned Parameter Space (Mean)") - with st.spinner("Generating mean plot..."): - if use_percentile: - fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50) + with tab1: + # Main plots + st.header(f"Visualization for {selected_metric.capitalize()}") + + # K-binned plot configuration + col_toggle, col_slider = st.columns([1, 1]) + + with col_toggle: + # Percentile normalization toggle for K-binned plots + use_percentile = st.toggle( + "Use Percentile Normalization", + value=True, + help="Apply percentile-based color normalization to K-binned parameter space plots", + ) + + with col_slider: + # Bin width slider for K-binned plots + k_min = int(results["initial_K"].min()) + k_max = int(results["initial_K"].max()) + k_range = k_max - k_min + + k_bin_width = st.slider( + "Initial K Bin Width", + min_value=10, + max_value=max(100, k_range // 2), + value=40, + step=10, + help=f"Width of bins for initial_K facets (range: {k_min}-{k_max})", + ) + + # Show estimated number of bins + estimated_bins = int(np.ceil(k_range / k_bin_width)) + st.caption(f"Creating approximately {estimated_bins} bins for initial_K") + + # Reload data if bin width changed from default + if k_bin_width != 40: + with st.spinner("Re-binning data..."): + results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width) + + # K-binned plots + col1, col2 = st.columns(2) + + with col1: + st.subheader("K-Binned Parameter Space (Mean)") + with st.spinner("Generating mean plot..."): + if use_percentile: + chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50) + 
else: + chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}") + st.altair_chart(chart1, use_container_width=True) + + with col2: + st.subheader("K-Binned Parameter Space (Std)") + with st.spinner("Generating std plot..."): + if use_percentile: + chart2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50) + else: + chart2 = _plot_k_binned(results, f"std_test_{selected_metric}") + st.altair_chart(chart2, use_container_width=True) + + # Epsilon-binned plots + col1, col2 = st.columns(2) + + with col1: + st.subheader("K vs eps_cl") + with st.spinner("Generating eps_cl plot..."): + chart3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}") + st.altair_chart(chart3, use_container_width=True) + + with col2: + st.subheader("K vs eps_e") + with st.spinner("Generating eps_e plot..."): + chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}") + st.altair_chart(chart4, use_container_width=True) + + # Optional: Raw data table + with st.expander("View Raw Results Data"): + st.dataframe(results, width="stretch") + + with tab2: + # Model state visualization + st.header("Best Estimator Model State") + + # Show basic model state info + with st.expander("Model State Information"): + st.write(f"**Variables:** {list(model_state.data_vars)}") + st.write(f"**Dimensions:** {dict(model_state.sizes)}") + st.write(f"**Coordinates:** {list(model_state.coords)}") + + # Show statistics + st.write("**Feature Weight Statistics:**") + feature_weights = model_state["feature_weights"].to_pandas() + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Mean Weight", f"{feature_weights.mean():.4f}") + with col2: + st.metric("Max Weight", f"{feature_weights.max():.4f}") + with col3: + st.metric("Total Features", len(feature_weights)) + + # Feature importance plot + st.subheader("Feature Importance") + st.markdown("The most important features based on learned feature weights from the best estimator.") + + # Slider to control number of features to display + top_n = st.slider( + "Number of top features to display", + min_value=5, + max_value=50, + value=10, + step=5, + help="Select how many of the most important features to visualize", + ) + + with st.spinner("Generating feature importance plot..."): + feature_chart = _plot_top_features(model_state, top_n=top_n) + st.altair_chart(feature_chart, use_container_width=True) + + st.markdown( + """ + **Interpretation:** + - **Magnitude**: Larger absolute values indicate more important features + """ + ) + + # Embedding features analysis (if present) + if embedding_feature_array is not None: + st.subheader("Embedding Feature Analysis") + st.markdown( + """ + Analysis of embedding features showing which aggregations, bands, and years + are most important for the model predictions. 
+ """ + ) + + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating dimension summaries..."): + chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array) + col1, col2, col3 = st.columns(3) + with col1: + st.altair_chart(chart_agg, use_container_width=True) + with col2: + st.altair_chart(chart_band, use_container_width=True) + with col3: + st.altair_chart(chart_year, use_container_width=True) + + # Detailed heatmap + st.markdown("### Detailed Heatmap by Aggregation") + st.markdown("Shows the weight of each band-year combination for each aggregation type.") + with st.spinner("Generating heatmap..."): + heatmap_chart = _plot_embedding_heatmap(embedding_feature_array) + st.altair_chart(heatmap_chart, use_container_width=True) + + # Statistics + with st.expander("Embedding Feature Statistics"): + st.write("**Overall Statistics:**") + n_emb_features = embedding_feature_array.size + mean_weight = float(embedding_feature_array.mean().values) + max_weight = float(embedding_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Embedding Features", n_emb_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + + # Show top embedding features + st.write("**Top 10 Embedding Features:**") + emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() + top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] + st.dataframe(top_emb, width="stretch") else: - fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}") - st.pyplot(fig1.figure) - plt.close() + st.info("No embedding features found in this model.") - st.subheader("K-Binned Parameter Space (Std)") - with st.spinner("Generating std plot..."): - if use_percentile: - fig2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50) + # ERA5 features analysis (if present) + if era5_feature_array is not None: + st.subheader("ERA5 Feature Analysis") + st.markdown( + """ + Analysis of ERA5 climate features showing which variables and time periods + are most important for the model predictions. 
+ """ + ) + + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating ERA5 dimension summaries..."): + chart_variable, chart_time = _plot_era5_summary(era5_feature_array) + col1, col2 = st.columns(2) + with col1: + st.altair_chart(chart_variable, use_container_width=True) + with col2: + st.altair_chart(chart_time, use_container_width=True) + + # Detailed heatmap + st.markdown("### Detailed Heatmap") + st.markdown("Shows the weight of each variable-time combination.") + with st.spinner("Generating ERA5 heatmap..."): + era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array) + st.altair_chart(era5_heatmap_chart, use_container_width=True) + + # Statistics + with st.expander("ERA5 Feature Statistics"): + st.write("**Overall Statistics:**") + n_era5_features = era5_feature_array.size + mean_weight = float(era5_feature_array.mean().values) + max_weight = float(era5_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total ERA5 Features", n_era5_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + + # Show top ERA5 features + st.write("**Top 10 ERA5 Features:**") + era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() + top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] + st.dataframe(top_era5, width="stretch") else: - fig2 = _plot_k_binned(results, f"std_test_{selected_metric}") - st.pyplot(fig2.figure) - plt.close() - - # Epsilon-binned plots - col1, col2 = st.columns(2) - - with col1: - st.subheader("K vs eps_cl") - with st.spinner("Generating eps_cl plot..."): - fig3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}") - st.pyplot(fig3.figure) - plt.close() - - with col2: - st.subheader("K vs eps_e") - with st.spinner("Generating eps_e plot..."): - fig4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}") - st.pyplot(fig4.figure) - plt.close() - - # Optional: Raw data table - with st.expander("View Raw Results Data"): - st.dataframe(results, use_container_width=True) + st.info("No ERA5 features found in this model.") if __name__ == "__main__":