Make inference page
This commit is contained in:
parent
1919cc6a7e
commit
6960571742
4 changed files with 813 additions and 6 deletions
|
|
@ -1,8 +1,107 @@
|
|||
"""Inference page: Visualization of model inference results across the study region."""
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.inference import (
|
||||
render_class_comparison,
|
||||
render_class_distribution_histogram,
|
||||
render_inference_map,
|
||||
render_inference_statistics,
|
||||
render_spatial_distribution_stats,
|
||||
)
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
|
||||
|
||||
def render_inference_page():
    """Render the Inference page of the dashboard.

    Loads every available training result, lets the user pick one in the
    sidebar, and then renders inference statistics, spatial coverage, a 3D
    map, and class-distribution plots for the selected run.
    """
    # Fix: the page previously rendered two titles ("Inference Results" and
    # "🗺️ Inference Results") plus leftover scaffold placeholder text
    # ("This page will display inference results..."); keep only the real title.
    st.title("🗺️ Inference Results")

    # Load all available training results
    training_results = load_all_training_results()

    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.training`")
        return

    # Sidebar: Training run selection
    with st.sidebar:
        st.header("Select Training Run")

        # Create selection options with task-first naming
        training_options = {tr.get_display_name("task_first"): tr for tr in training_results}

        selected_name = st.selectbox(
            "Training Run",
            options=list(training_options.keys()),
            index=0,
            help="Select a training run to view inference results",
        )

        selected_result = training_options[selected_name]

        st.divider()

        # Show run information in sidebar
        st.subheader("Run Information")

        settings = selected_result.settings

        st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
        st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
        st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
        st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
        st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")

    # Main content area - Run Information at the top
    st.header("📋 Run Configuration")

    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
    with col2:
        st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
    with col3:
        st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
    with col4:
        st.metric("Level", selected_result.settings.get("level", "Unknown"))
    with col5:
        st.metric("Target", selected_result.settings.get("target", "Unknown").replace("darts_", ""))

    st.divider()

    # Check if predictions file exists
    preds_file = selected_result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.error("No inference results found for this training run.")
        st.info("Inference results are generated automatically during training.")
        return

    # Load predictions for statistics.
    # NOTE(review): geopandas is imported here, inside the function, as in the
    # original — presumably to defer the import cost; confirm before hoisting.
    import geopandas as gpd

    predictions_gdf = gpd.read_parquet(preds_file)
    task = selected_result.settings.get("task", "binary")

    # Inference Statistics Section
    render_inference_statistics(predictions_gdf, task)

    st.divider()

    # Spatial Coverage Section
    render_spatial_distribution_stats(predictions_gdf)

    st.divider()

    # 3D Map Visualization Section
    render_inference_map(selected_result)

    st.divider()

    # Class Distribution Section
    render_class_distribution_histogram(predictions_gdf, task)

    st.divider()

    # Class Comparison Section
    render_class_comparison(predictions_gdf, task)
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import altair as alt
|
||||
import antimeridian
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pydeck as pdk
|
||||
import streamlit as st
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_cmap, get_palette
|
||||
from entropice.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||
|
|
@ -59,9 +66,8 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
|||
|
||||
# Check if refit metric exists in results
|
||||
if refit_col not in results.columns:
|
||||
st.warning(
|
||||
f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
|
||||
)
|
||||
available_metrics = [col.replace("mean_test_", "") for col in score_cols]
|
||||
st.warning(f"Refit metric '{refit_metric}' not found in results. Available metrics: {available_metrics}")
|
||||
# Use the first available metric as fallback
|
||||
refit_col = score_cols[0]
|
||||
refit_metric = refit_col.replace("mean_test_", "")
|
||||
|
|
@ -1092,3 +1098,268 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
|
|||
display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
|
||||
|
||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that crosses the antimeridian.

    On failure the error is surfaced in the Streamlit UI and the original
    geometry is returned unchanged.
    """
    try:
        fixed = shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        fixed = geom
    return fixed
|
||||
|
||||
|
||||
@st.fragment
def render_confusion_matrix_map(result_path: Path, settings: dict):
    """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).

    Runs as a Streamlit fragment, so interacting with the category filter or
    opacity slider reruns only this function, not the whole page.

    Args:
        result_path: Path to the training result directory.
        settings: Settings dictionary containing grid, level, task, and target information.

    """
    st.subheader("🗺️ Confusion Matrix Spatial Distribution")

    # Load predicted probabilities
    preds_file = result_path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        st.warning("No predicted probabilities found for this training run.")
        return

    preds_gdf = gpd.read_parquet(preds_file)

    # Get task and target information from settings (defaults mirror the
    # training pipeline's own defaults).
    task = settings.get("task", "binary")
    target = settings.get("target", "darts_rts")
    grid = settings.get("grid", "hex")
    level = settings.get("level", 3)

    # Create dataset ensemble to get true labels
    # We need to load the target data to get true labels
    try:
        ensemble = DatasetEnsemble(
            grid=grid,
            level=level,
            target=target,
            members=[],  # We don't need feature data, just target
        )
        training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
    except Exception as e:
        # Broad catch on purpose: any loading failure should degrade to an
        # on-page error rather than crash the dashboard.
        st.error(f"Error loading training data: {e}")
        return

    # Get the labeled cells (those with true labels)
    labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]

    # Merge predictions with true labels
    # Reset index to avoid ambiguity between index and column
    labeled_gdf = labeled_cells.copy()
    labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"})
    labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()

    # Merge with predictions - ensure we keep GeoDataFrame type
    # (DataFrame.merge drops the GeoDataFrame subclass, so re-wrap).
    merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
    merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)

    if len(merged) == 0:
        st.warning("No matching predictions found for labeled cells.")
        return

    # Determine confusion matrix category for a single row.
    # NOTE(review): binary labels are assumed to be exactly "RTS"/"No-RTS";
    # any other value falls into the final else branch — confirm upstream.
    def get_confusion_category(row):
        true_label = row["true_class"]
        pred_label = row["predicted_class"]

        if task == "binary":
            # For binary classification
            if true_label == "RTS" and pred_label == "RTS":
                return "True Positive"
            elif true_label == "RTS" and pred_label == "No-RTS":
                return "False Negative"
            elif true_label == "No-RTS" and pred_label == "RTS":
                return "False Positive"
            else:  # true_label == "No-RTS" and pred_label == "No-RTS"
                return "True Negative"
        else:
            # For multiclass (count/density)
            if true_label == pred_label:
                return "Correct"
            else:
                return "Incorrect"

    merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)

    # Create controls
    col1, col2 = st.columns([3, 1])

    with col1:
        # Filter by confusion category
        if task == "binary":
            categories = ["All", "True Positive", "False Positive", "True Negative", "False Negative"]
        else:
            categories = ["All", "Correct", "Incorrect"]

        selected_category = st.selectbox(
            "Filter by Category",
            options=categories,
            key="confusion_map_category",
        )

    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="confusion_map_opacity")

    # Filter data if needed
    if selected_category != "All":
        display_gdf = merged[merged["confusion_category"] == selected_category].copy()
    else:
        display_gdf = merged.copy()

    if len(display_gdf) == 0:
        st.warning(f"No cells found for category: {selected_category}")
        return

    # Convert to WGS84 for pydeck (pydeck expects lon/lat GeoJSON)
    display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")

    # Fix antimeridian issues for hex grids
    if grid == "hex":
        display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Assign colors based on confusion category (RGB triples for pydeck)
    if task == "binary":
        color_map = {
            "True Positive": [46, 204, 113],  # Green
            "False Positive": [231, 76, 60],  # Red
            "True Negative": [52, 152, 219],  # Blue
            "False Negative": [241, 196, 15],  # Yellow
        }
    else:
        color_map = {
            "Correct": [46, 204, 113],  # Green
            "Incorrect": [231, 76, 60],  # Red
        }

    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)

    # Add elevation based on confusion category (higher for errors)
    # Values are normalized 0-1 and multiplied by elevation_scale below.
    if task == "binary":
        elevation_map = {
            "True Positive": 0.8,
            "False Positive": 1.0,
            "True Negative": 0.3,
            "False Negative": 1.0,
        }
    else:
        elevation_map = {
            "Correct": 0.5,
            "Incorrect": 1.0,
        }

    display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)

    # Convert to GeoJSON format (one Feature per cell with display properties)
    geojson_data = []
    for _, row in display_gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "true_class": str(row["true_class"]),
                "predicted_class": str(row["predicted_class"]),
                "confusion_category": str(row["confusion_category"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation",
        elevation_scale=500000,  # normalized 0-1 elevation -> up to 500 km
        pickable=True,
    )

    # Set initial view state (centered on the Arctic)
    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>True Label:</b> {true_class}<br/>"
            "<b>Predicted Label:</b> {predicted_class}<br/>"
            "<b>Category:</b> {confusion_category}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    # Render the map
    st.pydeck_chart(deck)

    # Show statistics (computed over ALL labeled cells, not the filtered view)
    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric("Total Labeled Cells", len(merged))

    if task == "binary":
        with col2:
            tp = len(merged[merged["confusion_category"] == "True Positive"])
            fp = len(merged[merged["confusion_category"] == "False Positive"])
            tn = len(merged[merged["confusion_category"] == "True Negative"])
            fn = len(merged[merged["confusion_category"] == "False Negative"])

            accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")

        with col3:
            # Guard each ratio against a zero denominator.
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            st.metric("F1 Score", f"{f1:.3f}")

        # Show confusion matrix counts
        st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
    else:
        with col2:
            correct = len(merged[merged["confusion_category"] == "Correct"])
            accuracy = correct / len(merged) if len(merged) > 0 else 0
            st.metric("Accuracy", f"{accuracy:.2%}")

        with col3:
            incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
            st.metric("Incorrect", incorrect)

    # Add legend (percentages relative to all labeled cells)
    with st.expander("Legend", expanded=True):
        st.markdown("**Confusion Matrix Categories:**")

        for category, color in color_map.items():
            count = len(merged[merged["confusion_category"] == category])
            percentage = count / len(merged) * 100 if len(merged) > 0 else 0

            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); '
                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{category}: {count} ({percentage:.1f}%)</span></div>",
                unsafe_allow_html=True,
            )

        st.markdown("---")
        st.markdown("**Elevation (3D):**")
        st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.")
        st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")
|
||||
|
|
|
|||
429
src/entropice/dashboard/plots/inference.py
Normal file
429
src/entropice/dashboard/plots/inference.py
Normal file
|
|
@ -0,0 +1,429 @@
|
|||
"""Plotting functions for inference result visualizations."""
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
import streamlit as st
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_palette
|
||||
from entropice.dashboard.utils.data import TrainingResult
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that crosses the antimeridian.

    If the repair fails, the error is shown in the Streamlit UI and the
    geometry is returned as-is.
    """
    # Deferred import, as in the rest of this module's optional helpers.
    import antimeridian

    result = geom
    try:
        result = shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
    return result
|
||||
|
||||
|
||||
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render summary statistics about inference results.

    Shows total prediction count plus either RTS/No-RTS percentages (binary
    task) or class-count summaries (count/density tasks).

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').

    """
    st.subheader("📊 Inference Summary")

    # Frequency of each predicted class across the whole study region.
    class_counts = predictions_gdf["predicted_class"].value_counts()
    total = len(predictions_gdf)

    # The first metric (total count) is shared by both task layouts.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Predictions", f"{total:,}")

    if task == "binary":
        with col2:
            rts_count = class_counts.get("RTS", 0)
            rts_pct = rts_count / total * 100 if total > 0 else 0
            st.metric("RTS Predictions", f"{rts_count:,} ({rts_pct:.1f}%)")
        with col3:
            no_rts_count = class_counts.get("No-RTS", 0)
            no_rts_pct = no_rts_count / total * 100 if total > 0 else 0
            st.metric("No-RTS Predictions", f"{no_rts_count:,} ({no_rts_pct:.1f}%)")
    else:
        with col2:
            st.metric("Unique Classes", len(class_counts))
        with col3:
            # value_counts sorts descending, so index[0] is the mode.
            most_common = class_counts.index[0] if len(class_counts) > 0 else "N/A"
            st.metric("Most Common Class", most_common)
|
||||
|
||||
|
||||
def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render histogram of predicted class distribution.

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').

    """
    st.subheader("📊 Predicted Class Distribution")

    # Per-class counts, sorted by class label for a stable x-axis ordering.
    class_counts = predictions_gdf["predicted_class"].value_counts().sort_index()
    categories = class_counts.index.tolist()
    counts = class_counts.to_numpy()
    percentages = counts / len(predictions_gdf) * 100

    # One color per class from the task-specific palette.
    colors = get_palette(task, len(categories))

    bar = go.Bar(
        x=categories,
        y=class_counts.values,
        marker_color=colors,
        opacity=0.9,
        text=counts,
        textposition="outside",
        textfont={"size": 12},
        hovertemplate="<b>%{x}</b><br>Count: %{y:,}<br>Percentage: %{customdata:.1f}%<extra></extra>",
        customdata=percentages,
    )
    fig = go.Figure(data=[bar])

    fig.update_layout(
        height=400,
        margin={"l": 20, "r": 20, "t": 40, "b": 20},
        showlegend=False,
        xaxis_title="Predicted Class",
        yaxis_title="Count",
        # Tilt labels only when there are enough classes to crowd the axis.
        xaxis={"tickangle": -45 if len(categories) > 3 else 0},
    )

    st.plotly_chart(fig, use_container_width=True)

    # Companion table with exact counts and rounded percentages.
    with st.expander("📋 Detailed Class Distribution", expanded=False):
        distribution_df = pd.DataFrame(
            {
                "Class": categories,
                "Count": counts,
                "Percentage": percentages.round(2),
            }
        )
        st.dataframe(distribution_df, hide_index=True, use_container_width=True)
|
||||
|
||||
|
||||
def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
    """Render spatial statistics about predictions.

    Args:
        predictions_gdf: GeoDataFrame with predictions.

    """
    st.subheader("🌍 Spatial Coverage")

    # Fix: the metrics below are labelled in degrees, but total_bounds is
    # reported in the GeoDataFrame's own CRS — which may be a projected,
    # metre-based one (the map renderers in this module reproject to WGS84
    # for the same reason). Reproject first when necessary so the displayed
    # extent really is lat/lon.
    if predictions_gdf.crs is not None and not predictions_gdf.crs.is_geographic:
        bounds = predictions_gdf.to_crs("EPSG:4326").total_bounds
    else:
        bounds = predictions_gdf.total_bounds

    # total_bounds is (minx, miny, maxx, maxy) = (lon_min, lat_min, lon_max, lat_max).
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Min Latitude", f"{bounds[1]:.2f}°")

    with col2:
        st.metric("Max Latitude", f"{bounds[3]:.2f}°")

    with col3:
        st.metric("Min Longitude", f"{bounds[0]:.2f}°")

    with col4:
        st.metric("Max Longitude", f"{bounds[2]:.2f}°")

    # Calculate total area if cell_area is available
    if "cell_area" in predictions_gdf.columns:
        total_area = predictions_gdf["cell_area"].sum()
        # NOTE(review): assumes cell_area is stored in km² — confirm upstream.
        st.info(f"📏 **Total Area Covered:** {total_area:,.0f} km²")
|
||||
|
||||
|
||||
def _prepare_geojson_features(display_gdf_wgs84: gpd.GeoDataFrame) -> list:
    """Convert GeoDataFrame to GeoJSON features for pydeck.

    Each row becomes one Feature carrying the cell id, predicted class, and
    the precomputed fill_color / elevation display properties.

    Args:
        display_gdf_wgs84: GeoDataFrame in WGS84 projection with required columns.

    Returns:
        List of GeoJSON feature dictionaries.

    """
    return [
        {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "cell_id": str(row["cell_id"]),
                "predicted_class": str(row["predicted_class"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        for _, row in display_gdf_wgs84.iterrows()
    ]
|
||||
|
||||
|
||||
@st.fragment
def render_inference_map(result: TrainingResult):
    """Render 3D pydeck map showing inference results with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.

    Args:
        result: TrainingResult object containing prediction data.

    """
    st.subheader("🗺️ Inference Results Map")

    # Load predictions
    preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")

    # Get settings (defaults mirror the training pipeline's defaults)
    task = result.settings.get("task", "binary")
    grid = result.settings.get("grid", "hex")

    # Create controls in columns
    col1, col2, col3 = st.columns([2, 2, 1])

    with col1:
        # Get unique classes for filtering
        all_classes = sorted(preds_gdf["predicted_class"].unique())
        filter_options = ["All Classes", *all_classes]

        selected_filter = st.selectbox(
            "Filter by Predicted Class",
            options=filter_options,
            key="inference_map_filter",
        )

    with col2:
        use_elevation = st.checkbox(
            "Enable 3D Elevation",
            value=True,
            help="Show predictions with elevation (requires count/density for meaningful height)",
            key="inference_map_elevation",
        )

    with col3:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="inference_map_opacity",
        )

    # Filter data if needed
    if selected_filter != "All Classes":
        display_gdf = preds_gdf[preds_gdf["predicted_class"] == selected_filter].copy()
    else:
        display_gdf = preds_gdf.copy()

    if len(display_gdf) == 0:
        st.warning(f"No predictions found for filter: {selected_filter}")
        return

    st.info(f"Displaying {len(display_gdf):,} out of {len(preds_gdf):,} total predictions")

    # Convert to WGS84 for pydeck (pydeck expects lon/lat GeoJSON)
    display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")

    # Fix antimeridian issues for hex grids
    if grid == "hex":
        display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Assign colors based on predicted class
    colors_palette = get_palette(task, len(all_classes))

    # Create color mapping for all classes (hex strings, also used in legend)
    color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)}

    # Convert hex colors ("#rrggbb") to RGB triples for pydeck.
    def hex_to_rgb(hex_color):
        hex_color = hex_color.lstrip("#")
        return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]

    display_gdf_wgs84["fill_color"] = display_gdf_wgs84["predicted_class"].map(
        {cls: hex_to_rgb(color) for cls, color in color_map.items()}
    )

    # Add elevation based on class encoding (for ordered classes)
    if use_elevation and len(all_classes) > 1:
        # Create a normalized elevation based on class order (0 .. 1)
        class_to_elevation = {cls: i / (len(all_classes) - 1) for i, cls in enumerate(all_classes)}
        display_gdf_wgs84["elevation"] = display_gdf_wgs84["predicted_class"].map(class_to_elevation)
    else:
        display_gdf_wgs84["elevation"] = 0.0

    # Convert to GeoJSON format
    geojson_data = _prepare_geojson_features(display_gdf_wgs84)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale to 500km height
        pickable=True,
    )

    # Set initial view state (centered on the Arctic); zoom out and tilt
    # when 3D elevation is on so column heights are visible.
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not use_elevation else 1.5,
        pitch=0 if not use_elevation else 45,
    )

    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell ID:</b> {cell_id}<br/><b>Predicted Class:</b> {predicted_class}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    # Render the map
    st.pydeck_chart(deck)

    # Show info about 3D visualization
    if use_elevation:
        st.info("💡 3D elevation represents class order. Rotate the map by holding Ctrl/Cmd and dragging.")

    # Add legend (percentages always relative to the full prediction set)
    with st.expander("Legend", expanded=True):
        st.markdown("**Predicted Classes:**")

        for cls in all_classes:
            color_hex = color_map[cls]
            count = len(display_gdf[display_gdf["predicted_class"] == cls])
            total_count = len(preds_gdf[preds_gdf["predicted_class"] == cls])
            percentage = total_count / len(preds_gdf) * 100 if len(preds_gdf) > 0 else 0

            # Show if currently displayed or total count
            if selected_filter == "All Classes":
                count_str = f"{count:,} ({percentage:.1f}%)"
            else:
                count_str = f"{count:,} displayed / {total_count:,} total ({percentage:.1f}%)"

            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: {color_hex}; '
                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{cls}: {count_str}</span></div>",
                unsafe_allow_html=True,
            )

        if use_elevation and len(all_classes) > 1:
            st.markdown("---")
            st.markdown("**Elevation (3D):**")
            st.markdown(f"Height represents class order: {all_classes[0]} (low) → {all_classes[-1]} (high)")
|
||||
|
||||
|
||||
def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Render comparison plots between different predicted classes.

    Shows a pie chart of class proportions next to a cumulative-distribution
    curve over classes ranked by frequency. Requires at least two classes.

    Args:
        predictions_gdf: GeoDataFrame with predictions.
        task: Task type ('binary', 'count', 'density').

    """
    st.subheader("🔍 Class Comparison")

    # Get class distribution
    class_counts = predictions_gdf["predicted_class"].value_counts()

    if len(class_counts) < 2:
        st.info("Need at least 2 classes for comparison.")
        return

    # Create pie chart
    col1, col2 = st.columns(2)

    with col1:
        # Fix: bold marker was unterminated ("**Class Proportions"), which
        # rendered literal asterisks instead of bold text.
        st.markdown("**Class Proportions**")

        colors = get_palette(task, len(class_counts))

        fig = go.Figure(
            data=[
                go.Pie(
                    labels=class_counts.index,
                    values=class_counts.values,
                    marker_colors=colors,
                    textinfo="label+percent",
                    textposition="auto",
                    hovertemplate="<b>%{label}</b><br>Count: %{value:,}<br>Percentage: %{percent}<extra></extra>",
                )
            ]
        )

        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            showlegend=True,
        )

        st.plotly_chart(fig, use_container_width=True)

    with col2:
        # Fix: same unterminated bold marker as above.
        st.markdown("**Cumulative Distribution**")

        # Create cumulative distribution over classes ranked by frequency
        sorted_counts = class_counts.sort_values(ascending=False)
        cumulative = sorted_counts.cumsum()
        cumulative_pct = cumulative / cumulative.iloc[-1] * 100

        fig = go.Figure()

        fig.add_trace(
            go.Scatter(
                x=list(range(len(cumulative))),
                y=cumulative_pct.to_numpy(),
                mode="lines+markers",
                # Reuses the palette computed for the pie chart above.
                line={"color": colors[0], "width": 3},
                marker={"size": 8},
                customdata=sorted_counts.index,
                hovertemplate="<b>%{customdata}</b><br>Cumulative: %{y:.1f}%<extra></extra>",
            )
        )

        fig.update_layout(
            height=400,
            margin={"l": 20, "r": 20, "t": 20, "b": 20},
            xaxis_title="Class Rank",
            yaxis_title="Cumulative Percentage",
            # Headroom above 100% so the last marker isn't clipped.
            yaxis={"range": [0, 105]},
        )

        st.plotly_chart(fig, use_container_width=True)
|
||||
|
|
@ -4,6 +4,7 @@ import streamlit as st
|
|||
|
||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||
render_binned_parameter_space,
|
||||
render_confusion_matrix_map,
|
||||
render_espa_binned_parameter_space,
|
||||
render_multi_metric_comparison,
|
||||
render_parameter_correlation,
|
||||
|
|
@ -128,6 +129,13 @@ def render_training_analysis_page():
|
|||
|
||||
st.divider()
|
||||
|
||||
# Confusion Matrix Map Section
|
||||
st.header("🗺️ Prediction Results Map")
|
||||
|
||||
render_confusion_matrix_map(selected_result.path, settings)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Quick Statistics
|
||||
st.header("📈 Cross-Validation Statistics")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue