Extend the dashboard to use altair

This commit is contained in:
Tobias Hölzer 2025-11-08 03:07:38 +01:00
parent d498b1e752
commit 150f14ed52
6 changed files with 886 additions and 265 deletions

61
pixi.lock generated
View file

@ -94,6 +94,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
@ -209,6 +210,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/c6/2d/f0b184fa88d6630aa267680bdb8623fb69cb0d024b8c6f0d23f9a0f406d3/multidict-6.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
@ -300,6 +302,9 @@ environments:
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl
@ -417,6 +422,27 @@ packages:
- sphinxext-altair ; extra == 'doc'
- vl-convert-python>=1.7.0 ; extra == 'save'
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
name: altair-tiles
version: 0.4.0
sha256: eeed1a6d89800f6cf5aafa6a59ee735bc7c243cd133acebccabfdbf69cc7e33c
requires_dist:
- altair
- mercantile
- xyzservices
- geopandas ; extra == 'dev'
- ghp-import ; extra == 'dev'
- hatch ; extra == 'dev'
- ipykernel ; extra == 'dev'
- ipython ; extra == 'dev'
- mypy ; extra == 'dev'
- pytest ; extra == 'dev'
- ruff>=0.1.4 ; extra == 'dev'
- vega-datasets ; extra == 'dev'
- vl-convert-python ; extra == 'dev'
- jupyter-book ; extra == 'doc'
- vl-convert-python ; extra == 'doc'
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl
name: anywidget
version: 0.9.18
@ -1274,7 +1300,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: 39f2dabdc6891e121e03650dfde2f69b084370df8561d63478c6b3b518530e54
sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -1292,7 +1318,6 @@ packages:
- geemap>=0.36.3
- geopandas>=1.1.0
- h3>=4.2.2
- h5netcdf>=1.6.4
- ipycytoscape>=1.3.3
- ipykernel>=6.29.5
- ipywidgets>=8.1.7
@ -1321,6 +1346,8 @@ packages:
- zarr[remote]>=3.1.3
- geocube>=0.7.1,<0.8
- streamlit>=1.50.0,<2
- altair[all]>=5.5.0,<6
- h5netcdf>=1.7.3,<2
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
@ -3159,6 +3186,15 @@ packages:
version: 0.1.2
sha256: 84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
name: mercantile
version: 1.2.1
sha256: 30f457a73ee88261aab787b7069d85961a5703bb09dc57a170190bc042cd023f
requires_dist:
- click>=3.0
- check-manifest ; extra == 'dev'
- hypothesis ; extra == 'test'
- pytest ; extra == 'test'
- conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda
sha256: 1e59d0dc811f150d39c2ff2da930d69dcb91cb05966b7df5b7d85133006668ed
md5: e4ab075598123e783b788b995afbdad0
@ -4728,6 +4764,27 @@ packages:
- pysocks>=1.5.6,!=1.5.7,<2.0 ; extra == 'socks'
- zstandard>=0.18.0 ; extra == 'zstd'
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl
name: vega-datasets
version: 0.9.0
sha256: 3d7c63917be6ca9b154b565f4779a31fedce57b01b5b9d99d8a34a7608062a1d
requires_dist:
- pandas
requires_python: '>=3.5'
- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: vegafusion
version: 2.0.3
sha256: 0b11c19a70f1bfe3d23d0a09aeecaac7bd03fac01a966d69fbd4dd8679dcb7e7
requires_dist:
- arro3-core
- packaging
- narwhals>=1.42
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: vl-convert-python
version: 1.8.0
sha256: b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
name: watchdog
version: 6.0.0

View file

@ -22,7 +22,6 @@ dependencies = [
"geemap>=0.36.3",
"geopandas>=1.1.0",
"h3>=4.2.2",
"h5netcdf>=1.6.4",
"ipycytoscape>=1.3.3",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.7",
@ -50,7 +49,7 @@ dependencies = [
"xvec>=0.5.1",
"zarr[remote]>=3.1.3",
"geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2",
"streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2",
]
[project.scripts]

View file

View file

@ -87,10 +87,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file
def get_cv_results_file(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}"
results_dir.mkdir(parents=True, exist_ok=True)
results_file = results_dir / "search_results.parquet"
return results_file
return results_dir

View file

@ -1,14 +1,10 @@
# ruff: noqa: N806
"""Training dataset preparation and model training."""
from pathlib import Path
from typing import Literal
import cyclopts
import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
@ -21,7 +17,7 @@ from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from stopuhr import stopwatch
from entropice.paths import (
get_cv_results_file,
get_cv_results_dir,
get_darts_rts_file,
get_embeddings_store,
get_era5_stores,
@ -55,23 +51,41 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
# Get era5 data
era5_store = get_era5_stores("yearly", grid=grid, level=level)
era5 = xr.open_zarr(era5_store, consolidated=False)
era5 = era5.sel(cell_ids=rts["cell_id"].values)
era5_df = []
for var in era5.data_vars:
df = era5[var].drop_vars("spatial_ref").to_dataframe()
df["year"] = df.index.get_level_values("time").year
df = (
df.pivot_table(index="cell_ids", columns="year", values=var)
.rename(columns=lambda x: f"{var}_{x}")
.rename_axis(None, axis=1)
)
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
# TODO: season and shoulder data
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
seasons = {
10: "winter",
4: "summer",
}
for temporal in ["yearly", "seasonal", "shoulder"]:
era5_store = get_era5_stores(temporal, grid=grid, level=level)
era5 = xr.open_zarr(era5_store, consolidated=False)
era5 = era5.sel(cell_ids=rts["cell_id"].values)
for var in era5.data_vars:
df = era5[var].drop_vars("spatial_ref").to_dataframe()
if temporal == "yearly":
df["t"] = df.index.get_level_values("time").year
elif temporal == "seasonal":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"{var}_{x}")
.rename_axis(None, axis=1)
)
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
# Get embeddings data
embs_store = get_embeddings_store(grid=grid, level=level)
@ -83,6 +97,7 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
# Combine datasets by cell id / cell
# TODO: use prefixes to easy split the features in analysis again
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
@ -91,12 +106,13 @@ def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
@cli.command()
def random_cv(grid: Literal["hex", "healpix"], level: int):
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"""Perform random cross-validation on the training dataset.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
"""
data = get_train_dataset_file(grid=grid, level=level)
@ -125,8 +141,8 @@ def random_cv(grid: Literal["hex", "healpix"], level: int):
search = RandomizedSearchCV(
clf,
param_grid,
n_iter=20,
n_jobs=24,
n_iter=n_iter,
n_jobs=20,
cv=cv,
random_state=42,
verbose=10,
@ -155,141 +171,57 @@ def random_cv(grid: Literal["hex", "healpix"], level: int):
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
results["grid"] = grid
results["level"] = level
results_file = get_cv_results_file("random_search", grid=grid, level=level)
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
results_file = results_dir / "search_results.parquet"
print(f"Storing CV results to {results_file}")
results.to_parquet(results_file)
# Get the inner state of the best estimator
best_estimator = search.best_estimator_
# Annotate the state with xarray metadata
features = X_data.columns.tolist()
labels = y_data.unique().tolist()
boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(),
dims=["feature", "box"],
coords={"feature": features, "box": boxes},
name="box_centers",
attrs={"description": "Centers of the boxes in feature space."},
)
box_assignments = xr.DataArray(
best_estimator.Lambda_.cpu().numpy(),
dims=["class", "box"],
coords={"class": labels, "box": boxes},
name="box_assignments",
attrs={"description": "Assignments of samples to boxes."},
)
feature_weights = xr.DataArray(
best_estimator.W_.cpu().numpy(),
dims=["feature"],
coords={"feature": features},
name="feature_weights",
attrs={"description": "Feature weights for each box."},
)
state = xr.Dataset(
{
"box_centers": box_centers,
"box_assignments": box_assignments,
"feature_weights": feature_weights,
},
attrs={
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
"grid": grid,
"level": level,
},
)
state_file = results_dir / "best_estimator_state.nc"
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
stopwatch.summary()
print("Done.")
plot_random_cv_results(results_file)
def _plot_k_binned(
results: pd.DataFrame, target: str, *, vmin_percentile: float | None = None, vmax_percentile: float | None = None
):
assert vmin_percentile is None or vmax_percentile is None, (
"Only one of vmin_percentile or vmax_percentile can be set."
)
assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results."
assert target in results.columns, f"{target} column not found in results."
assert "eps_e" in results.columns, "eps_e column not found in results."
assert "eps_cl" in results.columns, "eps_cl column not found in results."
# add a colorbar instead of the sampled legend
cmap = sns.color_palette("ch:", as_cmap=True)
# sufisticated normalization
if vmin_percentile is not None:
vmin = np.percentile(results[target], vmin_percentile)
norm = mcolors.Normalize(vmin=vmin)
elif vmax_percentile is not None:
vmax = np.percentile(results[target], vmax_percentile)
norm = mcolors.Normalize(vmax=vmax)
else:
norm = mcolors.Normalize()
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# nice col-wrap based on columns
n_cols = results["initial_K_binned"].unique().size
col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
scatter = sns.relplot(
data=results,
x="eps_e",
y="eps_cl",
hue=target,
hue_norm=sm.norm,
palette=cmap,
legend=False,
col="initial_K_binned",
col_wrap=col_wrap,
)
# Apply log scale to all axes
for ax in scatter.axes.flat:
ax.set_xscale("log")
ax.set_yscale("log")
# Tight layout
scatter.figure.tight_layout()
# Add a shared colorbar at the bottom
scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar
cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height]
cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
cbar.set_label(target)
return scatter
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
assert "initial_K" in results.columns, "initial_K column not found in results."
assert metric in results.columns, f"{metric} not found in results."
if target == "eps_cl":
hue = "eps_cl"
col = "eps_e_binned"
elif target == "eps_e":
hue = "eps_e"
col = "eps_cl_binned"
assert hue in results.columns, f"{hue} column not found in results."
assert col in results.columns, f"{col} column not found in results."
return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
@cli.command()
def plot_random_cv_results(file: Path):
"""Plot analysis of the results from the RandomCVSearch.
Args:
file (Path): The file of the results.
"""
print(f"Plotting random CV results from {file}...")
results = pd.read_parquet(file)
# Bin the initial_K into 40er bins
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
# Bin the eps_cl and eps_e into logarithmic bins
eps_cl_bins = np.logspace(-3, 7, num=10)
eps_e_bins = np.logspace(-3, 7, num=10)
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
figdir = file.parent
# K-Plots
metrics = ["f1"]
for metric in metrics:
_plot_k_binned(
results,
f"mean_test_{metric}",
vmin_percentile=50,
).figure.savefig(figdir / f"params3d-mean_{metric}.pdf")
_plot_k_binned(
results,
f"std_test_{metric}",
vmax_percentile=50,
).figure.savefig(figdir / f"params3d-std_{metric}.pdf")
_plot_k_binned(results, f"mean_test_{metric}").figure.savefig(figdir / f"params3d-mean_{metric}-noperc.pdf")
_plot_k_binned(results, f"std_test_{metric}").figure.savefig(figdir / f"params3d-std_{metric}-noperc.pdf")
# eps-Plots
_plot_eps_binned(
results,
"eps_cl",
f"mean_test_{metric}",
).figure.savefig(figdir / f"k-eps_cl-mean_{metric}.pdf")
_plot_eps_binned(
results,
"eps_e",
f"mean_test_{metric}",
).figure.savefig(figdir / f"k-eps_e-mean_{metric}.pdf")
# Close all figures
plt.close("all")
print("Done.")
if __name__ == "__main__":
cli()

View file

@ -1,18 +1,16 @@
"""Streamlit dashboard for training analysis results visualization."""
from datetime import datetime
from pathlib import Path
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import altair as alt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import xarray as xr
from entropice.paths import RESULTS_DIR
sns.set_theme("talk", "whitegrid")
def _plot_k_binned(
results: pd.DataFrame,
@ -30,50 +28,47 @@ def _plot_k_binned(
assert "eps_e" in results.columns, "eps_e column not found in results."
assert "eps_cl" in results.columns, "eps_cl column not found in results."
# add a colorbar instead of the sampled legend
cmap = sns.color_palette("ch:", as_cmap=True)
# sophisticated normalization
# Prepare data
plot_data = results[["eps_e", "eps_cl", "initial_K_binned", target]].copy()
# Sort bins by their left value and convert to string with sorted categories
plot_data = plot_data.sort_values("initial_K_binned")
plot_data["initial_K_binned"] = plot_data["initial_K_binned"].astype(str)
bin_order = plot_data["initial_K_binned"].unique().tolist()
# Determine color scale domain
if vmin_percentile is not None:
vmin = np.percentile(results[target], vmin_percentile)
norm = mcolors.Normalize(vmin=vmin)
color_scale = alt.Scale(scheme="viridis", domain=[vmin, plot_data[target].max()])
elif vmax_percentile is not None:
vmax = np.percentile(results[target], vmax_percentile)
norm = mcolors.Normalize(vmax=vmax)
color_scale = alt.Scale(scheme="viridis", domain=[plot_data[target].min(), vmax])
else:
norm = mcolors.Normalize()
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
color_scale = alt.Scale(scheme="viridis")
# nice col-wrap based on columns
n_cols = results["initial_K_binned"].unique().size
col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
scatter = sns.relplot(
data=results,
x="eps_e",
y="eps_cl",
hue=target,
hue_norm=sm.norm,
palette=cmap,
legend=False,
col="initial_K_binned",
col_wrap=col_wrap,
# Create the chart
chart = (
alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(
"eps_e:Q",
scale=alt.Scale(type="log"),
axis=alt.Axis(title="eps_e", grid=True, gridOpacity=0.5),
),
y=alt.Y(
"eps_cl:Q",
scale=alt.Scale(type="log"),
axis=alt.Axis(title="eps_cl", grid=True, gridOpacity=0.5),
),
color=alt.Color(f"{target}:Q", scale=color_scale, title=target),
tooltip=["eps_e:Q", "eps_cl:Q", alt.Tooltip(f"{target}:Q", format=".4f"), "initial_K_binned:N"],
)
.properties(width=200, height=200)
.facet(facet=alt.Facet("initial_K_binned:N", title="Initial K", sort=bin_order), columns=5)
)
# Apply log scale to all axes
for ax in scatter.axes.flat:
ax.set_xscale("log")
ax.set_yscale("log")
# Tight layout
scatter.figure.tight_layout()
# Add a shared colorbar at the bottom
scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar
cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height]
cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
cbar.set_label(target)
return scatter
return chart
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
@ -93,25 +88,183 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
assert hue in results.columns, f"{hue} column not found in results."
assert col in results.columns, f"{col} column not found in results."
return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
# Prepare data
plot_data = results[["initial_K", metric, hue, col]].copy()
# Sort bins by their left value and convert to string with sorted categories
plot_data = plot_data.sort_values(col)
plot_data[col] = plot_data[col].astype(str)
bin_order = plot_data[col].unique().tolist()
# Create the chart
chart = (
alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X("initial_K:Q", title="Initial K"),
y=alt.Y(f"{metric}:Q", title=metric),
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue),
tooltip=[
"initial_K:Q",
alt.Tooltip(f"{metric}:Q", format=".4f"),
alt.Tooltip(f"{hue}:Q", format=".2e"),
f"{col}:N",
],
)
.properties(width=200, height=200)
.facet(facet=alt.Facet(f"{col}:N", title=col.replace("_binned", ""), sort=bin_order), columns=5)
)
return chart
def load_and_prepare_results(file_path: Path) -> pd.DataFrame:
"""Load results file and prepare binned columns."""
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns.
Args:
file_path: Path to the results parquet file.
k_bin_width: Width of bins for initial_K parameter.
Returns:
DataFrame with added binned columns.
"""
results = pd.read_parquet(file_path)
# Bin the initial_K into 40er bins
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
# Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min()
k_max = results["initial_K"].max()
# Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min())
eps_cl_max = np.log10(results["eps_cl"].max())
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
eps_e_min = np.log10(results["eps_e"].min())
eps_e_max = np.log10(results["eps_e"].max())
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
# Bin the eps_cl and eps_e into logarithmic bins
eps_cl_bins = np.logspace(-3, 7, num=10)
eps_e_bins = np.logspace(-3, 7, num=10)
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
return results
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
    """Read a stored model state back from a NetCDF file.

    Args:
        file_path (Path): The path to the NetCDF file.

    Returns:
        xr.Dataset: The model state as an xarray Dataset.
    """
    # h5netcdf is the engine used when the state was written during training
    state = xr.open_dataset(file_path, engine="h5netcdf")
    return state
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Extract embedding features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted embedding features. This DataArray has dimensions
            ('agg', 'band', 'year') corresponding to the different components of the embedding features.
            Returns None if no embedding features are found.
    """

    def _is_embedding_feature(feature: str) -> bool:
        # Embedding feature names follow "{agg}_{band}_{year}" where the band
        # token is "A" followed by digits (e.g. "mean_A01_2020") — presumably set
        # by the embeddings store column naming; ERA5 bands never start with "A".
        parts = feature.split("_")
        if len(parts) != 3:
            return False
        _, band, _ = parts
        if not band.startswith("A"):
            return False
        if not band[1:].isdigit():
            return False
        return True

    embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
    if len(embedding_features) == 0:
        return None
    # Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
    embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
    embedding_feature_array = embedding_feature_array.assign_coords(
        agg=("feature", [f.split("_")[0] for f in embedding_features]),
        band=("feature", [f.split("_")[1] for f in embedding_features]),
        year=("feature", [f.split("_")[2] for f in embedding_features]),
    )
    # Turn the composite (agg, band, year) index into three real dimensions
    embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
    return embedding_feature_array
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Extract ERA5 features from the model state.

    ERA5 feature names are either yearly (``{var}_{year}``) or carry a season
    token (``{var}_{season}_{year}``) as produced during dataset preparation.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
            ('variable', 'time') corresponding to the different components of the ERA5 features.
            Returns None if no ERA5 features are found.
    """
    # Season tokens as generated by prepare_dataset: seasonal data uses
    # winter/summer, shoulder data uses OND/JFM/AMJ/JAS. The previous check for
    # "spring" matched nothing, while the missing "summer" token caused summer
    # features to be split as yearly ones (wrong variable/time coordinates).
    season_tokens = ("winter", "summer", "OND", "JFM", "AMJ", "JAS")

    def _has_season(feature: str) -> bool:
        return any(season in feature for season in season_tokens)

    def _is_era5_feature(feature: str) -> bool:
        # Seasonal / shoulder features are ERA5 by construction
        if _has_season(feature):
            return True
        parts = feature.split("_")
        if len(parts) == 3:
            _, band, year = parts
            # Embedding features look like "{agg}_A{nn}_{year}" -> not ERA5
            if band.startswith("A"):
                return False
            if year.isdigit():
                return True
        elif parts[-1].isdigit():
            # Yearly ERA5 feature: "{var}_{year}"
            return True
        return False

    def _extract_var_name(feature: str) -> str:
        # Seasonal: "{var}_{season}_{year}" -> var; yearly: "{var}_{year}" -> var
        if _has_season(feature):
            return feature.rsplit("_", 2)[0]
        else:
            return feature.rsplit("_", 1)[0]

    def _extract_time_name(feature: str) -> str:
        # Seasonal: "{season}_{year}"; yearly: "{year}"
        if _has_season(feature):
            return "_".join(feature.rsplit("_", 2)[-2:])
        else:
            return feature.rsplit("_", 1)[-1]

    era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
    if len(era5_features) == 0:
        return None
    # Split the single feature dimension of era5 features into separate dimensions (variable, time)
    era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
    era5_features_array = era5_features_array.assign_coords(
        variable=("feature", [_extract_var_name(f) for f in era5_features]),
        time=("feature", [_extract_time_name(f) for f in era5_features]),
    )
    era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
    return era5_features_array
# TODO: Extract common features, e.g. area or water content
def get_available_result_files() -> list[Path]:
"""Get all available result files from RESULTS_DIR."""
if not RESULTS_DIR.exists():
@ -119,14 +272,308 @@ def get_available_result_files() -> list[Path]:
result_files = []
for search_dir in RESULTS_DIR.iterdir():
if search_dir.is_dir():
result_file = search_dir / "search_results.parquet"
if result_file.exists():
result_files.append(result_file)
if not search_dir.is_dir():
continue
result_file = search_dir / "search_results.parquet"
state_file = search_dir / "best_estimator_state.nc"
if result_file.exists() and state_file.exists():
result_files.append(search_dir)
return sorted(result_files, reverse=True) # Most recent first
def _parse_results_dir_name(results_dir: Path) -> str:
gridname, date = results_dir.name.split("_random_search_cv")
gridname = gridname.lstrip("permafrost_")
date = datetime.strptime(date, "%Y%m%d-%H%M%S")
date = date.strftime("%Y-%m-%d %H:%M:%S")
return f"{gridname} ({date})"
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
    """Plot the top N most important features based on feature weights.

    Args:
        model_state: The xarray Dataset containing the model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.
    """
    weights = model_state["feature_weights"].to_pandas()
    # Rank by magnitude, then order ascending so the largest bar ends up on top
    ranked = weights.abs().nlargest(top_n).sort_values(ascending=True)
    # Keep the original (signed) weights alongside their magnitudes for plotting
    chart_df = pd.DataFrame(
        {
            "feature": ranked.index,
            "weight": weights.loc[ranked.index].to_numpy(),
            "abs_weight": ranked.to_numpy(),
        }
    )
    # Sign-dependent bar color: blue for positive, coral for negative weights
    bar_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),
        alt.value("coral"),
    )
    return (
        alt.Chart(chart_df)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=bar_color,
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(
            width=600,
            height=400,
            title=f"Top {top_n} Most Important Features",
        )
    )
def _plot_embedding_heatmap(embedding_array: xr.DataArray):
    """Create a heatmap showing embedding feature weights across bands and years.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Altair chart showing the heatmap.
    """
    # Flatten the (agg, band, year) cube to long format for Altair
    weights_long = embedding_array.to_dataframe(name="weight").reset_index()
    heatmap = (
        alt.Chart(weights_long)
        .mark_rect()
        .encode(
            x=alt.X("year:O", title="Year"),
            y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
            color=alt.Color(
                "weight:Q",
                # Diverging scale centered at zero so sign is visible at a glance
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("agg:N", title="Aggregation"),
                alt.Tooltip("band:N", title="Band"),
                alt.Tooltip("year:O", title="Year"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
            ],
        )
    )
    # One panel per aggregation method
    return heatmap.properties(width=200, height=200).facet(
        facet=alt.Facet("agg:N", title="Aggregation"), columns=11
    )
def _embedding_summary_bar(df: pd.DataFrame, *, dim_type: str, dim_title: str, scheme: str, title: str):
    """Build one horizontal bar chart of mean absolute weight per dimension value.

    Args:
        df: Frame with columns ``dimension`` and ``mean_abs_weight``.
        dim_type: Altair field type for the dimension axis ("N" nominal or "O" ordinal).
        dim_title: Axis/tooltip title for the dimension.
        scheme: Vega color scheme name used to shade bars by magnitude.
        title: Chart title.

    Returns:
        Altair bar chart (250x200), bars sorted by descending weight.
    """
    return (
        alt.Chart(df)
        .mark_bar()
        .encode(
            y=alt.Y(f"dimension:{dim_type}", title=dim_title, sort="-x"),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color(
                "mean_abs_weight:Q",
                scale=alt.Scale(scheme=scheme),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip(f"dimension:{dim_type}", title=dim_title),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=250, height=200, title=title)
    )


def _plot_embedding_aggregation_summary(embedding_array: xr.DataArray):
    """Create bar charts summarizing embedding weights by aggregation, band, and year.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Tuple of three Altair charts (by_agg, by_band, by_year).
    """

    def _mean_abs_over(reduce_dims: list[str]) -> pd.DataFrame:
        # Average over the other two dimensions, then take absolute value,
        # yielding one row per value of the remaining dimension.
        series = embedding_array.mean(dim=reduce_dims).to_pandas().abs()
        frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
        return frame.sort_values("mean_abs_weight", ascending=True)

    # Year is ordinal ("O"); aggregation and band are nominal ("N").
    chart_agg = _embedding_summary_bar(
        _mean_abs_over(["band", "year"]),
        dim_type="N", dim_title="Aggregation", scheme="blues", title="By Aggregation",
    )
    chart_band = _embedding_summary_bar(
        _mean_abs_over(["agg", "year"]),
        dim_type="N", dim_title="Band", scheme="greens", title="By Band",
    )
    chart_year = _embedding_summary_bar(
        _mean_abs_over(["agg", "band"]),
        dim_type="O", dim_title="Year", scheme="oranges", title="By Year",
    )
    return chart_agg, chart_band, chart_year
def _plot_era5_heatmap(era5_array: xr.DataArray):
    """Render a heatmap of ERA5 feature weights over variables and time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Altair chart showing the heatmap.
    """
    # Long-form records: one row per (variable, time) cell
    long_form = era5_array.to_dataframe(name="weight").reset_index()
    # Keep the time axis in data order; sort variables by their color value
    encoding = {
        "x": alt.X("time:N", title="Time", sort=None),
        "y": alt.Y("variable:N", title="Variable", sort="-color"),
        "color": alt.Color(
            "weight:Q",
            scale=alt.Scale(scheme="redblue", domainMid=0),
            title="Weight",
        ),
        "tooltip": [
            alt.Tooltip("variable:N", title="Variable"),
            alt.Tooltip("time:N", title="Time"),
            alt.Tooltip("weight:Q", format=".4f", title="Weight"),
        ],
    }
    heatmap = alt.Chart(long_form).mark_rect().encode(**encoding)
    return heatmap.properties(height=400, title="ERA5 Feature Weights Heatmap")
def _plot_era5_summary(era5_array: xr.DataArray):
    """Create bar charts summarizing ERA5 weights by variable and time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_time).
    """
    # Per-chart spec: (dimension reduced away, axis title, color scheme,
    # y-axis label pixel limit, chart title)
    specs = [
        ("time", "Variable", "purples", 300, "By Variable"),
        ("variable", "Time", "teals", 200, "By Time"),
    ]
    charts = []
    for reduce_dim, axis_title, scheme, label_limit, chart_title in specs:
        # Mean over the reduced dimension, then absolute value, one row per
        # value of the remaining dimension, sorted by magnitude.
        series = era5_array.mean(dim=reduce_dim).to_pandas().abs()
        frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
        frame = frame.sort_values("mean_abs_weight", ascending=True)
        charts.append(
            alt.Chart(frame)
            .mark_bar()
            .encode(
                y=alt.Y("dimension:N", title=axis_title, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
                color=alt.Color(
                    "mean_abs_weight:Q",
                    scale=alt.Scale(scheme=scheme),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip("dimension:N", title=axis_title),
                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
                ],
            )
            .properties(width=400, height=300, title=chart_title)
        )
    return charts[0], charts[1]
def main():
"""Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -138,23 +585,30 @@ def main():
st.sidebar.header("Configuration")
# Get available result files
result_files = get_available_result_files()
result_dirs = get_available_result_files()
if not result_files:
if not result_dirs:
st.error(f"No result files found in {RESULTS_DIR}")
st.info("Please run a random CV search first to generate results.")
return
# File selection
file_options = {str(f.parent.name): f for f in result_files}
selected_file_name = st.sidebar.selectbox(
"Select Result File", options=list(file_options.keys()), help="Choose a search result file to visualize"
# Directory selection
dir_options = {_parse_results_dir_name(f): f for f in result_dirs}
selected_dir_name = st.sidebar.selectbox(
"Select Result Directory",
options=list(dir_options.keys()),
help="Choose a search result directory to visualize",
)
selected_file = file_options[selected_file_name]
results_dir = dir_options[selected_dir_name]
# Load and prepare data
# Load and prepare data with default bin width (will be reloaded with custom width later)
with st.spinner("Loading results..."):
results = load_and_prepare_results(selected_file)
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=40)
model_state = load_and_prepare_model_state(results_dir / "best_estimator_state.nc")
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
embedding_feature_array = extract_embedding_features(model_state)
era5_feature_array = extract_era5_features(model_state)
st.sidebar.success(f"Loaded {len(results)} results")
@ -164,11 +618,6 @@ def main():
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Percentile normalization option
use_percentile = st.sidebar.checkbox(
"Use Percentile Normalization", value=True, help="Apply percentile-based color normalization to plots"
)
# Display some basic statistics
st.header("Dataset Overview")
col1, col2, col3 = st.columns(3)
@ -186,50 +635,235 @@ def main():
with st.expander("Best Parameters"):
best_idx = results[f"mean_test_{selected_metric}"].idxmax()
best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{selected_metric}"]]
st.dataframe(best_params.to_frame().T, use_container_width=True)
st.dataframe(best_params.to_frame().T, width="content")
# Main plots
st.header(f"Visualization for {selected_metric.capitalize()}")
# Create tabs for different visualizations
tab1, tab2 = st.tabs(["Search Results", "Model State"])
# K-binned plots
st.subheader("K-Binned Parameter Space (Mean)")
with st.spinner("Generating mean plot..."):
if use_percentile:
fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50)
with tab1:
# Main plots
st.header(f"Visualization for {selected_metric.capitalize()}")
# K-binned plot configuration
col_toggle, col_slider = st.columns([1, 1])
with col_toggle:
# Percentile normalization toggle for K-binned plots
use_percentile = st.toggle(
"Use Percentile Normalization",
value=True,
help="Apply percentile-based color normalization to K-binned parameter space plots",
)
with col_slider:
# Bin width slider for K-binned plots
k_min = int(results["initial_K"].min())
k_max = int(results["initial_K"].max())
k_range = k_max - k_min
k_bin_width = st.slider(
"Initial K Bin Width",
min_value=10,
max_value=max(100, k_range // 2),
value=40,
step=10,
help=f"Width of bins for initial_K facets (range: {k_min}-{k_max})",
)
# Show estimated number of bins
estimated_bins = int(np.ceil(k_range / k_bin_width))
st.caption(f"Creating approximately {estimated_bins} bins for initial_K")
# Reload data if bin width changed from default
if k_bin_width != 40:
with st.spinner("Re-binning data..."):
results = load_and_prepare_results(results_dir / "search_results.parquet", k_bin_width=k_bin_width)
# K-binned plots
col1, col2 = st.columns(2)
with col1:
st.subheader("K-Binned Parameter Space (Mean)")
with st.spinner("Generating mean plot..."):
if use_percentile:
chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}", vmin_percentile=50)
else:
chart1 = _plot_k_binned(results, f"mean_test_{selected_metric}")
st.altair_chart(chart1, use_container_width=True)
with col2:
st.subheader("K-Binned Parameter Space (Std)")
with st.spinner("Generating std plot..."):
if use_percentile:
chart2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50)
else:
chart2 = _plot_k_binned(results, f"std_test_{selected_metric}")
st.altair_chart(chart2, use_container_width=True)
# Epsilon-binned plots
col1, col2 = st.columns(2)
with col1:
st.subheader("K vs eps_cl")
with st.spinner("Generating eps_cl plot..."):
chart3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}")
st.altair_chart(chart3, use_container_width=True)
with col2:
st.subheader("K vs eps_e")
with st.spinner("Generating eps_e plot..."):
chart4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
st.altair_chart(chart4, use_container_width=True)
# Optional: Raw data table
with st.expander("View Raw Results Data"):
st.dataframe(results, width="stretch")
with tab2:
# Model state visualization
st.header("Best Estimator Model State")
# Show basic model state info
with st.expander("Model State Information"):
st.write(f"**Variables:** {list(model_state.data_vars)}")
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
st.write(f"**Coordinates:** {list(model_state.coords)}")
# Show statistics
st.write("**Feature Weight Statistics:**")
feature_weights = model_state["feature_weights"].to_pandas()
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
with col2:
st.metric("Max Weight", f"{feature_weights.max():.4f}")
with col3:
st.metric("Total Features", len(feature_weights))
# Feature importance plot
st.subheader("Feature Importance")
st.markdown("The most important features based on learned feature weights from the best estimator.")
# Slider to control number of features to display
top_n = st.slider(
"Number of top features to display",
min_value=5,
max_value=50,
value=10,
step=5,
help="Select how many of the most important features to visualize",
)
with st.spinner("Generating feature importance plot..."):
feature_chart = _plot_top_features(model_state, top_n=top_n)
st.altair_chart(feature_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- **Magnitude**: Larger absolute values indicate more important features
"""
)
# Embedding features analysis (if present)
if embedding_feature_array is not None:
st.subheader("Embedding Feature Analysis")
st.markdown(
"""
Analysis of embedding features showing which aggregations, bands, and years
are most important for the model predictions.
"""
)
# Summary bar charts
st.markdown("### Importance by Dimension")
with st.spinner("Generating dimension summaries..."):
chart_agg, chart_band, chart_year = _plot_embedding_aggregation_summary(embedding_feature_array)
col1, col2, col3 = st.columns(3)
with col1:
st.altair_chart(chart_agg, use_container_width=True)
with col2:
st.altair_chart(chart_band, use_container_width=True)
with col3:
st.altair_chart(chart_year, use_container_width=True)
# Detailed heatmap
st.markdown("### Detailed Heatmap by Aggregation")
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
with st.spinner("Generating heatmap..."):
heatmap_chart = _plot_embedding_heatmap(embedding_feature_array)
st.altair_chart(heatmap_chart, use_container_width=True)
# Statistics
with st.expander("Embedding Feature Statistics"):
st.write("**Overall Statistics:**")
n_emb_features = embedding_feature_array.size
mean_weight = float(embedding_feature_array.mean().values)
max_weight = float(embedding_feature_array.max().values)
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Embedding Features", n_emb_features)
with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3:
st.metric("Max Weight", f"{max_weight:.4f}")
# Show top embedding features
st.write("**Top 10 Embedding Features:**")
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
st.dataframe(top_emb, width="stretch")
else:
fig1 = _plot_k_binned(results, f"mean_test_{selected_metric}")
st.pyplot(fig1.figure)
plt.close()
st.info("No embedding features found in this model.")
st.subheader("K-Binned Parameter Space (Std)")
with st.spinner("Generating std plot..."):
if use_percentile:
fig2 = _plot_k_binned(results, f"std_test_{selected_metric}", vmax_percentile=50)
# ERA5 features analysis (if present)
if era5_feature_array is not None:
st.subheader("ERA5 Feature Analysis")
st.markdown(
"""
Analysis of ERA5 climate features showing which variables and time periods
are most important for the model predictions.
"""
)
# Summary bar charts
st.markdown("### Importance by Dimension")
with st.spinner("Generating ERA5 dimension summaries..."):
chart_variable, chart_time = _plot_era5_summary(era5_feature_array)
col1, col2 = st.columns(2)
with col1:
st.altair_chart(chart_variable, use_container_width=True)
with col2:
st.altair_chart(chart_time, use_container_width=True)
# Detailed heatmap
st.markdown("### Detailed Heatmap")
st.markdown("Shows the weight of each variable-time combination.")
with st.spinner("Generating ERA5 heatmap..."):
era5_heatmap_chart = _plot_era5_heatmap(era5_feature_array)
st.altair_chart(era5_heatmap_chart, use_container_width=True)
# Statistics
with st.expander("ERA5 Feature Statistics"):
st.write("**Overall Statistics:**")
n_era5_features = era5_feature_array.size
mean_weight = float(era5_feature_array.mean().values)
max_weight = float(era5_feature_array.max().values)
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total ERA5 Features", n_era5_features)
with col2:
st.metric("Mean Weight", f"{mean_weight:.4f}")
with col3:
st.metric("Max Weight", f"{max_weight:.4f}")
# Show top ERA5 features
st.write("**Top 10 ERA5 Features:**")
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
st.dataframe(top_era5, width="stretch")
else:
fig2 = _plot_k_binned(results, f"std_test_{selected_metric}")
st.pyplot(fig2.figure)
plt.close()
# Epsilon-binned plots
col1, col2 = st.columns(2)
with col1:
st.subheader("K vs eps_cl")
with st.spinner("Generating eps_cl plot..."):
fig3 = _plot_eps_binned(results, "eps_cl", f"mean_test_{selected_metric}")
st.pyplot(fig3.figure)
plt.close()
with col2:
st.subheader("K vs eps_e")
with st.spinner("Generating eps_e plot..."):
fig4 = _plot_eps_binned(results, "eps_e", f"mean_test_{selected_metric}")
st.pyplot(fig4.figure)
plt.close()
# Optional: Raw data table
with st.expander("View Raw Results Data"):
st.dataframe(results, use_container_width=True)
st.info("No ERA5 features found in this model.")
if __name__ == "__main__":