# entropice/src/entropice/dashboard/plots/training_data.py (366 lines, 14 KiB, Python)

"""Plotting functions for training data visualizations."""
import geopandas as gpd
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import CategoricalTrainingDataset
def render_all_distribution_histograms(
    train_data_dict: dict[str, CategoricalTrainingDataset],
):
    """Render train/test target histograms for all three tasks side by side.

    For each task ('binary', 'count', 'density') a grouped bar chart shows how many
    train and how many test samples fall into each target class. A caption below it
    reports the total sample count and the train/test percentages.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and
            CategoricalTrainingDataset values.
    """
    st.subheader("📊 Target Distribution by Task")
    tasks = ["binary", "count", "density"]
    task_titles = {
        "binary": "Binary Classification",
        "count": "Count Classification",
        "density": "Density Classification",
    }
    # One column per task, laid out left to right in the order of `tasks`.
    cols = st.columns(len(tasks))
    for col, task in zip(cols, tasks):
        dataset = train_data_dict[task]
        categories = dataset.y.binned.cat.categories.tolist()
        colors = get_palette(task, len(categories))
        # Split masks are reused for every category and for the caption below.
        is_train = dataset.split == "train"
        is_test = dataset.split == "test"
        with col:
            st.markdown(f"**{task_titles[task]}**")
            # Per-category sample counts for each split
            counts_df = pd.DataFrame(
                {
                    "Category": categories,
                    "Train": [((dataset.y.binned == cat) & is_train).sum() for cat in categories],
                    "Test": [((dataset.y.binned == cat) & is_test).sum() for cat in categories],
                }
            )
            # Grouped (side-by-side) bars: both splits use the class colours,
            # the test bars are drawn with lower opacity to tell them apart.
            fig = go.Figure()
            for split_name, bar_opacity in (("Train", 0.9), ("Test", 0.6)):
                fig.add_trace(
                    go.Bar(
                        name=split_name,
                        x=counts_df["Category"],
                        y=counts_df[split_name],
                        marker_color=colors,
                        opacity=bar_opacity,
                        text=counts_df[split_name],
                        textposition="inside",
                        textfont={"size": 10, "color": "white"},
                    )
                )
            fig.update_layout(
                barmode="group",
                height=400,
                margin={"l": 20, "r": 20, "t": 20, "b": 20},
                showlegend=True,
                legend={
                    "orientation": "h",
                    "yanchor": "bottom",
                    "y": 1.02,
                    "xanchor": "right",
                    "x": 1,
                },
                xaxis_title=None,
                yaxis_title="Count",
                xaxis={"tickangle": -45},
            )
            st.plotly_chart(fig, width="stretch")
            # Summary statistics; an empty dataset would otherwise divide by zero.
            total = len(dataset)
            if total:
                train_pct = is_train.sum() / total * 100
                test_pct = is_test.sum() / total * 100
            else:
                train_pct = test_pct = 0.0
            st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
"""Assign colors to geodataframe based on the selected color mode.
Args:
gdf: GeoDataFrame to add colors to
color_mode: One of 'target_class' or 'split'
dataset: CategoricalTrainingDataset
selected_task: Task name for color palette selection
Returns:
GeoDataFrame with 'fill_color' column added
"""
if color_mode == "target_class":
categories = dataset.y.binned.cat.categories.tolist()
colors_palette = get_palette(selected_task, len(categories))
# Create color mapping
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
gdf["color"] = gdf["target_class"].map(color_map)
# Convert hex colors to RGB
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip("#")
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
elif color_mode == "split":
split_colors = {
"train": [66, 135, 245],
"test": [245, 135, 66],
} # Blue # Orange
gdf["fill_color"] = gdf["split"].map(split_colors)
return gdf
def _prepare_spatial_gdf(
    train_data_dict: dict[str, CategoricalTrainingDataset],
    dataset: CategoricalTrainingDataset,
) -> "gpd.GeoDataFrame":
    """Return a WGS84 copy of ``dataset.dataset`` annotated with the columns the map needs.

    Adds the selected dataset's binned class, split and raw target value. It also
    adds the binned labels and raw values of all three tasks, which the tooltip
    shows. Values are assigned positionally, so the datasets are assumed to share
    the row order of ``dataset.dataset`` — TODO confirm.
    """
    # dataset.dataset should already be a GeoDataFrame
    gdf: gpd.GeoDataFrame = dataset.dataset.copy()  # type: ignore[assignment]
    # Fix antimeridian issues
    gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
    # Binned labels and split information from the selected dataset
    gdf["target_class"] = dataset.y.binned.to_numpy()
    gdf["split"] = dataset.split.to_numpy()
    gdf["raw_value"] = dataset.z.to_numpy()
    # Information from all three tasks for the tooltip
    gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
    gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
    gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
    gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
    gdf["density_raw"] = train_data_dict["density"].z.to_numpy()
    # pydeck expects longitude/latitude coordinates
    gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326")  # type: ignore[assignment]
    return gdf_wgs84


def _build_geojson_features(gdf_wgs84: "gpd.GeoDataFrame", use_elevation: bool) -> list[dict]:
    """Convert the annotated GeoDataFrame into a list of GeoJSON Feature dicts for pydeck.

    Only the needed columns are zipped together. ``iterrows`` would build a pandas
    Series holding every column (including all feature columns) for each row.
    The 'elevation' column is read only when ``use_elevation`` is True; otherwise
    every feature gets elevation 0.
    """
    elevations = gdf_wgs84["elevation"] if use_elevation else [0] * len(gdf_wgs84)
    features = []
    for (
        geometry,
        target_class,
        split,
        raw_value,
        fill_color,
        elevation,
        binary_label,
        count_category,
        count_raw,
        density_category,
        density_raw,
    ) in zip(
        gdf_wgs84["geometry"],
        gdf_wgs84["target_class"],
        gdf_wgs84["split"],
        gdf_wgs84["raw_value"],
        gdf_wgs84["fill_color"],
        elevations,
        gdf_wgs84["binary_label"],
        gdf_wgs84["count_category"],
        gdf_wgs84["count_raw"],
        gdf_wgs84["density_category"],
        gdf_wgs84["density_raw"],
    ):
        features.append(
            {
                "type": "Feature",
                "geometry": geometry.__geo_interface__,
                "properties": {
                    "target_class": str(target_class),
                    "split": str(split),
                    "raw_value": float(raw_value),
                    "fill_color": fill_color,
                    "elevation": float(elevation) if use_elevation else 0,
                    "binary_label": str(binary_label),
                    "count_category": str(count_category),
                    "count_raw": int(count_raw),
                    "density_category": str(density_category),
                    "density_raw": f"{float(density_raw):.4f}",
                },
            }
        )
    return features


def _format_interval(selected_task: str, interval_min, interval_max) -> str:
    """Format a class interval as a legend suffix.

    Count intervals are shown as integers, e.g. " (3)" or " (3-5)". Any other task
    (density) is shown in percent with four decimals, e.g. " (0.1000%-0.2000%)".
    Returns "" when either bound is None.
    """
    if interval_min is None or interval_max is None:
        return ""
    if selected_task == "count":
        if interval_min == interval_max:
            return f" ({int(interval_min)})"
        return f" ({int(interval_min)}-{int(interval_max)})"
    if interval_min == interval_max:
        return f" ({interval_min * 100:.4f}%)"
    return f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"


def _render_spatial_legend(dataset, color_mode: str, selected_task: str, elevation_range) -> None:
    """Render the legend expander under the spatial map.

    Args:
        dataset: CategoricalTrainingDataset of the selected task (class order and intervals).
        color_mode: 'target_class' or 'split'.
        selected_task: Task name used for palette selection and interval formatting.
        elevation_range: (min, max) raw values when 3D elevation is shown, else None.
    """
    with st.expander("Legend", expanded=True):
        if color_mode == "target_class":
            st.markdown("**Target Classes:**")
            categories = dataset.y.binned.cat.categories.tolist()
            colors_palette = get_palette(selected_task, len(categories))
            if selected_task in ["count", "density"]:
                # Count and density classes are shown one per line with their interval
                intervals = dataset.y.intervals
                for i, cat in enumerate(categories):
                    color = colors_palette[i]
                    interval_min, interval_max = intervals[i]
                    interval_str = _format_interval(selected_task, interval_min, interval_max)
                    st.markdown(
                        f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                        f'<div style="width: 20px; height: 20px; background-color: {color}; '
                        f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                        f"<span>{cat}{interval_str}</span></div>",
                        unsafe_allow_html=True,
                    )
            else:
                # Binary task: one column per class
                legend_cols = st.columns(len(categories))
                for i, cat in enumerate(categories):
                    with legend_cols[i]:
                        color = colors_palette[i]
                        st.markdown(
                            f'<div style="display: flex; align-items: center;">'
                            f'<div style="width: 20px; height: 20px; background-color: {color}; '
                            f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                            f"<span>{cat}</span></div>",
                            unsafe_allow_html=True,
                        )
            if elevation_range is not None:
                min_val, max_val = elevation_range
                st.markdown("---")
                st.markdown("**Elevation (3D):**")
                st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)")
        elif color_mode == "split":
            st.markdown("**Data Split:**")
            legend_html = (
                '<div style="display: flex; gap: 20px;">'
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Train</span></div>"
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Test</span></div></div>"
            )
            st.markdown(legend_html, unsafe_allow_html=True)


@st.fragment
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render a pydeck spatial map showing training data distribution with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.
    In 'count' and 'density' mode the hexagons are extruded by their min-max normalised
    raw target value.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    st.subheader("🗺️ Spatial Distribution Map")
    # Controls: visualization mode (wide) and opacity (narrow)
    col1, col2 = st.columns([3, 1])
    with col1:
        vis_mode = st.selectbox(
            "Visualization mode",
            options=["binary", "count", "density", "split"],
            format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
            key="spatial_map_mode",
        )
    with col2:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="spatial_map_opacity",
        )

    # The split view has no task of its own; the binary dataset supplies its data.
    if vis_mode == "split":
        selected_task, color_mode = "binary", "split"
    else:
        selected_task, color_mode = vis_mode, "target_class"
    dataset = train_data_dict[selected_task]

    gdf_wgs84 = _prepare_spatial_gdf(train_data_dict, dataset)
    gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)

    # Elevation (3D) only for count and density: raw values normalised to 0-1
    use_elevation = vis_mode in ["count", "density"]
    elevation_range = None
    if use_elevation:
        raw_values = gdf_wgs84["raw_value"]
        min_val, max_val = raw_values.min(), raw_values.max()
        if max_val > min_val:
            gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
        else:
            gdf_wgs84["elevation"] = 0
        elevation_range = (min_val, max_val)

    geojson_data = _build_geojson_features(gdf_wgs84, use_elevation)

    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )
    # Initial view centered on the Arctic; tilted and zoomed out a bit for 3D
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not use_elevation else 1.5,
        pitch=0 if not use_elevation else 45,
    )
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Binary:</b> {binary_label}<br/>"
            "<b>Count Category:</b> {count_category}<br/>"
            "<b>Count Raw:</b> {count_raw}<br/>"
            "<b>Density Category:</b> {density_category}<br/>"
            "<b>Density Raw:</b> {density_raw}<br/>"
            "<b>Split:</b> {split}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    st.pydeck_chart(deck)

    if use_elevation:
        st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")

    _render_spatial_legend(dataset, color_mode, selected_task, elevation_range)