diff --git a/src/entropice/dashboard/inference_page.py b/src/entropice/dashboard/inference_page.py
index c3123bc..5cdd277 100644
--- a/src/entropice/dashboard/inference_page.py
+++ b/src/entropice/dashboard/inference_page.py
@@ -1,8 +1,107 @@
+"""Inference page: Visualization of model inference results across the study region."""
+
import streamlit as st
+from entropice.dashboard.plots.inference import (
+ render_class_comparison,
+ render_class_distribution_histogram,
+ render_inference_map,
+ render_inference_statistics,
+ render_spatial_distribution_stats,
+)
+from entropice.dashboard.utils.data import load_all_training_results
+
def render_inference_page():
"""Render the Inference page of the dashboard."""
- st.title("Inference Results")
- st.write("This page will display inference results and visualizations.")
- # Add more components and visualizations as needed for inference results.
+    st.title("🗺️ Inference Results")
+
+ # Load all available training results
+ training_results = load_all_training_results()
+
+ if not training_results:
+ st.warning("No training results found. Please run some training experiments first.")
+ st.info("Run training using: `pixi run python -m entropice.training`")
+ return
+
+ # Sidebar: Training run selection
+ with st.sidebar:
+ st.header("Select Training Run")
+
+ # Create selection options with task-first naming
+ training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
+
+ selected_name = st.selectbox(
+ "Training Run",
+ options=list(training_options.keys()),
+ index=0,
+ help="Select a training run to view inference results",
+ )
+
+ selected_result = training_options[selected_name]
+
+ st.divider()
+
+ # Show run information in sidebar
+ st.subheader("Run Information")
+
+ settings = selected_result.settings
+
+ st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
+ st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
+ st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
+ st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
+ st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
+
+ # Main content area - Run Information at the top
+    st.header("📋 Run Configuration")
+
+ col1, col2, col3, col4, col5 = st.columns(5)
+ with col1:
+ st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
+ with col2:
+ st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
+ with col3:
+ st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
+ with col4:
+ st.metric("Level", selected_result.settings.get("level", "Unknown"))
+ with col5:
+ st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))
+
+ st.divider()
+
+ # Check if predictions file exists
+ preds_file = selected_result.path / "predicted_probabilities.parquet"
+ if not preds_file.exists():
+ st.error("No inference results found for this training run.")
+ st.info("Inference results are generated automatically during training.")
+ return
+
+ # Load predictions for statistics
+ import geopandas as gpd
+
+ predictions_gdf = gpd.read_parquet(preds_file)
+ task = selected_result.settings.get("task", "binary")
+
+ # Inference Statistics Section
+ render_inference_statistics(predictions_gdf, task)
+
+ st.divider()
+
+ # Spatial Coverage Section
+ render_spatial_distribution_stats(predictions_gdf)
+
+ st.divider()
+
+ # 3D Map Visualization Section
+ render_inference_map(selected_result)
+
+ st.divider()
+
+ # Class Distribution Section
+ render_class_distribution_histogram(predictions_gdf, task)
+
+ st.divider()
+
+ # Class Comparison Section
+ render_class_comparison(predictions_gdf, task)
diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py
index 75381c5..18945a9 100644
--- a/src/entropice/dashboard/plots/hyperparameter_analysis.py
+++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py
@@ -1,12 +1,19 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
+from pathlib import Path
+
import altair as alt
+import antimeridian
+import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
+import pydeck as pdk
import streamlit as st
+from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap, get_palette
+from entropice.dataset import DatasetEnsemble
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@@ -59,9 +66,8 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
# Check if refit metric exists in results
if refit_col not in results.columns:
- st.warning(
- f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
- )
+ available_metrics = [col.replace("mean_test_", "") for col in score_cols]
+ st.warning(f"Refit metric '{refit_metric}' not found in results. Available metrics: {available_metrics}")
# Use the first available metric as fallback
refit_col = score_cols[0]
refit_metric = refit_col.replace("mean_test_", "")
@@ -1092,3 +1098,268 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
st.dataframe(display_df, hide_index=True, width="stretch")
+
+
+def _fix_hex_geometry(geom):
+ """Fix hexagon geometry crossing the antimeridian."""
+ try:
+ return shape(antimeridian.fix_shape(geom))
+ except ValueError as e:
+ st.error(f"Error fixing geometry: {e}")
+ return geom
+
+
+@st.fragment
+def render_confusion_matrix_map(result_path: Path, settings: dict):
+ """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).
+
+ Args:
+ result_path: Path to the training result directory.
+ settings: Settings dictionary containing grid, level, task, and target information.
+
+ """
+    st.subheader("🗺️ Confusion Matrix Spatial Distribution")
+
+ # Load predicted probabilities
+ preds_file = result_path / "predicted_probabilities.parquet"
+ if not preds_file.exists():
+ st.warning("No predicted probabilities found for this training run.")
+ return
+
+ preds_gdf = gpd.read_parquet(preds_file)
+
+ # Get task and target information from settings
+ task = settings.get("task", "binary")
+ target = settings.get("target", "darts_rts")
+ grid = settings.get("grid", "hex")
+ level = settings.get("level", 3)
+
+ # Create dataset ensemble to get true labels
+ # We need to load the target data to get true labels
+ try:
+ ensemble = DatasetEnsemble(
+ grid=grid,
+ level=level,
+ target=target,
+ members=[], # We don't need feature data, just target
+ )
+ training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
+ except Exception as e:
+ st.error(f"Error loading training data: {e}")
+ return
+
+ # Get the labeled cells (those with true labels)
+ labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]
+
+ # Merge predictions with true labels
+ # Reset index to avoid ambiguity between index and column
+ labeled_gdf = labeled_cells.copy()
+ labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"})
+ labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()
+
+ # Merge with predictions - ensure we keep GeoDataFrame type
+ merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
+ merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
+
+ if len(merged) == 0:
+ st.warning("No matching predictions found for labeled cells.")
+ return
+
+ # Determine confusion matrix category
+ def get_confusion_category(row):
+ true_label = row["true_class"]
+ pred_label = row["predicted_class"]
+
+ if task == "binary":
+ # For binary classification
+ if true_label == "RTS" and pred_label == "RTS":
+ return "True Positive"
+ elif true_label == "RTS" and pred_label == "No-RTS":
+ return "False Negative"
+ elif true_label == "No-RTS" and pred_label == "RTS":
+ return "False Positive"
+ else: # true_label == "No-RTS" and pred_label == "No-RTS"
+ return "True Negative"
+ else:
+ # For multiclass (count/density)
+ if true_label == pred_label:
+ return "Correct"
+ else:
+ return "Incorrect"
+
+ merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)
+
+ # Create controls
+ col1, col2 = st.columns([3, 1])
+
+ with col1:
+ # Filter by confusion category
+ if task == "binary":
+ categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
+ else:
+ categories = ["All", "Correct", "Incorrect"]
+
+ selected_category = st.selectbox(
+ "Filter by Category",
+ options=categories,
+ key="confusion_map_category",
+ )
+
+ with col2:
+ opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")
+
+ # Filter data if needed
+ if selected_category != "All":
+ display_gdf = merged[merged["confusion_category"] == selected_category].copy()
+ else:
+ display_gdf = merged.copy()
+
+ if len(display_gdf) == 0:
+ st.warning(f"No cells found for category: {selected_category}")
+ return
+
+ # Convert to WGS84 for pydeck
+ display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
+
+ # Fix antimeridian issues for hex grids
+ if grid == "hex":
+ display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)
+
+ # Assign colors based on confusion category
+ if task == "binary":
+ color_map = {
+ "True Positive": [46, 204, 113], # Green
+ "False Positive": [231, 76, 60], # Red
+ "True Negative": [52, 152, 219], # Blue
+ "False Negative": [241, 196, 15], # Yellow
+ }
+ else:
+ color_map = {
+ "Correct": [46, 204, 113], # Green
+ "Incorrect": [231, 76, 60], # Red
+ }
+
+ display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)
+
+ # Add elevation based on confusion category (higher for errors)
+ if task == "binary":
+ elevation_map = {
+ "True Positive": 0.8,
+ "False Positive": 1.0,
+ "True Negative": 0.3,
+ "False Negative": 1.0,
+ }
+ else:
+ elevation_map = {
+ "Correct": 0.5,
+ "Incorrect": 1.0,
+ }
+
+ display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)
+
+ # Convert to GeoJSON format
+ geojson_data = []
+ for _, row in display_gdf_wgs84.iterrows():
+ feature = {
+ "type": "Feature",
+ "geometry": row["geometry"].__geo_interface__,
+ "properties": {
+ "true_class": str(row["true_class"]),
+ "predicted_class": str(row["predicted_class"]),
+ "confusion_category": str(row["confusion_category"]),
+ "fill_color": row["fill_color"],
+ "elevation": float(row["elevation"]),
+ },
+ }
+ geojson_data.append(feature)
+
+ # Create pydeck layer
+ layer = pdk.Layer(
+ "GeoJsonLayer",
+ geojson_data,
+ opacity=opacity,
+ stroked=True,
+ filled=True,
+ extruded=True,
+ wireframe=False,
+ get_fill_color="properties.fill_color",
+ get_line_color=[80, 80, 80],
+ line_width_min_pixels=0.5,
+ get_elevation="properties.elevation",
+ elevation_scale=500000,
+ pickable=True,
+ )
+
+ # Set initial view state (centered on the Arctic)
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
+
+ # Create deck
+ deck = pdk.Deck(
+ layers=[layer],
+ initial_view_state=view_state,
+ tooltip={
+            "html": "True Label: {true_class}<br/>"
+            "Predicted Label: {predicted_class}<br/>"
+            "Category: {confusion_category}",
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
+ map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+ )
+
+ # Render the map
+ st.pydeck_chart(deck)
+
+ # Show statistics
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+ st.metric("Total Labeled Cells", len(merged))
+
+ if task == "binary":
+ with col2:
+ tp = len(merged[merged["confusion_category"] == "True Positive"])
+ fp = len(merged[merged["confusion_category"] == "False Positive"])
+ tn = len(merged[merged["confusion_category"] == "True Negative"])
+ fn = len(merged[merged["confusion_category"] == "False Negative"])
+
+ accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
+ st.metric("Accuracy", f"{accuracy:.2%}")
+
+ with col3:
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+ f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+ st.metric("F1 Score", f"{f1:.3f}")
+
+ # Show confusion matrix counts
+ st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
+ else:
+ with col2:
+ correct = len(merged[merged["confusion_category"] == "Correct"])
+ accuracy = correct / len(merged) if len(merged) > 0 else 0
+ st.metric("Accuracy", f"{accuracy:.2%}")
+
+ with col3:
+ incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
+ st.metric("Incorrect", incorrect)
+
+ # Add legend
+ with st.expander("Legend", expanded=True):
+ st.markdown("**Confusion Matrix Categories:**")
+
+ for category, color in color_map.items():
+ count = len(merged[merged["confusion_category"] == category])
+ percentage = count / len(merged) * 100 if len(merged) > 0 else 0
+
+ st.markdown(
+ f'