Add an inference map
This commit is contained in:
parent
150f14ed52
commit
fb522ddad5
6 changed files with 580 additions and 251 deletions
14
pixi.lock
generated
14
pixi.lock
generated
|
|
@ -288,6 +288,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/31/cc/099fab5a73909a117e9689c7da4c39a248595187f0f30dd879ad1d2c34ce/tblib-3.2.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl
|
||||
|
|
@ -1300,7 +1301,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0
|
||||
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -1348,6 +1349,7 @@ packages:
|
|||
- streamlit>=1.50.0,<2
|
||||
- altair[all]>=5.5.0,<6
|
||||
- h5netcdf>=1.7.3,<2
|
||||
- streamlit-folium>=0.25.3,<0.26
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
name: entropy
|
||||
|
|
@ -4560,6 +4562,16 @@ packages:
|
|||
- streamlit[auth,charts,pdf,snowflake,sql] ; extra == 'all'
|
||||
- rich>=11.0.0 ; extra == 'all'
|
||||
requires_python: '>=3.9,!=3.9.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
||||
name: streamlit-folium
|
||||
version: 0.25.3
|
||||
sha256: cfdf085764da3f9b5e1e0668f6e4cc0385ff041c98133d023800983a875ca26c
|
||||
requires_dist:
|
||||
- streamlit>=1.13.0
|
||||
- folium>=0.13,!=0.15.0
|
||||
- jinja2
|
||||
- branca
|
||||
requires_python: '>=3.9'
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda
|
||||
sha256: 09d3b6ac51d437bc996ad006d9f749ca5c645c1900a854a6c8f193cbd13f03a8
|
||||
md5: 8c09fac3785696e1c477156192d64b91
|
||||
|
|
|
|||
|
|
@ -49,7 +49,9 @@ dependencies = [
|
|||
"xvec>=0.5.1",
|
||||
"zarr[remote]>=3.1.3",
|
||||
"geocube>=0.7.1,<0.8",
|
||||
"streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2",
|
||||
"streamlit>=1.50.0,<2",
|
||||
"altair[all]>=5.5.0,<6",
|
||||
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
@ -57,7 +59,8 @@ create-grid = "entropice.grids:main"
|
|||
darts = "entropice.darts:main"
|
||||
alpha-earth = "entropice.alphaearth:main"
|
||||
era5 = "entropice.era5:cli"
|
||||
train = "entropice.training:cli"
|
||||
train = "entropice.training:main"
|
||||
dataset = "entropice.dataset:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
|
|
|||
130
src/entropice/dataset.py
Normal file
130
src/entropice/dataset.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""Training dataset preparation and model training."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import xarray as xr
|
||||
from rich import pretty, traceback
|
||||
from sklearn import set_config
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.paths import (
|
||||
get_darts_rts_file,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
get_train_dataset_file,
|
||||
)
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
set_config(array_api_dispatch=True)
|
||||
|
||||
sns.set_theme("talk", "whitegrid")
|
||||
|
||||
|
||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||
seasons = {10: "winter", 4: "summer"}
|
||||
|
||||
|
||||
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
|
||||
def _prep_era5(
|
||||
rts: gpd.GeoDataFrame,
|
||||
temporal: Literal["yearly", "seasonal", "shoulder"],
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
) -> pd.DataFrame:
|
||||
era5_df = []
|
||||
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||
|
||||
for var in era5.data_vars:
|
||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||
if temporal == "yearly":
|
||||
df["t"] = df.index.get_level_values("time").year
|
||||
elif temporal == "seasonal":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
elif temporal == "shoulder":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: shoulder_seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
df = (
|
||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||
.rename(columns=lambda x: f"{var}_{x}")
|
||||
.rename_axis(None, axis=1)
|
||||
)
|
||||
era5_df.append(df)
|
||||
era5_df = pd.concat(era5_df, axis=1)
|
||||
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
|
||||
return era5_df
|
||||
|
||||
|
||||
@stopwatch("Prepare embeddings data")
|
||||
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
|
||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
||||
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
||||
|
||||
embeddings_df = embeddings.to_dataframe(name="value")
|
||||
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||
|
||||
embeddings_df = embeddings_df.rename(
|
||||
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
|
||||
)
|
||||
return embeddings_df
|
||||
|
||||
|
||||
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
|
||||
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
"""
|
||||
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
||||
# Filter to coverage
|
||||
if filter_target:
|
||||
rts = rts[rts["darts_has_coverage"]]
|
||||
# Convert hex cell_id to int
|
||||
if grid == "hex":
|
||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||
|
||||
# Add the lat / lon of the cell centers
|
||||
rts["lon"] = rts.geometry.centroid.x
|
||||
rts["lat"] = rts.geometry.centroid.y
|
||||
|
||||
# Get era5 data
|
||||
era5_yearly = _prep_era5(rts, "yearly", grid, level)
|
||||
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
|
||||
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
|
||||
|
||||
# Get embeddings data
|
||||
embeddings = _prep_embeddings(rts, grid, level)
|
||||
|
||||
# Combine datasets by cell id / cell
|
||||
with stopwatch("Combine datasets"):
|
||||
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
|
||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||
|
||||
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
||||
dataset.reset_index().to_parquet(dataset_file)
|
||||
|
||||
|
||||
def main():
|
||||
cyclopts.run(prepare_dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
# ruff: noqa: N806
|
||||
"""Inference runs on trained models."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import torch
|
||||
from entropy import ESPAClassifier
|
||||
from rich import pretty, traceback
|
||||
from sklearn import set_config
|
||||
|
||||
from entropice.paths import get_train_dataset_file
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
|
||||
"""Get predicted probabilities for each cell.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
clf (ESPAClassifier): The trained classifier to use for predictions.
|
||||
|
||||
Returns:
|
||||
list: A list of predicted probabilities for each cell.
|
||||
|
||||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
print(f"Predicting probabilities for {len(data)} cells...")
|
||||
|
||||
# Predict in batches to avoid memory issues
|
||||
batch_size = 100_000
|
||||
preds = []
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data.iloc[i : i + batch_size]
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
||||
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
|
||||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||
X_batch = X_batch.to_numpy(dtype="float32")
|
||||
X_batch = torch.asarray(X_batch, device=0)
|
||||
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
|
||||
batch_preds = gpd.GeoDataFrame(
|
||||
{
|
||||
"cell_id": cell_ids,
|
||||
"predicted_proba": batch_preds,
|
||||
"geometry": cell_geoms,
|
||||
},
|
||||
crs="epsg:3413",
|
||||
)
|
||||
preds.append(batch_preds)
|
||||
preds = gpd.GeoDataFrame(pd.concat(preds))
|
||||
|
||||
return preds
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
# ruff: noqa: N806
|
||||
"""Training dataset preparation and model training."""
|
||||
|
||||
import pickle
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import toml
|
||||
import torch
|
||||
import xarray as xr
|
||||
from entropy import ESPAClassifier
|
||||
|
|
@ -16,11 +17,9 @@ from sklearn import set_config
|
|||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.inference import predict_proba
|
||||
from entropice.paths import (
|
||||
get_cv_results_dir,
|
||||
get_darts_rts_file,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
get_train_dataset_file,
|
||||
)
|
||||
|
||||
|
|
@ -29,83 +28,7 @@ pretty.install()
|
|||
|
||||
set_config(array_api_dispatch=True)
|
||||
|
||||
sns.set_theme("talk", "whitegrid")
|
||||
|
||||
cli = cyclopts.App()
|
||||
|
||||
|
||||
@cli.command()
|
||||
def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
||||
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
"""
|
||||
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
||||
# Filter to coverage
|
||||
rts = rts[rts["darts_has_coverage"]]
|
||||
# Convert hex cell_id to int
|
||||
if grid == "hex":
|
||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||
|
||||
# Get era5 data
|
||||
era5_df = []
|
||||
|
||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||
seasons = {
|
||||
10: "winter",
|
||||
4: "summer",
|
||||
}
|
||||
for temporal in ["yearly", "seasonal", "shoulder"]:
|
||||
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||
|
||||
for var in era5.data_vars:
|
||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||
if temporal == "yearly":
|
||||
df["t"] = df.index.get_level_values("time").year
|
||||
elif temporal == "seasonal":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
elif temporal == "shoulder":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: shoulder_seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
df = (
|
||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||
.rename(columns=lambda x: f"{var}_{x}")
|
||||
.rename_axis(None, axis=1)
|
||||
)
|
||||
era5_df.append(df)
|
||||
era5_df = pd.concat(era5_df, axis=1)
|
||||
|
||||
# Get embeddings data
|
||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
||||
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
||||
|
||||
embeddings_df = embeddings.to_dataframe(name="value")
|
||||
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||
|
||||
# Combine datasets by cell id / cell
|
||||
# TODO: use prefixes to easy split the features in analysis again
|
||||
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
|
||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||
|
||||
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
||||
dataset.reset_index().to_parquet(dataset_file)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
|
|
@ -117,6 +40,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
"""
|
||||
data = get_train_dataset_file(grid=grid, level=level)
|
||||
data = gpd.read_parquet(data)
|
||||
data = data[data["darts_has_coverage"]]
|
||||
|
||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
|
|
@ -142,7 +66,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
clf,
|
||||
param_grid,
|
||||
n_iter=n_iter,
|
||||
n_jobs=20,
|
||||
n_jobs=16,
|
||||
cv=cv,
|
||||
random_state=42,
|
||||
verbose=10,
|
||||
|
|
@ -163,6 +87,45 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||
|
||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
||||
|
||||
# Store the search settings
|
||||
settings = {
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"random_state": 42,
|
||||
"n_iter": n_iter,
|
||||
"param_grid": {
|
||||
"eps_cl": {
|
||||
"distribution": "loguniform",
|
||||
"low": param_grid["eps_cl"].a,
|
||||
"high": param_grid["eps_cl"].b,
|
||||
},
|
||||
"eps_e": {
|
||||
"distribution": "loguniform",
|
||||
"low": param_grid["eps_e"].a,
|
||||
"high": param_grid["eps_e"].b,
|
||||
},
|
||||
"initial_K": {
|
||||
"distribution": "randint",
|
||||
"low": param_grid["initial_K"].a,
|
||||
"high": param_grid["initial_K"].b,
|
||||
},
|
||||
},
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
}
|
||||
settings_file = results_dir / "search_settings.toml"
|
||||
print(f"Storing search settings to {settings_file}")
|
||||
with open(settings_file, "w") as f:
|
||||
toml.dump({"settings": settings}, f)
|
||||
|
||||
# Store the best estimator model
|
||||
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||
print(f"Storing best estimator model to {best_model_file}")
|
||||
with open(best_model_file, "wb") as f:
|
||||
pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# Store the search results
|
||||
results = pd.DataFrame(search.cv_results_)
|
||||
# Parse the params into individual columns
|
||||
|
|
@ -171,7 +134,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||
results["grid"] = grid
|
||||
results["level"] = level
|
||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
||||
results_file = results_dir / "search_results.parquet"
|
||||
print(f"Storing CV results to {results_file}")
|
||||
results.to_parquet(results_file)
|
||||
|
|
@ -219,9 +181,20 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
|||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
|
||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||
print(f"Storing predicted probabilities to {preds_file}")
|
||||
preds.to_parquet(preds_file)
|
||||
|
||||
stopwatch.summary()
|
||||
print("Done.")
|
||||
|
||||
|
||||
def main():
|
||||
cyclopts.run(random_cv)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -4,14 +4,177 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
|
||||
import altair as alt
|
||||
import folium
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from streamlit_folium import st_folium
|
||||
|
||||
from entropice.paths import RESULTS_DIR
|
||||
|
||||
|
||||
def get_available_result_files() -> list[Path]:
|
||||
"""Get all available result files from RESULTS_DIR."""
|
||||
if not RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
result_files = []
|
||||
for search_dir in RESULTS_DIR.iterdir():
|
||||
if not search_dir.is_dir():
|
||||
continue
|
||||
|
||||
result_file = search_dir / "search_results.parquet"
|
||||
state_file = search_dir / "best_estimator_state.nc"
|
||||
preds_file = search_dir / "predicted_probabilities.parquet"
|
||||
settings_file = search_dir / "search_settings.toml"
|
||||
if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
|
||||
result_files.append(search_dir)
|
||||
|
||||
return sorted(result_files, reverse=True) # Most recent first
|
||||
|
||||
|
||||
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
||||
"""Load results file and prepare binned columns.
|
||||
|
||||
Args:
|
||||
file_path: Path to the results parquet file.
|
||||
k_bin_width: Width of bins for initial_K parameter.
|
||||
|
||||
Returns:
|
||||
DataFrame with added binned columns.
|
||||
|
||||
"""
|
||||
results = pd.read_parquet(file_path)
|
||||
|
||||
# Automatically determine bin width for initial_K based on data range
|
||||
k_min = results["initial_K"].min()
|
||||
k_max = results["initial_K"].max()
|
||||
# Use configurable bin width, adapted to actual data range
|
||||
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||
|
||||
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||
# Use 10 bins spanning the actual data range
|
||||
eps_cl_min = np.log10(results["eps_cl"].min())
|
||||
eps_cl_max = np.log10(results["eps_cl"].max())
|
||||
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
||||
|
||||
eps_e_min = np.log10(results["eps_e"].min())
|
||||
eps_e_max = np.log10(results["eps_e"].max())
|
||||
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
||||
|
||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
|
||||
"""Load a model state from a NetCDF file.
|
||||
|
||||
Args:
|
||||
file_path (Path): The path to the NetCDF file.
|
||||
|
||||
Returns:
|
||||
xr.Dataset: The model state as an xarray Dataset.
|
||||
|
||||
"""
|
||||
return xr.open_dataset(file_path, engine="h5netcdf")
|
||||
|
||||
|
||||
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||
"""Extract embedding features from the model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted embedding features. This DataArray has dimensions
|
||||
('agg', 'band', 'year') corresponding to the different components of the embedding features.
|
||||
Returns None if no embedding features are found.
|
||||
|
||||
"""
|
||||
|
||||
def _is_embedding_feature(feature: str) -> bool:
|
||||
return feature.startswith("embeddings_")
|
||||
|
||||
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
|
||||
if len(embedding_features) == 0:
|
||||
return None
|
||||
|
||||
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
|
||||
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
|
||||
embedding_feature_array = embedding_feature_array.assign_coords(
|
||||
agg=("feature", [f.split("_")[1] for f in embedding_features]),
|
||||
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||
year=("feature", [f.split("_")[3] for f in embedding_features]),
|
||||
)
|
||||
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
|
||||
return embedding_feature_array
|
||||
|
||||
|
||||
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||
"""Extract ERA5 features from the model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
|
||||
('variable', 'time') corresponding to the different components of the ERA5 features.
|
||||
Returns None if no ERA5 features are found.
|
||||
|
||||
"""
|
||||
|
||||
def _is_era5_feature(feature: str) -> bool:
|
||||
return feature.startswith("era5_")
|
||||
|
||||
def _extract_var_name(feature: str) -> str:
|
||||
feature = feature.replace("era5_", "")
|
||||
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||
return feature.rsplit("_", 2)[0]
|
||||
else:
|
||||
return feature.rsplit("_", 1)[0]
|
||||
|
||||
def _extract_time_name(feature: str) -> str:
|
||||
feature = feature.replace("era5_", "")
|
||||
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||
return "_".join(feature.rsplit("_", 2)[-2:])
|
||||
else:
|
||||
return feature.rsplit("_", 1)[-1]
|
||||
|
||||
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
|
||||
if len(era5_features) == 0:
|
||||
return None
|
||||
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
|
||||
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
variable=("feature", [_extract_var_name(f) for f in era5_features]),
|
||||
time=("feature", [_extract_time_name(f) for f in era5_features]),
|
||||
)
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
||||
return era5_features_array
|
||||
|
||||
|
||||
# TODO: Extract common features, e.g. area or water content
|
||||
|
||||
|
||||
def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
|
||||
"""Plot predicted probabilities on a map using Streamlit and Folium.
|
||||
|
||||
Args:
|
||||
preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
|
||||
|
||||
Returns:
|
||||
folium.Map: A Folium map object with the predicted probabilities visualized.
|
||||
|
||||
"""
|
||||
m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||
return m
|
||||
|
||||
|
||||
def _plot_k_binned(
|
||||
results: pd.DataFrame,
|
||||
target: str,
|
||||
|
|
@ -118,171 +281,6 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
|||
return chart
|
||||
|
||||
|
||||
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
||||
"""Load results file and prepare binned columns.
|
||||
|
||||
Args:
|
||||
file_path: Path to the results parquet file.
|
||||
k_bin_width: Width of bins for initial_K parameter.
|
||||
|
||||
Returns:
|
||||
DataFrame with added binned columns.
|
||||
|
||||
"""
|
||||
results = pd.read_parquet(file_path)
|
||||
|
||||
# Automatically determine bin width for initial_K based on data range
|
||||
k_min = results["initial_K"].min()
|
||||
k_max = results["initial_K"].max()
|
||||
# Use configurable bin width, adapted to actual data range
|
||||
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||
|
||||
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||
# Use 10 bins spanning the actual data range
|
||||
eps_cl_min = np.log10(results["eps_cl"].min())
|
||||
eps_cl_max = np.log10(results["eps_cl"].max())
|
||||
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
||||
|
||||
eps_e_min = np.log10(results["eps_e"].min())
|
||||
eps_e_max = np.log10(results["eps_e"].max())
|
||||
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
||||
|
||||
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
|
||||
"""Load a model state from a NetCDF file.
|
||||
|
||||
Args:
|
||||
file_path (Path): The path to the NetCDF file.
|
||||
|
||||
Returns:
|
||||
xr.Dataset: The model state as an xarray Dataset.
|
||||
|
||||
"""
|
||||
return xr.open_dataset(file_path, engine="h5netcdf")
|
||||
|
||||
|
||||
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||
"""Extract embedding features from the model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted embedding features. This DataArray has dimensions
|
||||
('agg', 'band', 'year') corresponding to the different components of the embedding features.
|
||||
Returns None if no embedding features are found.
|
||||
|
||||
"""
|
||||
|
||||
def _is_embedding_feature(feature: str) -> bool:
|
||||
parts = feature.split("_")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
_, band, _ = parts
|
||||
if not band.startswith("A"):
|
||||
return False
|
||||
if not band[1:].isdigit():
|
||||
return False
|
||||
return True
|
||||
|
||||
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
|
||||
if len(embedding_features) == 0:
|
||||
return None
|
||||
|
||||
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
|
||||
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
|
||||
embedding_feature_array = embedding_feature_array.assign_coords(
|
||||
agg=("feature", [f.split("_")[0] for f in embedding_features]),
|
||||
band=("feature", [f.split("_")[1] for f in embedding_features]),
|
||||
year=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||
)
|
||||
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
|
||||
return embedding_feature_array
|
||||
|
||||
|
||||
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||
"""Extract ERA5 features from the model state.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state.
|
||||
|
||||
Returns:
|
||||
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
|
||||
('variable', 'time') corresponding to the different components of the ERA5 features.
|
||||
Returns None if no ERA5 features are found.
|
||||
|
||||
"""
|
||||
|
||||
def _is_era5_feature(feature: str) -> bool:
|
||||
# Instant fit if winter or summer in the name
|
||||
if "winter" in feature or "spring" in feature:
|
||||
return True
|
||||
# Instant fit if OND, JFM, AMJ or JAS in the name
|
||||
if any(season in feature for season in ["OND", "JFM", "AMJ", "JAS"]):
|
||||
return True
|
||||
parts = feature.split("_")
|
||||
if len(parts) == 3:
|
||||
_, band, year = parts
|
||||
if band.startswith("A"):
|
||||
return False
|
||||
if year.isdigit():
|
||||
return True
|
||||
elif parts[-1].isdigit():
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_var_name(feature: str) -> str:
|
||||
if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||
return feature.rsplit("_", 2)[0]
|
||||
else:
|
||||
return feature.rsplit("_", 1)[0]
|
||||
|
||||
def _extract_time_name(feature: str) -> str:
|
||||
if any(season in feature for season in ["spring", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||
return "_".join(feature.rsplit("_", 2)[-2:])
|
||||
else:
|
||||
return feature.rsplit("_", 1)[-1]
|
||||
|
||||
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
|
||||
if len(era5_features) == 0:
|
||||
return None
|
||||
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
|
||||
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
|
||||
era5_features_array = era5_features_array.assign_coords(
|
||||
variable=("feature", [_extract_var_name(f) for f in era5_features]),
|
||||
time=("feature", [_extract_time_name(f) for f in era5_features]),
|
||||
)
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
||||
return era5_features_array
|
||||
|
||||
|
||||
# TODO: Extract common features, e.g. area or water content
|
||||
|
||||
|
||||
def get_available_result_files() -> list[Path]:
|
||||
"""Get all available result files from RESULTS_DIR."""
|
||||
if not RESULTS_DIR.exists():
|
||||
return []
|
||||
|
||||
result_files = []
|
||||
for search_dir in RESULTS_DIR.iterdir():
|
||||
if not search_dir.is_dir():
|
||||
continue
|
||||
|
||||
result_file = search_dir / "search_results.parquet"
|
||||
state_file = search_dir / "best_estimator_state.nc"
|
||||
if result_file.exists() and state_file.exists():
|
||||
result_files.append(search_dir)
|
||||
|
||||
return sorted(result_files, reverse=True) # Most recent first
|
||||
|
||||
|
||||
def _parse_results_dir_name(results_dir: Path) -> str:
|
||||
gridname, date = results_dir.name.split("_random_search_cv")
|
||||
gridname = gridname.lstrip("permafrost_")
|
||||
|
|
@ -574,6 +572,95 @@ def _plot_era5_summary(era5_array: xr.DataArray):
|
|||
return chart_variable, chart_time
|
||||
|
||||
|
||||
def _plot_box_assignments(model_state: xr.Dataset):
|
||||
"""Create a heatmap showing which boxes are assigned to which labels/classes.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state with box_assignments.
|
||||
|
||||
Returns:
|
||||
Altair chart showing the box-to-label assignment heatmap.
|
||||
|
||||
"""
|
||||
# Extract box assignments
|
||||
box_assignments = model_state["box_assignments"]
|
||||
|
||||
# Convert to DataFrame for plotting
|
||||
df = box_assignments.to_dataframe(name="assignment").reset_index()
|
||||
|
||||
# Create heatmap
|
||||
chart = (
|
||||
alt.Chart(df)
|
||||
.mark_rect()
|
||||
.encode(
|
||||
x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)),
|
||||
y=alt.Y("class:N", title="Class Label"),
|
||||
color=alt.Color(
|
||||
"assignment:Q",
|
||||
scale=alt.Scale(scheme="viridis"),
|
||||
title="Assignment Strength",
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip("class:N", title="Class"),
|
||||
alt.Tooltip("box:O", title="Box"),
|
||||
alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
height=150,
|
||||
title="Box-to-Label Assignments (Lambda Matrix)",
|
||||
)
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def _plot_box_assignment_bars(model_state: xr.Dataset):
|
||||
"""Create a bar chart showing how many boxes are assigned to each class.
|
||||
|
||||
Args:
|
||||
model_state: The xarray Dataset containing the model state with box_assignments.
|
||||
|
||||
Returns:
|
||||
Altair chart showing count of boxes per class.
|
||||
|
||||
"""
|
||||
# Extract box assignments
|
||||
box_assignments = model_state["box_assignments"]
|
||||
|
||||
# Convert to DataFrame
|
||||
df = box_assignments.to_dataframe(name="assignment").reset_index()
|
||||
|
||||
# For each box, find which class it's most strongly assigned to
|
||||
box_to_class = df.groupby("box")["assignment"].idxmax()
|
||||
primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True)
|
||||
|
||||
# Count boxes per class
|
||||
counts = primary_classes.groupby("class").size().reset_index(name="count")
|
||||
|
||||
# Create bar chart
|
||||
chart = (
|
||||
alt.Chart(counts)
|
||||
.mark_bar()
|
||||
.encode(
|
||||
x=alt.X("class:N", title="Class Label"),
|
||||
y=alt.Y("count:Q", title="Number of Boxes"),
|
||||
color=alt.Color("class:N", title="Class", legend=None),
|
||||
tooltip=[
|
||||
alt.Tooltip("class:N", title="Class"),
|
||||
alt.Tooltip("count:Q", title="Number of Boxes"),
|
||||
],
|
||||
)
|
||||
.properties(
|
||||
width=600,
|
||||
height=300,
|
||||
title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
|
||||
)
|
||||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
def main():
|
||||
"""Run Streamlit dashboard application."""
|
||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||
|
|
@ -609,6 +696,7 @@ def main():
|
|||
model_state["feature_weights"] *= n_features
|
||||
embedding_feature_array = extract_embedding_features(model_state)
|
||||
era5_feature_array = extract_era5_features(model_state)
|
||||
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||
|
||||
st.sidebar.success(f"Loaded {len(results)} results")
|
||||
|
||||
|
|
@ -618,6 +706,12 @@ def main():
|
|||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||
)
|
||||
|
||||
# Map visualization
|
||||
st.header("Predictions Map")
|
||||
st.markdown("Map showing predicted classes from the best estimator")
|
||||
m = _plot_prediction_map(predictions)
|
||||
st_folium(m, width="100%", height=300)
|
||||
|
||||
# Display some basic statistics
|
||||
st.header("Dataset Overview")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
|
@ -765,6 +859,62 @@ def main():
|
|||
"""
|
||||
)
|
||||
|
||||
# Box-to-Label Assignment Visualization
|
||||
st.subheader("Box-to-Label Assignments")
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows how the learned boxes (prototypes in feature space) are
|
||||
assigned to different class labels. The ESPA classifier learns K boxes and assigns
|
||||
them to classes through the Lambda matrix. Higher values indicate stronger assignment
|
||||
of a box to a particular class.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.spinner("Generating box assignment visualizations..."):
|
||||
col1, col2 = st.columns([0.7, 0.3])
|
||||
|
||||
with col1:
|
||||
st.markdown("### Assignment Heatmap")
|
||||
box_assignment_heatmap = _plot_box_assignments(model_state)
|
||||
st.altair_chart(box_assignment_heatmap, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
st.markdown("### Box Count by Class")
|
||||
box_assignment_bars = _plot_box_assignment_bars(model_state)
|
||||
st.altair_chart(box_assignment_bars, use_container_width=True)
|
||||
|
||||
# Show statistics
|
||||
with st.expander("Box Assignment Statistics"):
|
||||
box_assignments = model_state["box_assignments"].to_pandas()
|
||||
st.write("**Assignment Matrix Statistics:**")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Boxes", len(box_assignments.columns))
|
||||
with col2:
|
||||
st.metric("Number of Classes", len(box_assignments.index))
|
||||
with col3:
|
||||
st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
|
||||
with col4:
|
||||
st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")
|
||||
|
||||
# Show which boxes are most strongly assigned to each class
|
||||
st.write("**Top Box Assignments per Class:**")
|
||||
for class_label in box_assignments.index:
|
||||
top_boxes = box_assignments.loc[class_label].nlargest(5)
|
||||
st.write(
|
||||
f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
|
||||
f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
|
||||
)
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- Each box can be assigned to multiple classes with different strengths
|
||||
- Boxes with higher assignment values for a class contribute more to that class's predictions
|
||||
- The distribution shows how the model partitions the feature space for classification
|
||||
"""
|
||||
)
|
||||
|
||||
# Embedding features analysis (if present)
|
||||
if embedding_feature_array is not None:
|
||||
st.subheader("Embedding Feature Analysis")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue