Add multiclass training

This commit is contained in:
Tobias Hölzer 2025-11-09 22:34:03 +01:00
parent 553b54bb32
commit d5b35d6da4
7 changed files with 814 additions and 4867 deletions

11
pixi.lock generated
View file

@ -285,6 +285,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/81/ec/8bdccea3ff7d557601183581340c3768b7bb7b1e32c8991f1130f0c1e2c4/spectate-1.0.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7a/31/7d601cc639b0362a213552a838af601105591598a4b08ec80666458083d2/stopuhr-0.0.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl
@ -1301,7 +1302,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: 9d3fd2f5a282082c9205df502797c350d94b3c8b588fe7d1662f5169589925a9
sha256: 852c87cdbd1d452fccaa6253c6ce7410dda9fa32d2f951d35441e454414acfc0
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -1350,6 +1351,7 @@ packages:
- altair[all]>=5.5.0,<6
- h5netcdf>=1.7.3,<2
- streamlit-folium>=0.25.3,<0.26
- st-theme>=1.2.3,<2
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
@ -4507,6 +4509,13 @@ packages:
version: 1.0.1
sha256: c4585194c238979f953fbf2ecf9f94c84d9d0a929432c7104e39984f52c9e718
requires_python: '>=3.6'
- pypi: https://files.pythonhosted.org/packages/30/09/cd7134f1ed5074a7d456640e7ba9a8c8e68a831837b4e7bfd9f29e5700a4/st_theme-1.2.3-py3-none-any.whl
name: st-theme
version: 1.2.3
sha256: 0a54d9817dd5f8a6d7b0d071b25ae72eacf536c63a5fb97374923938021b1389
requires_dist:
- streamlit>=1.33
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
name: stack-data
version: 0.6.3

View file

@ -51,7 +51,7 @@ dependencies = [
"geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
]
[project.scripts]

View file

@ -67,6 +67,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")

View file

@ -87,9 +87,14 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file
def get_cv_results_dir(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
def get_cv_results_dir(
name: str,
grid: Literal["hex", "healpix"],
level: int,
task: Literal["binary", "multi"],
) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_binary"
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True)
return results_dir

View file

@ -29,7 +29,50 @@ pretty.install()
set_config(array_api_dispatch=True)
def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, robust: bool = False):
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
    """Create X and y data from the training dataset.

    Loads the training parquet for the given grid/level, keeps only cells with
    DARTS coverage, and derives labels either as a boolean RTS-presence flag
    ("binary") or as quantile-binned RTS-count categories ("multi").

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
    """
    data = get_train_dataset_file(grid=grid, level=level)
    data = gpd.read_parquet(data)
    # Only cells actually covered by DARTS carry a trustworthy label.
    data = data[data["darts_has_coverage"]]
    # Drop identifiers, geometry and every DARTS-derived column so labels cannot
    # leak into the features. (The two explicit label columns are also matched by
    # the "darts_" prefix below; the duplication in the list is harmless to drop().)
    cols_to_drop = ["cell_id", "geometry", "darts_has_rts", "darts_rts_count"]
    cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
    X_data = data.drop(columns=cols_to_drop).dropna()
    if task == "binary":
        labels = ["No RTS", "RTS"]
        y_data = data.loc[X_data.index, "darts_has_rts"]
    else:
        # Put into n categories (log scaled)
        y_data = data.loc[X_data.index, "darts_rts_count"]
        n_categories = 5
        bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
        # Change the first interval to start at 1 and add a category for 0
        # NOTE(review): assumes int(first qcut interval.left) >= 0 so the prepended
        # (-1, 0] bin does not overlap the first quantile bin — TODO confirm for
        # count distributions dominated by zeros.
        bins = pd.IntervalIndex.from_tuples(
            [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
        )
        y_data = pd.cut(y_data, bins=bins)
        # Human-readable interval strings in interval sort order; these become the
        # class names stored in the search settings.
        labels = [str(v) for v in y_data.sort_values().unique()]
        # The classifier consumes the integer category codes, not the intervals.
        y_data = y_data.cat.codes
    return data, X_data, y_data, labels
def random_cv(
grid: Literal["hex", "healpix"],
level: int,
n_iter: int = 2000,
robust: bool = False,
task: Literal["binary", "multi"] = "binary",
):
"""Perform random cross-validation on the training dataset.
Args:
@ -37,21 +80,17 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
"""
data = get_train_dataset_file(grid=grid, level=level)
data = gpd.read_parquet(data)
data = data[data["darts_has_coverage"]]
cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
X_data = data.drop(columns=cols_to_drop).dropna()
y_data = data.loc[X_data.index, "darts_has_rts"]
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
print(f"{y_data.describe()=}")
X = X_data.to_numpy(dtype="float32")
y = y_data.to_numpy(dtype="int8")
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
print(f"{X.shape=}, {y.shape=}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
param_grid = {
@ -62,7 +101,21 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
cv = KFold(n_splits=5, shuffle=True, random_state=42)
if task == "binary":
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
else:
metrics = [
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
"f1_macro",
"f1_weighted",
"precision_macro",
"precision_weighted",
"recall_macro",
"recall_weighted",
"jaccard_micro",
"jaccard_macro",
"jaccard_weighted",
]
search = RandomizedSearchCV(
clf,
param_grid,
@ -72,7 +125,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
random_state=42,
verbose=3,
scoring=metrics,
refit="f1",
refit="f1" if task == "binary" else "f1_weighted",
)
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
@ -88,10 +141,11 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}")
results_dir = get_cv_results_dir("random_search", grid=grid, level=level)
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
# Store the search settings
settings = {
"task": task,
"grid": grid,
"level": level,
"random_state": 42,
@ -115,7 +169,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
},
"cv_splits": cv.get_n_splits(),
"metrics": metrics,
"classes": ["No RTS", "RTS"],
"classes": labels,
}
settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}")
@ -144,7 +198,6 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
best_estimator = search.best_estimator_
# Annotate the state with xarray metadata
features = X_data.columns.tolist()
labels = y_data.unique().tolist()
boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(),
@ -185,7 +238,7 @@ def random_cv(grid: Literal["hex", "healpix"], level: int, n_iter: int = 2000, r
# Predict probabilities for all cells
print("Predicting probabilities for all cells...")
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=settings["classes"])
preds = predict_proba(grid=grid, level=level, clf=best_estimator, classes=labels)
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file)

View file

@ -5,15 +5,93 @@ from pathlib import Path
import altair as alt
import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.path
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_folium as st_folium
import toml
import ultraplot as uplt
import xarray as xr
import xdggs
from matplotlib.patches import PathPatch
from entropice.paths import RESULTS_DIR
from entropice.training import create_xy_data
def generate_unified_colormap(settings: dict):
    """Build one consistent color scheme for every plotting backend.

    Derives the palette from the task type ("binary" vs. multi-class) and the
    number of classes recorded in *settings*, honouring the active Streamlit
    theme, and packages it for Matplotlib/Ultraplot, Folium (via
    geopandas.explore) and Altair.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair
    """
    task = settings.get("task", "binary")
    n_classes = len(settings.get("classes", []))
    dark = st.context.theme.type == "dark"

    if task == "binary":
        # Two fixed hues, chosen per theme for contrast.
        hex_colors = ["#1f77b4", "#ff7f0e"] if dark else ["#3498db", "#e74c3c"]
    else:
        # Sample a perceptually-uniform sequential map, skipping the extreme ends.
        viridis = uplt.Colormap("viridis")
        positions = np.linspace(0.1, 0.9, n_classes)
        hex_colors = [mcolors.rgb2hex(viridis(pos)[:3]) for pos in positions]

    # Matplotlib-style colormaps serve both ultraplot and geopandas/folium;
    # Altair consumes the raw hex list. Two separate ListedColormap objects are
    # returned so callers mutating one cannot affect the other.
    return (
        mcolors.ListedColormap(hex_colors),
        mcolors.ListedColormap(hex_colors),
        hex_colors,
    )
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
    """
    # Underscore-separated tokens become space-separated Title Case words;
    # "f1" is special-cased so its digit stays attached to the uppercase F.
    return " ".join(
        "F1" if token.lower() == "f1" else token.capitalize()
        for token in metric.split("_")
    )
def get_available_result_files() -> list[Path]:
@ -185,32 +263,229 @@ def extract_common_features(model_state: xr.Dataset) -> xr.DataArray | None:
return common_feature_array
def _plot_prediction_map(preds: gpd.GeoDataFrame):
return preds.explore(column="predicted_class", cmap="Set3", legend=True, tiles="CartoDB positron")
def _plot_prediction_map_static(preds: gpd.GeoDataFrame, matplotlib_cmap: mcolors.ListedColormap):
    """Create a static map of predictions using ultraplot and cartopy.

    Renders a North Polar Stereographic map of the predicted classes, styled to
    match the active Streamlit theme (light/dark), with an ordered legend.

    Args:
        preds: GeoDataFrame with predicted classes (intervals as strings).
        matplotlib_cmap: Matplotlib ListedColormap object.

    Returns:
        Matplotlib figure with predictions colored by class.
    """
    # Create a copy to avoid modifying the original data
    preds_plot = preds.copy()

    # Replace the special (-1, 0] interval with "No RTS"
    preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")

    # Sort the classes: "No RTS" first, then by the lower bound of intervals
    def sort_key(class_str):
        if class_str == "No RTS":
            return -1  # Put "No RTS" first
        # Parse interval string like "(0, 4]" or "(4, 36]"
        try:
            lower = float(class_str.split(",")[0].strip("([ "))
            return lower
        except (ValueError, IndexError):
            return float("inf")  # Put unparseable values at the end

    # Get unique classes and sort them
    unique_classes = sorted(preds_plot["predicted_class"].unique(), key=sort_key)

    # Create categorical with ordered categories
    preds_plot["predicted_class"] = pd.Categorical(
        preds_plot["predicted_class"], categories=unique_classes, ordered=True
    )

    # Detect theme for styling
    theme = st.context.theme.type
    if theme == "light":
        figcolor = "#f8f9fa"
        axcolor = "#ffffff"
        oceancolor = "#e3f2fd"
        landcolor = "#ecf0f1"
        gridcolor = "#666666"
        fgcolor = "#2c3e50"
        coastcolor = "#34495e"
    else:
        figcolor = "#0e1117"
        axcolor = "#0e1117"
        oceancolor = "#1a1d29"
        landcolor = "#262730"
        gridcolor = "#ffffff"
        fgcolor = "#fafafa"
        coastcolor = "#d0d0d0"

    # Create figure with North Polar Stereographic projection
    # proj = uplt.Proj("npaeqd")
    proj = uplt.Proj("npstere", lon_0=-45)
    fig, ax = uplt.subplots(proj=proj, figsize=(12, 12), facecolor=figcolor)

    # Put a background on the figure.
    # (zorder=-1 keeps the patch behind every axes element)
    bgpatch = PathPatch(matplotlib.path.Path.unit_rectangle(), transform=fig.transFigure, color=figcolor, zorder=-1)
    fig.patches.append(bgpatch)

    # Apply theme-appropriate styling with enhanced aesthetics
    ax.format(
        boundinglat=50,
        coast=True,
        coastcolor=coastcolor,
        coastlinewidth=0.8,
        land=True,
        landcolor=landcolor,
        ocean=True,
        oceancolor=oceancolor,
        title="Predicted RTS Classes",
        titlecolor=fgcolor,
        titleweight="bold",
        titlesize=14,
        color=fgcolor,
        gridcolor=gridcolor,
        gridlinewidth=0.5,
        gridlinestyle="-",
        gridalpha=0.3,
        facecolor=axcolor,
        labels=True,
        labelcolor=fgcolor,
        latlines=10,
        lonlines=30,
    )

    # Plot the predictions using the provided colormap with enhanced styling
    # NOTE(review): proj4_init assumes uplt.Proj returns a cartopy CRS — confirm.
    preds_plot.to_crs(proj.proj4_init).plot(
        ax=ax,
        column="predicted_class",
        cmap=matplotlib_cmap,
        legend=True,
        legend_kwds={
            "loc": "lower left",
            "frameon": True,
            "framealpha": 0.95,
            "edgecolor": gridcolor,
            "facecolor": figcolor,
            "fancybox": True,
            "shadow": False,
            "title": "RTS Class",
            "title_fontsize": 11,
            "fontsize": 9,
            "labelspacing": 0.8,
            "borderpad": 1.0,
            "columnspacing": 1.0,
        },
        edgecolor="none",
        linewidth=0,
    )

    # Enhance legend styling after creation
    legend = ax.get_legend()
    if legend:
        legend.set_title("RTS Class", prop={"size": 11, "weight": "bold"})
        legend.get_frame().set_linewidth(0.5)
        legend.get_frame().set_edgecolor(gridcolor)
        legend.get_frame().set_facecolor(figcolor)
        legend.get_frame().set_alpha(0.95)
        # Set text color for legend labels
        for text in legend.get_texts():
            text.set_color(fgcolor)
        # Set title color
        legend.get_title().set_color(fgcolor)

    return fig
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame):
def _plot_prediction_map(preds: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
    """Create an interactive map of predictions with properly sorted classes.

    Args:
        preds: GeoDataFrame with predicted classes (intervals as strings).
        folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).

    Returns:
        Folium map with predictions colored by class.
    """

    def _class_rank(label):
        # "No RTS" sorts before every interval; intervals sort by lower bound.
        if label == "No RTS":
            return -1
        # Interval labels look like "(0, 4]" or "(4, 36]".
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            # Anything unparseable goes last.
            return float("inf")

    # Work on a copy so the caller's frame stays untouched.
    plot_df = preds.copy()
    # The (-1, 0] bin encodes "zero RTS" — give it a readable name.
    plot_df["predicted_class"] = plot_df["predicted_class"].replace("(-1, 0]", "No RTS")

    # Impose a stable, human-sensible class order via an ordered categorical.
    ordered = sorted(plot_df["predicted_class"].unique(), key=_class_rank)
    plot_df["predicted_class"] = pd.Categorical(plot_df["predicted_class"], categories=ordered, ordered=True)

    # Match the basemap to the active Streamlit theme.
    basemap = "CartoDB dark_matter" if st.context.theme.type == "dark" else "CartoDB positron"
    return plot_df.explore(column="predicted_class", cmap=folium_cmap, legend=True, tiles=basemap)
def _plot_prediction_class_distribution(preds: gpd.GeoDataFrame, altair_colors: list[str]):
"""Create a bar chart showing the count of each predicted class.
Args:
preds: GeoDataFrame with predicted classes.
altair_colors: List of hex color strings for altair.
Returns:
Altair chart showing class distribution.
"""
df = pd.DataFrame({"predicted_class": preds["predicted_class"].to_numpy()})
# Create a copy and apply the same transformations as the map
preds_plot = preds.copy()
preds_plot["predicted_class"] = preds_plot["predicted_class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(class_str.split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
df = pd.DataFrame({"predicted_class": preds_plot["predicted_class"].to_numpy()})
counts = df["predicted_class"].value_counts().reset_index()
counts.columns = ["class", "count"]
counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)
# Sort counts by the same key
counts["sort_key"] = counts["class"].apply(sort_key)
counts = counts.sort_values("sort_key")
# Create an ordered list of classes for consistent color mapping
class_order = counts["class"].tolist()
chart = (
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("class:N", title="Predicted Class"),
x=alt.X("class:N", title="Predicted Class", sort=class_order, axis=alt.Axis(labelAngle=0)),
y=alt.Y("count:Q", title="Number of Cells"),
color=alt.Color("class:N", title="Class", legend=None),
color=alt.Color(
"class:N",
title="Class",
scale=alt.Scale(domain=class_order, range=altair_colors),
legend=None,
),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Count"),
@ -286,41 +561,36 @@ def _plot_k_binned(
return chart
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
def _plot_params_binned(results: pd.DataFrame, x: str, hue: str, col: str, metric: str):
"""Plot epsilon-binned results with K parameter."""
assert "initial_K" in results.columns, "initial_K column not found in results."
assert metric in results.columns, f"{metric} not found in results."
assert x in ["eps_cl", "eps_e", "initial_K"]
assert hue in ["eps_cl", "eps_e", "initial_K"]
assert col in ["eps_cl_binned", "eps_e_binned", "initial_K_binned"]
if target == "eps_cl":
hue = "eps_cl"
col = "eps_e_binned"
elif target == "eps_e":
hue = "eps_e"
col = "eps_cl_binned"
else:
raise ValueError(f"Invalid target: {target}")
assert x in results.columns, f"{x} column not found in results."
assert hue in results.columns, f"{hue} column not found in results."
assert col in results.columns, f"{col} column not found in results."
# Prepare data
plot_data = results[["initial_K", metric, hue, col]].copy()
plot_data = results[[x, metric, hue, col]].copy()
# Sort bins by their left value and convert to string with sorted categories
plot_data = plot_data.sort_values(col)
plot_data[col] = plot_data[col].astype(str)
bin_order = plot_data[col].unique().tolist()
xscale = alt.Scale(type="log") if x in ["eps_cl", "eps_e"] else alt.Scale()
cscheme = "bluepurple" if hue == "eps_e" else "purplered" if hue == "eps_cl" else "greenblue"
# Create the chart
chart = (
alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X("initial_K:Q", title="Initial K"),
x=alt.X(f"{x}:Q", title=x, scale=xscale),
y=alt.Y(f"{metric}:Q", title=metric),
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme="viridis"), title=hue),
color=alt.Color(f"{hue}:Q", scale=alt.Scale(type="log", scheme=cscheme), title=hue),
tooltip=[
"initial_K:Q",
f"{x}:Q",
alt.Tooltip(f"{metric}:Q", format=".4f"),
alt.Tooltip(f"{hue}:Q", format=".2e"),
f"{col}:N",
@ -334,11 +604,14 @@ def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
def _parse_results_dir_name(results_dir: Path) -> str:
    """Build a human-readable label for a CV-results directory.

    Directory names follow ``{gridname}_random_search_cv{timestamp}_{task}``
    (see get_cv_results_dir); the task suffix is stripped before splitting.

    Args:
        results_dir: Results directory containing a search_settings.toml file.

    Returns:
        Label like "[Binary] Hex5 (2025-11-09 22:34:03)".
    """
    gridname, date = results_dir.name.replace("_binary", "").replace("_multi", "").split("_random_search_cv")
    # BUG FIX: str.lstrip("permafrost_") strips any of those *characters* from the
    # left (not the literal prefix) and can eat leading letters of the grid name;
    # removeprefix removes exactly the "permafrost_" prefix when present.
    gridname = gridname.removeprefix("permafrost_")
    date = datetime.strptime(date, "%Y%m%d-%H%M%S")
    date = date.strftime("%Y-%m-%d %H:%M:%S")
    # The stored search settings record which task this run trained.
    settings = toml.load(results_dir / "search_settings.toml")["settings"]
    task = settings.get("task", "binary")
    return f"[{task.capitalize()}] {gridname.capitalize()} ({date})"
def _plot_top_features(model_state: xr.Dataset, top_n: int = 10):
@ -667,11 +940,12 @@ def _plot_box_assignments(model_state: xr.Dataset):
return chart
def _plot_box_assignment_bars(model_state: xr.Dataset):
def _plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]):
"""Create a bar chart showing how many boxes are assigned to each class.
Args:
model_state: The xarray Dataset containing the model state with box_assignments.
altair_colors: List of hex color strings for altair.
Returns:
Altair chart showing count of boxes per class.
@ -690,14 +964,40 @@ def _plot_box_assignment_bars(model_state: xr.Dataset):
# Count boxes per class
counts = primary_classes.groupby("class").size().reset_index(name="count")
# Replace the special (-1, 0] interval with "No RTS" if present
counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")
# Sort the classes: "No RTS" first, then by the lower bound of intervals
def sort_key(class_str):
if class_str == "No RTS":
return -1 # Put "No RTS" first
# Parse interval string like "(0, 4]" or "(4, 36]"
try:
lower = float(str(class_str).split(",")[0].strip("([ "))
return lower
except (ValueError, IndexError):
return float("inf") # Put unparseable values at the end
# Sort counts by the same key
counts["sort_key"] = counts["class"].apply(sort_key)
counts = counts.sort_values("sort_key")
# Create an ordered list of classes for consistent color mapping
class_order = counts["class"].tolist()
# Create bar chart
chart = (
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("class:N", title="Class Label"),
x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
y=alt.Y("count:Q", title="Number of Boxes"),
color=alt.Color("class:N", title="Class", legend=None),
color=alt.Color(
"class:N",
title="Class",
scale=alt.Scale(domain=class_order, range=altair_colors),
legend=None,
),
tooltip=[
alt.Tooltip("class:N", title="Class"),
alt.Tooltip("count:Q", title="Number of Boxes"),
@ -758,7 +1058,7 @@ def _plot_common_features(common_array: xr.DataArray):
return chart
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str):
def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str, color_param: str, refit_metric: str):
"""Create a scatter plot comparing two metrics with parameter-based coloring.
Args:
@ -766,6 +1066,7 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
x_metric: Name of the metric to plot on x-axis (e.g., 'precision', 'accuracy').
y_metric: Name of the metric to plot on y-axis (e.g., 'recall', 'jaccard').
color_param: Parameter to use for coloring ('initial_K', 'eps_cl', or 'eps_e').
refit_metric: The metric used for refitting (e.g., 'f1' or 'f1_macro').
Returns:
Altair chart showing the metric comparison.
@ -781,17 +1082,17 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
alt.Chart(results)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(f"mean_test_{x_metric}:Q", title=x_metric.capitalize(), scale=alt.Scale(zero=False)),
y=alt.Y(f"mean_test_{y_metric}:Q", title=y_metric.capitalize(), scale=alt.Scale(zero=False)),
x=alt.X(f"mean_test_{x_metric}:Q", title=format_metric_name(x_metric), scale=alt.Scale(zero=False)),
y=alt.Y(f"mean_test_{y_metric}:Q", title=format_metric_name(y_metric), scale=alt.Scale(zero=False)),
color=alt.Color(
f"{color_param}:Q",
scale=color_scale,
title=color_param,
),
tooltip=[
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=x_metric.capitalize()),
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=y_metric.capitalize()),
alt.Tooltip("mean_test_f1:Q", format=".4f", title="F1"),
alt.Tooltip(f"mean_test_{x_metric}:Q", format=".4f", title=format_metric_name(x_metric)),
alt.Tooltip(f"mean_test_{y_metric}:Q", format=".4f", title=format_metric_name(y_metric)),
alt.Tooltip(f"mean_test_{refit_metric}:Q", format=".4f", title=format_metric_name(refit_metric)),
alt.Tooltip("initial_K:Q", format=".0f", title="initial_K"),
alt.Tooltip("eps_cl:Q", format=".2e", title="eps_cl"),
alt.Tooltip("eps_e:Q", format=".2e", title="eps_e"),
@ -807,6 +1108,148 @@ def _plot_metric_comparison(results: pd.DataFrame, x_metric: str, y_metric: str,
return chart
def _load_training_data(grid: str, level: int, task: str):
    """Load training data for analysis.

    Args:
        grid: Grid type (hex or healpix).
        level: Grid level.
        task: Classification task type (binary or multi).

    Returns:
        Tuple of (GeoDataFrame with training data, feature DataFrame, labels Series, label names list).
    """
    # Use create_xy_data to get X and y data
    data, x_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)

    # Convert y_data to string labels for visualization
    if task == "binary":
        # Binary y_data is the boolean darts_has_rts column.
        y_data_str = y_data.map({False: "No RTS", True: "RTS"})
    else:
        # For multi-class, reconstruct the interval labels from the codes
        # NOTE(review): this re-runs the exact binning from create_xy_data on the
        # same rows so the intervals match — keep the two implementations in sync.
        y_data_counts = data.loc[x_data.index, "darts_rts_count"]
        n_categories = 5
        bins = pd.qcut(y_data_counts, q=n_categories, duplicates="drop").unique().categories
        bins = pd.IntervalIndex.from_tuples(
            [(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
        )
        y_data_str = pd.cut(y_data_counts, bins=bins).astype(str)

    # Create a GeoDataFrame with geometry and labels
    training_gdf = data.loc[x_data.index, ["geometry"]].copy()
    training_gdf["target_class"] = y_data_str

    return training_gdf, x_data, y_data_str, labels
def _plot_training_data_map(training_gdf: gpd.GeoDataFrame, folium_cmap: mcolors.ListedColormap):
    """Create an interactive map of training data with properly sorted classes.

    Args:
        training_gdf: GeoDataFrame with target classes (intervals as strings or category names).
        folium_cmap: Matplotlib ListedColormap object (for geopandas.explore).

    Returns:
        Folium map with training data colored by class.
    """

    def _class_rank(label):
        # "No RTS" sorts before every interval; intervals sort by lower bound.
        if label == "No RTS":
            return -1
        # Interval labels look like "(0, 4]" or "(4, 36]".
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            # Anything unparseable goes last.
            return float("inf")

    # Work on a copy so the caller's frame stays untouched.
    plot_df = training_gdf.copy()
    # The (-1, 0] bin encodes "zero RTS" — give it a readable name.
    plot_df["target_class"] = plot_df["target_class"].replace("(-1, 0]", "No RTS")

    # Impose a stable, human-sensible class order via an ordered categorical.
    ordered = sorted(plot_df["target_class"].unique(), key=_class_rank)
    plot_df["target_class"] = pd.Categorical(plot_df["target_class"], categories=ordered, ordered=True)

    # Match the basemap to the active Streamlit theme.
    basemap = "CartoDB dark_matter" if st.context.theme.type == "dark" else "CartoDB positron"
    return plot_df.explore(column="target_class", cmap=folium_cmap, legend=True, tiles=basemap)
def _plot_training_class_distribution(y_data: pd.Series, altair_colors: list[str]):
    """Create a bar chart showing the count of each class in training data.

    Args:
        y_data: Series with target classes.
        altair_colors: List of hex color strings for altair.

    Returns:
        Altair chart showing class distribution.
    """

    def _class_rank(label):
        # "No RTS" first, then intervals ordered by their lower bound.
        if label == "No RTS":
            return -1
        try:
            return float(label.split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")

    # Same relabelling as the maps: the (-1, 0] bin means "zero RTS".
    relabelled = y_data.copy().replace("(-1, 0]", "No RTS")

    counts = pd.DataFrame({"target_class": relabelled.to_numpy()})["target_class"].value_counts().reset_index()
    counts.columns = ["class", "count"]
    counts["percentage"] = (counts["count"] / counts["count"].sum() * 100).round(2)

    # Order classes consistently so bar order and colors line up across charts.
    counts["sort_key"] = counts["class"].apply(_class_rank)
    counts = counts.sort_values("sort_key")
    class_order = counts["class"].tolist()

    chart = (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Target Class", sort=class_order, axis=alt.Axis(labelAngle=0)),
            y=alt.Y("count:Q", title="Number of Samples"),
            color=alt.Color(
                "class:N",
                title="Class",
                scale=alt.Scale(domain=class_order, range=altair_colors),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("count:Q", title="Count"),
                alt.Tooltip("percentage:Q", format=".2f", title="Percentage (%)"),
            ],
        )
        .properties(
            width=400,
            height=300,
            title="Training Data Class Distribution",
        )
    )
    return chart
def main():
"""Run Streamlit dashboard application."""
st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")
@ -846,7 +1289,19 @@ def main():
common_feature_array = extract_common_features(model_state)
predictions = gpd.read_parquet(results_dir / "predicted_probabilities.parquet").set_crs("epsg:3413")
# Determine task type and available metrics
task = settings.get("task", "binary")
available_metrics = settings.get("metrics", ["accuracy", "recall", "precision", "f1", "jaccard"])
refit_metric = "f1" if task == "binary" else "f1_macro"
# Generate unified colormaps once for all visualizations
matplotlib_cmap, folium_cmap, altair_colors = generate_unified_colormap(settings)
st.sidebar.success(f"Loaded {len(results)} results")
st.sidebar.info(f"Task: {task.capitalize()} Classification")
# Dump the settings into the sidebar
with st.sidebar.expander("Search Settings", expanded=True):
st.json(settings)
# Display some basic statistics first (lightweight)
st.header("Parameter-Search Overview")
@ -857,16 +1312,18 @@ def main():
st.metric("Total Runs", len(results))
with col2:
# Best model based on F1 score
best_f1_idx = results["mean_test_f1"].idxmax()
st.metric("Best Model Index (by F1)", f"#{best_f1_idx}")
# Best model based on refit metric
best_idx = results[f"mean_test_{refit_metric}"].idxmax()
st.metric(f"Best Model Index (by {format_metric_name(refit_metric)})", f"#{best_idx}")
# Show best parameters for the best model
best_params = results.loc[best_f1_idx, ["initial_K", "eps_cl", "eps_e", "mean_test_f1", "std_test_f1"]]
best_params = results.loc[
best_idx, ["initial_K", "eps_cl", "eps_e", f"mean_test_{refit_metric}", f"std_test_{refit_metric}"]
]
with st.container(border=True):
st.subheader(":abacus: Best Model Parameters")
st.caption("Parameters of retrained best model (selected by F1 score)")
st.caption(f"Parameters of retrained best model (selected by {format_metric_name(refit_metric)} score)")
col1, col2, col3 = st.columns(3)
with col1:
st.metric("initial_K", f"{best_params['initial_K']:.0f}")
@ -877,32 +1334,132 @@ def main():
# Show all metrics for the best model in a container
st.subheader(":bar_chart: Performance Across All Metrics")
st.caption("Complete performance profile of the best model (selected by F1 score)")
st.caption(
f"Complete performance profile of the best model (selected by {format_metric_name(refit_metric)} score)"
)
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
cols = st.columns(len(available_metrics))
for idx, metric in enumerate(available_metrics):
with cols[idx]:
best_score = results.loc[best_f1_idx, f"mean_test_{metric}"]
best_std = results.loc[best_f1_idx, f"std_test_{metric}"]
# Highlight F1 since that's what we optimized for
best_score = results.loc[best_idx, f"mean_test_{metric}"]
best_std = results.loc[best_idx, f"std_test_{metric}"]
# Highlight refit metric since that's what we optimized for
st.metric(
f"{metric.capitalize()}",
format_metric_name(metric),
f"{best_score:.4f}",
delta=f"±{best_std:.4f}",
help="Mean ± std across cross-validation folds",
)
# Create tabs for different visualizations
tab1, tab2, tab3 = st.tabs(["Search Results", "Model State", "Inference Analysis"])
tab0, tab1, tab2, tab3 = st.tabs(["Training Data", "Search Results", "Model State", "Inference Analysis"])
with tab0:
# Training Data Analysis
st.header("Training Data Analysis")
st.markdown("Comprehensive analysis of the training dataset used for model development")
# Load training data
with st.spinner("Loading training data..."):
training_gdf, X_data, y_data, _ = _load_training_data(
grid=settings["grid"], level=settings["level"], task=task
)
# Summary statistics
st.subheader("Dataset Statistics")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Samples", f"{len(training_gdf):,}")
with col2:
st.metric("Number of Features", f"{X_data.shape[1]:,}")
with col3:
n_classes = y_data.nunique()
st.metric("Number of Classes", n_classes)
with col4:
missing_pct = X_data.isnull().sum().sum() / (X_data.shape[0] * X_data.shape[1]) * 100
st.metric("Missing Values", f"{missing_pct:.2f}%")
# Class distribution visualization
st.subheader("Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_training_class_distribution(y_data, altair_colors)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between different classes in the training dataset
- Class imbalance affects model learning and may require special handling
- Each bar represents the count of training samples for that class
"""
)
# Interactive map
st.subheader("Interactive Training Data Map")
st.markdown("Explore the spatial distribution of training samples by class")
with st.spinner("Generating interactive map..."):
training_map = _plot_training_data_map(training_gdf, folium_cmap)
st_folium.st_folium(training_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Training Data Statistics"):
st.write("**Class Distribution:**")
class_counts = y_data.value_counts().sort_index()
# Create columns for better layout
n_cols = min(5, len(class_counts))
cols = st.columns(n_cols)
for idx, (class_label, count) in enumerate(class_counts.items()):
percentage = count / len(y_data) * 100
with cols[idx % n_cols]:
st.metric(f"Class {class_label}", f"{count:,} ({percentage:.2f}%)")
# Show detailed table
st.write("**Detailed Class Breakdown:**")
class_df = pd.DataFrame(
{
"Class": class_counts.index,
"Count": class_counts.to_numpy(),
"Percentage": (class_counts.to_numpy() / len(y_data) * 100).round(2),
}
)
st.dataframe(class_df, width="stretch", hide_index=True)
# Feature statistics
st.write("**Feature Statistics:**")
st.markdown(f"- Total number of features: **{X_data.shape[1]}**")
st.markdown(f"- Features with missing values: **{X_data.isnull().any().sum()}**")
# Show feature types breakdown
feature_types = X_data.columns.to_series().apply(lambda x: x.split("_")[0]).value_counts()
st.write("**Feature Type Distribution:**")
feature_type_df = pd.DataFrame(
{
"Feature Type": feature_types.index,
"Count": feature_types.to_numpy(),
}
)
st.dataframe(feature_type_df, width="stretch", hide_index=True)
with tab1:
# Metric selection - only used in this tab
available_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
selected_metric = st.selectbox(
"Select Metric", options=available_metrics, help="Choose which metric to visualize"
metric_display_names = {metric: format_metric_name(metric) for metric in available_metrics}
selected_metric_display = st.selectbox(
"Select Metric",
options=[format_metric_name(m) for m in available_metrics],
help="Choose which metric to visualize",
)
# Convert back to raw metric name
selected_metric = next(k for k, v in metric_display_names.items() if v == selected_metric_display)
# Show best parameters
with st.expander("Best Parameters"):
@ -911,7 +1468,7 @@ def main():
st.dataframe(best_params.to_frame().T, width="content")
# Main plots
st.header(f"Visualization for {selected_metric.capitalize()}")
st.header(f"Visualization for {format_metric_name(selected_metric)}")
# K-binned plot configuration
@st.fragment
@ -978,17 +1535,35 @@ def main():
col1, col2 = st.columns(2)
with col1:
st.subheader("K vs eps_cl")
with st.spinner("Generating eps_cl plot..."):
chart3 = _plot_eps_binned(results_binned, "eps_cl", f"mean_test_{selected_metric}")
st.subheader(f"K vs {selected_metric} (binned by eps_e)")
with st.spinner("Generating plot..."):
chart3 = _plot_params_binned(
results_binned, "initial_K", "eps_cl", "eps_e_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart3, use_container_width=True)
with col2:
st.subheader("K vs eps_e")
with st.spinner("Generating eps_e plot..."):
chart4 = _plot_eps_binned(results_binned, "eps_e", f"mean_test_{selected_metric}")
st.subheader(f"eps_e vs {selected_metric} (binned by K)")
with st.spinner("Generating plot..."):
chart4 = _plot_params_binned(
results_binned, "eps_e", "eps_cl", "initial_K_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart4, use_container_width=True)
with col2:
st.subheader(f"K vs {selected_metric} (binned by eps_cl)")
with st.spinner("Generating plot..."):
chart5 = _plot_params_binned(
results_binned, "initial_K", "eps_e", "eps_cl_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart5, use_container_width=True)
st.subheader(f"eps_cl vs {selected_metric} (binned by K)")
with st.spinner("Generating plot..."):
chart6 = _plot_params_binned(
results_binned, "eps_cl", "eps_e", "initial_K_binned", f"mean_test_{selected_metric}"
)
st.altair_chart(chart6, use_container_width=True)
render_k_binned_plots()
# Metric comparison plots
@ -1003,19 +1578,37 @@ def main():
help="Choose which parameter to use for coloring the scatter plots",
)
# Dynamically determine which metrics to compare based on available metrics
if task == "binary":
# For binary: show recall vs precision and accuracy vs jaccard
comparisons = [
("precision", "recall", "Recall vs Precision"),
("accuracy", "jaccard", "Accuracy vs Jaccard"),
]
else:
# For multiclass: show micro vs macro variants
comparisons = [
("precision_macro", "recall_macro", "Recall Macro vs Precision Macro"),
("accuracy", "jaccard_macro", "Accuracy vs Jaccard Macro"),
]
col1, col2 = st.columns(2)
with col1:
st.subheader("Recall vs Precision")
with st.spinner("Generating Recall vs Precision plot..."):
recall_precision_chart = _plot_metric_comparison(results, "precision", "recall", color_param)
st.altair_chart(recall_precision_chart, use_container_width=True)
st.subheader(comparisons[0][2])
with st.spinner(f"Generating {comparisons[0][2]} plot..."):
chart1 = _plot_metric_comparison(
results, comparisons[0][0], comparisons[0][1], color_param, refit_metric
)
st.altair_chart(chart1, use_container_width=True)
with col2:
st.subheader("Accuracy vs Jaccard")
with st.spinner("Generating Accuracy vs Jaccard plot..."):
accuracy_jaccard_chart = _plot_metric_comparison(results, "accuracy", "jaccard", color_param)
st.altair_chart(accuracy_jaccard_chart, use_container_width=True)
st.subheader(comparisons[1][2])
with st.spinner(f"Generating {comparisons[1][2]} plot..."):
chart2 = _plot_metric_comparison(
results, comparisons[1][0], comparisons[1][1], color_param, refit_metric
)
st.altair_chart(chart2, use_container_width=True)
render_metric_comparisons()
@ -1094,7 +1687,7 @@ def main():
with col2:
st.markdown("### Box Count by Class")
box_assignment_bars = _plot_box_assignment_bars(model_state)
box_assignment_bars = _plot_box_assignment_bars(model_state, altair_colors)
st.altair_chart(box_assignment_bars, use_container_width=True)
# Show statistics
@ -1302,30 +1895,44 @@ def main():
n_classes = predictions["predicted_class"].nunique()
st.metric("Number of Classes", n_classes)
# Class distribution visualization
st.subheader("Class Distribution")
# Class distribution and static map side by side
st.subheader("Prediction Overview")
col_map, col_dist = st.columns([0.6, 0.4])
with col_map:
st.markdown("#### Static Prediction Map")
st.markdown("High-quality map with North Polar Stereographic projection")
with st.spinner("Generating static map..."):
static_fig = _plot_prediction_map_static(predictions, matplotlib_cmap)
st.pyplot(static_fig, width="stretch")
st.markdown(
"""
**Map Features:**
- North Polar Azimuthal Equidistant projection (optimized for Arctic)
- Predictions colored by class using inferno colormap
- Includes coastlines and 50°N latitude boundary
"""
)
with col_dist:
st.markdown("#### Class Distribution")
with st.spinner("Generating class distribution..."):
class_dist_chart = _plot_prediction_class_distribution(predictions)
class_dist_chart = _plot_prediction_class_distribution(predictions, altair_colors)
st.altair_chart(class_dist_chart, use_container_width=True)
st.markdown(
"""
**Interpretation:**
- Shows the balance between predicted classes
- Class imbalance may indicate regional patterns or model bias
- Each bar represents the count of cells predicted for that class
- Balance between predicted classes
- Class imbalance may indicate regional patterns
- Each bar shows cell count per class
"""
)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
with st.spinner("Generating interactive map..."):
chart_map = _plot_prediction_map(predictions)
st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
# Additional statistics in expander
with st.expander("Detailed Prediction Statistics"):
st.write("**Class Distribution:**")
@ -1351,6 +1958,17 @@ def main():
)
st.dataframe(class_df, width="stretch", hide_index=True)
# Interactive map
st.subheader("Interactive Prediction Map")
st.markdown("Explore predictions spatially with the interactive map below")
# with st.spinner("Generating interactive map..."):
# chart_map = _plot_prediction_map(predictions, folium_colors)
# st_folium.st_folium(chart_map, width="100%", height=600, returned_objects=[])
st.text("Interactive map functionality is currently disabled.")
st.balloons()
if __name__ == "__main__":
main()

4742
uv.lock generated

File diff suppressed because it is too large Load diff