Add multiclass training
This commit is contained in:
parent
553b54bb32
commit
d5b35d6da4
7 changed files with 814 additions and 4867 deletions
11
pixi.lock
generated
11
pixi.lock
generated
|
|
@ -285,6 +285,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
|
||||||
|
|
@ -1301,7 +1302,7 @@ packages:
|
||||||
- pypi: ./
|
- pypi: ./
|
||||||
name: entropice
|
name: entropice
|
||||||
version: 0.1.0
|
version: 0.1.0
|
||||||
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
|
sha256: 852c87cdbd1d452fccaa6253c6ce7410dda9fa32d2f951d35441e454414acfc0
|
||||||
requires_dist:
|
requires_dist:
|
||||||
- aiohttp>=3.12.11
|
- aiohttp>=3.12.11
|
||||||
- bokeh>=3.7.3
|
- bokeh>=3.7.3
|
||||||
|
|
@ -1350,6 +1351,7 @@ packages:
|
||||||
- altair[all]>=5.5.0,<6
|
- altair[all]>=5.5.0,<6
|
||||||
- h5netcdf>=1.7.3,<2
|
- h5netcdf>=1.7.3,<2
|
||||||
- streamlit-folium>=0.25.3,<0.26
|
- streamlit-folium>=0.25.3,<0.26
|
||||||
|
- st-theme>=1.2.3,<2
|
||||||
editable: true
|
editable: true
|
||||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||||
name: entropy
|
name: entropy
|
||||||
|
|
@ -4507,6 +4509,13 @@ packages:
|
||||||
version: 1.0.1
|
version: 1.0.1
|
||||||
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
|
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
|
||||||
requires_python: '>=3.6'
|
requires_python: '>=3.6'
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
|
||||||
|
name: st-theme
|
||||||
|
version: 1.2.3
|
||||||
|
sha256: 0a54d9817dd5f8a6d7b0d071b25ae72eacf536c63a5fb97374923938021b1389
|
||||||
|
requires_dist:
|
||||||
|
- streamlit>=1.33
|
||||||
|
requires_python: '>=3.8'
|
||||||
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
|
||||||
name: stack-data
|
name: stack-data
|
||||||
version: 0.6.3
|
version: 0.6.3
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ dependencies = [
|
||||||
"geocube>=0.7.1,<0.8",
|
"geocube>=0.7.1,<0.8",
|
||||||
"streamlit>=1.50.0,<2",
|
"streamlit>=1.50.0,<2",
|
||||||
"altair[all]>=5.5.0,<6",
|
"altair[all]>=5.5.0,<6",
|
||||||
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
|
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
||||||
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
||||||
|
|
||||||
|
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
|
||||||
|
darts_counts = grid_gdf[darts_counts_columns]
|
||||||
|
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
|
||||||
|
|
||||||
output_path = get_darts_rts_file(grid, level)
|
output_path = get_darts_rts_file(grid, level)
|
||||||
grid_gdf.to_parquet(output_path)
|
grid_gdf.to_parquet(output_path)
|
||||||
print(f"Saved RTS labels to {output_path}")
|
print(f"Saved RTS labels to {output_path}")
|
||||||
|
|
|
||||||
|
|
@ -87,9 +87,14 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
return dataset_file
|
return dataset_file
|
||||||
|
|
||||||
|
|
||||||
def get_cv_results_dir(
    name: str,
    grid: Literal["hex", "healpix"],
    level: int,
    task: Literal["binary", "multi"] = "binary",
) -> Path:
    """Create and return a fresh, timestamped directory for cross-validation results.

    The directory is placed below ``RESULTS_DIR`` and named
    ``{gridname}_{name}_cv{YYYYmmdd-HHMMSS}_{task}``.

    Args:
        name (str): Short identifier of the search (e.g. "random_search").
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        task (Literal["binary", "multi"], optional): The classification task type, used as
            directory suffix. Defaults to "binary", which yields the same ``_binary`` suffix
            that was hard-coded before the task parameter existed.

    Returns:
        Path: The results directory (created if it did not exist yet).

    """
    gridname = _get_gridname(grid, level)
    # Local wall-clock time with one-second resolution: two calls with the same
    # arguments within the same second resolve to the same directory.
    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
    results_dir.mkdir(parents=True, exist_ok=True)
    return results_dir
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,50 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
|
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
    """Create X and y data from the training dataset.

    Reads the training dataset of the given grid, keeps only the cells that have
    DARTS coverage and splits it into features and a classification target.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
            Any value other than "binary" is treated as "multi".

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
            For the binary task y is the ``darts_has_rts`` column. For the multi task y holds
            integer category codes and ``labels[code]`` is the name of the class with that code.

    """
    dataset_file = get_train_dataset_file(grid=grid, level=level)
    data = gpd.read_parquet(dataset_file)
    data = data[data["darts_has_coverage"]]

    # Everything that is not a feature: identifiers, geometry and all DARTS-derived targets.
    cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
    cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
    # Rows with any missing feature are discarded; y is aligned to the surviving rows below.
    X_data = data.drop(columns=cols_to_drop).dropna()

    if task == "binary":
        labels = ["No RTS", "RTS"]
        y_data = data.loc[X_data.index, "darts_has_rts"]
    else:
        # Put the RTS counts into categories whose edges are quantiles of the counts.
        y_data = data.loc[X_data.index, "darts_rts_count"]
        n_categories = 5
        quantile_bins = pd.qcut(y_data, q=n_categories, duplicates="drop").cat.categories
        # Truncate the quantile edges to integers and prepend a dedicated (-1, 0] category
        # so that cells without any RTS form their own class.
        bins = pd.IntervalIndex.from_tuples(
            [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in quantile_bins]
        )
        y_data = pd.cut(y_data, bins=bins)
        # Name the classes after the categories (not after the observed values): the codes
        # below index into the categories, so an empty bin must still occupy its position
        # and missing values must not add a spurious "nan" label.
        labels = [str(interval) for interval in y_data.cat.categories]
        # NOTE(review): a NaN count falls outside every bin and gets code -1 here -
        # confirm that covered cells never carry a NaN ``darts_rts_count``.
        y_data = y_data.cat.codes
    return data, X_data, y_data, labels
|
||||||
|
|
||||||
|
|
||||||
|
def random_cv(
|
||||||
|
grid: Literal["hex", "healpix"],
|
||||||
|
level: int,
|
||||||
|
n_iter: int = 2000,
|
||||||
|
robust: bool = False,
|
||||||
|
task: Literal["binary", "multi"] = "binary",
|
||||||
|
):
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -37,21 +80,17 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||||
|
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data = get_train_dataset_file(grid=grid, level=level)
|
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
||||||
data = gpd.read_parquet(data)
|
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
||||||
data = data[data["darts_has_coverage"]]
|
print(f"{y_data.describe()=}")
|
||||||
|
|
||||||
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
|
|
||||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
|
||||||
X_data = data.drop(columns=cols_to_drop).dropna()
|
|
||||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
|
||||||
X = X_data.to_numpy(dtype="float32")
|
X = X_data.to_numpy(dtype="float32")
|
||||||
y = y_data.to_numpy(dtype="int8")
|
y = y_data.to_numpy(dtype="int8")
|
||||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||||
print(f"{X.shape=}, {y.shape=}")
|
print(f"{X.shape=}, {y.shape=}")
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||||
|
|
||||||
param_grid = {
|
param_grid = {
|
||||||
|
|
@ -62,7 +101,21 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
|
|
||||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
|
if task == "binary":
|
||||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||||
|
else:
|
||||||
|
metrics = [
|
||||||
|
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
|
||||||
|
"f1_macro",
|
||||||
|
"f1_weighted",
|
||||||
|
"precision_macro",
|
||||||
|
"precision_weighted",
|
||||||
|
"recall_macro",
|
||||||
|
"recall_weighted",
|
||||||
|
"jaccard_micro",
|
||||||
|
"jaccard_macro",
|
||||||
|
"jaccard_weighted",
|
||||||
|
]
|
||||||
search = RandomizedSearchCV(
|
search = RandomizedSearchCV(
|
||||||
clf,
|
clf,
|
||||||
param_grid,
|
param_grid,
|
||||||
|
|
@ -72,7 +125,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=3,
|
verbose=3,
|
||||||
scoring=metrics,
|
scoring=metrics,
|
||||||
refit="f1",
|
refit="f1" if task == "binary" else "f1_weighted",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||||
|
|
@ -88,10 +141,11 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||||
|
|
||||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
|
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
||||||
|
|
||||||
# Store the search settings
|
# Store the search settings
|
||||||
settings = {
|
settings = {
|
||||||
|
"task": task,
|
||||||
"grid": grid,
|
"grid": grid,
|
||||||
"level": level,
|
"level": level,
|
||||||
"random_state": 42,
|
"random_state": 42,
|
||||||
|
|
@ -115,7 +169,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
},
|
},
|
||||||
"cv_splits": cv.get_n_splits(),
|
"cv_splits": cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
"metrics": metrics,
|
||||||
"classes": ["No RTS", "RTS"],
|
"classes": labels,
|
||||||
}
|
}
|
||||||
settings_file = results_dir / "search_settings.toml"
|
settings_file = results_dir / "search_settings.toml"
|
||||||
print(f"Storing search settings to {settings_file}")
|
print(f"Storing search settings to {settings_file}")
|
||||||
|
|
@ -144,7 +198,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
best_estimator = search.best_estimator_
|
best_estimator = search.best_estimator_
|
||||||
# Annotate the state with xarray metadata
|
# Annotate the state with xarray metadata
|
||||||
features = X_data.columns.tolist()
|
features = X_data.columns.tolist()
|
||||||
labels = y_data.unique().tolist()
|
|
||||||
boxes = list(range(best_estimator.K_))
|
boxes = list(range(best_estimator.K_))
|
||||||
box_centers = xr.DataArray(
|
box_centers = xr.DataArray(
|
||||||
best_estimator.S_.cpu().numpy(),
|
best_estimator.S_.cpu().numpy(),
|
||||||
|
|
@ -185,7 +238,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
|
||||||
|
|
||||||
# Predict probabilities for all cells
|
# Predict probabilities for all cells
|
||||||
print("Predicting probabilities for all cells...")
|
print("Predicting probabilities for all cells...")
|
||||||
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
|
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
|
||||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||||
print(f"Storing predicted probabilities to {preds_file}")
|
print(f"Storing predicted probabilities to {preds_file}")
|
||||||
preds.to_parquet(preds_file)
|
preds.to_parquet(preds_file)
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,93 @@ from pathlib import Path
|
||||||
|
|
||||||
import altair as alt
|
import altair as alt
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import matplotlib.path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import streamlit_folium as st_folium
|
import streamlit_folium as st_folium
|
||||||
import toml
|
import toml
|
||||||
|
import ultraplot as uplt
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
|
from matplotlib.patches import PathPatch
|
||||||
|
|
||||||
from entropice.paths import RESULTS_DIR
|
from entropice.paths import RESULTS_DIR
|
||||||
|
from entropice.training import create_xy_data
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unified_colormap(settings: dict):
    """Build one color scheme and expose it in the form each plotting library expects.

    The task type ("task", default "binary") and the number of classes (length of
    "classes", default empty) are read from the settings. The binary task uses a
    fixed two-color palette that depends on the active Streamlit theme; every other
    task samples as many colors as there are classes from the viridis colormap.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair

    """
    task = settings.get("task", "binary")
    class_count = len(settings.get("classes", []))
    dark = st.context.theme.type == "dark"

    if task != "binary":
        # Sequential palette: evenly spaced samples of viridis, skipping both extreme ends.
        viridis = uplt.Colormap("viridis")
        palette = [mcolors.rgb2hex(viridis(position)[:3]) for position in np.linspace(0.1, 0.9, class_count)]
    elif dark:
        # Blue and orange for the dark theme.
        palette = ["#1f77b4", "#ff7f0e"]
    else:
        # Brighter blue and red for the light theme.
        palette = ["#3498db", "#e74c3c"]

    # Matplotlib/ultraplot and geopandas.explore both take a ListedColormap,
    # Altair takes the plain list of hex strings.
    return mcolors.ListedColormap(palette), mcolors.ListedColormap(palette), palette
|
||||||
|
|
||||||
|
|
||||||
|
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    The name is split at underscores; every part is capitalized, except "f1"
    (in any casing), which is always rendered as "F1".

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').

    """
    return " ".join("F1" if part.lower() == "f1" else part.capitalize() for part in metric.split("_"))
|
||||||
|
|
||||||
|
|
||||||
def get_available_result_files() -> list[Path]:
|
def get_available_result_files() -> list[Path]:
|
||||||
|
|
@ -185,32 +263,229 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
|
||||||
return common_feature_array
|
return common_feature_array
|
||||||
|
|
||||||
|
|
||||||
def _plot_prediction_map(preds: gpd.GeoDataFrame):
|
def _plot_prediction_map_static(preds: gpd.GeoDataFrame, matplotlib_cmap: mcolors.ListedColormap):
    """Render the predicted classes as a static north-polar map with ultraplot.

    Args:
        preds: GeoDataFrame with predicted classes (intervals as strings).
        matplotlib_cmap: Matplotlib ListedColormap object.

    Returns:
        Matplotlib figure with predictions colored by class.

    """

    def class_rank(label):
        # "No RTS" sorts first, intervals such as "(0, 4]" by their lower bound,
        # anything that cannot be parsed sorts last.
        if label == "No RTS":
            return -1
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")

    # Work on a copy so the caller's frame stays untouched.
    preds_plot = preds.copy()
    # The artificial (-1, 0] interval stands for cells without any RTS.
    preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
    ordered_classes = sorted(preds_plot["predicted_class"].unique(), key=class_rank)
    preds_plot["predicted_class"] = pd.Categorical(
        preds_plot["predicted_class"], categories=ordered_classes, ordered=True
    )

    # Colors follow the active Streamlit theme.
    if st.context.theme.type == "light":
        colors = {
            "fig": "#f8f9fa",
            "ax": "#ffffff",
            "ocean": "#e3f2fd",
            "land": "#ecf0f1",
            "grid": "#666666",
            "fg": "#2c3e50",
            "coast": "#34495e",
        }
    else:
        colors = {
            "fig": "#0e1117",
            "ax": "#0e1117",
            "ocean": "#1a1d29",
            "land": "#262730",
            "grid": "#ffffff",
            "fg": "#fafafa",
            "coast": "#d0d0d0",
        }

    # North Polar Stereographic projection.
    proj = uplt.Proj("npstere", lon_0=-45)
    fig, ax = uplt.subplots(proj=proj, figsize=(12, 12), facecolor=colors["fig"])

    # Full-figure rectangle behind everything else as background.
    background = PathPatch(
        matplotlib.path.Path.unit_rectangle(),
        transform=fig.transFigure,
        color=colors["fig"],
        zorder=-1,
    )
    fig.patches.append(background)

    ax.format(
        boundinglat=50,
        coast=True,
        coastcolor=colors["coast"],
        coastlinewidth=0.8,
        land=True,
        landcolor=colors["land"],
        ocean=True,
        oceancolor=colors["ocean"],
        title="Predicted RTS Classes",
        titlecolor=colors["fg"],
        titleweight="bold",
        titlesize=14,
        color=colors["fg"],
        gridcolor=colors["grid"],
        gridlinewidth=0.5,
        gridlinestyle="-",
        gridalpha=0.3,
        facecolor=colors["ax"],
        labels=True,
        labelcolor=colors["fg"],
        latlines=10,
        lonlines=30,
    )

    legend_options = {
        "loc": "lower left",
        "frameon": True,
        "framealpha": 0.95,
        "edgecolor": colors["grid"],
        "facecolor": colors["fig"],
        "fancybox": True,
        "shadow": False,
        "title": "RTS Class",
        "title_fontsize": 11,
        "fontsize": 9,
        "labelspacing": 0.8,
        "borderpad": 1.0,
        "columnspacing": 1.0,
    }
    # Reproject to the map projection and draw the cells without outlines.
    projected = preds_plot.to_crs(proj.proj4_init)
    projected.plot(
        ax=ax,
        column="predicted_class",
        cmap=matplotlib_cmap,
        legend=True,
        legend_kwds=legend_options,
        edgecolor="none",
        linewidth=0,
    )

    legend = ax.get_legend()
    if not legend:
        return fig

    # Restyle the legend once it exists: frame, label colors and title color.
    legend.set_title("RTS Class", prop={"size": 11, "weight": "bold"})
    frame = legend.get_frame()
    frame.set_linewidth(0.5)
    frame.set_edgecolor(colors["grid"])
    frame.set_facecolor(colors["fig"])
    frame.set_alpha(0.95)
    for entry in legend.get_texts():
        entry.set_color(colors["fg"])
    legend.get_title().set_color(colors["fg"])

    return fig
|
||||||
|
|
||||||
|
|
||||||
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
|
def _plot_prediction_map(preds: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
    """Build the interactive prediction map with the classes in a stable order.

    Args:
        preds: GeoDataFrame with predicted classes (intervals as strings).
        folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).

    Returns:
        Folium map with predictions colored by class.

    """

    def class_rank(label):
        # "No RTS" sorts first, intervals such as "(0, 4]" by their lower bound,
        # anything that cannot be parsed sorts last.
        if label == "No RTS":
            return -1
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")

    # Work on a copy so the caller's frame stays untouched.
    preds_plot = preds.copy()
    # The artificial (-1, 0] interval stands for cells without any RTS.
    relabelled = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
    preds_plot["predicted_class"] = relabelled

    # An ordered categorical fixes the class order used for the colors and the legend.
    ordered_classes = sorted(preds_plot["predicted_class"].unique(), key=class_rank)
    preds_plot["predicted_class"] = pd.Categorical(
        preds_plot["predicted_class"], categories=ordered_classes, ordered=True
    )

    # Dark base map for the dark theme, light base map otherwise.
    if st.context.theme.type == "dark":
        tiles = "CartoDB dark_matter"
    else:
        tiles = "CartoDB positron"

    return preds_plot.explore(column="predicted_class", cmap=folium_cmap, legend=True, tiles=tiles)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame, altair_colors: list[str]):
|
||||||
"""Create a bar chart showing the count of each predicted class.
|
"""Create a bar chart showing the count of each predicted class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
preds: GeoDataFrame with predicted classes.
|
preds: GeoDataFrame with predicted classes.
|
||||||
|
altair_colors: List of hex color strings for altair.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Altair chart showing class distribution.
|
Altair chart showing class distribution.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
|
# Create a copy and apply the same transformations as the map
|
||||||
|
preds_plot = preds.copy()
|
||||||
|
preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
|
||||||
|
|
||||||
|
# Sort the classes: "No RTS" first, then by the lower bound of intervals
|
||||||
|
def sort_key(class_str):
|
||||||
|
if class_str == "No RTS":
|
||||||
|
return -1 # Put "No RTS" first
|
||||||
|
# Parse interval string like "(0, 4]" or "(4, 36]"
|
||||||
|
try:
|
||||||
|
lower = float(class_str.split(",")[0].strip("([ "))
|
||||||
|
return lower
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return float("inf") # Put unparseable values at the end
|
||||||
|
|
||||||
|
df = pd.DataFrame({"predicted_class": preds_plot["predicted_class"].to_numpy()})
|
||||||
counts = df["predicted_class"].value_counts().reset_index()
|
counts = df["predicted_class"].value_counts().reset_index()
|
||||||
counts.columns = ["class", "count"]
|
counts.columns = ["class", "count"]
|
||||||
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
|
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
|
||||||
|
|
||||||
|
# Sort counts by the same key
|
||||||
|
counts["sort_key"] = counts["class"].apply(sort_key)
|
||||||
|
counts = counts.sort_values("sort_key")
|
||||||
|
|
||||||
|
# Create an ordered list of classes for consistent color mapping
|
||||||
|
class_order = counts["class"].tolist()
|
||||||
|
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(counts)
|
alt.Chart(counts)
|
||||||
.mark_bar()
|
.mark_bar()
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("class:N", title="Predicted Class"),
|
x=alt.X("class:N", title="Predicted Class", sort=class_order, axis=alt.Axis(labelAngle=0)),
|
||||||
y=alt.Y("count:Q", title="Number of Cells"),
|
y=alt.Y("count:Q", title="Number of Cells"),
|
||||||
color=alt.Color("class:N", title="Class", legend=None),
|
color=alt.Color(
|
||||||
|
"class:N",
|
||||||
|
title="Class",
|
||||||
|
scale=alt.Scale(domain=class_order, range=altair_colors),
|
||||||
|
legend=None,
|
||||||
|
),
|
||||||
tooltip=[
|
tooltip=[
|
||||||
alt.Tooltip("class:N", title="Class"),
|
alt.Tooltip("class:N", title="Class"),
|
||||||
alt.Tooltip("count:Q", title="Count"),
|
alt.Tooltip("count:Q", title="Count"),
|
||||||
|
|
@ -286,41 +561,36 @@ def _plot_k_binned(
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
def _plot_params_binned(results: pd.DataFrame, x: str, hue: str, col: str, metric: str):
|
||||||
"""Plot epsilon-binned results with K parameter."""
|
"""Plot epsilon-binned results with K parameter."""
|
||||||
assert "initial_K" in results.columns, "initial_K column not found in results."
|
|
||||||
assert metric in results.columns, f"{metric} not found in results."
|
assert metric in results.columns, f"{metric} not found in results."
|
||||||
|
assert x in ["eps_cl", "eps_e", "initial_K"]
|
||||||
|
assert hue in ["eps_cl", "eps_e", "initial_K"]
|
||||||
|
assert col in ["eps_cl_binned", "eps_e_binned", "initial_K_binned"]
|
||||||
|
|
||||||
if target == "eps_cl":
|
assert x in results.columns, f"{x} column not found in results."
|
||||||
hue = "eps_cl"
|
|
||||||
col = "eps_e_binned"
|
|
||||||
elif target == "eps_e":
|
|
||||||
hue = "eps_e"
|
|
||||||
col = "eps_cl_binned"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid target: {target}")
|
|
||||||
|
|
||||||
assert hue in results.columns, f"{hue} column not found in results."
|
assert hue in results.columns, f"{hue} column not found in results."
|
||||||
assert col in results.columns, f"{col} column not found in results."
|
assert col in results.columns, f"{col} column not found in results."
|
||||||
|
|
||||||
# Prepare data
|
# Prepare data
|
||||||
plot_data = results[["initial_K", metric, hue, col]].copy()
|
plot_data = results[[x, metric, hue, col]].copy()
|
||||||
|
|
||||||
# Sort bins by their left value and convert to string with sorted categories
|
# Sort bins by their left value and convert to string with sorted categories
|
||||||
plot_data = plot_data.sort_values(col)
|
plot_data = plot_data.sort_values(col)
|
||||||
plot_data[col] = plot_data[col].astype(str)
|
plot_data[col] = plot_data[col].astype(str)
|
||||||
bin_order = plot_data[col].unique().tolist()
|
bin_order = plot_data[col].unique().tolist()
|
||||||
|
xscale = alt.Scale(type="log") if x in ["eps_cl", "eps_e"] else alt.Scale()
|
||||||
|
cscheme = "bluepurple" if hue == "eps_e" else "purplered" if hue == "eps_cl" else "greenblue"
|
||||||
# Create the chart
|
# Create the chart
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(plot_data)
|
alt.Chart(plot_data)
|
||||||
.mark_circle(size=60, opacity=0.7)
|
.mark_circle(size=60, opacity=0.7)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("initial_K:Q", title="Initial K"),
|
x=alt.X(f"{x}:Q", title=x, scale=xscale),
|
||||||
y=alt.Y(f"{metric}:Q", title=metric),
|
y=alt.Y(f"{metric}:Q", title=metric),
|
||||||
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue),
|
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme=cscheme), title=hue),
|
||||||
tooltip=[
|
tooltip=[
|
||||||
"initial_K:Q",
|
f"{x}:Q",
|
||||||
alt.Tooltip(f"{metric}:Q", format=".4f"),
|
alt.Tooltip(f"{metric}:Q", format=".4f"),
|
||||||
alt.Tooltip(f"{hue}:Q", format=".2e"),
|
alt.Tooltip(f"{hue}:Q", format=".2e"),
|
||||||
f"{col}:N",
|
f"{col}:N",
|
||||||
|
|
@ -334,11 +604,14 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
|
||||||
|
|
||||||
|
|
||||||
def _parse_results_dir_name(results_dir: Path) -> str:
    """Build a human-readable label for a results directory.

    The directory name is split on ``_random_search_cv`` into a grid part and a
    timestamp part (``%Y%m%d-%H%M%S``); any ``_binary`` / ``_multi`` marker is
    removed beforehand. The task shown in the label is read from
    ``search_settings.toml`` inside the directory and defaults to ``binary``
    when the settings do not contain a ``task`` entry.

    Args:
        results_dir: Path to a single random-search results directory.

    Returns:
        Label of the form ``"[Task] Gridname (YYYY-mm-dd HH:MM:SS)"``.

    Raises:
        ValueError: If the directory name does not split into exactly two
            parts on ``_random_search_cv`` or the timestamp does not match
            the expected format.

    """
    cleaned_name = results_dir.name.replace("_binary", "").replace("_multi", "")
    gridname, date = cleaned_name.split("_random_search_cv")
    # str.lstrip() removes a *set of characters*, not a prefix, so it would also
    # eat leading letters of the grid name that happen to occur in "permafrost_".
    # removeprefix() drops exactly the literal prefix and nothing else.
    gridname = gridname.removeprefix("permafrost_")
    date = datetime.strptime(date, "%Y%m%d-%H%M%S")
    date = date.strftime("%Y-%m-%d %H:%M:%S")

    settings = toml.load(results_dir / "search_settings.toml")["settings"]
    task = settings.get("task", "binary")
    return f"[{task.capitalize()}] {gridname.capitalize()} ({date})"
|
||||||
|
|
||||||
|
|
||||||
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
|
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
|
||||||
|
|
@ -667,11 +940,12 @@ def _plot_box_assignments(model_state: xr.Dataset):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def _plot_box_assignment_bars(model_state: xr.Dataset):
|
def _plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]):
|
||||||
"""Create a bar chart showing how many boxes are assigned to each class.
|
"""Create a bar chart showing how many boxes are assigned to each class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_state: The xarray Dataset containing the model state with box_assignments.
|
model_state: The xarray Dataset containing the model state with box_assignments.
|
||||||
|
altair_colors: List of hex color strings for altair.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Altair chart showing count of boxes per class.
|
Altair chart showing count of boxes per class.
|
||||||
|
|
@ -690,14 +964,40 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
|
||||||
# Count boxes per class
|
# Count boxes per class
|
||||||
counts = primary_classes.groupby("class").size().reset_index(name="count")
|
counts = primary_classes.groupby("class").size().reset_index(name="count")
|
||||||
|
|
||||||
|
# Replace the special (-1, 0] interval with "No RTS" if present
|
||||||
|
counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")
|
||||||
|
|
||||||
|
# Sort the classes: "No RTS" first, then by the lower bound of intervals
|
||||||
|
def sort_key(class_str):
|
||||||
|
if class_str == "No RTS":
|
||||||
|
return -1 # Put "No RTS" first
|
||||||
|
# Parse interval string like "(0, 4]" or "(4, 36]"
|
||||||
|
try:
|
||||||
|
lower = float(str(class_str).split(",")[0].strip("([ "))
|
||||||
|
return lower
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return float("inf") # Put unparseable values at the end
|
||||||
|
|
||||||
|
# Sort counts by the same key
|
||||||
|
counts["sort_key"] = counts["class"].apply(sort_key)
|
||||||
|
counts = counts.sort_values("sort_key")
|
||||||
|
|
||||||
|
# Create an ordered list of classes for consistent color mapping
|
||||||
|
class_order = counts["class"].tolist()
|
||||||
|
|
||||||
# Create bar chart
|
# Create bar chart
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(counts)
|
alt.Chart(counts)
|
||||||
.mark_bar()
|
.mark_bar()
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("class:N", title="Class Label"),
|
x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
|
||||||
y=alt.Y("count:Q", title="Number of Boxes"),
|
y=alt.Y("count:Q", title="Number of Boxes"),
|
||||||
color=alt.Color("class:N", title="Class", legend=None),
|
color=alt.Color(
|
||||||
|
"class:N",
|
||||||
|
title="Class",
|
||||||
|
scale=alt.Scale(domain=class_order, range=altair_colors),
|
||||||
|
legend=None,
|
||||||
|
),
|
||||||
tooltip=[
|
tooltip=[
|
||||||
alt.Tooltip("class:N", title="Class"),
|
alt.Tooltip("class:N", title="Class"),
|
||||||
alt.Tooltip("count:Q", title="Number of Boxes"),
|
alt.Tooltip("count:Q", title="Number of Boxes"),
|
||||||
|
|
@ -758,7 +1058,7 @@ def _plot_common_features(common_array: xr.DataArray):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
|
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str, refit_metric: str):
|
||||||
"""Create a scatter plot comparing two metrics with parameter-based coloring.
|
"""Create a scatter plot comparing two metrics with parameter-based coloring.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -766,6 +1066,7 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
|
||||||
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
|
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
|
||||||
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
|
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
|
||||||
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
|
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
|
||||||
|
refit_metric: The metric used for refitting (e.g., 'f1' or 'f1_macro').
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Altair chart showing the metric comparison.
|
Altair chart showing the metric comparison.
|
||||||
|
|
@ -781,17 +1082,17 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
|
||||||
alt.Chart(results)
|
alt.Chart(results)
|
||||||
.mark_circle(size=60, opacity=0.7)
|
.mark_circle(size=60, opacity=0.7)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)),
|
x=alt.X(f"mean_test_{x_metric}:Q", title=format_metric_name(x_metric), scale=alt.Scale(zero=False)),
|
||||||
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)),
|
y=alt.Y(f"mean_test_{y_metric}:Q", title=format_metric_name(y_metric), scale=alt.Scale(zero=False)),
|
||||||
color=alt.Color(
|
color=alt.Color(
|
||||||
f"{color_param}:Q",
|
f"{color_param}:Q",
|
||||||
scale=color_scale,
|
scale=color_scale,
|
||||||
title=color_param,
|
title=color_param,
|
||||||
),
|
),
|
||||||
tooltip=[
|
tooltip=[
|
||||||
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()),
|
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=format_metric_name(x_metric)),
|
||||||
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()),
|
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=format_metric_name(y_metric)),
|
||||||
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
|
alt.Tooltip(f"mean_test_{refit_metric}:Q", format=".4f", title=format_metric_name(refit_metric)),
|
||||||
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
|
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
|
||||||
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
|
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
|
||||||
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
|
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
|
||||||
|
|
@ -807,6 +1108,148 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
|
||||||
|
def _load_training_data(grid: str, level: int, task: str):
    """Assemble the training dataset together with printable class labels.

    Args:
        grid: Grid type (hex or healpix).
        level: Grid level.
        task: Classification task type (binary or multi).

    Returns:
        Tuple of (GeoDataFrame with geometry and ``target_class``, feature DataFrame,
        Series of string class labels, label names list).

    """
    # create_xy_data provides the raw table, the feature matrix and the targets
    data, x_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
    sample_index = x_data.index

    if task != "binary":
        # Multi-class: rebuild the interval labels from the raw RTS counts
        rts_counts = data.loc[sample_index, "darts_rts_count"]
        quantile_bins = pd.qcut(rts_counts, q=5, duplicates="drop").unique().categories
        # A dedicated (-1, 0] bin is placed in front of the integer-cast quantile bins
        edges = [(-1, 0)]
        edges.extend((int(interval.left), int(interval.right)) for interval in quantile_bins)
        class_names = pd.cut(rts_counts, bins=pd.IntervalIndex.from_tuples(edges)).astype(str)
    else:
        class_names = y_data.map({False: "No RTS", True: "RTS"})

    # Keep only the geometry of the samples that are part of the feature matrix
    training_gdf = data.loc[sample_index, ["geometry"]].copy()
    training_gdf["target_class"] = class_names

    return training_gdf, x_data, class_names, labels
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_training_data_map(training_gdf: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
    """Render the training samples on an interactive map, colored by class.

    Args:
        training_gdf: GeoDataFrame with target classes (intervals as strings or category names).
        folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).

    Returns:
        Folium map with training data colored by class.

    """

    def _class_rank(label):
        # "No RTS" always comes first; intervals such as "(0, 4]" are ranked
        # by their lower bound; anything unparseable goes to the end.
        if label == "No RTS":
            return -1
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")

    # Work on a copy so the caller's frame stays untouched
    plot_gdf = training_gdf.copy()

    # The special (-1, 0] interval is shown as "No RTS"
    plot_gdf["target_class"] = plot_gdf["target_class"].replace("(-1, 0]", "No RTS")

    # An ordered categorical makes the legend follow the class ranking
    ordered_classes = sorted(plot_gdf["target_class"].unique(), key=_class_rank)
    plot_gdf["target_class"] = pd.Categorical(
        plot_gdf["target_class"], categories=ordered_classes, ordered=True
    )

    # Basemap follows the active Streamlit theme
    if st.context.theme.type == "dark":
        tiles = "CartoDB dark_matter"
    else:
        tiles = "CartoDB positron"

    return plot_gdf.explore(column="target_class", cmap=folium_cmap, legend=True, tiles=tiles)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_training_class_distribution(y_data: pd.Series, altair_colors: list[str]):
    """Draw a bar chart with the number of training samples per class.

    Args:
        y_data: Series with target classes.
        altair_colors: List of hex color strings for altair.

    Returns:
        Altair chart showing class distribution.

    """

    def _class_rank(label):
        # "No RTS" always comes first; intervals such as "(0, 4]" are ranked
        # by their lower bound; anything unparseable goes to the end.
        if label == "No RTS":
            return -1
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")

    # Same relabeling as on the map: the (-1, 0] interval is shown as "No RTS"
    relabeled = y_data.copy().replace("(-1, 0]", "No RTS")

    # Tally the classes and express each tally as a share of all samples
    frame = pd.DataFrame({"target_class": relabeled.to_numpy()})
    counts = frame["target_class"].value_counts().reset_index()
    counts.columns = ["class", "count"]
    counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)

    # Order the rows by class rank; the resulting order fixes the color mapping
    counts["sort_key"] = counts["class"].apply(_class_rank)
    counts = counts.sort_values("sort_key")
    class_order = counts["class"].tolist()

    x_axis = alt.X("class:N", title="Target Class", sort=class_order, axis=alt.Axis(labelAngle=0))
    y_axis = alt.Y("count:Q", title="Number of Samples")
    bar_color = alt.Color(
        "class:N",
        title="Class",
        scale=alt.Scale(domain=class_order, range=altair_colors),
        legend=None,
    )
    hover = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("count:Q", title="Count"),
        alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
    ]

    bars = alt.Chart(counts).mark_bar().encode(x=x_axis, y=y_axis, color=bar_color, tooltip=hover)
    return bars.properties(
        width=400,
        height=300,
        title="Training Data Class Distribution",
    )
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Run Streamlit dashboard application."""
|
"""Run Streamlit dashboard application."""
|
||||||
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
|
||||||
|
|
@ -846,7 +1289,19 @@ def main():
|
||||||
common_feature_array = extract_common_features(model_state)
|
common_feature_array = extract_common_features(model_state)
|
||||||
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
|
||||||
|
|
||||||
|
# Determine task type and available metrics
|
||||||
|
task = settings.get("task", "binary")
|
||||||
|
available_metrics = settings.get("metrics", ["accuracy", "recall", "precision", "f1", "jaccard"])
|
||||||
|
refit_metric = "f1" if task == "binary" else "f1_macro"
|
||||||
|
|
||||||
|
# Generate unified colormaps once for all visualizations
|
||||||
|
matplotlib_cmap, folium_cmap, altair_colors = generate_unified_colormap(settings)
|
||||||
|
|
||||||
st.sidebar.success(f"Loaded {len(results)} results")
|
st.sidebar.success(f"Loaded {len(results)} results")
|
||||||
|
st.sidebar.info(f"Task: {task.capitalize()} Classification")
|
||||||
|
# Dump the settings into the sidebar
|
||||||
|
with st.sidebar.expander("Search Settings", expanded=True):
|
||||||
|
st.json(settings)
|
||||||
|
|
||||||
# Display some basic statistics first (lightweight)
|
# Display some basic statistics first (lightweight)
|
||||||
st.header("Parameter-Search Overview")
|
st.header("Parameter-Search Overview")
|
||||||
|
|
@ -857,16 +1312,18 @@ def main():
|
||||||
st.metric("Total Runs", len(results))
|
st.metric("Total Runs", len(results))
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
# Best model based on F1 score
|
# Best model based on refit metric
|
||||||
best_f1_idx = results["mean_test_f1"].idxmax()
|
best_idx = results[f"mean_test_{refit_metric}"].idxmax()
|
||||||
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
|
st.metric(f"Best Model Index (by {format_metric_name(refit_metric)})", f"#{best_idx}")
|
||||||
|
|
||||||
# Show best parameters for the best model
|
# Show best parameters for the best model
|
||||||
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
|
best_params = results.loc[
|
||||||
|
best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{refit_metric}", f"std_test_{refit_metric}"]
|
||||||
|
]
|
||||||
|
|
||||||
with st.container(border=True):
|
with st.container(border=True):
|
||||||
st.subheader(":abacus: Best Model Parameters")
|
st.subheader(":abacus: Best Model Parameters")
|
||||||
st.caption("Parameters of retrained best model (selected by F1 score)")
|
st.caption(f"Parameters of retrained best model (selected by {format_metric_name(refit_metric)} score)")
|
||||||
col1, col2, col3 = st.columns(3)
|
col1, col2, col3 = st.columns(3)
|
||||||
with col1:
|
with col1:
|
||||||
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
|
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
|
||||||
|
|
@ -877,32 +1334,132 @@ def main():
|
||||||
|
|
||||||
# Show all metrics for the best model in a container
|
# Show all metrics for the best model in a container
|
||||||
st.subheader(":bar_chart: Performance Across All Metrics")
|
st.subheader(":bar_chart: Performance Across All Metrics")
|
||||||
st.caption("Complete performance profile of the best model (selected by F1 score)")
|
st.caption(
|
||||||
|
f"Complete performance profile of the best model (selected by {format_metric_name(refit_metric)} score)"
|
||||||
|
)
|
||||||
|
|
||||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
|
||||||
cols = st.columns(len(available_metrics))
|
cols = st.columns(len(available_metrics))
|
||||||
|
|
||||||
for idx, metric in enumerate(available_metrics):
|
for idx, metric in enumerate(available_metrics):
|
||||||
with cols[idx]:
|
with cols[idx]:
|
||||||
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
|
best_score = results.loc[best_idx, f"mean_test_{metric}"]
|
||||||
best_std = results.loc[best_f1_idx, f"std_test_{metric}"]
|
best_std = results.loc[best_idx, f"std_test_{metric}"]
|
||||||
# Highlight F1 since that's what we optimized for
|
# Highlight refit metric since that's what we optimized for
|
||||||
st.metric(
|
st.metric(
|
||||||
f"{metric.capitalize()}",
|
format_metric_name(metric),
|
||||||
f"{best_score:.4f}",
|
f"{best_score:.4f}",
|
||||||
delta=f"±{best_std:.4f}",
|
delta=f"±{best_std:.4f}",
|
||||||
help="Mean ± std across cross-validation folds",
|
help="Mean ± std across cross-validation folds",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tabs for different visualizations
|
# Create tabs for different visualizations
|
||||||
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
|
tab0, tab1, tab2, tab3 = st.tabs(["Training Data", "Search Results", "Model State", "Inference Analysis"])
|
||||||
|
|
||||||
|
with tab0:
|
||||||
|
# Training Data Analysis
|
||||||
|
st.header("Training Data Analysis")
|
||||||
|
st.markdown("Comprehensive analysis of the training dataset used for model development")
|
||||||
|
|
||||||
|
# Load training data
|
||||||
|
with st.spinner("Loading training data..."):
|
||||||
|
training_gdf, X_data, y_data, _ = _load_training_data(
|
||||||
|
grid=settings["grid"], level=settings["level"], task=task
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
st.subheader("Dataset Statistics")
|
||||||
|
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Samples", f"{len(training_gdf):,}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.metric("Number of Features", f"{X_data.shape[1]:,}")
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
n_classes = y_data.nunique()
|
||||||
|
st.metric("Number of Classes", n_classes)
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
missing_pct = X_data.isnull().sum().sum() / (X_data.shape[0] * X_data.shape[1]) * 100
|
||||||
|
st.metric("Missing Values", f"{missing_pct:.2f}%")
|
||||||
|
|
||||||
|
# Class distribution visualization
|
||||||
|
st.subheader("Class Distribution")
|
||||||
|
|
||||||
|
with st.spinner("Generating class distribution..."):
|
||||||
|
class_dist_chart = _plot_training_class_distribution(y_data, altair_colors)
|
||||||
|
st.altair_chart(class_dist_chart, use_container_width=True)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Interpretation:**
|
||||||
|
- Shows the balance between different classes in the training dataset
|
||||||
|
- Class imbalance affects model learning and may require special handling
|
||||||
|
- Each bar represents the count of training samples for that class
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Interactive map
|
||||||
|
st.subheader("Interactive Training Data Map")
|
||||||
|
st.markdown("Explore the spatial distribution of training samples by class")
|
||||||
|
|
||||||
|
with st.spinner("Generating interactive map..."):
|
||||||
|
training_map = _plot_training_data_map(training_gdf, folium_cmap)
|
||||||
|
st_folium.st_folium(training_map, width="100%", height=600, returned_objects=[])
|
||||||
|
|
||||||
|
# Additional statistics in expander
|
||||||
|
with st.expander("Detailed Training Data Statistics"):
|
||||||
|
st.write("**Class Distribution:**")
|
||||||
|
class_counts = y_data.value_counts().sort_index()
|
||||||
|
|
||||||
|
# Create columns for better layout
|
||||||
|
n_cols = min(5, len(class_counts))
|
||||||
|
cols = st.columns(n_cols)
|
||||||
|
|
||||||
|
for idx, (class_label, count) in enumerate(class_counts.items()):
|
||||||
|
percentage = count / len(y_data) * 100
|
||||||
|
with cols[idx % n_cols]:
|
||||||
|
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
|
||||||
|
|
||||||
|
# Show detailed table
|
||||||
|
st.write("**Detailed Class Breakdown:**")
|
||||||
|
class_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Class": class_counts.index,
|
||||||
|
"Count": class_counts.to_numpy(),
|
||||||
|
"Percentage": (class_counts.to_numpy() / len(y_data) * 100).round(2),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
st.dataframe(class_df, width="stretch", hide_index=True)
|
||||||
|
|
||||||
|
# Feature statistics
|
||||||
|
st.write("**Feature Statistics:**")
|
||||||
|
st.markdown(f"- Total number of features: **{X_data.shape[1]}**")
|
||||||
|
st.markdown(f"- Features with missing values: **{X_data.isnull().any().sum()}**")
|
||||||
|
|
||||||
|
# Show feature types breakdown
|
||||||
|
feature_types = X_data.columns.to_series().apply(lambda x: x.split("_")[0]).value_counts()
|
||||||
|
st.write("**Feature Type Distribution:**")
|
||||||
|
feature_type_df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"Feature Type": feature_types.index,
|
||||||
|
"Count": feature_types.to_numpy(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
st.dataframe(feature_type_df, width="stretch", hide_index=True)
|
||||||
|
|
||||||
with tab1:
|
with tab1:
|
||||||
# Metric selection - only used in this tab
|
# Metric selection - only used in this tab
|
||||||
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
metric_display_names = {metric: format_metric_name(metric) for metric in available_metrics}
|
||||||
selected_metric = st.selectbox(
|
selected_metric_display = st.selectbox(
|
||||||
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
|
"Select Metric",
|
||||||
|
options=[format_metric_name(m) for m in available_metrics],
|
||||||
|
help="Choose which metric to visualize",
|
||||||
)
|
)
|
||||||
|
# Convert back to raw metric name
|
||||||
|
selected_metric = next(k for k, v in metric_display_names.items() if v == selected_metric_display)
|
||||||
|
|
||||||
# Show best parameters
|
# Show best parameters
|
||||||
with st.expander("Best Parameters"):
|
with st.expander("Best Parameters"):
|
||||||
|
|
@ -911,7 +1468,7 @@ def main():
|
||||||
st.dataframe(best_params.to_frame().T, width="content")
|
st.dataframe(best_params.to_frame().T, width="content")
|
||||||
|
|
||||||
# Main plots
|
# Main plots
|
||||||
st.header(f"Visualization for {selected_metric.capitalize()}")
|
st.header(f"Visualization for {format_metric_name(selected_metric)}")
|
||||||
|
|
||||||
# K-binned plot configuration
|
# K-binned plot configuration
|
||||||
@st.fragment
|
@st.fragment
|
||||||
|
|
@ -978,17 +1535,35 @@ def main():
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
st.subheader("K vs eps_cl")
|
st.subheader(f"K vs {selected_metric} (binned by eps_e)")
|
||||||
with st.spinner("Generating eps_cl plot..."):
|
with st.spinner("Generating plot..."):
|
||||||
chart3 = _plot_eps_binned(results_binned, "eps_cl", f"mean_test_{selected_metric}")
|
chart3 = _plot_params_binned(
|
||||||
|
results_binned, "initial_K", "eps_cl", "eps_e_binned", f"mean_test_{selected_metric}"
|
||||||
|
)
|
||||||
st.altair_chart(chart3, use_container_width=True)
|
st.altair_chart(chart3, use_container_width=True)
|
||||||
|
|
||||||
with col2:
|
st.subheader(f"eps_e vs {selected_metric} (binned by K)")
|
||||||
st.subheader("K vs eps_e")
|
with st.spinner("Generating plot..."):
|
||||||
with st.spinner("Generating eps_e plot..."):
|
chart4 = _plot_params_binned(
|
||||||
chart4 = _plot_eps_binned(results_binned, "eps_e", f"mean_test_{selected_metric}")
|
results_binned, "eps_e", "eps_cl", "initial_K_binned", f"mean_test_{selected_metric}"
|
||||||
|
)
|
||||||
st.altair_chart(chart4, use_container_width=True)
|
st.altair_chart(chart4, use_container_width=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
st.subheader(f"K vs {selected_metric} (binned by eps_cl)")
|
||||||
|
with st.spinner("Generating plot..."):
|
||||||
|
chart5 = _plot_params_binned(
|
||||||
|
results_binned, "initial_K", "eps_e", "eps_cl_binned", f"mean_test_{selected_metric}"
|
||||||
|
)
|
||||||
|
st.altair_chart(chart5, use_container_width=True)
|
||||||
|
|
||||||
|
st.subheader(f"eps_cl vs {selected_metric} (binned by K)")
|
||||||
|
with st.spinner("Generating plot..."):
|
||||||
|
chart6 = _plot_params_binned(
|
||||||
|
results_binned, "eps_cl", "eps_e", "initial_K_binned", f"mean_test_{selected_metric}"
|
||||||
|
)
|
||||||
|
st.altair_chart(chart6, use_container_width=True)
|
||||||
|
|
||||||
render_k_binned_plots()
|
render_k_binned_plots()
|
||||||
|
|
||||||
# Metric comparison plots
|
# Metric comparison plots
|
||||||
|
|
@ -1003,19 +1578,37 @@ def main():
|
||||||
help="Choose which parameter to use for coloring the scatter plots",
|
help="Choose which parameter to use for coloring the scatter plots",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Dynamically determine which metrics to compare based on available metrics
|
||||||
|
if task == "binary":
|
||||||
|
# For binary: show recall vs precision and accuracy vs jaccard
|
||||||
|
comparisons = [
|
||||||
|
("precision", "recall", "Recall vs Precision"),
|
||||||
|
("accuracy", "jaccard", "Accuracy vs Jaccard"),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# For multiclass: show micro vs macro variants
|
||||||
|
comparisons = [
|
||||||
|
("precision_macro", "recall_macro", "Recall Macro vs Precision Macro"),
|
||||||
|
("accuracy", "jaccard_macro", "Accuracy vs Jaccard Macro"),
|
||||||
|
]
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
st.subheader("Recall vs Precision")
|
st.subheader(comparisons[0][2])
|
||||||
with st.spinner("Generating Recall vs Precision plot..."):
|
with st.spinner(f"Generating {comparisons[0][2]} plot..."):
|
||||||
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
|
chart1 = _plot_metric_comparison(
|
||||||
st.altair_chart(recall_precision_chart, use_container_width=True)
|
results, comparisons[0][0], comparisons[0][1], color_param, refit_metric
|
||||||
|
)
|
||||||
|
st.altair_chart(chart1, use_container_width=True)
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
st.subheader("Accuracy vs Jaccard")
|
st.subheader(comparisons[1][2])
|
||||||
with st.spinner("Generating Accuracy vs Jaccard plot..."):
|
with st.spinner(f"Generating {comparisons[1][2]} plot..."):
|
||||||
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
|
chart2 = _plot_metric_comparison(
|
||||||
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
|
results, comparisons[1][0], comparisons[1][1], color_param, refit_metric
|
||||||
|
)
|
||||||
|
st.altair_chart(chart2, use_container_width=True)
|
||||||
|
|
||||||
render_metric_comparisons()
|
render_metric_comparisons()
|
||||||
|
|
||||||
|
|
@ -1094,7 +1687,7 @@ def main():
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
st.markdown("### Box Count by Class")
|
st.markdown("### Box Count by Class")
|
||||||
box_assignment_bars = _plot_box_assignment_bars(model_state)
|
box_assignment_bars = _plot_box_assignment_bars(model_state, altair_colors)
|
||||||
st.altair_chart(box_assignment_bars, use_container_width=True)
|
st.altair_chart(box_assignment_bars, use_container_width=True)
|
||||||
|
|
||||||
# Show statistics
|
# Show statistics
|
||||||
|
|
@ -1302,30 +1895,44 @@ def main():
|
||||||
n_classes = predictions["predicted_class"].nunique()
|
n_classes = predictions["predicted_class"].nunique()
|
||||||
st.metric("Number of Classes", n_classes)
|
st.metric("Number of Classes", n_classes)
|
||||||
|
|
||||||
# Class distribution visualization
|
# Class distribution and static map side by side
|
||||||
st.subheader("Class Distribution")
|
st.subheader("Prediction Overview")
|
||||||
|
|
||||||
|
col_map, col_dist = st.columns([0.6, 0.4])
|
||||||
|
|
||||||
|
with col_map:
|
||||||
|
st.markdown("#### Static Prediction Map")
|
||||||
|
st.markdown("High-quality map with North Polar Stereographic projection")
|
||||||
|
|
||||||
|
with st.spinner("Generating static map..."):
|
||||||
|
static_fig = _plot_prediction_map_static(predictions, matplotlib_cmap)
|
||||||
|
st.pyplot(static_fig, width="stretch")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
**Map Features:**
|
||||||
|
- North Polar Azimuthal Equidistant projection (optimized for Arctic)
|
||||||
|
- Predictions colored by class using inferno colormap
|
||||||
|
- Includes coastlines and 50°N latitude boundary
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with col_dist:
|
||||||
|
st.markdown("#### Class Distribution")
|
||||||
|
|
||||||
with st.spinner("Generating class distribution..."):
|
with st.spinner("Generating class distribution..."):
|
||||||
class_dist_chart = _plot_prediction_class_distribution(predictions)
|
class_dist_chart = _plot_prediction_class_distribution(predictions, altair_colors)
|
||||||
st.altair_chart(class_dist_chart, use_container_width=True)
|
st.altair_chart(class_dist_chart, use_container_width=True)
|
||||||
|
|
||||||
st.markdown(
|
st.markdown(
|
||||||
"""
|
"""
|
||||||
**Interpretation:**
|
**Interpretation:**
|
||||||
- Shows the balance between predicted classes
|
- Balance between predicted classes
|
||||||
- Class imbalance may indicate regional patterns or model bias
|
- Class imbalance may indicate regional patterns
|
||||||
- Each bar represents the count of cells predicted for that class
|
- Each bar shows cell count per class
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Interactive map
|
|
||||||
st.subheader("Interactive Prediction Map")
|
|
||||||
st.markdown("Explore predictions spatially with the interactive map below")
|
|
||||||
|
|
||||||
with st.spinner("Generating interactive map..."):
|
|
||||||
chart_map = _plot_prediction_map(predictions)
|
|
||||||
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
|
||||||
|
|
||||||
# Additional statistics in expander
|
# Additional statistics in expander
|
||||||
with st.expander("Detailed Prediction Statistics"):
|
with st.expander("Detailed Prediction Statistics"):
|
||||||
st.write("**Class Distribution:**")
|
st.write("**Class Distribution:**")
|
||||||
|
|
@ -1351,6 +1958,17 @@ def main():
|
||||||
)
|
)
|
||||||
st.dataframe(class_df, width="stretch", hide_index=True)
|
st.dataframe(class_df, width="stretch", hide_index=True)
|
||||||
|
|
||||||
|
# Interactive map
|
||||||
|
st.subheader("Interactive Prediction Map")
|
||||||
|
st.markdown("Explore predictions spatially with the interactive map below")
|
||||||
|
|
||||||
|
# with st.spinner("Generating interactive map..."):
|
||||||
|
# chart_map = _plot_prediction_map(predictions, folium_colors)
|
||||||
|
# st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
|
||||||
|
st.text("Interactive map functionality is currently disabled.")
|
||||||
|
|
||||||
|
st.balloons()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue