366 lines
14 KiB
Python
366 lines
14 KiB
Python
"""Plotting functions for training data visualizations."""
|
|
|
|
import geopandas as gpd
|
|
import pandas as pd
|
|
import plotly.graph_objects as go
|
|
import pydeck as pdk
|
|
import streamlit as st
|
|
|
|
from entropice.dashboard.utils.colors import get_palette
|
|
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
|
from entropice.ml.dataset import CategoricalTrainingDataset
|
|
|
|
|
|
def render_all_distribution_histograms(
    train_data_dict: dict[str, CategoricalTrainingDataset],
):
    """Render train/test target-class histograms for all three tasks side by side.

    Each task ('binary', 'count', 'density') gets its own column containing a
    grouped bar chart (one train bar and one test bar per target class, coloured
    with the task palette) and a caption with the total number of samples and the
    train/test percentages.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    st.subheader("📊 Target Distribution by Task")

    # Create a 3-column layout for the three tasks
    cols = st.columns(3)

    tasks = ["binary", "count", "density"]
    task_titles = {
        "binary": "Binary Classification",
        "count": "Count Classification",
        "density": "Density Classification",
    }

    for col, task in zip(cols, tasks):
        dataset = train_data_dict[task]
        binned = dataset.y.binned
        categories = binned.cat.categories.tolist()
        colors = get_palette(task, len(categories))

        # The split masks do not depend on the category, so build them once per
        # task instead of once per category inside the comprehensions below.
        is_train = dataset.split == "train"
        is_test = dataset.split == "test"

        with col:
            st.markdown(f"**{task_titles[task]}**")

            # Per-category sample counts for each split
            counts_df = pd.DataFrame(
                {
                    "Category": categories,
                    "Train": [((binned == cat) & is_train).sum() for cat in categories],
                    "Test": [((binned == cat) & is_test).sum() for cat in categories],
                }
            )

            # Grouped (side-by-side) bar chart: train and test bars share the
            # per-category palette and are distinguished by opacity.
            fig = go.Figure()

            fig.add_trace(
                go.Bar(
                    name="Train",
                    x=counts_df["Category"],
                    y=counts_df["Train"],
                    marker_color=colors,
                    opacity=0.9,
                    text=counts_df["Train"],
                    textposition="inside",
                    textfont={"size": 10, "color": "white"},
                )
            )

            fig.add_trace(
                go.Bar(
                    name="Test",
                    x=counts_df["Category"],
                    y=counts_df["Test"],
                    marker_color=colors,
                    opacity=0.6,
                    text=counts_df["Test"],
                    textposition="inside",
                    textfont={"size": 10, "color": "white"},
                )
            )

            fig.update_layout(
                barmode="group",
                height=400,
                margin={"l": 20, "r": 20, "t": 20, "b": 20},
                showlegend=True,
                legend={
                    "orientation": "h",
                    "yanchor": "bottom",
                    "y": 1.02,
                    "xanchor": "right",
                    "x": 1,
                },
                xaxis_title=None,
                yaxis_title="Count",
                xaxis={"tickangle": -45},
            )

            st.plotly_chart(fig, width="stretch")

            # Summary statistics; an empty dataset would otherwise divide by zero
            # and render "nan%" in the caption.
            total = len(dataset)
            train_pct = is_train.sum() / total * 100 if total else 0.0
            test_pct = is_test.sum() / total * 100 if total else 0.0

            st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
|
|
|
|
|
|
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
|
|
"""Assign colors to geodataframe based on the selected color mode.
|
|
|
|
Args:
|
|
gdf: GeoDataFrame to add colors to
|
|
color_mode: One of 'target_class' or 'split'
|
|
dataset: CategoricalTrainingDataset
|
|
selected_task: Task name for color palette selection
|
|
|
|
Returns:
|
|
GeoDataFrame with 'fill_color' column added
|
|
|
|
"""
|
|
if color_mode == "target_class":
|
|
categories = dataset.y.binned.cat.categories.tolist()
|
|
colors_palette = get_palette(selected_task, len(categories))
|
|
|
|
# Create color mapping
|
|
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
|
|
gdf["color"] = gdf["target_class"].map(color_map)
|
|
|
|
# Convert hex colors to RGB
|
|
def hex_to_rgb(hex_color):
|
|
hex_color = hex_color.lstrip("#")
|
|
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
|
|
|
|
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
|
|
|
|
elif color_mode == "split":
|
|
split_colors = {
|
|
"train": [66, 135, 245],
|
|
"test": [245, 135, 66],
|
|
} # Blue # Orange
|
|
gdf["fill_color"] = gdf["split"].map(split_colors)
|
|
|
|
return gdf
|
|
|
|
|
|
def _optional_float(value):
    """Return ``value`` as a float, or ``None`` when it is missing (None/NaN/NA).

    ``None`` serialises to JSON ``null``; a float NaN would be written as the
    non-standard ``NaN`` token into the deck JSON.
    """
    return None if pd.isna(value) else float(value)


def _build_geojson_features(gdf, use_elevation):
    """Convert the prepared map GeoDataFrame into a list of GeoJSON feature dicts.

    Args:
        gdf: GeoDataFrame (already in EPSG:4326) carrying the label, raw-value and
            ``fill_color`` columns added by ``render_spatial_map``, plus an
            ``elevation`` column when ``use_elevation`` is true.
        use_elevation: Whether to copy the ``elevation`` column into the feature
            properties; otherwise every feature gets elevation 0.

    Returns:
        List of GeoJSON ``Feature`` dicts whose ``properties`` feed the layer
        accessors and the tooltip template.

    """
    columns = [
        "geometry",
        "target_class",
        "split",
        "raw_value",
        "fill_color",
        "binary_label",
        "count_category",
        "count_raw",
        "density_category",
        "density_raw",
    ]
    if use_elevation:
        columns.append("elevation")

    features = []
    # itertuples over only the needed columns avoids building a Series per row
    # (as iterrows does) and ignores the dataset's other feature columns.
    for row in gdf[columns].itertuples(index=False):
        features.append(
            {
                "type": "Feature",
                "geometry": row.geometry.__geo_interface__,
                "properties": {
                    "target_class": str(row.target_class),
                    "split": str(row.split),
                    "raw_value": _optional_float(row.raw_value),
                    "fill_color": row.fill_color,
                    "elevation": float(row.elevation) if use_elevation else 0,
                    "binary_label": str(row.binary_label),
                    "count_category": str(row.count_category),
                    # int(NaN) raises, so a missing count becomes JSON null instead.
                    "count_raw": None if pd.isna(row.count_raw) else int(row.count_raw),
                    "density_category": str(row.density_category),
                    "density_raw": f"{float(row.density_raw):.4f}",
                },
            }
        )
    return features


def _format_interval(interval_min, interval_max, task):
    """Return the `` (…)`` legend suffix for a count/density bin, or '' if a bound is None."""
    if interval_min is None or interval_max is None:
        return ""
    if task == "count":
        # Integer values for count
        if interval_min == interval_max:
            return f" ({int(interval_min)})"
        return f" ({int(interval_min)}-{int(interval_max)})"
    # Percentage values for density
    if interval_min == interval_max:
        return f" ({interval_min * 100:.4f}%)"
    return f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"


def _render_target_class_legend(dataset, selected_task):
    """Render one colour swatch per target class (with bin intervals for count/density)."""
    st.markdown("**Target Classes:**")
    categories = dataset.y.binned.cat.categories.tolist()
    colors_palette = get_palette(selected_task, len(categories))

    if selected_task in ["count", "density"]:
        # Intervals are only needed (and only read) for the binned numeric tasks.
        intervals = dataset.y.intervals
        for i, cat in enumerate(categories):
            color = colors_palette[i]
            interval_min, interval_max = intervals[i]
            interval_str = _format_interval(interval_min, interval_max, selected_task)

            st.markdown(
                f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                f'<div style="width: 20px; height: 20px; background-color: {color}; '
                f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                f"<span>{cat}{interval_str}</span></div>",
                unsafe_allow_html=True,
            )
    else:
        # Binary task: one column per class
        legend_cols = st.columns(len(categories))
        for i, cat in enumerate(categories):
            with legend_cols[i]:
                color = colors_palette[i]
                st.markdown(
                    f'<div style="display: flex; align-items: center;">'
                    f'<div style="width: 20px; height: 20px; background-color: {color}; '
                    f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                    f"<span>{cat}</span></div>",
                    unsafe_allow_html=True,
                )


def _render_elevation_legend(raw_values, selected_task):
    """Render the raw-value range that the 3D extrusion height is normalised from."""
    st.markdown("---")
    st.markdown("**Elevation (3D):**")
    min_val = raw_values.min()
    max_val = raw_values.max()
    if selected_task == "density":
        # The class legend renders density bins as percentages (value * 100), so
        # show the elevation range in the same unit instead of a 2-decimal fraction.
        range_str = f"{min_val * 100:.4f}% (low) → {max_val * 100:.4f}% (high)"
    else:
        range_str = f"{min_val:.2f} (low) → {max_val:.2f} (high)"
    st.markdown(f"Height represents raw value: {range_str}")


def _render_split_legend():
    """Render the fixed train (blue) / test (orange) swatches."""
    st.markdown("**Data Split:**")
    legend_html = (
        '<div style="display: flex; gap: 20px;">'
        '<div style="display: flex; align-items: center;">'
        '<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
        'margin-right: 8px; border: 1px solid #ccc;"></div>'
        "<span>Train</span></div>"
        '<div style="display: flex; align-items: center;">'
        '<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
        'margin-right: 8px; border: 1px solid #ccc;"></div>'
        "<span>Test</span></div></div>"
    )
    st.markdown(legend_html, unsafe_allow_html=True)


@st.fragment
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render a pydeck spatial map showing training data distribution with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.

    Cells are coloured by the selected task's target class or by train/test split.
    For the count and density tasks the cells are extruded in 3D with a height
    proportional to the min-max normalised raw target value.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    st.subheader("🗺️ Spatial Distribution Map")

    # Create controls in columns
    col1, col2 = st.columns([3, 1])

    with col1:
        vis_mode = st.selectbox(
            "Visualization mode",
            options=["binary", "count", "density", "split"],
            format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
            key="spatial_map_mode",
        )

    with col2:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="spatial_map_opacity",
        )

    # Determine which task dataset to use and color mode
    if vis_mode == "split":
        # Use binary dataset for split visualization
        dataset = train_data_dict["binary"]
        color_mode = "split"
        selected_task = "binary"
    else:
        dataset = train_data_dict[vis_mode]
        color_mode = "target_class"
        selected_task = vis_mode

    # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
    gdf: gpd.GeoDataFrame = dataset.dataset.copy()  # type: ignore[assignment]

    # Fix antimeridian issues
    gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)

    # Add binned labels and split information from current dataset
    gdf["target_class"] = dataset.y.binned.to_numpy()
    gdf["split"] = dataset.split.to_numpy()
    gdf["raw_value"] = dataset.z.to_numpy()

    # Add information from all three tasks for the tooltip.
    # NOTE(review): positional assignment assumes every task dataset has the same
    # rows in the same order as `dataset.dataset` — TODO confirm.
    gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
    gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
    gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
    gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
    gdf["density_raw"] = train_data_dict["density"].z.to_numpy()

    # Convert to WGS84 for pydeck
    gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326")  # type: ignore[assignment]

    # Assign colors based on the selected mode
    gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)

    # Normalize raw values to 0-1 for the 3D extrusion (only for count and density)
    use_elevation = vis_mode in ["count", "density"]
    if use_elevation:
        raw_values = gdf_wgs84["raw_value"]
        min_val, max_val = raw_values.min(), raw_values.max()
        if max_val > min_val:
            gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
        else:
            # Constant (or all-missing) values: keep the map flat.
            gdf_wgs84["elevation"] = 0

    geojson_data = _build_geojson_features(gdf_wgs84, use_elevation)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )

    # Initial view centred at lat 70 / lon 0; tilt and zoom out when extruded.
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not use_elevation else 1.5,
        pitch=0 if not use_elevation else 45,
    )

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Binary:</b> {binary_label}<br/>"
            "<b>Count Category:</b> {count_category}<br/>"
            "<b>Count Raw:</b> {count_raw}<br/>"
            "<b>Density Category:</b> {density_category}<br/>"
            "<b>Density Raw:</b> {density_raw}<br/>"
            "<b>Split:</b> {split}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    if use_elevation:
        st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")

    with st.expander("Legend", expanded=True):
        if color_mode == "target_class":
            _render_target_class_legend(dataset, selected_task)
            if use_elevation:
                _render_elevation_legend(gdf_wgs84["raw_value"], selected_task)
        elif color_mode == "split":
            _render_split_legend()
|