"""Inference page: Visualization of model inference results across the study region."""

import streamlit as st

from entropice.dashboard.plots.inference import (
    render_class_comparison,
    render_class_distribution_histogram,
    render_inference_map,
    render_inference_statistics,
    render_spatial_distribution_stats,
)
from entropice.dashboard.utils.data import load_all_training_results


def render_inference_page():
    """Render the Inference page of the dashboard.

    Lets the user pick a training run in the sidebar, shows that run's
    configuration, then renders the inference visualizations for it:
    summary statistics, spatial coverage, 3D map, class distribution,
    and class comparison.
    """
    st.title("πŸ—ΊοΈ Inference Results")

    # Load every available training result; bail out early if none exist yet.
    training_results = load_all_training_results()
    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.training`")
        return

    # Sidebar: pick which training run to inspect.
    with st.sidebar:
        st.header("Select Training Run")

        # Task-first display names keep related runs grouped in the selectbox.
        runs_by_name = {tr.get_display_name("task_first"): tr for tr in training_results}

        selected_name = st.selectbox(
            "Training Run",
            options=list(runs_by_name.keys()),
            index=0,
            help="Select a training run to view inference results",
        )
        selected_result = runs_by_name[selected_name]

        st.divider()

        # Compact run summary in the sidebar.
        st.subheader("Run Information")
        settings = selected_result.settings
        st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
        st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
        st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
        st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
        st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")

    # Main content area - run configuration metrics at the top.
    st.header("πŸ“‹ Run Configuration")
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
    with col2:
        st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
    with col3:
        st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
    with col4:
        st.metric("Level", selected_result.settings.get("level", "Unknown"))
    with col5:
        # Strip the "darts_" prefix for a friendlier display name.
        st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))

    st.divider()

    # Inference results are written next to the training artifacts; without
    # them there is nothing to visualize.
    preds_file = selected_result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.error("No inference results found for this training run.")
        st.info("Inference results are generated automatically during training.")
        return

    # Deferred import: geopandas is heavy and only needed once a run with
    # predictions has actually been selected.
    import geopandas as gpd

    predictions_gdf = gpd.read_parquet(preds_file)
    task = selected_result.settings.get("task", "binary")

    # Render each section, separated by dividers.
    render_inference_statistics(predictions_gdf, task)
    st.divider()
    render_spatial_distribution_stats(predictions_gdf)
    st.divider()
    render_inference_map(selected_result)
    st.divider()
    render_class_distribution_histogram(predictions_gdf, task)
    st.divider()
    render_class_comparison(predictions_gdf, task)
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian."""
    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom


@st.fragment
def render_confusion_matrix_map(result_path: Path, settings: dict):
    """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).

    Args:
        result_path: Path to the training result directory.
        settings: Settings dictionary containing grid, level, task, and target information.

    """
    st.subheader("πŸ—ΊοΈ Confusion Matrix Spatial Distribution")

    # Predictions are written next to the training artifacts during training.
    preds_file = result_path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.warning("No predicted probabilities found for this training run.")
        return
    preds_gdf = gpd.read_parquet(preds_file)

    task = settings.get("task", "binary")
    target = settings.get("target", "darts_rts")
    grid = settings.get("grid", "hex")
    level = settings.get("level", 3)

    # Rebuild the dataset ensemble only to recover the true labels; no
    # feature members are needed for that.
    try:
        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])
        training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
    except Exception as e:
        st.error(f"Error loading training data: {e}")
        return

    # Keep only the cells that actually carry a true label.
    labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]

    # Attach the true class, then join with the predictions on cell_id.
    # Resetting the index avoids index/column ambiguity in the merge.
    labeled_gdf = labeled_cells.copy().reset_index().rename(columns={"index": "cell_id"})
    labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()
    merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
    # merge() can demote to a plain DataFrame; restore the GeoDataFrame type.
    merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)

    if len(merged) == 0:
        st.warning("No matching predictions found for labeled cells.")
        return

    def get_confusion_category(row):
        # Binary task gets the full TP/FP/TN/FN split; multiclass collapses
        # to correct vs incorrect.
        true_label = row["true_class"]
        pred_label = row["predicted_class"]
        if task == "binary":
            if true_label == "RTS" and pred_label == "RTS":
                return "True Positive"
            if true_label == "RTS" and pred_label == "No-RTS":
                return "False Negative"
            if true_label == "No-RTS" and pred_label == "RTS":
                return "False Positive"
            return "True Negative"
        return "Correct" if true_label == pred_label else "Incorrect"

    merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)

    # Controls: category filter and layer opacity.
    col1, col2 = st.columns([3, 1])
    with col1:
        if task == "binary":
            categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
        else:
            categories = ["All", "Correct", "Incorrect"]
        selected_category = st.selectbox(
            "Filter by Category",
            options=categories,
            key="confusion_map_category",
        )
    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")

    if selected_category != "All":
        display_gdf = merged[merged["confusion_category"] == selected_category].copy()
    else:
        display_gdf = merged.copy()

    if len(display_gdf) == 0:
        st.warning(f"No cells found for category: {selected_category}")
        return

    # pydeck expects WGS84 coordinates.
    display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
    if grid == "hex":
        # Hex cells spanning the antimeridian render as wrap-around polygons
        # unless fixed first.
        display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Category colors and extrusion heights (errors are elevated higher).
    if task == "binary":
        color_map = {
            "True Positive": [46, 204, 113],  # Green
            "False Positive": [231, 76, 60],  # Red
            "True Negative": [52, 152, 219],  # Blue
            "False Negative": [241, 196, 15],  # Yellow
        }
        elevation_map = {
            "True Positive": 0.8,
            "False Positive": 1.0,
            "True Negative": 0.3,
            "False Negative": 1.0,
        }
    else:
        color_map = {
            "Correct": [46, 204, 113],  # Green
            "Incorrect": [231, 76, 60],  # Red
        }
        elevation_map = {
            "Correct": 0.5,
            "Incorrect": 1.0,
        }

    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)
    display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)

    # Plain GeoJSON features; pydeck accessor strings read from "properties".
    geojson_data = [
        {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "true_class": str(row["true_class"]),
                "predicted_class": str(row["predicted_class"]),
                "confusion_category": str(row["confusion_category"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        for _, row in display_gdf_wgs84.iterrows()
    ]

    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation",
        elevation_scale=500000,
        pickable=True,
    )

    # Arctic-centred initial camera, tilted so extrusion is visible.
    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            # NOTE(review): separators were lost in transit; assumed <br/> tags.
            "html": "True Label: {true_class}<br/>"
            "Predicted Label: {predicted_class}<br/>"
            "Category: {confusion_category}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    st.pydeck_chart(deck)

    # Headline statistics below the map.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Labeled Cells", len(merged))

    if task == "binary":
        with col2:
            tp = len(merged[merged["confusion_category"] == "True Positive"])
            fp = len(merged[merged["confusion_category"] == "False Positive"])
            tn = len(merged[merged["confusion_category"] == "True Negative"])
            fn = len(merged[merged["confusion_category"] == "False Negative"])
            accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")
        with col3:
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            st.metric("F1 Score", f"{f1:.3f}")
        st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
    else:
        with col2:
            correct = len(merged[merged["confusion_category"] == "Correct"])
            accuracy = correct / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")
        with col3:
            incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
            st.metric("Incorrect", incorrect)

    with st.expander("Legend", expanded=True):
        st.markdown("**Confusion Matrix Categories:**")
        for category, color in color_map.items():
            count = len(merged[merged["confusion_category"] == category])
            percentage = count / len(merged) * 100 if len(merged) > 0 else 0
            # NOTE(review): the legend HTML was lost in transit; reconstructed
            # as a colored swatch next to the category label.
            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 14px; height: 14px; '
                f"background-color: rgb({color[0]}, {color[1]}, {color[2]}); "
                f'margin-right: 8px;"></div>'
                f"{category}: {count} ({percentage:.1f}%)</div>",
                unsafe_allow_html=True,
            )

        st.markdown("---")
        st.markdown("**Elevation (3D):**")
        st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.")
        st.info("πŸ’‘ Rotate the map by holding Ctrl/Cmd and dragging.")
+ + """ + st.subheader("πŸ“Š Inference Summary") + + # Get class distribution + class_counts = predictions_gdf["predicted_class"].value_counts() + + # Create metrics layout + if task == "binary": + col1, col2, col3 = st.columns(3) + + with col1: + st.metric("Total Predictions", f"{len(predictions_gdf):,}") + + with col2: + rts_count = class_counts.get("RTS", 0) + rts_pct = rts_count / len(predictions_gdf) * 100 if len(predictions_gdf) > 0 else 0 + st.metric("RTS Predictions", f"{rts_count:,} ({rts_pct:.1f}%)") + + with col3: + no_rts_count = class_counts.get("No-RTS", 0) + no_rts_pct = no_rts_count / len(predictions_gdf) * 100 if len(predictions_gdf) > 0 else 0 + st.metric("No-RTS Predictions", f"{no_rts_count:,} ({no_rts_pct:.1f}%)") + else: + col1, col2, col3 = st.columns(3) + + with col1: + st.metric("Total Predictions", f"{len(predictions_gdf):,}") + + with col2: + st.metric("Unique Classes", len(class_counts)) + + with col3: + most_common = class_counts.index[0] if len(class_counts) > 0 else "N/A" + st.metric("Most Common Class", most_common) + + +def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: str): + """Render histogram of predicted class distribution. + + Args: + predictions_gdf: GeoDataFrame with predictions. + task: Task type ('binary', 'count', 'density'). + + """ + st.subheader("πŸ“Š Predicted Class Distribution") + + # Get class counts + class_counts = predictions_gdf["predicted_class"].value_counts().sort_index() + + # Get colors based on task + categories = class_counts.index.tolist() + colors = get_palette(task, len(categories)) + + # Create bar chart + fig = go.Figure() + + fig.add_trace( + go.Bar( + x=categories, + y=class_counts.values, + marker_color=colors, + opacity=0.9, + text=class_counts.to_numpy(), + textposition="outside", + textfont={"size": 12}, + hovertemplate="%{x}
Count: %{y:,}
Percentage: %{customdata:.1f}%", + customdata=class_counts.to_numpy() / len(predictions_gdf) * 100, + ) + ) + + fig.update_layout( + height=400, + margin={"l": 20, "r": 20, "t": 40, "b": 20}, + showlegend=False, + xaxis_title="Predicted Class", + yaxis_title="Count", + xaxis={"tickangle": -45 if len(categories) > 3 else 0}, + ) + + st.plotly_chart(fig, use_container_width=True) + + # Show percentages in a table + with st.expander("πŸ“‹ Detailed Class Distribution", expanded=False): + distribution_df = pd.DataFrame( + { + "Class": categories, + "Count": class_counts.to_numpy(), + "Percentage": (class_counts.to_numpy() / len(predictions_gdf) * 100).round(2), + } + ) + st.dataframe(distribution_df, hide_index=True, use_container_width=True) + + +def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame): + """Render spatial statistics about predictions. + + Args: + predictions_gdf: GeoDataFrame with predictions. + + """ + st.subheader("🌍 Spatial Coverage") + + # Calculate spatial extent + bounds = predictions_gdf.total_bounds + + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.metric("Min Latitude", f"{bounds[1]:.2f}Β°") + + with col2: + st.metric("Max Latitude", f"{bounds[3]:.2f}Β°") + + with col3: + st.metric("Min Longitude", f"{bounds[0]:.2f}Β°") + + with col4: + st.metric("Max Longitude", f"{bounds[2]:.2f}Β°") + + # Calculate total area if cell_area is available + if "cell_area" in predictions_gdf.columns: + total_area = predictions_gdf["cell_area"].sum() + st.info(f"πŸ“ **Total Area Covered:** {total_area:,.0f} kmΒ²") + + +def _prepare_geojson_features(display_gdf_wgs84: gpd.GeoDataFrame) -> list: + """Convert GeoDataFrame to GeoJSON features for pydeck. + + Args: + display_gdf_wgs84: GeoDataFrame in WGS84 projection with required columns. + + Returns: + List of GeoJSON feature dictionaries. 
+ + """ + geojson_data = [] + for _, row in display_gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "cell_id": str(row["cell_id"]), + "predicted_class": str(row["predicted_class"]), + "fill_color": row["fill_color"], + "elevation": float(row["elevation"]), + }, + } + geojson_data.append(feature) + return geojson_data + + +@st.fragment +def render_inference_map(result: TrainingResult): + """Render 3D pydeck map showing inference results with interactive controls. + + This is a Streamlit fragment that reruns independently when users interact with the + visualization controls (color mode and opacity), without re-running the entire page. + + Args: + result: TrainingResult object containing prediction data. + + """ + st.subheader("πŸ—ΊοΈ Inference Results Map") + + # Load predictions + preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet") + + # Get settings + task = result.settings.get("task", "binary") + grid = result.settings.get("grid", "hex") + + # Create controls in columns + col1, col2, col3 = st.columns([2, 2, 1]) + + with col1: + # Get unique classes for filtering + all_classes = sorted(preds_gdf["predicted_class"].unique()) + filter_options = ["All Classes", *all_classes] + + selected_filter = st.selectbox( + "Filter by Predicted Class", + options=filter_options, + key="inference_map_filter", + ) + + with col2: + use_elevation = st.checkbox( + "Enable 3D Elevation", + value=True, + help="Show predictions with elevation (requires count/density for meaningful height)", + key="inference_map_elevation", + ) + + with col3: + opacity = st.slider( + "Opacity", + min_value=0.1, + max_value=1.0, + value=0.7, + step=0.1, + key="inference_map_opacity", + ) + + # Filter data if needed + if selected_filter != "All Classes": + display_gdf = preds_gdf[preds_gdf["predicted_class"] == selected_filter].copy() + else: + display_gdf = preds_gdf.copy() + + if len(display_gdf) == 
0: + st.warning(f"No predictions found for filter: {selected_filter}") + return + + st.info(f"Displaying {len(display_gdf):,} out of {len(preds_gdf):,} total predictions") + + # Convert to WGS84 for pydeck + display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326") + + # Fix antimeridian issues for hex grids + if grid == "hex": + display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry) + + # Assign colors based on predicted class + colors_palette = get_palette(task, len(all_classes)) + + # Create color mapping for all classes + color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)} + + # Convert hex colors to RGB + def hex_to_rgb(hex_color): + hex_color = hex_color.lstrip("#") + return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)] + + display_gdf_wgs84["fill_color"] = display_gdf_wgs84["predicted_class"].map( + {cls: hex_to_rgb(color) for cls, color in color_map.items()} + ) + + # Add elevation based on class encoding (for ordered classes) + if use_elevation and len(all_classes) > 1: + # Create a normalized elevation based on class order + class_to_elevation = {cls: i / (len(all_classes) - 1) for i, cls in enumerate(all_classes)} + display_gdf_wgs84["elevation"] = display_gdf_wgs84["predicted_class"].map(class_to_elevation) + else: + display_gdf_wgs84["elevation"] = 0.0 + + # Convert to GeoJSON format + geojson_data = _prepare_geojson_features(display_gdf_wgs84) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=opacity, + stroked=True, + filled=True, + extruded=use_elevation, + wireframe=False, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + get_elevation="properties.elevation" if use_elevation else 0, + elevation_scale=500000, # Scale to 500km height + pickable=True, + ) + + # Set initial view state (centered on the Arctic) + view_state = pdk.ViewState( + latitude=70, + longitude=0, + zoom=2 if not use_elevation else 1.5, 
+ pitch=0 if not use_elevation else 45, + ) + + # Create deck + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": "Cell ID: {cell_id}
Predicted Class: {predicted_class}", + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + # Render the map + st.pydeck_chart(deck) + + # Show info about 3D visualization + if use_elevation: + st.info("πŸ’‘ 3D elevation represents class order. Rotate the map by holding Ctrl/Cmd and dragging.") + + # Add legend + with st.expander("Legend", expanded=True): + st.markdown("**Predicted Classes:**") + + for cls in all_classes: + color_hex = color_map[cls] + count = len(display_gdf[display_gdf["predicted_class"] == cls]) + total_count = len(preds_gdf[preds_gdf["predicted_class"] == cls]) + percentage = total_count / len(preds_gdf) * 100 if len(preds_gdf) > 0 else 0 + + # Show if currently displayed or total count + if selected_filter == "All Classes": + count_str = f"{count:,} ({percentage:.1f}%)" + else: + count_str = f"{count:,} displayed / {total_count:,} total ({percentage:.1f}%)" + + st.markdown( + f'
' + f'
' + f"{cls}: {count_str}
", + unsafe_allow_html=True, + ) + + if use_elevation and len(all_classes) > 1: + st.markdown("---") + st.markdown("**Elevation (3D):**") + st.markdown(f"Height represents class order: {all_classes[0]} (low) β†’ {all_classes[-1]} (high)") + + +def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str): + """Render comparison plots between different predicted classes. + + Args: + predictions_gdf: GeoDataFrame with predictions. + task: Task type ('binary', 'count', 'density'). + + """ + st.subheader("πŸ” Class Comparison") + + # Get class distribution + class_counts = predictions_gdf["predicted_class"].value_counts() + + if len(class_counts) < 2: + st.info("Need at least 2 classes for comparison.") + return + + # Create pie chart + col1, col2 = st.columns(2) + + with col1: + st.markdown("**Class Proportions") + + colors = get_palette(task, len(class_counts)) + + fig = go.Figure( + data=[ + go.Pie( + labels=class_counts.index, + values=class_counts.values, + marker_colors=colors, + textinfo="label+percent", + textposition="auto", + hovertemplate="%{label}
Count: %{value:,}
Percentage: %{percent}", + ) + ] + ) + + fig.update_layout( + height=400, + margin={"l": 20, "r": 20, "t": 20, "b": 20}, + showlegend=True, + ) + + st.plotly_chart(fig, use_container_width=True) + + with col2: + st.markdown("**Cumulative Distribution") + + # Create cumulative distribution + sorted_counts = class_counts.sort_values(ascending=False) + cumulative = sorted_counts.cumsum() + cumulative_pct = cumulative / cumulative.iloc[-1] * 100 + + fig = go.Figure() + + fig.add_trace( + go.Scatter( + x=list(range(len(cumulative))), + y=cumulative_pct.to_numpy(), + mode="lines+markers", + line={"color": colors[0], "width": 3}, + marker={"size": 8}, + customdata=sorted_counts.index, + hovertemplate="%{customdata}
Cumulative: %{y:.1f}%", + ) + ) + + fig.update_layout( + height=400, + margin={"l": 20, "r": 20, "t": 20, "b": 20}, + xaxis_title="Class Rank", + yaxis_title="Cumulative Percentage", + yaxis={"range": [0, 105]}, + ) + + st.plotly_chart(fig, use_container_width=True) diff --git a/src/entropice/dashboard/training_analysis_page.py b/src/entropice/dashboard/training_analysis_page.py index 1a72de7..e2484ff 100644 --- a/src/entropice/dashboard/training_analysis_page.py +++ b/src/entropice/dashboard/training_analysis_page.py @@ -4,6 +4,7 @@ import streamlit as st from entropice.dashboard.plots.hyperparameter_analysis import ( render_binned_parameter_space, + render_confusion_matrix_map, render_espa_binned_parameter_space, render_multi_metric_comparison, render_parameter_correlation, @@ -128,6 +129,13 @@ def render_training_analysis_page(): st.divider() + # Confusion Matrix Map Section + st.header("πŸ—ΊοΈ Prediction Results Map") + + render_confusion_matrix_map(selected_result.path, settings) + + st.divider() + # Quick Statistics st.header("πŸ“ˆ Cross-Validation Statistics")