Small fixes all over the place

This commit is contained in:
Tobias Hölzer 2026-01-08 20:00:09 +01:00
parent c92e856c55
commit 1495f71ac9
9 changed files with 3923 additions and 4084 deletions

7310
pixi.lock generated

File diff suppressed because it is too large Load diff

View file

@ -12,7 +12,7 @@ dependencies = [
"cartopy>=0.24.1",
"cdsapi>=0.7.6",
"cyclopts>=4.0.0",
"dask>=2025.5.1",
"dask>=2025.11.0",
"distributed>=2025.5.1",
"earthengine-api>=1.6.9",
"eemont>=2025.7.1",
@ -34,7 +34,6 @@ dependencies = [
"odc-geo[all]>=0.4.10",
"opt-einsum>=3.4.0",
"pyarrow>=18.1.0",
"rechunker>=0.5.2",
"requests>=2.32.3",
"rich>=14.0.0",
"rioxarray>=0.19.0",
@ -66,7 +65,9 @@ dependencies = [
"pypalettes>=0.2.1,<0.3",
"ty>=0.0.2,<0.0.3",
"ruff>=0.14.9,<0.15",
"pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10",
"pandas-stubs>=2.3.3.251201,<3",
"pytest>=9.0.2,<10",
"autogluon-tabular[all]>=1.5.0",
]
[project.scripts]
@ -90,15 +91,15 @@ url = "https://pypi.nvidia.com"
explicit = true
[tool.uv.sources]
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
# entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
xanimate = { git = "https://github.com/davbyr/xAnimate" }
xdem = { git = "https://github.com/GlacioHack/xdem" }
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" }
# cudf-cu12 = { index = "nvidia" }
# cuml-cu12 = { index = "nvidia" }
# cuspatial-cu12 = { index = "nvidia" }
[tool.ruff]
line-length = 120
@ -148,5 +149,10 @@ nccl = ">=2.27.7.1,<3"
cudnn = ">=9.13.1.26,<10"
cusparselt = ">=0.8.1.1,<0.9"
cuda-version = "12.9.*"
rapids = ">=25.10.0,<26"
# cudf = ">=25.10.0,<26"
# cuml = ">=25.10.0,<26"
healpix-geo = ">=0.0.6"
scikit-learn = ">=1.4.0,<1.8.0"
pyarrow = ">=7.0.0,<21.0.0"
cudf = ">=25.12.0,<26"
cuml = ">=25.12.0,<26"

View file

@ -1,12 +1,12 @@
import xarray as xr
import zarr
from rich import print
import dask.distributed as dd
import xarray as xr
from rich import print
from entropice.utils.paths import get_era5_stores
import entropice.utils.codecs
from entropice.utils.paths import get_era5_stores
def print_info(daily_raw = None, show_vars: bool = True):
def print_info(daily_raw=None, show_vars: bool = True):
if daily_raw is None:
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
@ -30,29 +30,109 @@ def print_info(daily_raw = None, show_vars: bool = True):
print(da.encoding)
print("")
def rechunk():
def rechunk(use_shards: bool = False):
if use_shards:
# ! MEEEP: https://github.com/pydata/xarray/issues/10831
print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
with (
dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
dd.Client(cluster) as client,
):
print(f"Dashboard: {client.dashboard_link}")
daily_store = get_era5_stores("daily")
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked_sharded")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print_info(daily_raw, False)
daily_raw = daily_raw.chunk({
daily_raw = daily_raw.chunk(
{
"time": 120,
"latitude": -1, # Should be 337,
"longitude": -1 # Should be 3600
})
print_info(daily_raw, False)
"longitude": -1, # Should be 3600
}
)
encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
for var in daily_raw.data_vars:
encoding[var]["chunks"] = (120, 337, 3600)
if use_shards:
encoding[var]["shards"] = (1200, 337, 3600)
print(encoding)
daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
def validate():
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
encoding = entropice.utils.codecs.from_ds(daily_raw)
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)
print("\n=== Comparing Datasets ===")
# Compare dimensions
if daily_raw.sizes != daily_rechunked.sizes:
print("❌ Dimensions differ:")
print(f" Original: {daily_raw.sizes}")
print(f" Rechunked: {daily_rechunked.sizes}")
else:
print("✅ Dimensions match")
# Compare variables
raw_vars = set(daily_raw.data_vars)
rechunked_vars = set(daily_rechunked.data_vars)
if raw_vars != rechunked_vars:
print("❌ Variables differ:")
print(f" Only in original: {raw_vars - rechunked_vars}")
print(f" Only in rechunked: {rechunked_vars - raw_vars}")
else:
print("✅ Variables match")
# Compare each variable
print("\n=== Variable Comparison ===")
all_equal = True
for var in raw_vars & rechunked_vars:
raw_var = daily_raw[var]
rechunked_var = daily_rechunked[var]
if raw_var.equals(rechunked_var):
print(f"{var}: Equal")
else:
all_equal = False
print(f"{var}: NOT Equal")
# Check if values are equal
try:
values_equal = raw_var.values.shape == rechunked_var.values.shape
if values_equal:
import numpy as np
values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)
if values_equal:
print(" → Values are numerically equal (likely metadata/encoding difference)")
else:
print(" → Values differ!")
print(f" Original shape: {raw_var.values.shape}")
print(f" Rechunked shape: {rechunked_var.values.shape}")
except Exception as e:
print(f" → Error comparing values: {e}")
# Check attributes
if raw_var.attrs != rechunked_var.attrs:
print(" → Attributes differ:")
print(f" Original: {raw_var.attrs}")
print(f" Rechunked: {rechunked_var.attrs}")
# Check encoding
if raw_var.encoding != rechunked_var.encoding:
print(" → Encoding differs:")
print(f" Original: {raw_var.encoding}")
print(f" Rechunked: {rechunked_var.encoding}")
if all_equal:
print("\n✅ Validation successful: All datasets are equal.")
else:
print("\n❌ Validation failed: Datasets have differences (see above).")
if __name__ == "__main__":
with (
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
rechunk()
print("Done.")
validate()

View file

@ -9,11 +9,11 @@ import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.class_ordering import get_ordered_classes
from entropice.dashboard.utils.colors import get_cmap, get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.training import TrainingSettings
@ -1154,73 +1154,36 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
@st.fragment
def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
"""Render 3D pydeck map showing prediction results.
def render_confusion_matrix_map(
result_path: Path, settings: TrainingSettings, merged_predictions: gpd.GeoDataFrame | None = None
):
"""Render 3D pydeck map showing model performance on training data.
Uses true labels for elevation (height) and different shades of red for incorrect predictions
based on the predicted class.
Displays cells from the training dataset with predictions, colored by correctness.
Uses true labels for elevation (height) and different shades of red for incorrect predictions.
Args:
result_path: Path to the training result directory.
result_path: Path to the training result directory (not used, kept for compatibility).
settings: Settings dictionary containing grid, level, task, and target information.
merged_predictions: GeoDataFrame with predictions, true labels, and split info.
"""
st.subheader("🗺️ Prediction Results Map")
# Load predicted probabilities
preds_file = result_path / "predicted_probabilities.parquet"
if not preds_file.exists():
st.warning("No predicted probabilities found for this training run.")
if merged_predictions is None:
st.warning("Prediction data not available. Cannot display map.")
return
preds_gdf = gpd.read_parquet(preds_file)
# Get task and target information from settings
task = settings.task
target = settings.target
# Get grid type and task from settings
grid = settings.grid
level = settings.level
task = settings.task
# Create dataset ensemble to get true labels
# We need to load the target data to get true labels
try:
ensemble = DatasetEnsemble(
grid=grid,
level=level,
target=target,
members=[], # We don't need feature data, just target
)
training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
except Exception as e:
st.error(f"Error loading training data: {e}")
return
# Get all cells from the complete dataset (not just test split)
# Use the full dataset which includes both train and test splits
all_cells = training_data.dataset.copy()
# Merge predictions with true labels
# Reset index to avoid ambiguity between index and column
labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"})
labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy()
# Merge with predictions - use left join to keep all cells
merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left")
merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
# Mark which cells have predictions (test split) vs not (training split)
merged["in_test_split"] = merged["predicted_class"].notna()
# For cells without predictions (training split), use true class as predicted class for visualization
merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"])
# Use the merged predictions which already have true labels, predictions, and split info
merged = merged_predictions.copy()
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
if len(merged) == 0:
st.warning("No matching predictions found for labeled cells.")
st.warning("No predictions found for labeled cells.")
return
# Mark correct vs incorrect predictions (only meaningful for test split)
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
# Get ordered class labels for the task
ordered_classes = get_ordered_classes(task)
@ -1228,51 +1191,55 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
col1, col2, col3 = st.columns([2, 1, 1])
with col1:
# Filter by prediction correctness and split
categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"]
selected_category = st.selectbox(
"Filter by Category",
options=categories,
key="confusion_map_category",
# Split selector (similar to confusion matrix)
split_type = st.selectbox(
"Select Data Split",
options=["test", "train", "all"],
format_func=lambda x: {"test": "Test Set", "train": "Training Set (CV)", "all": "All Data"}[x],
help="Choose which data split to display on the map",
key="prediction_map_split_select",
)
with col2:
# Color scheme selector
show_only_incorrect = st.checkbox(
"Highlight Errors Only",
value=False,
help="Show only incorrect predictions in red, hide correct ones",
key="prediction_map_errors_only",
)
with col3:
opacity = st.slider(
"Opacity",
min_value=0.1,
max_value=1.0,
value=0.7,
step=0.1,
key="confusion_map_opacity",
key="prediction_map_opacity",
)
with col3:
line_width = st.slider(
"Line Width",
min_value=0.5,
max_value=3.0,
value=1.0,
step=0.5,
key="confusion_map_line_width",
)
# Filter data if needed
if selected_category == "Test Split Only":
display_gdf = merged[merged["in_test_split"]].copy()
elif selected_category == "Training Split Only":
display_gdf = merged[~merged["in_test_split"]].copy()
elif selected_category == "Correct (Test)":
display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy()
elif selected_category == "Incorrect (Test)":
display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy()
else: # "All"
# Filter data by split
if split_type == "test":
display_gdf = merged[merged["split"] == "test"].copy()
split_caption = "Test Set (held-out data)"
elif split_type == "train":
display_gdf = merged[merged["split"] == "train"].copy()
split_caption = "Training Set (CV data)"
else: # "all"
display_gdf = merged.copy()
split_caption = "All Available Data"
# Optionally filter to show only incorrect predictions
if show_only_incorrect:
display_gdf = display_gdf[~display_gdf["is_correct"]].copy()
if len(display_gdf) == 0:
st.warning(f"No cells found for category: {selected_category}")
st.warning(f"No cells found for {split_caption}.")
return
st.caption(f"📍 Showing {len(display_gdf)} cells from {split_caption}")
# Convert to WGS84 for pydeck
display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
@ -1303,12 +1270,12 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
# Add line color based on split: blue for test split, orange for training split
# Add line color based on split: blue for test, orange for train
def get_line_color(row):
if row["in_test_split"]:
if row["split"] == "test":
return [52, 152, 219] # Blue for test split
else:
return [230, 126, 34] # Orange for training split
return [230, 126, 34] # Orange for train split
display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
@ -1329,18 +1296,15 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
geojson_data = []
for _, row in display_gdf_wgs84.iterrows():
# Determine split and status for tooltip
split_name = "Test Split" if row["in_test_split"] else "Training Split"
if row["in_test_split"]:
split_name = "Test" if row["split"] == "test" else "Training (CV)"
status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
else:
status = "(No prediction - training data)"
feature = {
"type": "Feature",
"geometry": row["geometry"].__geo_interface__,
"properties": {
"true_class": str(row["true_class"]),
"predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A",
"true_label": str(row["true_class"]),
"predicted_label": str(row["predicted_class"]),
"is_correct": bool(row["is_correct"]),
"split": split_name,
"status": status,
@ -1362,7 +1326,7 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
wireframe=False,
get_fill_color="properties.fill_color",
get_line_color="properties.line_color",
line_width_min_pixels=line_width,
line_width_min_pixels=2,
get_elevation="properties.elevation",
elevation_scale=500000,
pickable=True,
@ -1376,10 +1340,10 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
layers=[layer],
initial_view_state=view_state,
tooltip={
"html": "<b>Split:</b> {split}<br/>"
"<b>True Label:</b> {true_class}<br/>"
"<b>Predicted Label:</b> {predicted_class}<br/>"
"<b>Status:</b> {status}",
"html": "<b>Status:</b> {status}<br/>"
"<b>True Label:</b> {true_label}<br/>"
"<b>Predicted Label:</b> {predicted_label}<br/>"
"<b>Split:</b> {split}",
"style": {"backgroundColor": "steelblue", "color": "white"},
},
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
@ -1388,83 +1352,55 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
# Render the map
st.pydeck_chart(deck)
# Show statistics
col1, col2, col3, col4 = st.columns(4)
# Show statistics for displayed data
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Labeled Cells", len(merged))
st.metric("Cells Displayed", len(display_gdf))
with col2:
test_count = len(merged[merged["in_test_split"]])
st.metric("Test Split", test_count)
correct = len(display_gdf[display_gdf["is_correct"]])
st.metric("Correct Predictions", correct)
with col3:
train_count = len(merged[~merged["in_test_split"]])
st.metric("Training Split", train_count)
with col4:
test_cells = merged[merged["in_test_split"]]
if len(test_cells) > 0:
correct = len(test_cells[test_cells["is_correct"]])
accuracy = correct / len(test_cells)
st.metric("Test Accuracy", f"{accuracy:.2%}")
if len(display_gdf) > 0:
accuracy = correct / len(display_gdf)
st.metric("Accuracy", f"{accuracy:.2%}")
else:
st.metric("Test Accuracy", "N/A")
st.metric("Accuracy", "N/A")
# Add legend
with st.expander("Legend", expanded=True):
# Split indicators (border colors)
st.markdown("**Data Split (Border Color):**")
st.markdown("**Fill Color (Prediction Correctness):**")
test_count = len(merged[merged["in_test_split"]])
train_count = len(merged[~merged["in_test_split"]])
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Test Split</b> ({test_count} cells, {test_count / len(merged) * 100:.1f}%)</span></div>",
unsafe_allow_html=True,
)
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 12px;">'
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Training Split</b> ({train_count} cells, {train_count / len(merged) * 100:.1f}%)</span></div>",
unsafe_allow_html=True,
)
st.markdown("---")
st.markdown("**Fill Color (Prediction Results):**")
# Correct predictions (test split only)
test_cells = merged[merged["in_test_split"]]
correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0
incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0
# Correct predictions
correct_count = len(display_gdf[display_gdf["is_correct"]])
incorrect_count = len(display_gdf[~display_gdf["is_correct"]])
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
f"<span><b>Correct Predictions (Test)</b> ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)</span></div>",
f"<span><b>Correct Predictions</b> ({correct_count} cells, "
f"{correct_count / len(display_gdf) * 100 if len(display_gdf) > 0 else 0:.1f}%)</span></div>",
unsafe_allow_html=True,
)
# Incorrect predictions by predicted class (shades of red)
if incorrect_count > 0:
st.markdown(
f"<b>Incorrect Predictions by Predicted Class (Test)</b> ({incorrect} cells):", unsafe_allow_html=True
f"<b>Incorrect Predictions by Predicted Class</b> ({incorrect_count} cells):", unsafe_allow_html=True
)
for class_idx, class_label in enumerate(ordered_classes):
# Get count of incorrect predictions for this predicted class (test split only)
count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)])
# Get count of incorrect predictions for this predicted class
count = len(display_gdf[(~display_gdf["is_correct"]) & (display_gdf["predicted_class"] == class_label)])
if count > 0:
# Get color for this predicted class
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
percentage = count / incorrect * 100 if incorrect > 0 else 0
percentage = count / incorrect_count * 100
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
@ -1474,16 +1410,33 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
unsafe_allow_html=True,
)
# Note about training split
st.markdown("---")
st.markdown("**Border Color (Data Split):**")
# Count by split in displayed data
test_in_display = len(display_gdf[display_gdf["split"] == "test"])
train_in_display = len(display_gdf[display_gdf["split"] == "train"])
if test_in_display > 0:
st.markdown(
f'<div style="margin-top: 8px; font-style: italic; color: #888;">'
f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) "
f"since predictions are only available for the test split.</div>",
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Test Split</b> ({test_in_display} cells)</span></div>",
unsafe_allow_html=True,
)
if train_in_display > 0:
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Training Split</b> ({train_in_display} cells)</span></div>",
unsafe_allow_html=True,
)
st.markdown("---")
st.markdown("**Elevation (3D):**")
st.markdown("**Elevation (3D Height):**")
# Show elevation mapping for each true class
st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
@ -1514,13 +1467,23 @@ def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str)
true_labels = confusion_matrix.coords["true_label"].values
pred_labels = confusion_matrix.coords["predicted_label"].values
# For binary classification, map 0/1 to No-RTS/RTS
if task == "binary":
# Check if labels are already strings (from predictions) or numeric (from stored confusion matrices)
first_true_label = true_labels[0]
is_string_labels = isinstance(first_true_label, str) or (
hasattr(first_true_label, "dtype") and first_true_label.dtype.kind in ("U", "O")
)
if is_string_labels:
# Labels are already string labels, use them directly
true_labels_str = [str(label) for label in true_labels]
pred_labels_str = [str(label) for label in pred_labels]
elif task == "binary":
# Numeric binary labels - map 0/1 to No-RTS/RTS
label_map = {0: "No-RTS", 1: "RTS"}
true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
else:
# For multiclass, use numeric labels as is
# Numeric multiclass labels - use as is
true_labels_str = [str(label) for label in true_labels]
pred_labels_str = [str(label) for label in pred_labels]

View file

@ -41,7 +41,6 @@ def get_members_from_settings(settings) -> list[L2SourceDataset]:
return settings.members
@st.fragment
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
"""Render sidebar for training run selection.

View file

@ -2,7 +2,9 @@
from typing import cast
import geopandas as gpd
import streamlit as st
import xarray as xr
from stopuhr import stopwatch
from entropice.dashboard.plots.hyperparameter_analysis import (
@ -22,6 +24,127 @@ from entropice.dashboard.utils.stats import CVResultsStatistics
from entropice.utils.types import GridConfig
def load_predictions_with_labels(selected_result: TrainingResult) -> gpd.GeoDataFrame | None:
"""Load predictions and merge with training data to get true labels and split info.
Args:
selected_result: The selected TrainingResult object.
Returns:
GeoDataFrame with predictions, true labels, and split information, or None if unavailable.
"""
from sklearn.model_selection import train_test_split
from entropice.ml.dataset import DatasetEnsemble, bin_values, taskcol
# Load predictions
preds_gdf = selected_result.load_predictions()
if preds_gdf is None:
return None
# Create a minimal dataset ensemble to access target data
settings = selected_result.settings
dataset_ensemble = DatasetEnsemble(
grid=settings.grid,
level=settings.level,
target=settings.target,
members=[], # No feature data needed, just targets
)
# Load target dataset (just labels, no features)
with st.spinner("Loading target labels..."):
targets = dataset_ensemble._read_target()
# Get coverage and task columns
task_col = taskcol[settings.task][settings.target]
# Filter for valid labels (same as in _cat_and_split)
valid_labels = targets[task_col].notna()
filtered_targets = targets.loc[valid_labels].copy()
# Apply binning to get class labels (same logic as _cat_and_split)
if settings.task == "binary":
binned = filtered_targets[task_col].map({False: "No RTS", True: "RTS"}).astype("category")
elif settings.task == "count":
binned = bin_values(filtered_targets[task_col].astype(int), task=settings.task)
elif settings.task == "density":
binned = bin_values(filtered_targets[task_col], task=settings.task)
else:
raise ValueError(f"Invalid task: {settings.task}")
filtered_targets["true_class"] = binned.to_numpy()
# Recreate the train/test split deterministically (same random_state=42 as in _cat_and_split)
_train_idx, test_idx = train_test_split(
filtered_targets.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True
)
filtered_targets["split"] = "train"
filtered_targets.loc[test_idx, "split"] = "test"
filtered_targets["split"] = filtered_targets["split"].astype("category")
# Ensure cell_id is available as a column for merging
# Check if cell_id already exists, otherwise use the index
if "cell_id" not in filtered_targets.columns:
filtered_targets = filtered_targets.reset_index().rename(columns={"index": "cell_id"})
# Merge predictions with labels (inner join to keep only cells with predictions)
merged = filtered_targets.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
merged_gdf = gpd.GeoDataFrame(merged, geometry="geometry", crs=targets.crs)
return merged_gdf
def compute_confusion_matrix_from_merged_data(
merged_data: gpd.GeoDataFrame,
split_type: str,
label_names: list[str],
) -> xr.DataArray | None:
"""Compute confusion matrix from merged predictions and labels.
Args:
merged_data: GeoDataFrame with 'true_class', 'predicted_class', and 'split' columns.
split_type: One of 'test', 'train', or 'all'.
label_names: List of class label names in order.
Returns:
xarray.DataArray with confusion matrix or None if data unavailable.
"""
from sklearn.metrics import confusion_matrix
# Filter by split type
if split_type == "train":
data = merged_data[merged_data["split"] == "train"]
elif split_type == "test":
data = merged_data[merged_data["split"] == "test"]
elif split_type == "all":
data = merged_data
else:
raise ValueError(f"Invalid split_type: {split_type}")
if len(data) == 0:
st.warning(f"No data available for {split_type} split.")
return None
# Get true and predicted labels
y_true = data["true_class"].to_numpy()
y_pred = data["predicted_class"].to_numpy()
# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=label_names)
# Create xarray DataArray
cm_xr = xr.DataArray(
cm,
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",
)
return cm_xr
def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
"""Render sidebar for training run and analysis settings selection.
@ -233,21 +356,56 @@ def render_cv_statistics_section(selected_result, selected_metric):
)
def render_confusion_matrix_section(selected_result: TrainingResult):
@st.fragment
def render_confusion_matrix_section(selected_result: TrainingResult, merged_predictions: gpd.GeoDataFrame | None):
"""Render confusion matrix visualization and analysis.
Args:
selected_result: The selected TrainingResult object.
merged_predictions: GeoDataFrame with predictions merged with true labels and split info.
"""
st.header("🎲 Confusion Matrix")
st.caption("Detailed breakdown of predictions on the test set")
st.caption("Detailed breakdown of predictions")
# Add selector for confusion matrix type
cm_type = st.selectbox(
"Select Data Split",
options=["test", "train", "all"],
format_func=lambda x: {"test": "Test Set", "train": "CV Set (Train Split)", "all": "All Available Data"}[x],
help="Choose which data split to display the confusion matrix for",
key="cm_split_select",
)
# Get label names from settings
label_names = selected_result.settings.classes
# Compute or load confusion matrix based on selection
if cm_type == "test":
if selected_result.confusion_matrix is None:
st.warning("No confusion matrix available for this training run.")
st.warning("No confusion matrix available for the test set.")
return
cm = selected_result.confusion_matrix
st.info("📊 Showing confusion matrix for the **Test Set** (held-out data, never used during training)")
else:
if merged_predictions is None:
st.warning("Predictions data not available. Cannot compute confusion matrix.")
return
render_confusion_matrix_heatmap(selected_result.confusion_matrix, selected_result.settings.task)
with st.spinner(f"Computing confusion matrix for {cm_type} split..."):
cm = compute_confusion_matrix_from_merged_data(merged_predictions, cm_type, label_names)
if cm is None:
return
if cm_type == "train":
st.info(
"📊 Showing confusion matrix for the **CV Set (Train Split)** "
"(data used during hyperparameter search cross-validation)"
)
else: # all
st.info("📊 Showing confusion matrix for **All Available Data** (combined train and test splits)")
render_confusion_matrix_heatmap(cm, selected_result.settings.task)
def render_parameter_space_section(selected_result, selected_metric):
@ -374,6 +532,9 @@ def render_training_analysis_page():
return
selected_result, selected_metric, refit_metric, top_n = selection_result
# Load predictions with labels once (used by confusion matrix and map)
merged_predictions = load_predictions_with_labels(selected_result)
# Main content area
results = selected_result.results
settings = selected_result.settings
@ -389,7 +550,7 @@ def render_training_analysis_page():
st.divider()
# Confusion Matrix Section
render_confusion_matrix_section(selected_result)
render_confusion_matrix_section(selected_result, merged_predictions)
st.divider()
@ -400,9 +561,10 @@ def render_training_analysis_page():
st.divider()
# Confusion Matrix Map Section
st.header("🗺️ Prediction Results Map")
render_confusion_matrix_map(selected_result.path, settings)
# Prediction Analysis Map Section
st.header("🗺️ Model Performance Map")
st.caption("Interactive 3D map showing prediction correctness across the training dataset")
render_confusion_matrix_map(selected_result.path, settings, merged_predictions)
st.divider()

View file

@ -734,14 +734,14 @@ def spatial_agg(
3: _Aggregations.common(),
4: _Aggregations.common(),
5: _Aggregations(mean=True),
6: "interpolate",
6: "interpolate", # nearest neighbor interpolation
},
"healpix": {
6: _Aggregations.common(),
7: _Aggregations.common(),
8: _Aggregations(mean=True),
9: _Aggregations(mean=True),
10: "interpolate",
10: "interpolate", # nearest neighbor interpolation
},
}
aggregations = aggregations_by_gridlevel[grid][level]

View file

@ -8,8 +8,8 @@ import cupy as cp
import cyclopts
import pandas as pd
import toml
import torch
import xarray as xr
from array_api_compat import get_namespace
from cuml.ensemble import RandomForestClassifier
from cuml.neighbors import KNeighborsClassifier
from entropy import ESPAClassifier
@ -233,10 +233,9 @@ def random_cv(
# Compute predictions on the test set
y_pred = best_estimator.predict(training_data.X.test)
labels = list(range(len(training_data.y.labels)))
xp = get_namespace(y_test)
y_test = xp.as_array(y_test)
y_pred = xp.as_array(y_pred)
labels = xp.as_array(labels)
y_test = torch.asarray(y_test, device="cuda")
y_pred = torch.asarray(y_pred, device="cuda")
labels = torch.asarray(labels, device="cuda")
test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
@ -244,7 +243,7 @@ def random_cv(
cm = confusion_matrix(y_test, y_pred, labels=labels)
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
cm = xr.DataArray(
xp.as_numpy(cm),
cm.cpu().numpy(),
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",

View file

@ -368,7 +368,7 @@ def _init_worker(r: xr.Dataset | None):
def _align_partition(
grid_partition_gdf: gpd.GeoDataFrame,
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
aggregations: _Aggregations | None, # None -> Interpolation
aggregations: _Aggregations | str, # str -> Interpolation method
pxbuffer: int,
):
# ? This function is expected to run inside a worker process
@ -441,7 +441,7 @@ def _align_partition(
)
memprof.log_memory("After reading partial raster", log=False)
if aggregations is None:
if isinstance(aggregations, str):
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
if grid_partition_gdf.crs.to_epsg() == 4326:
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
@ -453,14 +453,26 @@ def _align_partition(
cy = centroids.y
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
interp_coords = (
{"latitude": interp_y, "longitude": interp_x}
if "latitude" in raster.dims and "longitude" in raster.dims
else {"y": interp_y, "x": interp_x}
)
ydim = "latitude" if "latitude" in raster.dims else "y"
xdim = "longitude" if "longitude" in raster.dims else "x"
interp_coords = {ydim: interp_y, xdim: interp_x}
# ?: Cubic does not work with NaNs in xarray interp
with stopwatch("Interpolating data to grid centroids", log=False):
ongrid = partial_raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
# Fill the nan
y_is_rev = partial_raster.indexes[ydim].is_monotonic_decreasing
if y_is_rev:
partial_raster = partial_raster.sortby(ydim)
partial_raster = partial_raster.interpolate_na(
dim=ydim,
method=aggregations,
).interpolate_na(
dim=xdim,
method=aggregations,
)
if y_is_rev:
partial_raster = partial_raster.sortby(ydim)
ongrid = partial_raster.interp(interp_coords, method=aggregations, kwargs={"fill_value": np.nan})
memprof.log_memory("After interpolating data", log=False)
else:
others_shape = tuple(
@ -519,7 +531,7 @@ def _align_partition(
def _align_data(
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
raster: xr.Dataset | Callable[[], xr.Dataset],
aggregations: _Aggregations | None,
aggregations: _Aggregations | str,
n_partitions: int | None,
concurrent_partitions: int,
pxbuffer: int,
@ -620,7 +632,7 @@ def _align_data(
def aggregate_raster_into_grid(
raster: xr.Dataset | Callable[[], xr.Dataset],
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
aggregations: _Aggregations | Literal["interpolate"],
aggregations: _Aggregations | Literal["nearest", "linear", "cubic", "interpolate"],
grid: Grid,
level: int,
n_partitions: int | None = 20,
@ -634,7 +646,11 @@ def aggregate_raster_into_grid(
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
No further partitioning will be done and the n_partitions argument will be ignored.
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
aggregations (_Aggregations | Literal[["nearest", "linear", "cubic", "interpolate"]):
The aggregations to perform.
If a string is provided, interpolation will be used with the specified method.
If "interpolate" is provided, the nearest neighbor interpolation will be used.
Supported methods are "nearest", "linear", and "cubic".
grid (Grid): The type of grid to use.
level (int): The level of the grid.
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
@ -649,7 +665,7 @@ def aggregate_raster_into_grid(
ongrid = _align_data(
grid_gdf,
raster,
aggregations if aggregations != "interpolate" else None,
aggregations if aggregations != "interpolate" else "nearest",
n_partitions=n_partitions,
concurrent_partitions=concurrent_partitions,
pxbuffer=pxbuffer,