Add multiclass training

This commit is contained in:
Tobias Hölzer 2025-11-09 22:34:03 +01:00
parent 553b54bb32
commit d5b35d6da4
7 changed files with 814 additions and 4867 deletions

11
pixi.lock generated
View file

@ -285,6 +285,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
@ -1301,7 +1302,7 @@ packages:
- pypi: ./ - pypi: ./
name: entropice name: entropice
version: 0.1.0 version: 0.1.0
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9 sha256: 852c87cdbd1d452fccaa6253c6ce7410dda9fa32d2f951d35441e454414acfc0
requires_dist: requires_dist:
- aiohttp>=3.12.11 - aiohttp>=3.12.11
- bokeh>=3.7.3 - bokeh>=3.7.3
@ -1350,6 +1351,7 @@ packages:
- altair[all]>=5.5.0,<6 - altair[all]>=5.5.0,<6
- h5netcdf>=1.7.3,<2 - h5netcdf>=1.7.3,<2
- streamlit-folium>=0.25.3,<0.26 - streamlit-folium>=0.25.3,<0.26
- st-theme>=1.2.3,<2
editable: true editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy name: entropy
@ -4507,6 +4509,13 @@ packages:
version: 1.0.1 version: 1.0.1
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718 sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
requires_python: '>=3.6' requires_python: '>=3.6'
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
name: st-theme
version: 1.2.3
sha256: 0a54d9817dd5f8a6d7b0d071b25ae72eacf536c63a5fb97374923938021b1389
requires_dist:
- streamlit>=1.33
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
name: stack-data name: stack-data
version: 0.6.3 version: 0.6.3

View file

@ -51,7 +51,7 @@ dependencies = [
"geocube>=0.7.1,<0.8", "geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2", "streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6", "altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
] ]
[project.scripts] [project.scripts]

View file

@ -67,6 +67,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1) grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1) grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
output_path = get_darts_rts_file(grid, level) output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path) grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}") print(f"Saved RTS labels to {output_path}")

View file

@ -87,9 +87,14 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file return dataset_file
def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path: def get_cv_results_dir(
name: str,
grid: Literal["hex", "healpix"],
level: int,
task: Literal["binary", "multi"],
) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_binary" results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True) results_dir.mkdir(parents=True, exist_ok=True)
return results_dir return results_dir

View file

@ -29,7 +29,50 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False): def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
"""Create X and y data from the training dataset.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
Returns:
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
"""
data = get_train_dataset_file(grid=grid, level=level)
data = gpd.read_parquet(data)
data = data[data["darts_has_coverage"]]
cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
X_data = data.drop(columns=cols_to_drop).dropna()
if task == "binary":
labels = ["No RTS", "RTS"]
y_data = data.loc[X_data.index, "darts_has_rts"]
else:
# Put into n categories (log scaled)
y_data = data.loc[X_data.index, "darts_rts_count"]
n_categories = 5
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
# Change the first interval to start at 1 and add a category for 0
bins = pd.IntervalIndex.from_tuples(
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
)
y_data = pd.cut(y_data, bins=bins)
labels = [str(v) for v in y_data.sort_values().unique()]
y_data = y_data.cat.codes
return data, X_data, y_data, labels
def random_cv(
grid: Literal["hex", "healpix"],
level: int,
n_iter: int = 2000,
robust: bool = False,
task: Literal["binary", "multi"] = "binary",
):
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
Args: Args:
@ -37,21 +80,17 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False. robust (bool, optional): Whether to use robust training. Defaults to False.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
""" """
data = get_train_dataset_file(grid=grid, level=level) _, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
data = gpd.read_parquet(data) print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
data = data[data["darts_has_coverage"]] print(f"{y_data.describe()=}")
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
X_data = data.drop(columns=cols_to_drop).dropna()
y_data = data.loc[X_data.index, "darts_has_rts"]
X = X_data.to_numpy(dtype="float32") X = X_data.to_numpy(dtype="float32")
y = y_data.to_numpy(dtype="int8") y = y_data.to_numpy(dtype="int8")
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
print(f"{X.shape=}, {y.shape=}") print(f"{X.shape=}, {y.shape=}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
param_grid = { param_grid = {
@ -62,7 +101,21 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)
if task == "binary":
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
else:
metrics = [
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
"f1_macro",
"f1_weighted",
"precision_macro",
"precision_weighted",
"recall_macro",
"recall_weighted",
"jaccard_micro",
"jaccard_macro",
"jaccard_weighted",
]
search = RandomizedSearchCV( search = RandomizedSearchCV(
clf, clf,
param_grid, param_grid,
@ -72,7 +125,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
random_state=42, random_state=42,
verbose=3, verbose=3,
scoring=metrics, scoring=metrics,
refit="f1", refit="f1" if task == "binary" else "f1_weighted",
) )
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
@ -88,10 +141,11 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}")
results_dir = get_cv_results_dir("random_search", grid=grid, level=level) results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
# Store the search settings # Store the search settings
settings = { settings = {
"task": task,
"grid": grid, "grid": grid,
"level": level, "level": level,
"random_state": 42, "random_state": 42,
@ -115,7 +169,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
}, },
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": ["No RTS", "RTS"], "classes": labels,
} }
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
@ -144,7 +198,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
best_estimator = search.best_estimator_ best_estimator = search.best_estimator_
# Annotate the state with xarray metadata # Annotate the state with xarray metadata
features = X_data.columns.tolist() features = X_data.columns.tolist()
labels = y_data.unique().tolist()
boxes = list(range(best_estimator.K_)) boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray( box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(), best_estimator.S_.cpu().numpy(),
@ -185,7 +238,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"]) preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
preds_file = results_dir / "predicted_probabilities.parquet" preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}") print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file) preds.to_parquet(preds_file)

View file

@ -5,15 +5,93 @@ from pathlib import Path
import altair as alt import altair as alt
import geopandas as gpd import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.path
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
import streamlit_folium as st_folium import streamlit_folium as st_folium
import toml import toml
import ultraplot as uplt
import xarray as xr import xarray as xr
import xdggs import xdggs
from matplotlib.patches import PathPatch
from entropice.paths import RESULTS_DIR from entropice.paths import RESULTS_DIR
from entropice.training import create_xy_data
def generate_unified_colormap(settings: dict):
"""Generate unified colormaps for all plotting libraries.
This function creates consistent color schemes across Matplotlib/Ultraplot,
Folium/Leaflet, and Altair/Vega-Lite by determining the task type and number
of classes from the settings, then generating appropriate colormaps for each library.
Args:
settings: Settings dictionary containing task type, classes, and other configuration.
Returns:
Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
- matplotlib_cmap: matplotlib ListedColormap object
- folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
- altair_colors: list of hex color strings for Altair
"""
# Determine task type and number of classes from settings
task = settings.get("task", "binary")
n_classes = len(settings.get("classes", []))
# Check theme
is_dark_theme = st.context.theme.type == "dark"
# Define base colormaps for different tasks
if task == "binary":
# For binary: use a simple two-color scheme
if is_dark_theme:
base_colors = ["#1f77b4", "#ff7f0e"] # Blue and orange for dark theme
else:
base_colors = ["#3498db", "#e74c3c"] # Brighter blue and red for light theme
else:
# For multi-class: use a sequential colormap
# Use matplotlib's viridis colormap as base (better perceptual uniformity than inferno)
cmap = uplt.Colormap("viridis")
# Sample colors evenly across the colormap
indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends
base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]
# Create matplotlib colormap (for ultraplot and geopandas)
matplotlib_cmap = mcolors.ListedColormap(base_colors)
# Create Folium/Leaflet colormap (geopandas.explore uses matplotlib colormaps)
folium_cmap = mcolors.ListedColormap(base_colors)
# Create Altair color list (Altair uses hex color strings in range)
altair_colors = base_colors
return matplotlib_cmap, folium_cmap, altair_colors
def format_metric_name(metric: str) -> str:
"""Format metric name for display.
Args:
metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').
Returns:
Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
"""
# Split by underscore and capitalize each part
parts = metric.split("_")
# Special handling for F1
formatted_parts = []
for part in parts:
if part.lower() == "f1":
formatted_parts.append("F1")
else:
formatted_parts.append(part.capitalize())
return " ".join(formatted_parts)
def get_available_result_files() -> list[Path]: def get_available_result_files() -> list[Path]:
@ -185,32 +263,229 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
return common_feature_array return common_feature_array
def _plot_prediction_map(preds: gpd.GeoDataFrame): def _plot_prediction_map_static(preds: gpd.GeoDataFrame, matplotlib_cmap: mcolors.ListedColormap):
return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron") """Create a static map of predictions using ultraplot and cartopy.
Args:
preds: GeoDataFrame with predicted classes (intervals as strings).
matplotlib_cmap: Matplotlib ListedColormap object.
Returns:
Matplotlib figure with predictions colored by class.
"""
# Create a copy to avoid modifying the original data
preds_plot = preds.copy()
# Replace the special (-1, 0] interval with "No RTS"
preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(class_str.split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
# Get unique classes and sort them
unique_classes = sorted(preds_plot["predicted_class"].unique(), key=sort_key)
# Create categorical with ordered categories
preds_plot["predicted_class"] = pd.Categorical(
preds_plot["predicted_class"], categories=unique_classes, ordered=True
)
# Detect theme for styling
theme = st.context.theme.type
if theme == "light":
figcolor = "#f8f9fa"
axcolor = "#ffffff"
oceancolor = "#e3f2fd"
landcolor = "#ecf0f1"
gridcolor = "#666666"
fgcolor = "#2c3e50"
coastcolor = "#34495e"
else:
figcolor = "#0e1117"
axcolor = "#0e1117"
oceancolor = "#1a1d29"
landcolor = "#262730"
gridcolor = "#ffffff"
fgcolor = "#fafafa"
coastcolor = "#d0d0d0"
# Create figure with North Polar Stereographic projection
# proj = uplt.Proj("npaeqd")
proj = uplt.Proj("npstere", lon_0=-45)
fig, ax = uplt.subplots(proj=proj, figsize=(12, 12), facecolor=figcolor)
# Put a background on the figure.
bgpatch = PathPatch(matplotlib.path.Path.unit_rectangle(), transform=fig.transFigure, color=figcolor, zorder=-1)
fig.patches.append(bgpatch)
# Apply theme-appropriate styling with enhanced aesthetics
ax.format(
boundinglat=50,
coast=True,
coastcolor=coastcolor,
coastlinewidth=0.8,
land=True,
landcolor=landcolor,
ocean=True,
oceancolor=oceancolor,
title="Predicted RTS Classes",
titlecolor=fgcolor,
titleweight="bold",
titlesize=14,
color=fgcolor,
gridcolor=gridcolor,
gridlinewidth=0.5,
gridlinestyle="-",
gridalpha=0.3,
facecolor=axcolor,
labels=True,
labelcolor=fgcolor,
latlines=10,
lonlines=30,
)
# Plot the predictions using the provided colormap with enhanced styling
preds_plot.to_crs(proj.proj4_init).plot(
ax=ax,
column="predicted_class",
cmap=matplotlib_cmap,
legend=True,
legend_kwds={
"loc": "lower left",
"frameon": True,
"framealpha": 0.95,
"edgecolor": gridcolor,
"facecolor": figcolor,
"fancybox": True,
"shadow": False,
"title": "RTS Class",
"title_fontsize": 11,
"fontsize": 9,
"labelspacing": 0.8,
"borderpad": 1.0,
"columnspacing": 1.0,
},
edgecolor="none",
linewidth=0,
)
# Enhance legend styling after creation
legend = ax.get_legend()
if legend:
legend.set_title("RTS Class", prop={"size": 11, "weight": "bold"})
legend.get_frame().set_linewidth(0.5)
legend.get_frame().set_edgecolor(gridcolor)
legend.get_frame().set_facecolor(figcolor)
legend.get_frame().set_alpha(0.95)
# Set text color for legend labels
for text in legend.get_texts():
text.set_color(fgcolor)
# Set title color
legend.get_title().set_color(fgcolor)
return fig
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame): def _plot_prediction_map(preds: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
"""Create an interactive map of predictions with properly sorted classes.
Args:
preds: GeoDataFrame with predicted classes (intervals as strings).
folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).
Returns:
Folium map with predictions colored by class.
"""
# Create a copy to avoid modifying the original data
preds_plot = preds.copy()
# Replace the special (-1, 0] interval with "No RTS"
preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(class_str.split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
# Get unique classes and sort them
unique_classes = sorted(preds_plot["predicted_class"].unique(), key=sort_key)
# Create categorical with ordered categories
preds_plot["predicted_class"] = pd.Categorical(
preds_plot["predicted_class"], categories=unique_classes, ordered=True
)
# Select tiles based on theme
tiles = "CartoDB dark_matter" if st.context.theme.type == "dark" else "CartoDB positron"
return preds_plot.explore(column="predicted_class", cmap=folium_cmap, legend=True, tiles=tiles)
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame, altair_colors: list[str]):
"""Create a bar chart showing the count of each predicted class. """Create a bar chart showing the count of each predicted class.
Args: Args:
preds: GeoDataFrame with predicted classes. preds: GeoDataFrame with predicted classes.
altair_colors: List of hex color strings for altair.
Returns: Returns:
Altair chart showing class distribution. Altair chart showing class distribution.
""" """
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()}) # Create a copy and apply the same transformations as the map
preds_plot = preds.copy()
preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(class_str.split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
df = pd.DataFrame({"predicted_class": preds_plot["predicted_class"].to_numpy()})
counts = df["predicted_class"].value_counts().reset_index() counts = df["predicted_class"].value_counts().reset_index()
counts.columns = ["class", "count"] counts.columns = ["class", "count"]
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2) counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
# Sort counts by the same key
counts["sort_key"] = counts["class"].apply(sort_key)
counts = counts.sort_values("sort_key")
# Create an ordered list of classes for consistent color mapping
class_order = counts["class"].tolist()
chart = ( chart = (
alt.Chart(counts) alt.Chart(counts)
.mark_bar() .mark_bar()
.encode( .encode(
x=alt.X("class:N", title="Predicted Class"), x=alt.X("class:N", title="Predicted Class", sort=class_order, axis=alt.Axis(labelAngle=0)),
y=alt.Y("count:Q", title="Number of Cells"), y=alt.Y("count:Q", title="Number of Cells"),
color=alt.Color("class:N", title="Class", legend=None), color=alt.Color(
"class:N",
title="Class",
scale=alt.Scale(domain=class_order, range=altair_colors),
legend=None,
),
tooltip=[ tooltip=[
alt.Tooltip("class:N", title="Class"), alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Count"), alt.Tooltip("count:Q", title="Count"),
@ -286,41 +561,36 @@ def _plot_k_binned(
return chart return chart
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str): def _plot_params_binned(results: pd.DataFrame, x: str, hue: str, col: str, metric: str):
"""Plot epsilon-binned results with K parameter.""" """Plot epsilon-binned results with K parameter."""
assert "initial_K" in results.columns, "initial_K column not found in results."
assert metric in results.columns, f"{metric} not found in results." assert metric in results.columns, f"{metric} not found in results."
assert x in ["eps_cl", "eps_e", "initial_K"]
assert hue in ["eps_cl", "eps_e", "initial_K"]
assert col in ["eps_cl_binned", "eps_e_binned", "initial_K_binned"]
if target == "eps_cl": assert x in results.columns, f"{x} column not found in results."
hue = "eps_cl"
col = "eps_e_binned"
elif target == "eps_e":
hue = "eps_e"
col = "eps_cl_binned"
else:
raise ValueError(f"Invalid target: {target}")
assert hue in results.columns, f"{hue} column not found in results." assert hue in results.columns, f"{hue} column not found in results."
assert col in results.columns, f"{col} column not found in results." assert col in results.columns, f"{col} column not found in results."
# Prepare data # Prepare data
plot_data = results[["initial_K", metric, hue, col]].copy() plot_data = results[[x, metric, hue, col]].copy()
# Sort bins by their left value and convert to string with sorted categories # Sort bins by their left value and convert to string with sorted categories
plot_data = plot_data.sort_values(col) plot_data = plot_data.sort_values(col)
plot_data[col] = plot_data[col].astype(str) plot_data[col] = plot_data[col].astype(str)
bin_order = plot_data[col].unique().tolist() bin_order = plot_data[col].unique().tolist()
xscale = alt.Scale(type="log") if x in ["eps_cl", "eps_e"] else alt.Scale()
cscheme = "bluepurple" if hue == "eps_e" else "purplered" if hue == "eps_cl" else "greenblue"
# Create the chart # Create the chart
chart = ( chart = (
alt.Chart(plot_data) alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7) .mark_circle(size=60, opacity=0.7)
.encode( .encode(
x=alt.X("initial_K:Q", title="Initial K"), x=alt.X(f"{x}:Q", title=x, scale=xscale),
y=alt.Y(f"{metric}:Q", title=metric), y=alt.Y(f"{metric}:Q", title=metric),
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue), color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme=cscheme), title=hue),
tooltip=[ tooltip=[
"initial_K:Q", f"{x}:Q",
alt.Tooltip(f"{metric}:Q", format=".4f"), alt.Tooltip(f"{metric}:Q", format=".4f"),
alt.Tooltip(f"{hue}:Q", format=".2e"), alt.Tooltip(f"{hue}:Q", format=".2e"),
f"{col}:N", f"{col}:N",
@ -334,11 +604,14 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
def _parse_results_dir_name(results_dir: Path) -> str: def _parse_results_dir_name(results_dir: Path) -> str:
gridname, date = results_dir.name.split("_random_search_cv") gridname, date = results_dir.name.replace("_binary", "").replace("_multi", "").split("_random_search_cv")
gridname = gridname.lstrip("permafrost_") gridname = gridname.lstrip("permafrost_")
date = datetime.strptime(date, "%Y%m%d-%H%M%S") date = datetime.strptime(date, "%Y%m%d-%H%M%S")
date = date.strftime("%Y-%m-%d %H:%M:%S") date = date.strftime("%Y-%m-%d %H:%M:%S")
return f"{gridname} ({date})"
settings = toml.load(results_dir / "search_settings.toml")["settings"]
task = settings.get("task", "binary")
return f"[{task.capitalize()}] {gridname.capitalize()} ({date})"
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10): def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
@ -667,11 +940,12 @@ def _plot_box_assignments(model_state: xr.Dataset):
return chart return chart
def _plot_box_assignment_bars(model_state: xr.Dataset): def _plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]):
"""Create a bar chart showing how many boxes are assigned to each class. """Create a bar chart showing how many boxes are assigned to each class.
Args: Args:
model_state: The xarray Dataset containing the model state with box_assignments. model_state: The xarray Dataset containing the model state with box_assignments.
altair_colors: List of hex color strings for altair.
Returns: Returns:
Altair chart showing count of boxes per class. Altair chart showing count of boxes per class.
@ -690,14 +964,40 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
# Count boxes per class # Count boxes per class
counts = primary_classes.groupby("class").size().reset_index(name="count") counts = primary_classes.groupby("class").size().reset_index(name="count")
# Replace the special (-1, 0] interval with "No RTS" if present
counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(str(class_str).split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
# Sort counts by the same key
counts["sort_key"] = counts["class"].apply(sort_key)
counts = counts.sort_values("sort_key")
# Create an ordered list of classes for consistent color mapping
class_order = counts["class"].tolist()
# Create bar chart # Create bar chart
chart = ( chart = (
alt.Chart(counts) alt.Chart(counts)
.mark_bar() .mark_bar()
.encode( .encode(
x=alt.X("class:N", title="Class Label"), x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
y=alt.Y("count:Q", title="Number of Boxes"), y=alt.Y("count:Q", title="Number of Boxes"),
color=alt.Color("class:N", title="Class", legend=None), color=alt.Color(
"class:N",
title="Class",
scale=alt.Scale(domain=class_order, range=altair_colors),
legend=None,
),
tooltip=[ tooltip=[
alt.Tooltip("class:N", title="Class"), alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Number of Boxes"), alt.Tooltip("count:Q", title="Number of Boxes"),
@ -758,7 +1058,7 @@ def _plot_common_features(common_array: xr.DataArray):
return chart return chart
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str): def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str, refit_metric: str):
"""Create a scatter plot comparing two metrics with parameter-based coloring. """Create a scatter plot comparing two metrics with parameter-based coloring.
Args: Args:
@ -766,6 +1066,7 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy'). x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard'). y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e'). color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
refit_metric: The metric used for refitting (e.g., 'f1' or 'f1_macro').
Returns: Returns:
Altair chart showing the metric comparison. Altair chart showing the metric comparison.
@ -781,17 +1082,17 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
alt.Chart(results) alt.Chart(results)
.mark_circle(size=60, opacity=0.7) .mark_circle(size=60, opacity=0.7)
.encode( .encode(
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)), x=alt.X(f"mean_test_{x_metric}:Q", title=format_metric_name(x_metric), scale=alt.Scale(zero=False)),
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)), y=alt.Y(f"mean_test_{y_metric}:Q", title=format_metric_name(y_metric), scale=alt.Scale(zero=False)),
color=alt.Color( color=alt.Color(
f"{color_param}:Q", f"{color_param}:Q",
scale=color_scale, scale=color_scale,
title=color_param, title=color_param,
), ),
tooltip=[ tooltip=[
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()), alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=format_metric_name(x_metric)),
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()), alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=format_metric_name(y_metric)),
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"), alt.Tooltip(f"mean_test_{refit_metric}:Q", format=".4f", title=format_metric_name(refit_metric)),
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"), alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"), alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"), alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
@ -807,6 +1108,148 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
return chart return chart
def _load_training_data(grid: str, level: int, task: str, n_categories: int = 5):
    """Load training data for analysis.

    Args:
        grid: Grid type (hex or healpix).
        level: Grid level.
        task: Classification task type (binary or multi).
        n_categories: Number of quantile bins used to derive multi-class labels
            from the RTS counts (only used when ``task != "binary"``).
            Defaults to 5, the value previously hard-coded here.

    Returns:
        Tuple of (GeoDataFrame with training data, feature DataFrame, labels Series, label names list).
    """
    # Use create_xy_data to get X and y data
    data, x_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)

    # Convert y_data to string labels for visualization
    if task == "binary":
        y_data_str = y_data.map({False: "No RTS", True: "RTS"})
    else:
        # For multi-class, reconstruct the interval labels from the codes.
        # NOTE(review): this re-derives quantile bins from the raw counts and
        # assumes they match the binning used during training — confirm upstream.
        y_data_counts = data.loc[x_data.index, "darts_rts_count"]
        bins = pd.qcut(y_data_counts, q=n_categories, duplicates="drop").unique().categories
        # Prepend a dedicated (-1, 0] bin so zero-count cells form their own class.
        bins = pd.IntervalIndex.from_tuples(
            [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
        )
        y_data_str = pd.cut(y_data_counts, bins=bins).astype(str)

    # Create a GeoDataFrame with geometry and labels
    training_gdf = data.loc[x_data.index, ["geometry"]].copy()
    training_gdf["target_class"] = y_data_str

    return training_gdf, x_data, y_data_str, labels
def _plot_training_data_map(training_gdf: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
    """Create an interactive map of training data with properly sorted classes.

    Args:
        training_gdf: GeoDataFrame with target classes (intervals as strings or category names).
        folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).

    Returns:
        Folium map with training data colored by class.
    """
    # Work on a copy so the caller's frame is left untouched.
    gdf = training_gdf.copy()

    # The special (-1, 0] interval represents cells without any RTS.
    gdf["target_class"] = gdf["target_class"].replace("(-1, 0]", "No RTS")

    def class_rank(label):
        """Order "No RTS" first, then intervals by their lower bound."""
        if label == "No RTS":
            return -1
        try:
            # Interval strings look like "(0, 4]" or "(4, 36]".
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            # Anything unparseable sinks to the end of the ordering.
            return float("inf")

    # Fix the category order so legend and colors are deterministic.
    ordered_classes = sorted(gdf["target_class"].unique(), key=class_rank)
    gdf["target_class"] = pd.Categorical(gdf["target_class"], categories=ordered_classes, ordered=True)

    # Match the basemap to the dashboard theme.
    if st.context.theme.type == "dark":
        tiles = "CartoDB dark_matter"
    else:
        tiles = "CartoDB positron"

    return gdf.explore(column="target_class", cmap=folium_cmap, legend=True, tiles=tiles)
def _plot_training_class_distribution(y_data: pd.Series, altair_colors: list[str]):
    """Create a bar chart showing the count of each class in training data.

    Args:
        y_data: Series with target classes.
        altair_colors: List of hex color strings for altair.

    Returns:
        Altair chart showing class distribution.
    """
    # Apply the same relabeling as the training-data map so both views agree.
    relabeled = y_data.copy().replace("(-1, 0]", "No RTS")

    def class_rank(label):
        """Order "No RTS" first, then intervals by their lower bound."""
        if label == "No RTS":
            return -1
        try:
            # Interval strings look like "(0, 4]" or "(4, 36]".
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            # Anything unparseable sinks to the end of the ordering.
            return float("inf")

    counts = pd.DataFrame({"target_class": relabeled.to_numpy()})["target_class"].value_counts().reset_index()
    counts.columns = ["class", "count"]
    counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)

    # Stable sort by the class rank keeps the bar order deterministic.
    counts["sort_key"] = counts["class"].apply(class_rank)
    counts = counts.sort_values("sort_key")

    # Ordered class list drives both the x-axis order and the color mapping.
    class_order = counts["class"].tolist()

    return (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Target Class", sort=class_order, axis=alt.Axis(labelAngle=0)),
            y=alt.Y("count:Q", title="Number of Samples"),
            color=alt.Color(
                "class:N",
                title="Class",
                scale=alt.Scale(domain=class_order, range=altair_colors),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("count:Q", title="Count"),
                alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
            ],
        )
        .properties(
            width=400,
            height=300,
            title="Training Data Class Distribution",
        )
    )
def main(): def main():
"""Run Streamlit dashboard application.""" """Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide") st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -846,7 +1289,19 @@ def main():
common_feature_array = extract_common_features(model_state) common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413") predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
# Determine task type and available metrics
task = settings.get("task", "binary")
available_metrics = settings.get("metrics", ["accuracy", "recall", "precision", "f1", "jaccard"])
refit_metric = "f1" if task == "binary" else "f1_macro"
# Generate unified colormaps once for all visualizations
matplotlib_cmap, folium_cmap, altair_colors = generate_unified_colormap(settings)
st.sidebar.success(f"Loaded {len(results)} results") st.sidebar.success(f"Loaded {len(results)} results")
st.sidebar.info(f"Task: {task.capitalize()} Classification")
# Dump the settings into the sidebar
with st.sidebar.expander("Search Settings", expanded=True):
st.json(settings)
# Display some basic statistics first (lightweight) # Display some basic statistics first (lightweight)
st.header("Parameter-Search Overview") st.header("Parameter-Search Overview")
@ -857,16 +1312,18 @@ def main():
st.metric("Total Runs", len(results)) st.metric("Total Runs", len(results))
with col2: with col2:
# Best model based on F1 score # Best model based on refit metric
best_f1_idx = results["mean_test_f1"].idxmax() best_idx = results[f"mean_test_{refit_metric}"].idxmax()
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}") st.metric(f"Best Model Index (by {format_metric_name(refit_metric)})", f"#{best_idx}")
# Show best parameters for the best model # Show best parameters for the best model
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]] best_params = results.loc[
best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{refit_metric}", f"std_test_{refit_metric}"]
]
with st.container(border=True): with st.container(border=True):
st.subheader(":abacus: Best Model Parameters") st.subheader(":abacus: Best Model Parameters")
st.caption("Parameters of retrained best model (selected by F1 score)") st.caption(f"Parameters of retrained best model (selected by {format_metric_name(refit_metric)} score)")
col1, col2, col3 = st.columns(3) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.metric("initial_K", f"{best_params['initial_K']:.0f}") st.metric("initial_K", f"{best_params['initial_K']:.0f}")
@ -877,32 +1334,132 @@ def main():
# Show all metrics for the best model in a container # Show all metrics for the best model in a container
st.subheader(":bar_chart: Performance Across All Metrics") st.subheader(":bar_chart: Performance Across All Metrics")
st.caption("Complete performance profile of the best model (selected by F1 score)") st.caption(
f"Complete performance profile of the best model (selected by {format_metric_name(refit_metric)} score)"
)
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
cols = st.columns(len(available_metrics)) cols = st.columns(len(available_metrics))
for idx, metric in enumerate(available_metrics): for idx, metric in enumerate(available_metrics):
with cols[idx]: with cols[idx]:
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"] best_score = results.loc[best_idx, f"mean_test_{metric}"]
best_std = results.loc[best_f1_idx, f"std_test_{metric}"] best_std = results.loc[best_idx, f"std_test_{metric}"]
# Highlight F1 since that's what we optimized for # Highlight refit metric since that's what we optimized for
st.metric( st.metric(
f"{metric.capitalize()}", format_metric_name(metric),
f"{best_score:.4f}", f"{best_score:.4f}",
delta=f"±{best_std:.4f}", delta=f"±{best_std:.4f}",
help="Mean ± std across cross-validation folds", help="Mean ± std across cross-validation folds",
) )
# Create tabs for different visualizations # Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"]) tab0, tab1, tab2, tab3 = st.tabs(["Training Data", "Search Results", "Model State", "Inference Analysis"])
with tab0:
# Training Data Analysis
st.header("Training Data Analysis")
st.markdown("Comprehensive analysis of the training dataset used for model development")
# Load training data
with st.spinner("Loading training data..."):
training_gdf, X_data, y_data, _ = _load_training_data(
grid=settings["grid"], level=settings["level"], task=task
)
# Summary statistics
st.subheader("Dataset Statistics")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Samples", f"{len(training_gdf):,}")
with col2:
st.metric("Number of Features", f"{X_data.shape[1]:,}")
with col3:
n_classes = y_data.nunique()
st.metric("Number of Classes", n_classes)
with col4:
missing_pct = X_data.isnull().sum().sum() / (X_data.shape[0] * X_data.shape[1]) * 100
st.metric("Missing Values", f"{missing_pct:.2f}%")
# Class distribution visualization
st.subheader("Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_training_class_distribution(y_data, altair_colors)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between different classes in the training dataset
- Class imbalance affects model learning and may require special handling
- Each bar represents the count of training samples for that class
"""
)
# Interactive map
st.subheader("Interactive Training Data Map")
st.markdown("Explore the spatial distribution of training samples by class")
with st.spinner("Generating interactive map..."):
training_map = _plot_training_data_map(training_gdf, folium_cmap)
st_folium.st_folium(training_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Training Data Statistics"):
st.write("**Class Distribution:**")
class_counts = y_data.value_counts().sort_index()
# Create columns for better layout
n_cols = min(5, len(class_counts))
cols = st.columns(n_cols)
for idx, (class_label, count) in enumerate(class_counts.items()):
percentage = count / len(y_data) * 100
with cols[idx % n_cols]:
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
# Show detailed table
st.write("**Detailed Class Breakdown:**")
class_df = pd.DataFrame(
{
"Class": class_counts.index,
"Count": class_counts.to_numpy(),
"Percentage": (class_counts.to_numpy() / len(y_data) * 100).round(2),
}
)
st.dataframe(class_df, width="stretch", hide_index=True)
# Feature statistics
st.write("**Feature Statistics:**")
st.markdown(f"- Total number of features: **{X_data.shape[1]}**")
st.markdown(f"- Features with missing values: **{X_data.isnull().any().sum()}**")
# Show feature types breakdown
feature_types = X_data.columns.to_series().apply(lambda x: x.split("_")[0]).value_counts()
st.write("**Feature Type Distribution:**")
feature_type_df = pd.DataFrame(
{
"Feature Type": feature_types.index,
"Count": feature_types.to_numpy(),
}
)
st.dataframe(feature_type_df, width="stretch", hide_index=True)
with tab1: with tab1:
# Metric selection - only used in this tab # Metric selection - only used in this tab
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] metric_display_names = {metric: format_metric_name(metric) for metric in available_metrics}
selected_metric = st.selectbox( selected_metric_display = st.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize" "Select Metric",
options=[format_metric_name(m) for m in available_metrics],
help="Choose which metric to visualize",
) )
# Convert back to raw metric name
selected_metric = next(k for k, v in metric_display_names.items() if v == selected_metric_display)
# Show best parameters # Show best parameters
with st.expander("Best Parameters"): with st.expander("Best Parameters"):
@ -911,7 +1468,7 @@ def main():
st.dataframe(best_params.to_frame().T, width="content") st.dataframe(best_params.to_frame().T, width="content")
# Main plots # Main plots
st.header(f"Visualization for {selected_metric.capitalize()}") st.header(f"Visualization for {format_metric_name(selected_metric)}")
# K-binned plot configuration # K-binned plot configuration
@st.fragment @st.fragment
@ -978,17 +1535,35 @@ def main():
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
with col1: with col1:
st.subheader("K vs eps_cl") st.subheader(f"K vs {selected_metric} (binned by eps_e)")
with st.spinner("Generating eps_cl plot..."): with st.spinner("Generating plot..."):
chart3 = _plot_eps_binned(results_binned, "eps_cl", f"mean_test_{selected_metric}") chart3 = _plot_params_binned(
results_binned, "initial_K", "eps_cl", "eps_e_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart3, use_container_width=True) st.altair_chart(chart3, use_container_width=True)
with col2: st.subheader(f"eps_e vs {selected_metric} (binned by K)")
st.subheader("K vs eps_e") with st.spinner("Generating plot..."):
with st.spinner("Generating eps_e plot..."): chart4 = _plot_params_binned(
chart4 = _plot_eps_binned(results_binned, "eps_e", f"mean_test_{selected_metric}") results_binned, "eps_e", "eps_cl", "initial_K_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart4, use_container_width=True) st.altair_chart(chart4, use_container_width=True)
with col2:
st.subheader(f"K vs {selected_metric} (binned by eps_cl)")
with st.spinner("Generating plot..."):
chart5 = _plot_params_binned(
results_binned, "initial_K", "eps_e", "eps_cl_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart5, use_container_width=True)
st.subheader(f"eps_cl vs {selected_metric} (binned by K)")
with st.spinner("Generating plot..."):
chart6 = _plot_params_binned(
results_binned, "eps_cl", "eps_e", "initial_K_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart6, use_container_width=True)
render_k_binned_plots() render_k_binned_plots()
# Metric comparison plots # Metric comparison plots
@ -1003,19 +1578,37 @@ def main():
help="Choose which parameter to use for coloring the scatter plots", help="Choose which parameter to use for coloring the scatter plots",
) )
# Dynamically determine which metrics to compare based on available metrics
if task == "binary":
# For binary: show recall vs precision and accuracy vs jaccard
comparisons = [
("precision", "recall", "Recall vs Precision"),
("accuracy", "jaccard", "Accuracy vs Jaccard"),
]
else:
# For multiclass: show micro vs macro variants
comparisons = [
("precision_macro", "recall_macro", "Recall Macro vs Precision Macro"),
("accuracy", "jaccard_macro", "Accuracy vs Jaccard Macro"),
]
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
with col1: with col1:
st.subheader("Recall vs Precision") st.subheader(comparisons[0][2])
with st.spinner("Generating Recall vs Precision plot..."): with st.spinner(f"Generating {comparisons[0][2]} plot..."):
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param) chart1 = _plot_metric_comparison(
st.altair_chart(recall_precision_chart, use_container_width=True) results, comparisons[0][0], comparisons[0][1], color_param, refit_metric
)
st.altair_chart(chart1, use_container_width=True)
with col2: with col2:
st.subheader("Accuracy vs Jaccard") st.subheader(comparisons[1][2])
with st.spinner("Generating Accuracy vs Jaccard plot..."): with st.spinner(f"Generating {comparisons[1][2]} plot..."):
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param) chart2 = _plot_metric_comparison(
st.altair_chart(accuracy_jaccard_chart, use_container_width=True) results, comparisons[1][0], comparisons[1][1], color_param, refit_metric
)
st.altair_chart(chart2, use_container_width=True)
render_metric_comparisons() render_metric_comparisons()
@ -1094,7 +1687,7 @@ def main():
with col2: with col2:
st.markdown("### Box Count by Class") st.markdown("### Box Count by Class")
box_assignment_bars = _plot_box_assignment_bars(model_state) box_assignment_bars = _plot_box_assignment_bars(model_state, altair_colors)
st.altair_chart(box_assignment_bars, use_container_width=True) st.altair_chart(box_assignment_bars, use_container_width=True)
# Show statistics # Show statistics
@ -1302,30 +1895,44 @@ def main():
n_classes = predictions["predicted_class"].nunique() n_classes = predictions["predicted_class"].nunique()
st.metric("Number of Classes", n_classes) st.metric("Number of Classes", n_classes)
# Class distribution visualization # Class distribution and static map side by side
st.subheader("Class Distribution") st.subheader("Prediction Overview")
col_map, col_dist = st.columns([0.6, 0.4])
with col_map:
st.markdown("#### Static Prediction Map")
st.markdown("High-quality map with North Polar Stereographic projection")
with st.spinner("Generating static map..."):
static_fig = _plot_prediction_map_static(predictions, matplotlib_cmap)
st.pyplot(static_fig, width="stretch")
st.markdown(
"""
**Map Features:**
- North Polar Azimuthal Equidistant projection (optimized for Arctic)
- Predictions colored by class using inferno colormap
- Includes coastlines and 50°N latitude boundary
"""
)
with col_dist:
st.markdown("#### Class Distribution")
with st.spinner("Generating class distribution..."): with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_prediction_class_distribution(predictions) class_dist_chart = _plot_prediction_class_distribution(predictions, altair_colors)
st.altair_chart(class_dist_chart, use_container_width=True) st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown( st.markdown(
""" """
**Interpretation:** **Interpretation:**
- Shows the balance between predicted classes - Balance between predicted classes
- Class imbalance may indicate regional patterns or model bias - Class imbalance may indicate regional patterns
- Each bar represents the count of cells predicted for that class - Each bar shows cell count per class
""" """
) )
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
with st.spinner("Generating interactive map..."):
chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander # Additional statistics in expander
with st.expander("Detailed Prediction Statistics"): with st.expander("Detailed Prediction Statistics"):
st.write("**Class Distribution:**") st.write("**Class Distribution:**")
@ -1351,6 +1958,17 @@ def main():
) )
st.dataframe(class_df, width="stretch", hide_index=True) st.dataframe(class_df, width="stretch", hide_index=True)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
# with st.spinner("Generating interactive map..."):
# chart_map = _plot_prediction_map(predictions, folium_colors)
# st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
st.text("Interactive map functionality is currently disabled.")
st.balloons()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

4742
uv.lock generated

File diff suppressed because it is too large Load diff