Add an inference map
This commit is contained in:
parent
150f14ed52
commit
fb522ddad5
6 changed files with 580 additions and 251 deletions
14
pixi.lock
generated
14
pixi.lock
generated
|
|
@ -288,6 +288,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/31/cc/099fab5a73909a117e9689c7da4c39a248595187f0f30dd879ad1d2c34ce/tblib-3.2.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/31/cc/099fab5a73909a117e9689c7da4c39a248595187f0f30dd879ad1d2c34ce/tblib-3.2.1-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl
|
||||||
|
|
@ -1300,7 +1301,7 @@ packages:
|
||||||
- pypi: ./
|
- pypi: ./
|
||||||
name: entropice
|
name: entropice
|
||||||
version: 0.1.0
|
version: 0.1.0
|
||||||
sha256: 4f45dd8bbe428416b7bcb3a904e31376735a9bbbc0d5438e91913e7477e3c0c0
|
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
|
||||||
requires_dist:
|
requires_dist:
|
||||||
- aiohttp>=3.12.11
|
- aiohttp>=3.12.11
|
||||||
- bokeh>=3.7.3
|
- bokeh>=3.7.3
|
||||||
|
|
@ -1348,6 +1349,7 @@ packages:
|
||||||
- streamlit>=1.50.0,<2
|
- streamlit>=1.50.0,<2
|
||||||
- altair[all]>=5.5.0,<6
|
- altair[all]>=5.5.0,<6
|
||||||
- h5netcdf>=1.7.3,<2
|
- h5netcdf>=1.7.3,<2
|
||||||
|
- streamlit-folium>=0.25.3,<0.26
|
||||||
editable: true
|
editable: true
|
||||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||||
name: entropy
|
name: entropy
|
||||||
|
|
@ -4560,6 +4562,16 @@ packages:
|
||||||
- streamlit[auth,charts,pdf,snowflake,sql] ; extra == 'all'
|
- streamlit[auth,charts,pdf,snowflake,sql] ; extra == 'all'
|
||||||
- rich>=11.0.0 ; extra == 'all'
|
- rich>=11.0.0 ; extra == 'all'
|
||||||
requires_python: '>=3.9,!=3.9.7'
|
requires_python: '>=3.9,!=3.9.7'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
||||||
|
name: streamlit-folium
|
||||||
|
version: 0.25.3
|
||||||
|
sha256: cfdf085764da3f9b5e1e0668f6e4cc0385ff041c98133d023800983a875ca26c
|
||||||
|
requires_dist:
|
||||||
|
- streamlit>=1.13.0
|
||||||
|
- folium>=0.13,!=0.15.0
|
||||||
|
- jinja2
|
||||||
|
- branca
|
||||||
|
requires_python: '>=3.9'
|
||||||
- conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda
|
- conda: https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda
|
||||||
sha256: 09d3b6ac51d437bc996ad006d9f749ca5c645c1900a854a6c8f193cbd13f03a8
|
sha256: 09d3b6ac51d437bc996ad006d9f749ca5c645c1900a854a6c8f193cbd13f03a8
|
||||||
md5: 8c09fac3785696e1c477156192d64b91
|
md5: 8c09fac3785696e1c477156192d64b91
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,9 @@ dependencies = [
|
||||||
"xvec>=0.5.1",
|
"xvec>=0.5.1",
|
||||||
"zarr[remote]>=3.1.3",
|
"zarr[remote]>=3.1.3",
|
||||||
"geocube>=0.7.1,<0.8",
|
"geocube>=0.7.1,<0.8",
|
||||||
"streamlit>=1.50.0,<2", "altair[all]>=5.5.0,<6", "h5netcdf>=1.7.3,<2",
|
"streamlit>=1.50.0,<2",
|
||||||
|
"altair[all]>=5.5.0,<6",
|
||||||
|
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -57,7 +59,8 @@ create-grid = "entropice.grids:main"
|
||||||
darts = "entropice.darts:main"
|
darts = "entropice.darts:main"
|
||||||
alpha-earth = "entropice.alphaearth:main"
|
alpha-earth = "entropice.alphaearth:main"
|
||||||
era5 = "entropice.era5:cli"
|
era5 = "entropice.era5:cli"
|
||||||
train = "entropice.training:cli"
|
train = "entropice.training:main"
|
||||||
|
dataset = "entropice.dataset:main"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|
|
||||||
130
src/entropice/dataset.py
Normal file
130
src/entropice/dataset.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""Training dataset preparation and model training."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import xarray as xr
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from sklearn import set_config
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.paths import (
|
||||||
|
get_darts_rts_file,
|
||||||
|
get_embeddings_store,
|
||||||
|
get_era5_stores,
|
||||||
|
get_train_dataset_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
|
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||||
|
seasons = {10: "winter", 4: "summer"}
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
|
||||||
|
def _prep_era5(
|
||||||
|
rts: gpd.GeoDataFrame,
|
||||||
|
temporal: Literal["yearly", "seasonal", "shoulder"],
|
||||||
|
grid: Literal["hex", "healpix"],
|
||||||
|
level: int,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
era5_df = []
|
||||||
|
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
||||||
|
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||||
|
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||||
|
|
||||||
|
for var in era5.data_vars:
|
||||||
|
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||||
|
if temporal == "yearly":
|
||||||
|
df["t"] = df.index.get_level_values("time").year
|
||||||
|
elif temporal == "seasonal":
|
||||||
|
df["t"] = (
|
||||||
|
df.index.get_level_values("time")
|
||||||
|
.month.map(lambda x: seasons.get(x))
|
||||||
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
|
)
|
||||||
|
elif temporal == "shoulder":
|
||||||
|
df["t"] = (
|
||||||
|
df.index.get_level_values("time")
|
||||||
|
.month.map(lambda x: shoulder_seasons.get(x))
|
||||||
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
|
)
|
||||||
|
df = (
|
||||||
|
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||||
|
.rename(columns=lambda x: f"{var}_{x}")
|
||||||
|
.rename_axis(None, axis=1)
|
||||||
|
)
|
||||||
|
era5_df.append(df)
|
||||||
|
era5_df = pd.concat(era5_df, axis=1)
|
||||||
|
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
|
||||||
|
return era5_df
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Prepare embeddings data")
|
||||||
|
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
|
||||||
|
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||||
|
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
||||||
|
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
||||||
|
|
||||||
|
embeddings_df = embeddings.to_dataframe(name="value")
|
||||||
|
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
||||||
|
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||||
|
|
||||||
|
embeddings_df = embeddings_df.rename(
|
||||||
|
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
|
||||||
|
)
|
||||||
|
return embeddings_df
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
|
||||||
|
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
|
level (int): The grid level to use.
|
||||||
|
|
||||||
|
"""
|
||||||
|
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
||||||
|
# Filter to coverage
|
||||||
|
if filter_target:
|
||||||
|
rts = rts[rts["darts_has_coverage"]]
|
||||||
|
# Convert hex cell_id to int
|
||||||
|
if grid == "hex":
|
||||||
|
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
|
||||||
|
# Add the lat / lon of the cell centers
|
||||||
|
rts["lon"] = rts.geometry.centroid.x
|
||||||
|
rts["lat"] = rts.geometry.centroid.y
|
||||||
|
|
||||||
|
# Get era5 data
|
||||||
|
era5_yearly = _prep_era5(rts, "yearly", grid, level)
|
||||||
|
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
|
||||||
|
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
|
||||||
|
|
||||||
|
# Get embeddings data
|
||||||
|
embeddings = _prep_embeddings(rts, grid, level)
|
||||||
|
|
||||||
|
# Combine datasets by cell id / cell
|
||||||
|
with stopwatch("Combine datasets"):
|
||||||
|
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
|
||||||
|
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||||
|
|
||||||
|
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
||||||
|
dataset.reset_index().to_parquet(dataset_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cyclopts.run(prepare_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
# ruff: noqa: N806
|
||||||
|
"""Inference runs on trained models."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from entropy import ESPAClassifier
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from sklearn import set_config
|
||||||
|
|
||||||
|
from entropice.paths import get_train_dataset_file
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifier):
|
||||||
|
"""Get predicted probabilities for each cell.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
|
level (int): The grid level to use.
|
||||||
|
clf (ESPAClassifier): The trained classifier to use for predictions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of predicted probabilities for each cell.
|
||||||
|
|
||||||
|
"""
|
||||||
|
data = get_train_dataset_file(grid=grid, level=level)
|
||||||
|
data = gpd.read_parquet(data)
|
||||||
|
print(f"Predicting probabilities for {len(data)} cells...")
|
||||||
|
|
||||||
|
# Predict in batches to avoid memory issues
|
||||||
|
batch_size = 100_000
|
||||||
|
preds = []
|
||||||
|
for i in range(0, len(data), batch_size):
|
||||||
|
batch = data.iloc[i : i + batch_size]
|
||||||
|
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||||
|
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||||
|
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
||||||
|
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
|
||||||
|
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||||
|
X_batch = X_batch.to_numpy(dtype="float32")
|
||||||
|
X_batch = torch.asarray(X_batch, device=0)
|
||||||
|
batch_preds = clf.predict_proba(X_batch)[:, 1].cpu().numpy()
|
||||||
|
batch_preds = gpd.GeoDataFrame(
|
||||||
|
{
|
||||||
|
"cell_id": cell_ids,
|
||||||
|
"predicted_proba": batch_preds,
|
||||||
|
"geometry": cell_geoms,
|
||||||
|
},
|
||||||
|
crs="epsg:3413",
|
||||||
|
)
|
||||||
|
preds.append(batch_preds)
|
||||||
|
preds = gpd.GeoDataFrame(pd.concat(preds))
|
||||||
|
|
||||||
|
return preds
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
"""Training dataset preparation and model training."""
|
"""Training dataset preparation and model training."""
|
||||||
|
|
||||||
|
import pickle
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import toml
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from entropy import ESPAClassifier
|
from entropy import ESPAClassifier
|
||||||
|
|
@ -16,11 +17,9 @@ from sklearn import set_config
|
||||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.inference import predict_proba
|
||||||
from entropice.paths import (
|
from entropice.paths import (
|
||||||
get_cv_results_dir,
|
get_cv_results_dir,
|
||||||
get_darts_rts_file,
|
|
||||||
get_embeddings_store,
|
|
||||||
get_era5_stores,
|
|
||||||
get_train_dataset_file,
|
get_train_dataset_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,83 +28,7 @@ pretty.install()
|
||||||
|
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
sns.set_theme("talk", "whitegrid")
|
|
||||||
|
|
||||||
cli = cyclopts.App()
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
|
||||||
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
|
||||||
level (int): The grid level to use.
|
|
||||||
|
|
||||||
"""
|
|
||||||
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
|
||||||
# Filter to coverage
|
|
||||||
rts = rts[rts["darts_has_coverage"]]
|
|
||||||
# Convert hex cell_id to int
|
|
||||||
if grid == "hex":
|
|
||||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
|
||||||
|
|
||||||
# Get era5 data
|
|
||||||
era5_df = []
|
|
||||||
|
|
||||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
|
||||||
seasons = {
|
|
||||||
10: "winter",
|
|
||||||
4: "summer",
|
|
||||||
}
|
|
||||||
for temporal in ["yearly", "seasonal", "shoulder"]:
|
|
||||||
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
|
||||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
|
||||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
|
||||||
|
|
||||||
for var in era5.data_vars:
|
|
||||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
|
||||||
if temporal == "yearly":
|
|
||||||
df["t"] = df.index.get_level_values("time").year
|
|
||||||
elif temporal == "seasonal":
|
|
||||||
df["t"] = (
|
|
||||||
df.index.get_level_values("time")
|
|
||||||
.month.map(lambda x: seasons.get(x))
|
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
|
||||||
)
|
|
||||||
elif temporal == "shoulder":
|
|
||||||
df["t"] = (
|
|
||||||
df.index.get_level_values("time")
|
|
||||||
.month.map(lambda x: shoulder_seasons.get(x))
|
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
|
||||||
)
|
|
||||||
df = (
|
|
||||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
|
||||||
.rename(columns=lambda x: f"{var}_{x}")
|
|
||||||
.rename_axis(None, axis=1)
|
|
||||||
)
|
|
||||||
era5_df.append(df)
|
|
||||||
era5_df = pd.concat(era5_df, axis=1)
|
|
||||||
|
|
||||||
# Get embeddings data
|
|
||||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
|
||||||
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
|
||||||
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
|
||||||
|
|
||||||
embeddings_df = embeddings.to_dataframe(name="value")
|
|
||||||
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
|
||||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
|
||||||
|
|
||||||
# Combine datasets by cell id / cell
|
|
||||||
# TODO: use prefixes to easy split the features in analysis again
|
|
||||||
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
|
|
||||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
|
||||||
|
|
||||||
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
|
||||||
dataset.reset_index().to_parquet(dataset_file)
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
|
|
@ -117,6 +40,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
"""
|
"""
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
data = get_train_dataset_file(grid=grid, level=level)
|
||||||
data = gpd.read_parquet(data)
|
data = gpd.read_parquet(data)
|
||||||
|
data = data[data["darts_has_coverage"]]
|
||||||
|
|
||||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
||||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||||
|
|
@ -142,7 +66,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
clf,
|
clf,
|
||||||
param_grid,
|
param_grid,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
n_jobs=20,
|
n_jobs=16,
|
||||||
cv=cv,
|
cv=cv,
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=10,
|
verbose=10,
|
||||||
|
|
@ -163,6 +87,45 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||||
|
|
||||||
|
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
||||||
|
|
||||||
|
# Store the search settings
|
||||||
|
settings = {
|
||||||
|
"grid": grid,
|
||||||
|
"level": level,
|
||||||
|
"random_state": 42,
|
||||||
|
"n_iter": n_iter,
|
||||||
|
"param_grid": {
|
||||||
|
"eps_cl": {
|
||||||
|
"distribution": "loguniform",
|
||||||
|
"low": param_grid["eps_cl"].a,
|
||||||
|
"high": param_grid["eps_cl"].b,
|
||||||
|
},
|
||||||
|
"eps_e": {
|
||||||
|
"distribution": "loguniform",
|
||||||
|
"low": param_grid["eps_e"].a,
|
||||||
|
"high": param_grid["eps_e"].b,
|
||||||
|
},
|
||||||
|
"initial_K": {
|
||||||
|
"distribution": "randint",
|
||||||
|
"low": param_grid["initial_K"].a,
|
||||||
|
"high": param_grid["initial_K"].b,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cv_splits": cv.get_n_splits(),
|
||||||
|
"metrics": metrics,
|
||||||
|
}
|
||||||
|
settings_file = results_dir / "search_settings.toml"
|
||||||
|
print(f"Storing search settings to {settings_file}")
|
||||||
|
with open(settings_file, "w") as f:
|
||||||
|
toml.dump({"settings": settings}, f)
|
||||||
|
|
||||||
|
# Store the best estimator model
|
||||||
|
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||||
|
print(f"Storing best estimator model to {best_model_file}")
|
||||||
|
with open(best_model_file, "wb") as f:
|
||||||
|
pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
# Store the search results
|
# Store the search results
|
||||||
results = pd.DataFrame(search.cv_results_)
|
results = pd.DataFrame(search.cv_results_)
|
||||||
# Parse the params into individual columns
|
# Parse the params into individual columns
|
||||||
|
|
@ -171,7 +134,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
results = pd.concat([results.drop(columns=["params"]), params], axis=1)
|
||||||
results["grid"] = grid
|
results["grid"] = grid
|
||||||
results["level"] = level
|
results["level"] = level
|
||||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
|
||||||
results_file = results_dir / "search_results.parquet"
|
results_file = results_dir / "search_results.parquet"
|
||||||
print(f"Storing CV results to {results_file}")
|
print(f"Storing CV results to {results_file}")
|
||||||
results.to_parquet(results_file)
|
results.to_parquet(results_file)
|
||||||
|
|
@ -219,9 +181,20 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000):
|
||||||
print(f"Storing best estimator state to {state_file}")
|
print(f"Storing best estimator state to {state_file}")
|
||||||
state.to_netcdf(state_file, engine="h5netcdf")
|
state.to_netcdf(state_file, engine="h5netcdf")
|
||||||
|
|
||||||
|
# Predict probabilities for all cells
|
||||||
|
print("Predicting probabilities for all cells...")
|
||||||
|
preds = predict_proba(grid=grid, level=level, clf=best_estimator)
|
||||||
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
|
preds.to_parquet(preds_file)
|
||||||
|
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cyclopts.run(random_cv)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,177 @@ from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import altair as alt
|
import altair as alt
|
||||||
|
import folium
|
||||||
|
import geopandas as gpd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from streamlit_folium import st_folium
|
||||||
|
|
||||||
from entropice.paths import RESULTS_DIR
|
from entropice.paths import RESULTS_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_result_files() -> list[Path]:
|
||||||
|
"""Get all available result files from RESULTS_DIR."""
|
||||||
|
if not RESULTS_DIR.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
result_files = []
|
||||||
|
for search_dir in RESULTS_DIR.iterdir():
|
||||||
|
if not search_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
result_file = search_dir / "search_results.parquet"
|
||||||
|
state_file = search_dir / "best_estimator_state.nc"
|
||||||
|
preds_file = search_dir / "predicted_probabilities.parquet"
|
||||||
|
settings_file = search_dir / "search_settings.toml"
|
||||||
|
if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
|
||||||
|
result_files.append(search_dir)
|
||||||
|
|
||||||
|
return sorted(result_files, reverse=True) # Most recent first
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
|
||||||
|
"""Load results file and prepare binned columns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the results parquet file.
|
||||||
|
k_bin_width: Width of bins for initial_K parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with added binned columns.
|
||||||
|
|
||||||
|
"""
|
||||||
|
results = pd.read_parquet(file_path)
|
||||||
|
|
||||||
|
# Automatically determine bin width for initial_K based on data range
|
||||||
|
k_min = results["initial_K"].min()
|
||||||
|
k_max = results["initial_K"].max()
|
||||||
|
# Use configurable bin width, adapted to actual data range
|
||||||
|
k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
|
||||||
|
results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)
|
||||||
|
|
||||||
|
# Automatically create logarithmic bins for epsilon parameters based on data range
|
||||||
|
# Use 10 bins spanning the actual data range
|
||||||
|
eps_cl_min = np.log10(results["eps_cl"].min())
|
||||||
|
eps_cl_max = np.log10(results["eps_cl"].max())
|
||||||
|
eps_cl_bins = np.logspace(eps_cl_min, eps_cl_max, num=10)
|
||||||
|
|
||||||
|
eps_e_min = np.log10(results["eps_e"].min())
|
||||||
|
eps_e_max = np.log10(results["eps_e"].max())
|
||||||
|
eps_e_bins = np.logspace(eps_e_min, eps_e_max, num=10)
|
||||||
|
|
||||||
|
results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
|
||||||
|
results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
|
||||||
|
"""Load a model state from a NetCDF file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (Path): The path to the NetCDF file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xr.Dataset: The model state as an xarray Dataset.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return xr.open_dataset(file_path, engine="h5netcdf")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
|
"""Extract embedding features from the model state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_state: The xarray Dataset containing the model state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xr.DataArray: The extracted embedding features. This DataArray has dimensions
|
||||||
|
('agg', 'band', 'year') corresponding to the different components of the embedding features.
|
||||||
|
Returns None if no embedding features are found.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _is_embedding_feature(feature: str) -> bool:
|
||||||
|
return feature.startswith("embeddings_")
|
||||||
|
|
||||||
|
embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
|
||||||
|
if len(embedding_features) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
|
||||||
|
embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
|
||||||
|
embedding_feature_array = embedding_feature_array.assign_coords(
|
||||||
|
agg=("feature", [f.split("_")[1] for f in embedding_features]),
|
||||||
|
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||||
|
year=("feature", [f.split("_")[3] for f in embedding_features]),
|
||||||
|
)
|
||||||
|
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
|
||||||
|
return embedding_feature_array
|
||||||
|
|
||||||
|
|
||||||
|
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
|
"""Extract ERA5 features from the model state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_state: The xarray Dataset containing the model state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
|
||||||
|
('variable', 'time') corresponding to the different components of the ERA5 features.
|
||||||
|
Returns None if no ERA5 features are found.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _is_era5_feature(feature: str) -> bool:
|
||||||
|
return feature.startswith("era5_")
|
||||||
|
|
||||||
|
def _extract_var_name(feature: str) -> str:
|
||||||
|
feature = feature.replace("era5_", "")
|
||||||
|
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||||
|
return feature.rsplit("_", 2)[0]
|
||||||
|
else:
|
||||||
|
return feature.rsplit("_", 1)[0]
|
||||||
|
|
||||||
|
def _extract_time_name(feature: str) -> str:
|
||||||
|
feature = feature.replace("era5_", "")
|
||||||
|
if any(season in feature for season in ["summer", "winter", "OND", "JFM", "AMJ", "JAS"]):
|
||||||
|
return "_".join(feature.rsplit("_", 2)[-2:])
|
||||||
|
else:
|
||||||
|
return feature.rsplit("_", 1)[-1]
|
||||||
|
|
||||||
|
era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
|
||||||
|
if len(era5_features) == 0:
|
||||||
|
return None
|
||||||
|
# Split the single feature dimension of era5 features into separate dimensions (variable, time)
|
||||||
|
era5_features_array = model_state.sel(feature=era5_features)["feature_weights"]
|
||||||
|
era5_features_array = era5_features_array.assign_coords(
|
||||||
|
variable=("feature", [_extract_var_name(f) for f in era5_features]),
|
||||||
|
time=("feature", [_extract_time_name(f) for f in era5_features]),
|
||||||
|
)
|
||||||
|
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
||||||
|
return era5_features_array
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Extract common features, e.g. area or water content
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_prediction_map(preds: gpd.GeoDataFrame) -> folium.Map:
|
||||||
|
"""Plot predicted probabilities on a map using Streamlit and Folium.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preds: GeoDataFrame containing 'predicted_proba' and 'geometry' columns.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
folium.Map: A Folium map object with the predicted probabilities visualized.
|
||||||
|
|
||||||
|
"""
|
||||||
|
m = preds.explore(column="predicted_proba", cmap="Set3", legend=True, tiles="CartoDB positron")
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
def _plot_k_binned(
|
def _plot_k_binned(
|
||||||
results: pd.DataFrame,
|
results: pd.DataFrame,
|
||||||
target: str,
|
target: str,
|
||||||
|
|
@ -118,171 +281,6 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def load_and_prepare_results(file_path: Path, k_bin_width: int = 40) -> pd.DataFrame:
    """Load results file and prepare binned columns.

    Adds three categorical columns: 'initial_K_binned' (fixed-width bins) and
    'eps_cl_binned' / 'eps_e_binned' (logarithmic bins spanning each column's
    actual data range).

    Args:
        file_path: Path to the results parquet file.
        k_bin_width: Width of bins for initial_K parameter.

    Returns:
        DataFrame with added binned columns.

    """
    results = pd.read_parquet(file_path)

    # Fixed-width bins for initial_K, adapted to the actual data range.
    # right=False makes the bins left-closed so the minimum value is included;
    # extending the upper edge by one bin width keeps the maximum included too.
    k_min = results["initial_K"].min()
    k_max = results["initial_K"].max()
    k_bins = np.arange(k_min, k_max + k_bin_width, k_bin_width)
    results["initial_K_binned"] = pd.cut(results["initial_K"], bins=k_bins, right=False)

    # Logarithmic bins (10 edges) for the epsilon parameters, spanning each
    # column's data range.
    eps_cl_bins = np.logspace(
        np.log10(results["eps_cl"].min()), np.log10(results["eps_cl"].max()), num=10
    )
    eps_e_bins = np.logspace(
        np.log10(results["eps_e"].min()), np.log10(results["eps_e"].max()), num=10
    )

    # include_lowest=True: the first bin edge equals the column minimum, and
    # pd.cut's default right-closed intervals exclude the left edge — without
    # this flag the minimum value would be binned as NaN.
    results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins, include_lowest=True)
    results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins, include_lowest=True)

    return results
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_prepare_model_state(file_path: Path) -> xr.Dataset:
    """Load a model state from a NetCDF file.

    Args:
        file_path (Path): The path to the NetCDF file.

    Returns:
        xr.Dataset: The model state as an xarray Dataset.

    """
    # The state files are read through the h5netcdf engine.
    model_state = xr.open_dataset(file_path, engine="h5netcdf")
    return model_state
|
|
||||||
|
|
||||||
|
|
||||||
def extract_embedding_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Extract embedding features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted embedding features. This DataArray has dimensions
            ('agg', 'band', 'year') corresponding to the different components of the
            embedding features. Returns None if no embedding features are found.

    """

    def _is_embedding_feature(feature: str) -> bool:
        # Embedding features have exactly three "_"-separated parts with a
        # middle band token of the form "A<digits>" (e.g. "mean_A01_2020").
        parts = feature.split("_")
        if len(parts) != 3:
            return False
        band = parts[1]
        return band.startswith("A") and band[1:].isdigit()

    embedding_features = [f for f in model_state.feature.to_numpy() if _is_embedding_feature(f)]
    if len(embedding_features) == 0:
        return None

    # Split each feature name once (instead of once per coordinate) and unzip
    # into the three coordinate component lists.
    aggs, bands, years = zip(*(f.split("_") for f in embedding_features))

    # Split the single feature dimension of embedding features into separate dimensions (agg, band, year)
    embedding_feature_array = model_state.sel(feature=embedding_features)["feature_weights"]
    embedding_feature_array = embedding_feature_array.assign_coords(
        agg=("feature", list(aggs)),
        band=("feature", list(bands)),
        year=("feature", list(years)),
    )
    embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
    return embedding_feature_array
|
|
||||||
|
|
||||||
|
|
||||||
def extract_era5_features(model_state: xr.Dataset) -> xr.DataArray | None:
    """Extract ERA5 features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
            ('variable', 'time') corresponding to the different components of the
            ERA5 features. Returns None if no ERA5 features are found.

    """
    # Markers that identify a seasonal ERA5 feature name.
    # NOTE(review): an earlier copy of this check matched "summer" instead of
    # "spring" — confirm which matches the current feature naming scheme.
    seasonal_markers = ("winter", "spring")
    quarter_markers = ("OND", "JFM", "AMJ", "JAS")
    time_markers = seasonal_markers + quarter_markers

    def _is_era5_feature(name: str) -> bool:
        # Any seasonal/quarterly marker in the name is an immediate match.
        if any(marker in name for marker in time_markers):
            return True
        parts = name.split("_")
        if len(parts) == 3:
            _, band, year = parts
            # "A<..>" middle tokens belong to embedding features, not ERA5.
            if band.startswith("A"):
                return False
            return year.isdigit()
        return parts[-1].isdigit()

    def _split_feature(name: str) -> tuple[str, str]:
        # Seasonal/quarterly features carry a two-token time suffix
        # ("<season>_<year>"); all others carry a single trailing token.
        if any(marker in name for marker in time_markers):
            pieces = name.rsplit("_", 2)
            return pieces[0], "_".join(pieces[-2:])
        pieces = name.rsplit("_", 1)
        return pieces[0], pieces[-1]

    era5_names = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)]
    if not era5_names:
        return None

    # Split the single feature dimension of ERA5 features into separate
    # dimensions (variable, time).
    variables, times = zip(*(_split_feature(f) for f in era5_names))
    era5_array = model_state.sel(feature=era5_names)["feature_weights"]
    era5_array = era5_array.assign_coords(
        variable=("feature", list(variables)),
        time=("feature", list(times)),
    )
    era5_array = era5_array.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
    return era5_array
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Extract common features, e.g. area or water content
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_result_files() -> list[Path]:
    """Get all available result files from RESULTS_DIR."""
    if not RESULTS_DIR.exists():
        return []

    # A run directory is usable only when both the search results and the
    # best-estimator state file are present.
    complete_runs = [
        run_dir
        for run_dir in RESULTS_DIR.iterdir()
        if run_dir.is_dir()
        and (run_dir / "search_results.parquet").exists()
        and (run_dir / "best_estimator_state.nc").exists()
    ]

    return sorted(complete_runs, reverse=True)  # Most recent first
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_results_dir_name(results_dir: Path) -> str:
|
def _parse_results_dir_name(results_dir: Path) -> str:
|
||||||
gridname, date = results_dir.name.split("_random_search_cv")
|
gridname, date = results_dir.name.split("_random_search_cv")
|
||||||
gridname = gridname.lstrip("permafrost_")
|
gridname = gridname.lstrip("permafrost_")
|
||||||
|
|
@ -574,6 +572,95 @@ def _plot_era5_summary(era5_array: xr.DataArray):
|
||||||
return chart_variable, chart_time
|
return chart_variable, chart_time
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_box_assignments(model_state: xr.Dataset):
    """Create a heatmap showing which boxes are assigned to which labels/classes.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.

    Returns:
        Altair chart showing the box-to-label assignment heatmap.

    """
    # Flatten the (class, box) assignment matrix into long form for Altair.
    assignments = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    x_axis = alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0))
    y_axis = alt.Y("class:N", title="Class Label")
    color = alt.Color(
        "assignment:Q",
        scale=alt.Scale(scheme="viridis"),
        title="Assignment Strength",
    )
    hover = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("box:O", title="Box"),
        alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
    ]

    return (
        alt.Chart(assignments)
        .mark_rect()
        .encode(x=x_axis, y=y_axis, color=color, tooltip=hover)
        .properties(
            height=150,
            title="Box-to-Label Assignments (Lambda Matrix)",
        )
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_box_assignment_bars(model_state: xr.Dataset):
    """Create a bar chart showing how many boxes are assigned to each class.

    Each box is attributed to the single class with its highest assignment
    strength (its "primary assignment"), then boxes are counted per class.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.

    Returns:
        Altair chart showing count of boxes per class.

    """
    # Long-form (class, box, assignment) table.
    assignments = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    # Row index of the strongest assignment for every box → its primary class.
    strongest_rows = assignments.groupby("box")["assignment"].idxmax()
    primary = assignments.loc[strongest_rows, ["box", "class", "assignment"]].reset_index(drop=True)

    # Number of boxes whose primary assignment is each class.
    counts = primary.groupby("class").size().reset_index(name="count")

    hover = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("count:Q", title="Number of Boxes"),
    ]
    return (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Class Label"),
            y=alt.Y("count:Q", title="Number of Boxes"),
            color=alt.Color("class:N", title="Class", legend=None),
            tooltip=hover,
        )
        .properties(
            width=600,
            height=300,
            title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
        )
    )
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run Streamlit dashboard application."""
|
"""Run Streamlit dashboard application."""
|
||||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||||
|
|
@ -609,6 +696,7 @@ def main():
|
||||||
model_state["feature_weights"] *= n_features
|
model_state["feature_weights"] *= n_features
|
||||||
embedding_feature_array = extract_embedding_features(model_state)
|
embedding_feature_array = extract_embedding_features(model_state)
|
||||||
era5_feature_array = extract_era5_features(model_state)
|
era5_feature_array = extract_era5_features(model_state)
|
||||||
|
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||||
|
|
||||||
st.sidebar.success(f"Loaded {len(results)} results")
|
st.sidebar.success(f"Loaded {len(results)} results")
|
||||||
|
|
||||||
|
|
@ -618,6 +706,12 @@ def main():
|
||||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Map visualization
|
||||||
|
st.header("Predictions Map")
|
||||||
|
st.markdown("Map showing predicted classes from the best estimator")
|
||||||
|
m = _plot_prediction_map(predictions)
|
||||||
|
st_folium(m, width="100%", height=300)
|
||||||
|
|
||||||
# Display some basic statistics
|
# Display some basic statistics
|
||||||
st.header("Dataset Overview")
|
st.header("Dataset Overview")
|
||||||
col1, col2, col3 = st.columns(3)
|
col1, col2, col3 = st.columns(3)
|
||||||
|
|
@ -765,6 +859,62 @@ def main():
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Box-to-Label Assignment Visualization
|
||||||
|
st.subheader("Box-to-Label Assignments")
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This visualization shows how the learned boxes (prototypes in feature space) are
|
||||||
|
assigned to different class labels. The ESPA classifier learns K boxes and assigns
|
||||||
|
them to classes through the Lambda matrix. Higher values indicate stronger assignment
|
||||||
|
of a box to a particular class.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with st.spinner("Generating box assignment visualizations..."):
|
||||||
|
col1, col2 = st.columns([0.7, 0.3])
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.markdown("### Assignment Heatmap")
|
||||||
|
box_assignment_heatmap = _plot_box_assignments(model_state)
|
||||||
|
st.altair_chart(box_assignment_heatmap, use_container_width=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.markdown("### Box Count by Class")
|
||||||
|
box_assignment_bars = _plot_box_assignment_bars(model_state)
|
||||||
|
st.altair_chart(box_assignment_bars, use_container_width=True)
|
||||||
|
|
||||||
|
# Show statistics
|
||||||
|
with st.expander("Box Assignment Statistics"):
|
||||||
|
box_assignments = model_state["box_assignments"].to_pandas()
|
||||||
|
st.write("**Assignment Matrix Statistics:**")
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Boxes", len(box_assignments.columns))
|
||||||
|
with col2:
|
||||||
|
st.metric("Number of Classes", len(box_assignments.index))
|
||||||
|
with col3:
|
||||||
|
st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
|
||||||
|
with col4:
|
||||||
|
st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")
|
||||||
|
|
||||||
|
# Show which boxes are most strongly assigned to each class
|
||||||
|
st.write("**Top Box Assignments per Class:**")
|
||||||
|
for class_label in box_assignments.index:
|
||||||
|
top_boxes = box_assignments.loc[class_label].nlargest(5)
|
||||||
|
st.write(
|
||||||
|
f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
|
||||||
|
f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Interpretation:**
|
||||||
|
- Each box can be assigned to multiple classes with different strengths
|
||||||
|
- Boxes with higher assignment values for a class contribute more to that class's predictions
|
||||||
|
- The distribution shows how the model partitions the feature space for classification
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Embedding features analysis (if present)
|
# Embedding features analysis (if present)
|
||||||
if embedding_feature_array is not None:
|
if embedding_feature_array is not None:
|
||||||
st.subheader("Embedding Feature Analysis")
|
st.subheader("Embedding Feature Analysis")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue