Make inference page

This commit is contained in:
Tobias Hölzer 2025-12-25 18:59:27 +01:00
parent 1919cc6a7e
commit 6960571742
4 changed files with 813 additions and 6 deletions

View file

@@ -1,8 +1,107 @@
"""Inference page: Visualization of model inference results across the study region."""
import streamlit as st
from entropice.dashboard.plots.inference import (
render_class_comparison,
render_class_distribution_histogram,
render_inference_map,
render_inference_statistics,
render_spatial_distribution_stats,
)
from entropice.dashboard.utils.data import load_all_training_results
def render_inference_page():
    """Render the Inference page of the dashboard.

    Loads all available training runs, lets the user select one in the
    sidebar, and then renders the inference statistics, spatial coverage,
    3D map, and class distribution/comparison sections for that run.
    Shows a warning and returns early when no training results (or no
    prediction file for the selected run) exist.
    """
    # Fixed: removed the leftover placeholder block that rendered a second
    # page title and a "This page will display..." stub before the real content.
    st.title("🗺️ Inference Results")
    # Load all available training results
    training_results = load_all_training_results()
    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.training`")
        return
    # Sidebar: Training run selection
    with st.sidebar:
        st.header("Select Training Run")
        # Create selection options with task-first naming
        training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
        selected_name = st.selectbox(
            "Training Run",
            options=list(training_options.keys()),
            index=0,
            help="Select a training run to view inference results",
        )
        selected_result = training_options[selected_name]
        st.divider()
        # Show run information in sidebar
        st.subheader("Run Information")
        settings = selected_result.settings
        st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
        st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
        st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
        st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
        st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
    # Main content area - Run Information at the top
    st.header("📋 Run Configuration")
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
    with col2:
        st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
    with col3:
        st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
    with col4:
        st.metric("Level", selected_result.settings.get("level", "Unknown"))
    with col5:
        st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))
    st.divider()
    # Check if predictions file exists
    preds_file = selected_result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.error("No inference results found for this training run.")
        st.info("Inference results are generated automatically during training.")
        return
    # Load predictions for statistics; geopandas is imported lazily so the
    # page module stays cheap to import when this branch is never reached.
    import geopandas as gpd

    predictions_gdf = gpd.read_parquet(preds_file)
    task = selected_result.settings.get("task", "binary")
    # Inference Statistics Section
    render_inference_statistics(predictions_gdf, task)
    st.divider()
    # Spatial Coverage Section
    render_spatial_distribution_stats(predictions_gdf)
    st.divider()
    # 3D Map Visualization Section
    render_inference_map(selected_result)
    st.divider()
    # Class Distribution Section
    render_class_distribution_histogram(predictions_gdf, task)
    st.divider()
    # Class Comparison Section
    render_class_comparison(predictions_gdf, task)

View file

@ -1,12 +1,19 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
from pathlib import Path
import altair as alt
import antimeridian
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap, get_palette
from entropice.dataset import DatasetEnsemble
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@ -59,9 +66,8 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
# Check if refit metric exists in results
if refit_col not in results.columns:
st.warning(
f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
)
available_metrics = [col.replace("mean_test_", "") for col in score_cols]
st.warning(f"Refit metric '{refit_metric}' not found in results. Available metrics: {available_metrics}")
# Use the first available metric as fallback
refit_col = score_cols[0]
refit_metric = refit_col.replace("mean_test_", "")
@ -1092,3 +1098,268 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
st.dataframe(display_df, hide_index=True, width="stretch")
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that crosses the antimeridian.

    Returns the repaired shapely geometry on success; on failure the error
    is surfaced in the dashboard and the original geometry is returned
    unchanged.
    """
    try:
        fixed = shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
    return fixed
@st.fragment
def render_confusion_matrix_map(result_path: Path, settings: dict):
    """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).

    Runs as a Streamlit fragment, so interacting with the category filter or
    opacity slider reruns only this section rather than the whole page.
    Binary tasks are mapped to the four confusion-matrix categories; all
    other tasks collapse to Correct/Incorrect.

    Args:
        result_path: Path to the training result directory.
        settings: Settings dictionary containing grid, level, task, and target information.
    """
    st.subheader("🗺️ Confusion Matrix Spatial Distribution")
    # Load predicted probabilities; bail out early when inference was never run.
    preds_file = result_path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.warning("No predicted probabilities found for this training run.")
        return
    preds_gdf = gpd.read_parquet(preds_file)
    # Get task and target information from settings (defaults mirror the
    # training pipeline's defaults).
    task = settings.get("task", "binary")
    target = settings.get("target", "darts_rts")
    grid = settings.get("grid", "hex")
    level = settings.get("level", 3)
    # Create dataset ensemble to get true labels.
    # We need to load the target data to get true labels.
    try:
        ensemble = DatasetEnsemble(
            grid=grid,
            level=level,
            target=target,
            members=[],  # We don't need feature data, just target
        )
        training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
    except Exception as e:
        # Broad catch is deliberate: any data-loading failure should degrade
        # to an on-page error, not crash the dashboard.
        st.error(f"Error loading training data: {e}")
        return
    # Get the labeled cells (those with true labels)
    labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]
    # Merge predictions with true labels.
    # Reset index to avoid ambiguity between index and column.
    labeled_gdf = labeled_cells.copy()
    labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"})
    labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()
    # Merge with predictions - ensure we keep GeoDataFrame type
    # (DataFrame.merge would otherwise return a plain DataFrame).
    merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
    merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
    if len(merged) == 0:
        st.warning("No matching predictions found for labeled cells.")
        return

    # Determine confusion matrix category for a single row.
    # NOTE(review): assumes binary labels are exactly "RTS"/"No-RTS" — any
    # other value falls into the True Negative branch; confirm upstream.
    def get_confusion_category(row):
        true_label = row["true_class"]
        pred_label = row["predicted_class"]
        if task == "binary":
            # For binary classification
            if true_label == "RTS" and pred_label == "RTS":
                return "True Positive"
            elif true_label == "RTS" and pred_label == "No-RTS":
                return "False Negative"
            elif true_label == "No-RTS" and pred_label == "RTS":
                return "False Positive"
            else:  # true_label == "No-RTS" and pred_label == "No-RTS"
                return "True Negative"
        else:
            # For multiclass (count/density)
            if true_label == pred_label:
                return "Correct"
            else:
                return "Incorrect"

    merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)
    # Create controls
    col1, col2 = st.columns([3, 1])
    with col1:
        # Filter by confusion category
        if task == "binary":
            categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
        else:
            categories = ["All", "Correct", "Incorrect"]
        selected_category = st.selectbox(
            "Filter by Category",
            options=categories,
            key="confusion_map_category",
        )
    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")
    # Filter data if needed
    if selected_category != "All":
        display_gdf = merged[merged["confusion_category"] == selected_category].copy()
    else:
        display_gdf = merged.copy()
    if len(display_gdf) == 0:
        st.warning(f"No cells found for category: {selected_category}")
        return
    # Convert to WGS84 for pydeck (deck.gl expects lon/lat coordinates)
    display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
    # Fix antimeridian issues for hex grids
    if grid == "hex":
        display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Assign colors based on confusion category (RGB triplets for pydeck)
    if task == "binary":
        color_map = {
            "True Positive": [46, 204, 113],  # Green
            "False Positive": [231, 76, 60],  # Red
            "True Negative": [52, 152, 219],  # Blue
            "False Negative": [241, 196, 15],  # Yellow
        }
    else:
        color_map = {
            "Correct": [46, 204, 113],  # Green
            "Incorrect": [231, 76, 60],  # Red
        }
    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)
    # Add elevation based on confusion category (higher for errors)
    if task == "binary":
        elevation_map = {
            "True Positive": 0.8,
            "False Positive": 1.0,
            "True Negative": 0.3,
            "False Negative": 1.0,
        }
    else:
        elevation_map = {
            "Correct": 0.5,
            "Incorrect": 1.0,
        }
    display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)
    # Convert to GeoJSON format expected by the GeoJsonLayer
    geojson_data = []
    for _, row in display_gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "true_class": str(row["true_class"]),
                "predicted_class": str(row["predicted_class"]),
                "confusion_category": str(row["confusion_category"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation",
        elevation_scale=500000,
        pickable=True,
    )
    # Set initial view state (centered on the Arctic)
    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>True Label:</b> {true_class}<br/>"
            "<b>Predicted Label:</b> {predicted_class}<br/>"
            "<b>Category:</b> {confusion_category}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    # Render the map
    st.pydeck_chart(deck)
    # Show statistics (computed over ALL labeled cells, not the filtered view)
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Labeled Cells", len(merged))
    if task == "binary":
        with col2:
            tp = len(merged[merged["confusion_category"] == "True Positive"])
            fp = len(merged[merged["confusion_category"] == "False Positive"])
            tn = len(merged[merged["confusion_category"] == "True Negative"])
            fn = len(merged[merged["confusion_category"] == "False Negative"])
            accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")
        with col3:
            # tp/fp/tn/fn remain in scope here: `with` blocks do not scope names.
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            st.metric("F1 Score", f"{f1:.3f}")
        # Show confusion matrix counts
        st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
    else:
        with col2:
            correct = len(merged[merged["confusion_category"] == "Correct"])
            accuracy = correct / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")
        with col3:
            incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
            st.metric("Incorrect", incorrect)
    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown("**Confusion Matrix Categories:**")
        for category, color in color_map.items():
            count = len(merged[merged["confusion_category"] == category])
            percentage = count / len(merged) * 100 if len(merged) > 0 else 0
            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); '
                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{category}: {count} ({percentage:.1f}%)</span></div>",
                unsafe_allow_html=True,
            )
        st.markdown("---")
        st.markdown("**Elevation (3D):**")
        st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.")
    st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")

View file

@@ -0,0 +1,429 @@
"""Plotting functions for inference result visualizations."""
import geopandas as gpd
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_palette
from entropice.dashboard.utils.data import TrainingResult
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that crosses the antimeridian.

    ``antimeridian`` is imported lazily so the plots module stays importable
    without it. On failure the error is shown in the dashboard and the
    original geometry is returned unchanged.
    """
    import antimeridian

    try:
        fixed = shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
    return fixed
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render summary statistics about inference results.

    Shows the total prediction count plus either RTS/No-RTS percentages
    (binary task) or class-count/most-common-class metrics (other tasks).

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').
    """
    st.subheader("📊 Inference Summary")
    class_counts = predictions_gdf["predicted_class"].value_counts()
    total = len(predictions_gdf)
    # Both branches share the same three-column layout, so build it once.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Predictions", f"{total:,}")
    if task == "binary":
        with col2:
            rts_count = class_counts.get("RTS", 0)
            rts_pct = rts_count / total * 100 if total > 0 else 0
            st.metric("RTS Predictions", f"{rts_count:,} ({rts_pct:.1f}%)")
        with col3:
            no_rts_count = class_counts.get("No-RTS", 0)
            no_rts_pct = no_rts_count / total * 100 if total > 0 else 0
            st.metric("No-RTS Predictions", f"{no_rts_count:,} ({no_rts_pct:.1f}%)")
    else:
        with col2:
            st.metric("Unique Classes", len(class_counts))
        with col3:
            most_common = class_counts.index[0] if len(class_counts) > 0 else "N/A"
            st.metric("Most Common Class", most_common)
def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render histogram of predicted class distribution.

    Draws a bar chart of class counts plus an expandable table with counts
    and percentages per class.

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').
    """
    st.subheader("📊 Predicted Class Distribution")
    class_counts = predictions_gdf["predicted_class"].value_counts().sort_index()
    categories = class_counts.index.tolist()
    colors = get_palette(task, len(categories))
    # Shared between the bar chart and the detail table.
    counts = class_counts.to_numpy()
    percentages = counts / len(predictions_gdf) * 100
    bar = go.Bar(
        x=categories,
        y=class_counts.values,
        marker_color=colors,
        opacity=0.9,
        text=counts,
        textposition="outside",
        textfont={"size": 12},
        hovertemplate="<b>%{x}</b><br>Count: %{y:,}<br>Percentage: %{customdata:.1f}%<extra></extra>",
        customdata=percentages,
    )
    fig = go.Figure()
    fig.add_trace(bar)
    fig.update_layout(
        height=400,
        margin={"l": 20, "r": 20, "t": 40, "b": 20},
        showlegend=False,
        xaxis_title="Predicted Class",
        yaxis_title="Count",
        # Tilt the labels only when there are enough classes to overlap.
        xaxis={"tickangle": -45 if len(categories) > 3 else 0},
    )
    st.plotly_chart(fig, use_container_width=True)
    # Show percentages in a table
    with st.expander("📋 Detailed Class Distribution", expanded=False):
        distribution_df = pd.DataFrame(
            {
                "Class": categories,
                "Count": counts,
                "Percentage": percentages.round(2),
            }
        )
        st.dataframe(distribution_df, hide_index=True, use_container_width=True)
def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
    """Render spatial statistics about predictions.

    Shows the bounding-box extent of the predictions and, when a
    ``cell_area`` column is present, the total covered area.

    Args:
        predictions_gdf: GeoDataFrame with predictions.
    """
    st.subheader("🌍 Spatial Coverage")
    # total_bounds is (minx, miny, maxx, maxy) in the GeoDataFrame's CRS.
    min_lon, min_lat, max_lon, max_lat = predictions_gdf.total_bounds
    metrics = [
        ("Min Latitude", min_lat),
        ("Max Latitude", max_lat),
        ("Min Longitude", min_lon),
        ("Max Longitude", max_lon),
    ]
    for column, (label, value) in zip(st.columns(4), metrics):
        with column:
            st.metric(label, f"{value:.2f}°")
    # Calculate total area if cell_area is available
    if "cell_area" in predictions_gdf.columns:
        total_area = predictions_gdf["cell_area"].sum()
        st.info(f"📏 **Total Area Covered:** {total_area:,.0f} km²")
def _prepare_geojson_features(display_gdf_wgs84: gpd.GeoDataFrame) -> list:
"""Convert GeoDataFrame to GeoJSON features for pydeck.
Args:
display_gdf_wgs84: GeoDataFrame in WGS84 projection with required columns.
Returns:
List of GeoJSON feature dictionaries.
"""
geojson_data = []
for _, row in display_gdf_wgs84.iterrows():
feature = {
"type": "Feature",
"geometry": row["geometry"].__geo_interface__,
"properties": {
"cell_id": str(row["cell_id"]),
"predicted_class": str(row["predicted_class"]),
"fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
},
}
geojson_data.append(feature)
return geojson_data
@st.fragment
def render_inference_map(result: TrainingResult):
    """Render 3D pydeck map showing inference results with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (class filter, 3D toggle, and opacity), without re-running
    the entire page.

    Args:
        result: TrainingResult object containing prediction data.
    """
    st.subheader("🗺️ Inference Results Map")
    # Load predictions.
    # NOTE(review): unlike render_confusion_matrix_map, there is no exists()
    # check here — the caller is expected to have verified the file.
    preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
    # Get settings (defaults mirror the training pipeline's defaults)
    task = result.settings.get("task", "binary")
    grid = result.settings.get("grid", "hex")
    # Create controls in columns
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        # Get unique classes for filtering
        all_classes = sorted(preds_gdf["predicted_class"].unique())
        filter_options = ["All Classes", *all_classes]
        selected_filter = st.selectbox(
            "Filter by Predicted Class",
            options=filter_options,
            key="inference_map_filter",
        )
    with col2:
        use_elevation = st.checkbox(
            "Enable 3D Elevation",
            value=True,
            help="Show predictions with elevation (requires count/density for meaningful height)",
            key="inference_map_elevation",
        )
    with col3:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="inference_map_opacity",
        )
    # Filter data if needed
    if selected_filter != "All Classes":
        display_gdf = preds_gdf[preds_gdf["predicted_class"] == selected_filter].copy()
    else:
        display_gdf = preds_gdf.copy()
    if len(display_gdf) == 0:
        st.warning(f"No predictions found for filter: {selected_filter}")
        return
    st.info(f"Displaying {len(display_gdf):,} out of {len(preds_gdf):,} total predictions")
    # Convert to WGS84 for pydeck (deck.gl expects lon/lat coordinates)
    display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
    # Fix antimeridian issues for hex grids
    if grid == "hex":
        display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Assign colors based on predicted class; the palette covers ALL classes
    # so colors stay stable when a single class is filtered.
    colors_palette = get_palette(task, len(all_classes))
    # Create color mapping for all classes
    color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)}

    # Convert hex colors ("#rrggbb") to [r, g, b] ints for pydeck
    def hex_to_rgb(hex_color):
        hex_color = hex_color.lstrip("#")
        return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]

    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["predicted_class"].map(
        {cls: hex_to_rgb(color) for cls, color in color_map.items()}
    )
    # Add elevation based on class encoding (for ordered classes)
    if use_elevation and len(all_classes) > 1:
        # Create a normalized elevation (0..1) based on class order
        class_to_elevation = {cls: i / (len(all_classes) - 1) for i, cls in enumerate(all_classes)}
        display_gdf_wgs84["elevation"] = display_gdf_wgs84["predicted_class"].map(class_to_elevation)
    else:
        display_gdf_wgs84["elevation"] = 0.0
    # Convert to GeoJSON format
    geojson_data = _prepare_geojson_features(display_gdf_wgs84)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale to 500km height
        pickable=True,
    )
    # Set initial view state (centered on the Arctic); tilt only in 3D mode
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not use_elevation else 1.5,
        pitch=0 if not use_elevation else 45,
    )
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell ID:</b> {cell_id}<br/><b>Predicted Class:</b> {predicted_class}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    # Render the map
    st.pydeck_chart(deck)
    # Show info about 3D visualization
    if use_elevation:
        st.info("💡 3D elevation represents class order. Rotate the map by holding Ctrl/Cmd and dragging.")
    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown("**Predicted Classes:**")
        for cls in all_classes:
            color_hex = color_map[cls]
            count = len(display_gdf[display_gdf["predicted_class"] == cls])
            total_count = len(preds_gdf[preds_gdf["predicted_class"] == cls])
            percentage = total_count / len(preds_gdf) * 100 if len(preds_gdf) > 0 else 0
            # Show if currently displayed or total count
            if selected_filter == "All Classes":
                count_str = f"{count:,} ({percentage:.1f}%)"
            else:
                count_str = f"{count:,} displayed / {total_count:,} total ({percentage:.1f}%)"
            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: {color_hex}; '
                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{cls}: {count_str}</span></div>",
                unsafe_allow_html=True,
            )
        if use_elevation and len(all_classes) > 1:
            st.markdown("---")
            st.markdown("**Elevation (3D):**")
            st.markdown(f"Height represents class order: {all_classes[0]} (low) → {all_classes[-1]} (high)")
def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render comparison plots between different predicted classes.

    Shows a pie chart of class proportions next to a cumulative-distribution
    curve over classes ranked by frequency. Requires at least two distinct
    predicted classes; otherwise an info message is shown instead.

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').
    """
    st.subheader("🔍 Class Comparison")
    # Get class distribution
    class_counts = predictions_gdf["predicted_class"].value_counts()
    if len(class_counts) < 2:
        st.info("Need at least 2 classes for comparison.")
        return
    # Create pie chart
    col1, col2 = st.columns(2)
    with col1:
        # Fixed: bold marker was unterminated ("**Class Proportions")
        st.markdown("**Class Proportions**")
        colors = get_palette(task, len(class_counts))
        fig = go.Figure(
            data=[
                go.Pie(
                    labels=class_counts.index,
                    values=class_counts.values,
                    marker_colors=colors,
                    textinfo="label+percent",
                    textposition="auto",
                    hovertemplate="<b>%{label}</b><br>Count: %{value:,}<br>Percentage: %{percent}<extra></extra>",
                )
            ]
        )
        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            showlegend=True,
        )
        st.plotly_chart(fig, use_container_width=True)
    with col2:
        # Fixed: bold marker was unterminated ("**Cumulative Distribution")
        st.markdown("**Cumulative Distribution**")
        # Create cumulative distribution over classes sorted by frequency
        sorted_counts = class_counts.sort_values(ascending=False)
        cumulative = sorted_counts.cumsum()
        cumulative_pct = cumulative / cumulative.iloc[-1] * 100
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=list(range(len(cumulative))),
                y=cumulative_pct.to_numpy(),
                mode="lines+markers",
                # `colors` is still in scope from the col1 block above.
                line={"color": colors[0], "width": 3},
                marker={"size": 8},
                customdata=sorted_counts.index,
                hovertemplate="<b>%{customdata}</b><br>Cumulative: %{y:.1f}%<extra></extra>",
            )
        )
        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            xaxis_title="Class Rank",
            yaxis_title="Cumulative Percentage",
            # Leave headroom above 100% so the top marker is not clipped.
            yaxis={"range": [0, 105]},
        )
        st.plotly_chart(fig, use_container_width=True)

View file

@ -4,6 +4,7 @@ import streamlit as st
from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
render_confusion_matrix_map,
render_espa_binned_parameter_space,
render_multi_metric_comparison,
render_parameter_correlation,
@ -128,6 +129,13 @@ def render_training_analysis_page():
st.divider()
# Confusion Matrix Map Section
st.header("🗺️ Prediction Results Map")
render_confusion_matrix_map(selected_result.path, settings)
st.divider()
# Quick Statistics
st.header("📈 Cross-Validation Statistics")