Add an inference map

Tobias Hölzer 2025-11-08 22:44:08 +01:00
parent 150f14ed52
commit fb522ddad5
6 changed files with 580 additions and 251 deletions

pixi.lock (generated, 14 lines changed)

@@ -288,6 +288,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/31/cc/099fab5a73909a117e9689c7da4c39a248595187f0f30dd879ad1d2c34ce/tblib-3.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl
@@ -1300,7 +1301,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@@ -1348,6 +1349,7 @@ packages:
- streamlit>=1.50.0,<2
- altair[all]>=5.5.0,<6
- h5netcdf>=1.7.3,<2
- streamlit-folium>=0.25.3,<0.26
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
@@ -4560,6 +4562,16 @@ packages:
- streamlit[auth,charts,pdf,snowflake,sql] ; extra == 'all'
- rich>=11.0.0 ; extra == 'all'
requires_python: '>=3.9,!=3.9.7'
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
name: streamlit-folium
version: 0.25.3
sha256: cfdf085764da3f9b5e1e0668f6e4cc0385ff041c98133d023800983a875ca26c
requires_dist:
- streamlit>=1.13.0
- folium>=0.13,!=0.15.0
- jinja2
- branca
requires_python: '>=3.9'
- conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda
sha256: 09d3b6ac51d437bc996ad006d9f749ca5c645c1900a854a6c8f193cbd13f03a8
md5: 8c09fac3785696e1c477156192d64b91

pyproject.toml

@@ -49,7 +49,9 @@ dependencies = [
"xvec>=0.5.1",
"zarr[remote]>=3.1.3",
"geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2",
"streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
]
[project.scripts]
@@ -57,7 +59,8 @@ create-grid = "entropice.grids:main"
darts = "entropice.darts:main"
alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli"
train = "entropice.training:cli"
train = "entropice.training:main"
dataset = "entropice.dataset:main"
[build-system]
requires = ["hatchling"]

src/entropice/dataset.py (new file, 130 lines)

@@ -0,0 +1,130 @@
"""Training dataset preparation and model training."""
from typing import Literal
import cyclopts
import geopandas as gpd
import pandas as pd
import seaborn as sns
import xarray as xr
from rich import pretty, traceback
from sklearn import set_config
from stopuhr import stopwatch
from entropice.paths import (
get_darts_rts_file,
get_embeddings_store,
get_era5_stores,
get_train_dataset_file,
)
traceback.install()
pretty.install()
set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
seasons = {10: "winter", 4: "summer"}
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
def _prep_era5(
rts: gpd.GeoDataFrame,
temporal: Literal["yearly", "seasonal", "shoulder"],
grid: Literal["hex", "healpix"],
level: int,
) -> pd.DataFrame:
era5_df = []
era5_store = get_era5_stores(temporal, grid=grid, level=level)
era5 = xr.open_zarr(era5_store, consolidated=False)
era5 = era5.sel(cell_ids=rts["cell_id"].values)
for var in era5.data_vars:
df = era5[var].drop_vars("spatial_ref").to_dataframe()
if temporal == "yearly":
df["t"] = df.index.get_level_values("time").year
elif temporal == "seasonal":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"{var}_{x}")
.rename_axis(None, axis=1)
)
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
return era5_df
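As an aside, a minimal sketch of the pivot performed above, using toy data (the variable name t2m and the values are assumptions; real variables come from the ERA5 zarr stores). Each ERA5 variable ends up as one wide column per period, later prefixed with era5_:

import pandas as pd

df = pd.DataFrame(
    {
        "cell_ids": [1, 1, 2, 2],
        "t": ["winter_2020", "summer_2020", "winter_2020", "summer_2020"],
        "t2m": [250.1, 281.3, 252.7, 279.9],
    }
)
wide = (
    df.pivot_table(index="cell_ids", columns="t", values="t2m")
    .rename(columns=lambda x: f"t2m_{x}")
    .rename_axis(None, axis=1)
)
# After the era5_ prefixing in _prep_era5 the columns read
# era5_t2m_summer_2020 and era5_t2m_winter_2020, one row per cell id.
print(wide)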
@stopwatch("Prepare embeddings data")
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
embs_store = get_embeddings_store(grid=grid, level=level)
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
embeddings = embeddings.sel(cell=rts["cell_id"].values)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
embeddings_df = embeddings_df.rename(
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
)
return embeddings_df
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
filter_target (bool): If True, keep only cells with DARTS coverage. Defaults to False.
"""
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
# Filter to coverage
if filter_target:
rts = rts[rts["darts_has_coverage"]]
# Convert hex cell_id to int
if grid == "hex":
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
# Add the lat / lon of the cell centers
rts["lon"] = rts.geometry.centroid.x
rts["lat"] = rts.geometry.centroid.y
# Get era5 data
era5_yearly = _prep_era5(rts, "yearly", grid, level)
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
# Get embeddings data
embeddings = _prep_embeddings(rts, grid, level)
# Combine datasets by cell id / cell
with stopwatch("Combine datasets"):
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
dataset_file = get_train_dataset_file(grid=grid, level=level)
dataset.reset_index().to_parquet(dataset_file)
def main():
cyclopts.run(prepare_dataset)
if __name__ == "__main__":
main()
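For reference, a usage sketch of the new module (the grid and level values below are assumptions; the dataset console script added in pyproject.toml wraps the same function through cyclopts):

from entropice.dataset import prepare_dataset

# Equivalent to invoking the "dataset" console script with these arguments.
prepare_dataset(grid="healpix", level=8, filter_target=False)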

src/entropice/inference.py (new file)

@@ -0,0 +1,61 @@
# ruff: noqa: N806
"""Inference runs on trained models."""
from typing import Literal
import geopandas as gpd
import pandas as pd
import torch
from entropy import ESPAClassifier
from rich import pretty, traceback
from sklearn import set_config
from entropice.paths import get_train_dataset_file
traceback.install()
pretty.install()
set_config(array_api_dispatch=True)
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
"""Get predicted probabilities for each cell.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
clf (ESPAClassifier): The trained classifier to use for predictions.
Returns:
gpd.GeoDataFrame: Predicted probabilities together with cell ids and geometries.
"""
data = get_train_dataset_file(grid=grid, level=level)
data = gpd.read_parquet(data)
print(f"Predicting probabilities for {len(data)} cells...")
# Predict in batches to avoid memory issues
batch_size = 100_000
preds = []
for i in range(0, len(data), batch_size):
batch = data.iloc[i : i + batch_size]
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
X_batch = batch.drop(columns=cols_to_drop).dropna()
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32")
X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
batch_preds = gpd.GeoDataFrame(
{
"cell_id": cell_ids,
"predicted_proba": batch_preds,
"geometry": cell_geoms,
},
crs="epsg:3413",
)
preds.append(batch_preds)
preds = gpd.GeoDataFrame(pd.concat(preds))
return preds
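A sketch of driving this function from a stored search result. The results directory name is an assumption, but random_cv (below) writes best_estimator_model.pkl alongside the other artifacts, and predict_proba moves each batch to CUDA device 0:

import pickle
from pathlib import Path

from entropice.inference import predict_proba

# Hypothetical results directory produced by random_cv.
results_dir = Path("results/permafrost_healpix_random_search_cv_2025-11-08")
with open(results_dir / "best_estimator_model.pkl", "rb") as f:
    clf = pickle.load(f)  # ESPAClassifier stored by the training run

preds = predict_proba(grid="healpix", level=8, clf=clf)  # needs a CUDA device
preds.to_parquet(results_dir / "predicted_probabilities.parquet")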

src/entropice/training.py

@@ -1,12 +1,13 @@
# ruff: noqa: N806
"""Training dataset preparation and model training."""
import pickle
from typing import Literal
import cyclopts
import geopandas as gpd
import pandas as pd
import seaborn as sns
import toml
import torch
import xarray as xr
from entropy import ESPAClassifier
@@ -16,11 +17,9 @@ from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from stopuhr import stopwatch
from entropice.inference import predict_proba
from entropice.paths import (
get_cv_results_dir,
get_darts_rts_file,
get_embeddings_store,
get_era5_stores,
get_train_dataset_file,
)
@@ -29,83 +28,7 @@ pretty.install()
set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
cli = cyclopts.App()
@cli.command()
def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
"""
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
# Filter to coverage
rts = rts[rts["darts_has_coverage"]]
# Convert hex cell_id to int
if grid == "hex":
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
# Get era5 data
era5_df = []
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
seasons = {
10: "winter",
4: "summer",
}
for temporal in ["yearly", "seasonal", "shoulder"]:
era5_store = get_era5_stores(temporal, grid=grid, level=level)
era5 = xr.open_zarr(era5_store, consolidated=False)
era5 = era5.sel(cell_ids=rts["cell_id"].values)
for var in era5.data_vars:
df = era5[var].drop_vars("spatial_ref").to_dataframe()
if temporal == "yearly":
df["t"] = df.index.get_level_values("time").year
elif temporal == "seasonal":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"{var}_{x}")
.rename_axis(None, axis=1)
)
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
# Get embeddings data
embs_store = get_embeddings_store(grid=grid, level=level)
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
embeddings = embeddings.sel(cell=rts["cell_id"].values)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
# Combine datasets by cell id / cell
# TODO: use prefixes to easy split the features in analysis again
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
dataset_file = get_train_dataset_file(grid=grid, level=level)
dataset.reset_index().to_parquet(dataset_file)
@cli.command()
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"""Perform random cross-validation on the training dataset.
@@ -117,6 +40,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
"""
data = get_train_dataset_file(grid=grid, level=level)
data = gpd.read_parquet(data)
data = data[data["darts_has_coverage"]]
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
@@ -142,7 +66,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
clf,
param_grid,
n_iter=n_iter,
n_jobs=20,
n_jobs=16,
cv=cv,
random_state=42,
verbose=10,
@@ -163,6 +87,45 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}")
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
# Store the search settings
settings = {
"grid": grid,
"level": level,
"random_state": 42,
"n_iter": n_iter,
"param_grid": {
"eps_cl": {
"distribution": "loguniform",
"low": param_grid["eps_cl"].a,
"high": param_grid["eps_cl"].b,
},
"eps_e": {
"distribution": "loguniform",
"low": param_grid["eps_e"].a,
"high": param_grid["eps_e"].b,
},
"initial_K": {
"distribution": "randint",
"low": param_grid["initial_K"].a,
"high": param_grid["initial_K"].b,
},
},
"cv_splits": cv.get_n_splits(),
"metrics": metrics,
}
settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}")
with open(settings_file, "w") as f:
toml.dump({"settings": settings}, f)
# Store the best estimator model
best_model_file = results_dir / "best_estimator_model.pkl"
print(f"Storing best estimator model to {best_model_file}")
with open(best_model_file, "wb") as f:
pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
# Store the search results
results = pd.DataFrame(search.cv_results_)
# Parse the params into individual columns
@@ -171,7 +134,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
results["grid"] = grid
results["level"] = level
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
results_file = results_dir / "search_results.parquet"
print(f"Storing CV results to {results_file}")
results.to_parquet(results_file)
@@ -219,9 +181,20 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
# Predict probabilities for all cells
print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file)
stopwatch.summary()
print("Done.")
def main():
cyclopts.run(random_cv)
if __name__ == "__main__":
cli()
main()
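Taken together, a random_cv run now leaves search_settings.toml, search_results.parquet, best_estimator_state.nc, best_estimator_model.pkl, and predicted_probabilities.parquet in the results directory, which the dashboard below expects to find. A small sketch of reloading them (the directory name is an assumption):

from pathlib import Path

import geopandas as gpd
import pandas as pd
import toml
import xarray as xr

results_dir = Path("results/permafrost_healpix_random_search_cv_2025-11-08")  # assumed layout
with open(results_dir / "search_settings.toml") as f:
    settings = toml.load(f)["settings"]
cv_results = pd.read_parquet(results_dir / "search_results.parquet")
state = xr.open_dataset(results_dir / "best_estimator_state.nc", engine="h5netcdf")
preds = gpd.read_parquet(results_dir / "predicted_probabilities.parquet")
print(settings["n_iter"], len(cv_results), list(state.data_vars), len(preds))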

Streamlit training analysis dashboard module (filename not shown)

@@ -4,14 +4,177 @@ from datetime import datetime
from pathlib import Path
import altair as alt
import folium
import geopandas as gpd
import numpy as np
import pandas as pd
import streamlit as st
import xarray as xr
from streamlit_folium import st_folium
from entropice.paths import RESULTS_DIR
def get_available_result_files() -> list[Path]:
"""Get all available result files from RESULTS_DIR."""
if not RESULTS_DIR.exists():
return []
result_files = []
for search_dir in RESULTS_DIR.iterdir():
if not search_dir.is_dir():
continue
result_file = search_dir / "search_results.parquet"
state_file = search_dir / "best_estimator_state.nc"
preds_file = search_dir / "predicted_probabilities.parquet"
settings_file = search_dir / "search_settings.toml"
if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
result_files.append(search_dir)
return sorted(result_files, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns.
Args:
file_path: Path to the results parquet file.
k_bin_width: Width of bins for initial_K parameter.
Returns:
DataFrame with added binned columns.
"""
results = pd.read_parquet(file_path)
# Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min()
k_max = results["initial_K"].max()
# Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min())
eps_cl_max = np.log10(results["eps_cl"].max())
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
eps_e_min = np.log10(results["eps_e"].min())
eps_e_max = np.log10(results["eps_e"].max())
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
return results
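A toy illustration of the logarithmic binning used for the epsilon parameters (values are made up). Note that pd.cut builds right-closed intervals, so the row holding the exact minimum falls outside the first bin unless include_lowest=True is passed; the code above shares that property:

import numpy as np
import pandas as pd

eps = pd.Series([1e-4, 3e-4, 1e-3, 5e-3, 2e-2])  # made-up eps_cl values
bins = np.logspace(np.log10(eps.min()), np.log10(eps.max()), num=10)
binned = pd.cut(eps, bins=bins)
print(binned.value_counts(sort=False))  # the 1e-4 row lands in no bin (NaN)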
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
"""Load a model state from a NetCDF file.
Args:
file_path (Path): The path to the NetCDF file.
Returns:
xr.Dataset: The model state as an xarray Dataset.
"""
return xr.open_dataset(file_path, engine="h5netcdf")
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
"""Extract embedding features from the model state.
Args:
model_state: The xarray Dataset containing the model state.
Returns:
xr.DataArray: The extracted embedding features. This DataArray has dimensions
('agg', 'band', 'year') corresponding to the different components of the embedding features.
Returns None if no embedding features are found.
"""
def _is_embedding_feature(feature: str) -> bool:
return feature.startswith("embeddings_")
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
if len(embedding_features) == 0:
return None
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
embedding_feature_array = embedding_feature_array.assign_coords(
agg=("feature", [f.split("_")[1] for f in embedding_features]),
band=("feature", [f.split("_")[2] for f in embedding_features]),
year=("feature", [f.split("_")[3] for f in embedding_features]),
)
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
return embedding_feature_array
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
"""Extract ERA5 features from the model state.
Args:
model_state: The xarray Dataset containing the model state.
Returns:
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
('variable', 'time') corresponding to the different components of the ERA5 features.
Returns None if no ERA5 features are found.
"""
def _is_era5_feature(feature: str) -> bool:
return feature.startswith("era5_")
def _extract_var_name(feature: str) -> str:
feature = feature.replace("era5_", "")
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
return feature.rsplit("_", 2)[0]
else:
return feature.rsplit("_", 1)[0]
def _extract_time_name(feature: str) -> str:
feature = feature.replace("era5_", "")
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
return "_".join(feature.rsplit("_", 2)[-2:])
else:
return feature.rsplit("_", 1)[-1]
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
if len(era5_features) == 0:
return None
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
era5_features_array = era5_features_array.assign_coords(
variable=("feature", [_extract_var_name(f) for f in era5_features]),
time=("feature", [_extract_time_name(f) for f in era5_features]),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
return era5_features_array
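The new extractors rely purely on the era5_ and embeddings_ prefixes and the column layout written by dataset.py. A quick sanity check of those naming conventions (the tokens mean, A01 and t2m are assumed examples):

emb = "embeddings_mean_A01_2020"
agg, band, year = emb.split("_")[1], emb.split("_")[2], emb.split("_")[3]
assert (agg, band, year) == ("mean", "A01", "2020")

era5_seasonal = "era5_t2m_winter_2020"
stem = era5_seasonal.replace("era5_", "")                    # "t2m_winter_2020"
assert stem.rsplit("_", 2)[0] == "t2m"                       # variable
assert "_".join(stem.rsplit("_", 2)[-2:]) == "winter_2020"   # time

era5_yearly = "era5_t2m_2020"
stem = era5_yearly.replace("era5_", "")
assert stem.rsplit("_", 1)[0] == "t2m"
assert stem.rsplit("_", 1)[-1] == "2020"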
# TODO: Extract common features, e.g. area or water content
def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
"""Plot predicted probabilities on a map using Streamlit and Folium.
Args:
preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
Returns:
folium.Map: A Folium map object with the predicted probabilities visualized.
"""
m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
return m
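A minimal standalone sketch of the explore-to-st_folium flow (toy geometries and values; the real predictions are stored in EPSG:3413 and explore() reprojects them to WGS84 for folium):

import geopandas as gpd
from shapely.geometry import box
from streamlit_folium import st_folium

toy = gpd.GeoDataFrame(
    {"predicted_proba": [0.1, 0.9]},
    geometry=[box(10, 70, 11, 71), box(11, 70, 12, 71)],
    crs="epsg:4326",
)
m = toy.explore(column="predicted_proba", legend=True, tiles="CartoDB positron")
st_folium(m, width="100%", height=300)  # run inside a `streamlit run` session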
def _plot_k_binned(
results: pd.DataFrame,
target: str,
@@ -118,171 +281,6 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
return chart
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
"""Load results file and prepare binned columns.
Args:
file_path: Path to the results parquet file.
k_bin_width: Width of bins for initial_K parameter.
Returns:
DataFrame with added binned columns.
"""
results = pd.read_parquet(file_path)
# Automatically determine bin width for initial_K based on data range
k_min = results["initial_K"].min()
k_max = results["initial_K"].max()
# Use configurable bin width, adapted to actual data range
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
# Automatically create logarithmic bins for epsilon parameters based on data range
# Use 10 bins spanning the actual data range
eps_cl_min = np.log10(results["eps_cl"].min())
eps_cl_max = np.log10(results["eps_cl"].max())
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
eps_e_min = np.log10(results["eps_e"].min())
eps_e_max = np.log10(results["eps_e"].max())
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
return results
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
"""Load a model state from a NetCDF file.
Args:
file_path (Path): The path to the NetCDF file.
Returns:
xr.Dataset: The model state as an xarray Dataset.
"""
return xr.open_dataset(file_path, engine="h5netcdf")
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
"""Extract embedding features from the model state.
Args:
model_state: The xarray Dataset containing the model state.
Returns:
xr.DataArray: The extracted embedding features. This DataArray has dimensions
('agg', 'band', 'year') corresponding to the different components of the embedding features.
Returns None if no embedding features are found.
"""
def _is_embedding_feature(feature: str) -> bool:
parts = feature.split("_")
if len(parts) != 3:
return False
_, band, _ = parts
if not band.startswith("A"):
return False
if not band[1:].isdigit():
return False
return True
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
if len(embedding_features) == 0:
return None
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
embedding_feature_array = embedding_feature_array.assign_coords(
agg=("feature", [f.split("_")[0] for f in embedding_features]),
band=("feature", [f.split("_")[1] for f in embedding_features]),
year=("feature", [f.split("_")[2] for f in embedding_features]),
)
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
return embedding_feature_array
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
"""Extract ERA5 features from the model state.
Args:
model_state: The xarray Dataset containing the model state.
Returns:
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
('variable', 'time') corresponding to the different components of the ERA5 features.
Returns None if no ERA5 features are found.
"""
def _is_era5_feature(feature: str) -> bool:
# Instant fit if winter or summer in the name
if "winter" in feature or "spring" in feature:
return True
# Instant fit if OND, JFM, AMJ or JAS in the name
if any(season in feature for season in ["OND", "JFM", "AMJ", "JAS"]):
return True
parts = feature.split("_")
if len(parts) == 3:
_, band, year = parts
if band.startswith("A"):
return False
if year.isdigit():
return True
elif parts[-1].isdigit():
return True
return False
def _extract_var_name(feature: str) -> str:
if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]):
return feature.rsplit("_", 2)[0]
else:
return feature.rsplit("_", 1)[0]
def _extract_time_name(feature: str) -> str:
if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]):
return "_".join(feature.rsplit("_", 2)[-2:])
else:
return feature.rsplit("_", 1)[-1]
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
if len(era5_features) == 0:
return None
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
era5_features_array = era5_features_array.assign_coords(
variable=("feature", [_extract_var_name(f) for f in era5_features]),
time=("feature", [_extract_time_name(f) for f in era5_features]),
)
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
return era5_features_array
# TODO: Extract common features, e.g. area or water content
def get_available_result_files() -> list[Path]:
"""Get all available result files from RESULTS_DIR."""
if not RESULTS_DIR.exists():
return []
result_files = []
for search_dir in RESULTS_DIR.iterdir():
if not search_dir.is_dir():
continue
result_file = search_dir / "search_results.parquet"
state_file = search_dir / "best_estimator_state.nc"
if result_file.exists() and state_file.exists():
result_files.append(search_dir)
return sorted(result_files, reverse=True) # Most recent first
def _parse_results_dir_name(results_dir: Path) -> str:
gridname, date = results_dir.name.split("_random_search_cv")
gridname = gridname.lstrip("permafrost_")
@@ -574,6 +572,95 @@ def _plot_era5_summary(era5_array: xr.DataArray):
return chart_variable, chart_time
def _plot_box_assignments(model_state: xr.Dataset):
"""Create a heatmap showing which boxes are assigned to which labels/classes.
Args:
model_state: The xarray Dataset containing the model state with box_assignments.
Returns:
Altair chart showing the box-to-label assignment heatmap.
"""
# Extract box assignments
box_assignments = model_state["box_assignments"]
# Convert to DataFrame for plotting
df = box_assignments.to_dataframe(name="assignment").reset_index()
# Create heatmap
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)),
y=alt.Y("class:N", title="Class Label"),
color=alt.Color(
"assignment:Q",
scale=alt.Scale(scheme="viridis"),
title="Assignment Strength",
),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("box:O", title="Box"),
alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
],
)
.properties(
height=150,
title="Box-to-Label Assignments (Lambda Matrix)",
)
)
return chart
def _plot_box_assignment_bars(model_state: xr.Dataset):
"""Create a bar chart showing how many boxes are assigned to each class.
Args:
model_state: The xarray Dataset containing the model state with box_assignments.
Returns:
Altair chart showing count of boxes per class.
"""
# Extract box assignments
box_assignments = model_state["box_assignments"]
# Convert to DataFrame
df = box_assignments.to_dataframe(name="assignment").reset_index()
# For each box, find which class it's most strongly assigned to
box_to_class = df.groupby("box")["assignment"].idxmax()
primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True)
# Count boxes per class
counts = primary_classes.groupby("class").size().reset_index(name="count")
# Create bar chart
chart = (
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("class:N", title="Class Label"),
y=alt.Y("count:Q", title="Number of Boxes"),
color=alt.Color("class:N", title="Class", legend=None),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Number of Boxes"),
],
)
.properties(
width=600,
height=300,
title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
)
)
return chart
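A toy version of the primary-assignment counting above (box ids, class labels, and strengths are made up); groupby(...).idxmax() returns, per box, the row index of its strongest assignment:

import pandas as pd

df = pd.DataFrame(
    {
        "box": [0, 0, 1, 1],
        "class": ["no_rts", "rts", "no_rts", "rts"],
        "assignment": [0.8, 0.2, 0.3, 0.7],
    }
)
primary = df.loc[df.groupby("box")["assignment"].idxmax(), ["box", "class"]]
print(primary.groupby("class").size())  # box 0 -> no_rts, box 1 -> rts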
def main():
"""Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@@ -609,6 +696,7 @@ def main():
model_state["feature_weights"] *= n_features
embedding_feature_array = extract_embedding_features(model_state)
era5_feature_array = extract_era5_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
st.sidebar.success(f"Loaded {len(results)} results")
@@ -618,6 +706,12 @@ def main():
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
)
# Map visualization
st.header("Predictions Map")
st.markdown("Map showing predicted classes from the best estimator")
m = _plot_prediction_map(predictions)
st_folium(m, width="100%", height=300)
# Display some basic statistics
st.header("Dataset Overview")
col1, col2, col3 = st.columns(3)
@@ -765,6 +859,62 @@ def main():
"""
)
# Box-to-Label Assignment Visualization
st.subheader("Box-to-Label Assignments")
st.markdown(
"""
This visualization shows how the learned boxes (prototypes in feature space) are
assigned to different class labels. The ESPA classifier learns K boxes and assigns
them to classes through the Lambda matrix. Higher values indicate stronger assignment
of a box to a particular class.
"""
)
with st.spinner("Generating box assignment visualizations..."):
col1, col2 = st.columns([0.7, 0.3])
with col1:
st.markdown("### Assignment Heatmap")
box_assignment_heatmap = _plot_box_assignments(model_state)
st.altair_chart(box_assignment_heatmap, use_container_width=True)
with col2:
st.markdown("### Box Count by Class")
box_assignment_bars = _plot_box_assignment_bars(model_state)
st.altair_chart(box_assignment_bars, use_container_width=True)
# Show statistics
with st.expander("Box Assignment Statistics"):
box_assignments = model_state["box_assignments"].to_pandas()
st.write("**Assignment Matrix Statistics:**")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Boxes", len(box_assignments.columns))
with col2:
st.metric("Number of Classes", len(box_assignments.index))
with col3:
st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
with col4:
st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")
# Show which boxes are most strongly assigned to each class
st.write("**Top Box Assignments per Class:**")
for class_label in box_assignments.index:
top_boxes = box_assignments.loc[class_label].nlargest(5)
st.write(
f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
)
st.markdown(
"""
**Interpretation:**
- Each box can be assigned to multiple classes with different strengths
- Boxes with higher assignment values for a class contribute more to that class's predictions
- The distribution shows how the model partitions the feature space for classification
"""
)
# Embedding features analysis (if present)
if embedding_feature_array is not None:
st.subheader("Embedding Feature Analysis")