diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 6a54837..85e3671 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -18,7 +18,6 @@ from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
-from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@@ -28,7 +27,6 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
- training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
@@ -38,8 +36,7 @@ def main():
{
"Overview": [overview_page],
"Data": [data_page],
- "Training": [training_data_page, training_analysis_page, autogluon_page],
- "Model State": [model_state_page],
+ "Experiments": [training_analysis_page, autogluon_page, model_state_page],
"Inference": [inference_page],
}
)
diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py
deleted file mode 100644
index cb0299a..0000000
--- a/src/entropice/dashboard/plots/source_data.py
+++ /dev/null
@@ -1,1054 +0,0 @@
-"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
-
-import geopandas as gpd
-import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-import pydeck as pdk
-import streamlit as st
-import xarray as xr
-
-from entropice.dashboard.utils.colors import get_cmap
-from entropice.dashboard.utils.geometry import fix_hex_geometry
-
-# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations
-
-
-def render_alphaearth_overview(ds: xr.Dataset):
- """Render overview statistics for AlphaEarth embeddings data.
-
- Args:
- ds: xarray Dataset containing AlphaEarth embeddings.
-
- """
- st.subheader("📊 Data Overview")
-
- # Key metrics
- col1, col2, col3, col4 = st.columns(4)
-
- with col1:
- st.metric("Cells", f"{len(ds['cell_ids']):,}")
-
- with col2:
- st.metric("Embedding Dims", f"{len(ds['band'])}")
-
- with col3:
- years = sorted(ds["year"].values)
- st.metric("Years", f"{min(years)}–{max(years)}")
-
- with col4:
- st.metric("Aggregations", f"{len(ds['agg'])}")
-
- # Show aggregations as badges in an expander
- with st.expander("ℹ️ Data Details", expanded=False):
- st.markdown("**Spatial Aggregations:**")
- aggs = ds["agg"].to_numpy()
- aggs_html = " ".join(
- [
- f'{a}'
- for a in aggs
- ]
- )
- st.markdown(aggs_html, unsafe_allow_html=True)
-
-
-@st.fragment
-def render_alphaearth_plots(ds: xr.Dataset):
- """Render interactive plots for AlphaEarth embeddings data.
-
- Args:
- ds: xarray Dataset containing AlphaEarth embeddings.
-
- """
- st.markdown("---")
- st.markdown("**Embedding Distribution by Band**")
-
- embeddings_data = ds["embeddings"]
-
- # Select year and aggregation for visualization
- col1, col2 = st.columns(2)
- with col1:
- selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
- with col2:
- selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")
-
- # Get data for selected year and aggregation
- year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)
-
- # Calculate statistics for each band
- band_stats = []
- for band_idx in range(len(ds["band"])):
- band_data = year_agg_data.isel(band=band_idx).values.flatten()
- band_data = band_data[~np.isnan(band_data)] # Remove NaN values
-
- if len(band_data) > 0:
- band_stats.append(
- {
- "Band": band_idx,
- "Mean": float(np.mean(band_data)),
- "Std": float(np.std(band_data)),
- "Min": float(np.min(band_data)),
- "25%": float(np.percentile(band_data, 25)),
- "Median": float(np.median(band_data)),
- "75%": float(np.percentile(band_data, 75)),
- "Max": float(np.max(band_data)),
- }
- )
-
- band_df = pd.DataFrame(band_stats)
-
- # Create plot showing distribution across bands
- fig = go.Figure()
-
- # Add min/max range first (background)
- fig.add_trace(
- go.Scatter(
- x=band_df["Band"],
- y=band_df["Min"],
- mode="lines",
- line={"color": "lightgray", "width": 1, "dash": "dash"},
- name="Min/Max Range",
- showlegend=True,
- )
- )
-
- fig.add_trace(
- go.Scatter(
- x=band_df["Band"],
- y=band_df["Max"],
- mode="lines",
- fill="tonexty",
- fillcolor="rgba(200, 200, 200, 0.1)",
- line={"color": "lightgray", "width": 1, "dash": "dash"},
- showlegend=False,
- )
- )
-
- # Add std band
- fig.add_trace(
- go.Scatter(
- x=band_df["Band"],
- y=band_df["Mean"] - band_df["Std"],
- mode="lines",
- line={"width": 0},
- showlegend=False,
- hoverinfo="skip",
- )
- )
-
- fig.add_trace(
- go.Scatter(
- x=band_df["Band"],
- y=band_df["Mean"] + band_df["Std"],
- mode="lines",
- fill="tonexty",
- fillcolor="rgba(31, 119, 180, 0.2)",
- line={"width": 0},
- name="±1 Std",
- )
- )
-
- # Add mean line on top
- fig.add_trace(
- go.Scatter(
- x=band_df["Band"],
- y=band_df["Mean"],
- mode="lines+markers",
- name="Mean",
- line={"color": "#1f77b4", "width": 2},
- marker={"size": 4},
- )
- )
-
- fig.update_layout(
- title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
- xaxis_title="Band",
- yaxis_title="Embedding Value",
- height=450,
- hovermode="x unified",
- )
-
- st.plotly_chart(fig, width="stretch")
-
- # Band statistics
- with st.expander("📈 Statistics by Embedding Band", expanded=False):
- st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")
-
- # Calculate statistics for each band
- band_stats = []
- for band_idx in range(min(10, len(ds["band"]))): # Show first 10 bands
- band_data = embeddings_data.isel(band=band_idx)
- band_stats.append(
- {
- "Band": int(band_idx),
- "Mean": float(band_data.mean().values),
- "Std": float(band_data.std().values),
- "Min": float(band_data.min().values),
- "Max": float(band_data.max().values),
- }
- )
-
- band_df = pd.DataFrame(band_stats)
- st.dataframe(band_df, width="stretch", hide_index=True)
-
- if len(ds["band"]) > 10:
- st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
-
-
-def render_arcticdem_overview(ds: xr.Dataset):
- """Render overview statistics for ArcticDEM terrain data.
-
- Args:
- ds: xarray Dataset containing ArcticDEM data.
-
- """
- st.subheader("📊 Data Overview")
-
- # Key metrics
- col1, col2, col3 = st.columns(3)
-
- with col1:
- st.metric("Cells", f"{len(ds['cell_ids']):,}")
-
- with col2:
- st.metric("Variables", f"{len(ds.data_vars)}")
-
- with col3:
- st.metric("Aggregations", f"{len(ds['aggregations'])}")
-
- # Show details in expander
- with st.expander("ℹ️ Data Details", expanded=False):
- st.markdown("**Spatial Aggregations:**")
- aggs = ds["aggregations"].to_numpy()
- aggs_html = " ".join(
- [
- f'{a}'
- for a in aggs
- ]
- )
- st.markdown(aggs_html, unsafe_allow_html=True)
-
- # Statistics by variable
- st.markdown("---")
- st.markdown("**📈 Variable Statistics**")
-
- var_stats = []
- for var_name in ds.data_vars:
- var_data = ds[var_name]
- var_stats.append(
- {
- "Variable": var_name,
- "Mean": float(var_data.mean().values),
- "Std": float(var_data.std().values),
- "Min": float(var_data.min().values),
- "Max": float(var_data.max().values),
- "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
- }
- )
-
- stats_df = pd.DataFrame(var_stats)
- st.dataframe(stats_df, width="stretch", hide_index=True)
-
-
-@st.fragment
-def render_arcticdem_plots(ds: xr.Dataset):
- """Render interactive plots for ArcticDEM terrain data.
-
- Args:
- ds: xarray Dataset containing ArcticDEM data.
-
- """
- st.markdown("---")
- st.markdown("**Variable Distributions**")
-
- variables = list(ds.data_vars)
-
- # Select a variable to visualize
- selected_var = st.selectbox("Select variable to visualize", options=variables, key="arcticdem_var_select")
-
- if selected_var:
- var_data = ds[selected_var]
-
- # Create histogram
- fig = go.Figure()
-
- for agg in ds["aggregations"].to_numpy():
- agg_data = var_data.sel(aggregations=agg).to_numpy().flatten()
- agg_data = agg_data[~np.isnan(agg_data)]
-
- fig.add_trace(
- go.Histogram(
- x=agg_data,
- name=str(agg),
- opacity=0.7,
- nbinsx=50,
- )
- )
-
- fig.update_layout(
- title=f"Distribution of {selected_var} by Aggregation",
- xaxis_title=selected_var,
- yaxis_title="Count",
- barmode="overlay",
- height=400,
- )
-
- st.plotly_chart(fig, width="stretch")
-
-
-def render_era5_overview(ds: xr.Dataset, temporal_type: str):
- """Render overview statistics for ERA5 climate data.
-
- Args:
- ds: xarray Dataset containing ERA5 data.
- temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
-
- """
- st.subheader("📊 Data Overview")
-
- # Key metrics
- has_agg = "aggregations" in ds.dims
- col1, col2, col3, col4 = st.columns(4)
-
- with col1:
- st.metric("Cells", f"{len(ds['cell_ids']):,}")
-
- with col2:
- st.metric("Variables", f"{len(ds.data_vars)}")
-
- with col3:
- time_values = pd.to_datetime(ds["time"].values)
- st.metric(
- "Time Steps",
- f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}",
- )
-
- with col4:
- if has_agg:
- st.metric("Aggregations", f"{len(ds['aggregations'])}")
- else:
- st.metric("Temporal Type", temporal_type.capitalize())
-
- # Show details in expander
- with st.expander("ℹ️ Data Details", expanded=False):
- st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
- st.markdown(
- f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}"
- )
-
- if has_agg:
- st.markdown("**Spatial Aggregations:**")
- aggs = ds["aggregations"].to_numpy()
- aggs_html = " ".join(
- [
- f'{a}'
- for a in aggs
- ]
- )
- st.markdown(aggs_html, unsafe_allow_html=True)
-
- # Statistics by variable
- st.markdown("---")
- st.markdown("**📈 Variable Statistics**")
-
- var_stats = []
- for var_name in ds.data_vars:
- var_data = ds[var_name]
- var_stats.append(
- {
- "Variable": var_name,
- "Mean": float(var_data.mean().values),
- "Std": float(var_data.std().values),
- "Min": float(var_data.min().values),
- "Max": float(var_data.max().values),
- "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
- }
- )
-
- stats_df = pd.DataFrame(var_stats)
- st.dataframe(stats_df, width="stretch", hide_index=True)
-
-
-@st.fragment
-def render_era5_plots(ds: xr.Dataset, temporal_type: str):
- """Render interactive plots for ERA5 climate data.
-
- Args:
- ds: xarray Dataset containing ERA5 data.
- temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
-
- """
- st.markdown("---")
- st.markdown("**Temporal Trends**")
-
- variables = list(ds.data_vars)
- has_agg = "aggregations" in ds.dims
-
- if has_agg:
- col1, col2, col3 = st.columns([2, 2, 1])
- with col1:
- selected_var = st.selectbox(
- "Select variable to visualize",
- options=variables,
- key=f"era5_{temporal_type}_var_select",
- )
- with col2:
- selected_agg = st.selectbox(
- "Aggregation",
- options=ds["aggregations"].values,
- key=f"era5_{temporal_type}_agg_select",
- )
- with col3:
- show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
- show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
- else:
- col1, col2 = st.columns([3, 1])
- with col1:
- selected_var = st.selectbox(
- "Select variable to visualize",
- options=variables,
- key=f"era5_{temporal_type}_var_select",
- )
- with col2:
- show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
- show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
-
- if selected_var:
- var_data = ds[selected_var]
-
- # Calculate statistics over space for each time step
- time_values = pd.to_datetime(ds["time"].to_numpy())
-
- if has_agg:
- # Select specific aggregation, then calculate stats over cells
- var_data_agg = var_data.sel(aggregations=selected_agg)
- time_mean = var_data_agg.mean(dim="cell_ids").to_numpy()
- time_std = var_data_agg.std(dim="cell_ids").to_numpy()
- time_min = var_data_agg.min(dim="cell_ids").to_numpy()
- time_max = var_data_agg.max(dim="cell_ids").to_numpy()
- else:
- time_mean = var_data.mean(dim="cell_ids").to_numpy()
- time_std = var_data.std(dim="cell_ids").to_numpy()
- time_min = var_data.min(dim="cell_ids").to_numpy()
- time_max = var_data.max(dim="cell_ids").to_numpy()
-
- fig = go.Figure()
-
- # Add min/max range first (background) - optional
- if show_minmax:
- fig.add_trace(
- go.Scatter(
- x=time_values,
- y=time_min,
- mode="lines",
- line={"color": "lightgray", "width": 1, "dash": "dash"},
- name="Min/Max Range",
- showlegend=True,
- )
- )
-
- fig.add_trace(
- go.Scatter(
- x=time_values,
- y=time_max,
- mode="lines",
- fill="tonexty",
- fillcolor="rgba(200, 200, 200, 0.1)",
- line={"color": "lightgray", "width": 1, "dash": "dash"},
- showlegend=False,
- )
- )
-
- # Add std band - optional
- if show_std:
- fig.add_trace(
- go.Scatter(
- x=time_values,
- y=time_mean - time_std,
- mode="lines",
- line={"width": 0},
- showlegend=False,
- hoverinfo="skip",
- )
- )
-
- fig.add_trace(
- go.Scatter(
- x=time_values,
- y=time_mean + time_std,
- mode="lines",
- fill="tonexty",
- fillcolor="rgba(31, 119, 180, 0.2)",
- line={"width": 0},
- name="±1 Std",
- )
- )
-
- # Add mean line on top
- fig.add_trace(
- go.Scatter(
- x=time_values,
- y=time_mean,
- mode="lines+markers",
- name="Mean",
- line={"color": "#1f77b4", "width": 2},
- marker={"size": 4},
- )
- )
-
- title_suffix = f" (Aggregation: {selected_agg})" if has_agg else ""
- fig.update_layout(
- title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
- xaxis_title="Time",
- yaxis_title=selected_var,
- height=400,
- hovermode="x unified",
- )
-
- st.plotly_chart(fig, width="stretch")
-
-
-@st.fragment
-def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
- """Render interactive pydeck map for AlphaEarth embeddings.
-
- Args:
- ds: xarray Dataset containing AlphaEarth data.
- targets: GeoDataFrame with geometry for each cell.
- grid: Grid type ('hex' or 'healpix').
-
- """
- st.subheader("🗺️ AlphaEarth Spatial Distribution")
-
- # Year slider (full width)
- years = sorted(ds["year"].values)
- selected_year = st.slider(
- "Year",
- min_value=int(years[0]),
- max_value=int(years[-1]),
- value=int(years[-1]),
- step=1,
- key="alphaearth_year",
- )
-
- # Other controls
- col1, col2, col3 = st.columns([2, 2, 1])
-
- with col1:
- selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
-
- with col2:
- selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
-
- with col3:
- opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
-
- # Extract data for selected parameters
- data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)
-
- # Create GeoDataFrame
- gdf = targets.copy()
- gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
- gdf = gdf.set_index("cell_id")
-
- # Add values
- values_df = data_values.to_dataframe(name="value")
- gdf = gdf.join(values_df, how="inner")
-
- # Convert to WGS84 first
- gdf_wgs84 = gdf.to_crs("EPSG:4326")
-
- # Fix geometries after CRS conversion
- if grid == "hex":
- gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
-
- # Normalize values for color mapping
- values = gdf_wgs84["value"].to_numpy()
- vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers
- normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
-
- # Apply colormap
- cmap = get_cmap("embeddings")
- colors = [cmap(val) for val in normalized]
- gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
-
- # Set elevation based on normalized values
- gdf_wgs84["elevation"] = normalized
-
- # Create GeoJSON
- geojson_data = []
- for _, row in gdf_wgs84.iterrows():
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": {
- "value": float(row["value"]),
- "fill_color": row["fill_color"],
- "elevation": float(row["elevation"]),
- },
- }
- geojson_data.append(feature)
-
- # Create pydeck layer with 3D elevation
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=True,
- get_fill_color="properties.fill_color",
- get_line_color=[80, 80, 80],
- get_elevation="properties.elevation",
- elevation_scale=500000,
- line_width_min_pixels=0.5,
- pickable=True,
- )
-
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
-
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": "Value: {value}",
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- st.pydeck_chart(deck)
-
- # Show statistics
- st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}")
-
-
-@st.fragment
-def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
- """Render interactive pydeck map for ArcticDEM terrain data.
-
- Args:
- ds: xarray Dataset containing ArcticDEM data.
- targets: GeoDataFrame with geometry for each cell.
- grid: Grid type ('hex' or 'healpix').
-
- """
- st.subheader("🗺️ ArcticDEM Spatial Distribution")
-
- # Controls
- variables = list(ds.data_vars)
- col1, col2, col3 = st.columns([3, 3, 1])
-
- with col1:
- selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
-
- with col2:
- selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
-
- with col3:
- opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")
-
- # Extract data for selected parameters
- data_values = ds[selected_var].sel(aggregations=selected_agg)
-
- # Create GeoDataFrame
- gdf = targets.copy()
- gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
- gdf = gdf.set_index("cell_id")
-
- # Add values
- values_df = data_values.to_dataframe(name="value")
- gdf = gdf.join(values_df, how="inner")
-
- # Add all aggregation values for tooltip
- if len(ds["aggregations"]) > 1:
- for agg in ds["aggregations"].values:
- agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
- # Drop the aggregations column to avoid conflicts
- if "aggregations" in agg_data.columns:
- agg_data = agg_data.drop(columns=["aggregations"])
- gdf = gdf.join(agg_data, how="left")
-
- # Convert to WGS84 first
- gdf_wgs84 = gdf.to_crs("EPSG:4326")
-
- # Fix geometries after CRS conversion
- if grid == "hex":
- gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
-
- # Normalize values for color mapping
- values = gdf_wgs84["value"].values
- values_clean = values[~np.isnan(values)]
-
- if len(values_clean) > 0:
- vmin, vmax = np.nanpercentile(values_clean, [2, 98])
- normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
-
- # Apply colormap
- cmap = get_cmap(f"arcticdem_{selected_var}")
- colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
- gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
-
- # Set elevation based on normalized values
- gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
-
- # Create GeoJSON
- geojson_data = []
- for _, row in gdf_wgs84.iterrows():
- properties = {
- "value": float(row["value"]) if not np.isnan(row["value"]) else None,
- "fill_color": row["fill_color"],
- "elevation": float(row["elevation"]),
- }
- # Add all aggregation values if available
- if len(ds["aggregations"]) > 1:
- for agg in ds["aggregations"].values:
- agg_col = f"agg_{agg}"
- if agg_col in row.index:
- properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
-
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": properties,
- }
- geojson_data.append(feature)
-
- # Create pydeck layer with 3D elevation
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=True,
- get_fill_color="properties.fill_color",
- get_line_color=[80, 80, 80],
- get_elevation="properties.elevation",
- elevation_scale=500000,
- line_width_min_pixels=0.5,
- pickable=True,
- )
-
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
-
- # Build tooltip HTML for ArcticDEM
- if len(ds["aggregations"]) > 1:
- tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"]
- tooltip_lines.append("All aggregations:
")
- for agg in ds["aggregations"].values:
- tooltip_lines.append(f" {agg}: {{agg_{agg}}}
")
- tooltip_html = "".join(tooltip_lines)
- else:
- tooltip_html = f"{selected_var}: {{value}}"
-
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": tooltip_html,
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- st.pydeck_chart(deck)
-
- # Show statistics
- st.caption(
- f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | "
- f"Std: {np.nanstd(values_clean):.2f}"
- )
- else:
- st.warning("No valid data available for selected parameters")
-
-
-@st.fragment
-def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
- """Render interactive pydeck map for grid cell areas.
-
- Args:
- grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
- grid: Grid type ('hex' or 'healpix').
-
- """
- st.subheader("🗺️ Grid Cell Areas Distribution")
-
- # Controls
- col1, col2 = st.columns([3, 1])
-
- with col1:
- area_metric = st.selectbox(
- "Area Metric",
- options=["cell_area", "land_area", "water_area", "land_ratio"],
- format_func=lambda x: x.replace("_", " ").title(),
- key="areas_metric",
- )
-
- with col2:
- opacity = st.slider(
- "Opacity",
- min_value=0.1,
- max_value=1.0,
- value=0.7,
- step=0.1,
- key="areas_map_opacity",
- )
-
- # Create GeoDataFrame
- gdf = grid_gdf.copy()
-
- # Convert to WGS84 first
- gdf_wgs84 = gdf.to_crs("EPSG:4326")
-
- # Fix geometries after CRS conversion
- if grid == "hex":
- gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
-
- # Get values for the selected metric
- values = gdf_wgs84[area_metric].to_numpy()
-
- # Normalize values for color mapping
- vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers
- normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
-
- # Apply colormap based on metric type
- if area_metric == "land_ratio":
- cmap = get_cmap("terrain") # Different colormap for ratio
- else:
- cmap = get_cmap("terrain")
-
- colors = [cmap(val) for val in normalized]
- gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
-
- # Set elevation based on normalized values for 3D visualization
- gdf_wgs84["elevation"] = normalized
-
- # Create GeoJSON
- geojson_data = []
- for _, row in gdf_wgs84.iterrows():
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": {
- "cell_area": f"{float(row['cell_area']):.2f}",
- "land_area": f"{float(row['land_area']):.2f}",
- "water_area": f"{float(row['water_area']):.2f}",
- "land_ratio": f"{float(row['land_ratio']):.2%}",
- "fill_color": row["fill_color"],
- "elevation": float(row["elevation"]),
- },
- }
- geojson_data.append(feature)
-
- # Create pydeck layer with 3D elevation
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=True,
- get_fill_color="properties.fill_color",
- get_line_color=[80, 80, 80],
- get_elevation="properties.elevation",
- elevation_scale=500000,
- line_width_min_pixels=0.5,
- pickable=True,
- )
-
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
-
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": "Cell Area: {cell_area} km²
"
- "Land Area: {land_area} km²
"
- "Water Area: {water_area} km²
"
- "Land Ratio: {land_ratio}",
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- st.pydeck_chart(deck)
-
- # Show statistics
- st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}")
-
- # Show additional info
- st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
-
-
-@st.fragment
-def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
- """Render interactive pydeck map for ERA5 climate data.
-
- Args:
- ds: xarray Dataset containing ERA5 data.
- targets: GeoDataFrame with geometry for each cell.
- grid: Grid type ('hex' or 'healpix').
- temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
-
- """
- st.subheader("🗺️ ERA5 Spatial Distribution")
-
- # Controls
- variables = list(ds.data_vars)
- has_agg = "aggregations" in ds.dims
-
- # Top row: Variable, Aggregation (if applicable), and Opacity
- if has_agg:
- col1, col2, col3 = st.columns([2, 2, 1])
- with col1:
- selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
- with col2:
- selected_agg = st.selectbox(
- "Aggregation",
- options=ds["aggregations"].values,
- key=f"era5_{temporal_type}_agg",
- )
- with col3:
- opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
- else:
- col1, col2 = st.columns([4, 1])
- with col1:
- selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
- with col2:
- opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
-
- # Bottom row: Time slider (full width)
- time_values = pd.to_datetime(ds["time"].values)
- time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
- selected_time_idx = st.slider(
- "Time",
- min_value=0,
- max_value=len(time_values) - 1,
- value=len(time_values) - 1,
- format="",
- key=f"era5_{temporal_type}_time_slider",
- )
- st.caption(f"Selected: {time_labels[selected_time_idx]}")
- selected_time = time_values[selected_time_idx]
-
- # Extract data for selected parameters
- if has_agg:
- data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg)
- else:
- data_values = ds[selected_var].sel(time=selected_time)
-
- # Create GeoDataFrame
- gdf = targets.copy()
- gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
- gdf = gdf.set_index("cell_id")
-
- # Add values
- values_df = data_values.to_dataframe(name="value")
- gdf = gdf.join(values_df, how="inner")
-
- # Add all aggregation values for tooltip if has_agg
- if has_agg and len(ds["aggregations"]) > 1:
- for agg in ds["aggregations"].values:
- agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}")
- # Drop dimension columns to avoid conflicts
- cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
- if cols_to_drop:
- agg_data = agg_data.drop(columns=cols_to_drop)
- gdf = gdf.join(agg_data, how="left")
-
- # Convert to WGS84 first
- gdf_wgs84 = gdf.to_crs("EPSG:4326")
-
- # Fix geometries after CRS conversion
- if grid == "hex":
- gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
-
- # Normalize values for color mapping
- values = gdf_wgs84["value"].values
- values_clean = values[~np.isnan(values)]
-
- if len(values_clean) > 0:
- vmin, vmax = np.nanpercentile(values_clean, [2, 98])
- normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
-
- # Apply colormap - use variable-specific colors
- cmap = get_cmap(f"era5_{selected_var}")
- colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
- gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
-
- # Set elevation based on normalized values
- gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
-
- # Create GeoJSON
- geojson_data = []
- for _, row in gdf_wgs84.iterrows():
- properties = {
- "value": float(row["value"]) if not np.isnan(row["value"]) else None,
- "fill_color": row["fill_color"],
- "elevation": float(row["elevation"]),
- }
- # Add all aggregation values if available
- if has_agg and len(ds["aggregations"]) > 1:
- for agg in ds["aggregations"].values:
- agg_col = f"agg_{agg}"
- if agg_col in row.index:
- properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
-
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": properties,
- }
- geojson_data.append(feature)
-
- # Create pydeck layer with 3D elevation
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=True,
- get_fill_color="properties.fill_color",
- get_line_color=[80, 80, 80],
- get_elevation="properties.elevation",
- elevation_scale=500000,
- line_width_min_pixels=0.5,
- pickable=True,
- )
-
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
-
- # Build tooltip HTML for ERA5
- if has_agg and len(ds["aggregations"]) > 1:
- tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"]
- tooltip_lines.append("All aggregations:
")
- for agg in ds["aggregations"].values:
- tooltip_lines.append(f" {agg}: {{agg_{agg}}}
")
- tooltip_html = "".join(tooltip_lines)
- else:
- tooltip_html = f"{selected_var}: {{value}}"
-
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": tooltip_html,
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- st.pydeck_chart(deck)
-
- # Show statistics
- st.caption(
- f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
- f"Std: {np.nanstd(values_clean):.4f}"
- )
- else:
- st.warning("No valid data available for selected parameters")
diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py
deleted file mode 100644
index 966c64a..0000000
--- a/src/entropice/dashboard/plots/training_data.py
+++ /dev/null
@@ -1,366 +0,0 @@
-"""Plotting functions for training data visualizations."""
-
-import geopandas as gpd
-import pandas as pd
-import plotly.graph_objects as go
-import pydeck as pdk
-import streamlit as st
-
-from entropice.dashboard.utils.colors import get_palette
-from entropice.dashboard.utils.geometry import fix_hex_geometry
-from entropice.ml.dataset import CategoricalTrainingDataset
-
-
-def render_all_distribution_histograms(
- train_data_dict: dict[str, CategoricalTrainingDataset],
-):
- """Render histograms for all three tasks side by side.
-
- Args:
- train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
-
- """
- st.subheader("📊 Target Distribution by Task")
-
- # Create a 3-column layout for the three tasks
- cols = st.columns(3)
-
- tasks = ["binary", "count", "density"]
- task_titles = {
- "binary": "Binary Classification",
- "count": "Count Classification",
- "density": "Density Classification",
- }
-
- for idx, task in enumerate(tasks):
- dataset = train_data_dict[task]
- categories = dataset.y.binned.cat.categories.tolist()
- colors = get_palette(task, len(categories))
-
- with cols[idx]:
- st.markdown(f"**{task_titles[task]}**")
-
- # Create histogram data
- counts_df = pd.DataFrame(
- {
- "Category": categories,
- "Train": [((dataset.y.binned == cat) & (dataset.split == "train")).sum() for cat in categories],
- "Test": [((dataset.y.binned == cat) & (dataset.split == "test")).sum() for cat in categories],
- }
- )
-
- # Create stacked bar chart
- fig = go.Figure()
-
- fig.add_trace(
- go.Bar(
- name="Train",
- x=counts_df["Category"],
- y=counts_df["Train"],
- marker_color=colors,
- opacity=0.9,
- text=counts_df["Train"],
- textposition="inside",
- textfont={"size": 10, "color": "white"},
- )
- )
-
- fig.add_trace(
- go.Bar(
- name="Test",
- x=counts_df["Category"],
- y=counts_df["Test"],
- marker_color=colors,
- opacity=0.6,
- text=counts_df["Test"],
- textposition="inside",
- textfont={"size": 10, "color": "white"},
- )
- )
-
- fig.update_layout(
- barmode="group",
- height=400,
- margin={"l": 20, "r": 20, "t": 20, "b": 20},
- showlegend=True,
- legend={
- "orientation": "h",
- "yanchor": "bottom",
- "y": 1.02,
- "xanchor": "right",
- "x": 1,
- },
- xaxis_title=None,
- yaxis_title="Count",
- xaxis={"tickangle": -45},
- )
-
- st.plotly_chart(fig, width="stretch")
-
- # Show summary statistics
- total = len(dataset)
- train_pct = (dataset.split == "train").sum() / total * 100
- test_pct = (dataset.split == "test").sum() / total * 100
-
- st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
-
-
-def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
- """Assign colors to geodataframe based on the selected color mode.
-
- Args:
- gdf: GeoDataFrame to add colors to
- color_mode: One of 'target_class' or 'split'
- dataset: CategoricalTrainingDataset
- selected_task: Task name for color palette selection
-
- Returns:
- GeoDataFrame with 'fill_color' column added
-
- """
- if color_mode == "target_class":
- categories = dataset.y.binned.cat.categories.tolist()
- colors_palette = get_palette(selected_task, len(categories))
-
- # Create color mapping
- color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
- gdf["color"] = gdf["target_class"].map(color_map)
-
- # Convert hex colors to RGB
- def hex_to_rgb(hex_color):
- hex_color = hex_color.lstrip("#")
- return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
-
- gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
-
- elif color_mode == "split":
- split_colors = {
- "train": [66, 135, 245],
- "test": [245, 135, 66],
- } # Blue # Orange
- gdf["fill_color"] = gdf["split"].map(split_colors)
-
- return gdf
-
-
-@st.fragment
-def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
- """Render a pydeck spatial map showing training data distribution with interactive controls.
-
- This is a Streamlit fragment that reruns independently when users interact with the
- visualization controls (color mode and opacity), without re-running the entire page.
-
- Args:
- train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
-
- """
- st.subheader("🗺️ Spatial Distribution Map")
-
- # Create controls in columns
- col1, col2 = st.columns([3, 1])
-
- with col1:
- vis_mode = st.selectbox(
- "Visualization mode",
- options=["binary", "count", "density", "split"],
- format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
- key="spatial_map_mode",
- )
-
- with col2:
- opacity = st.slider(
- "Opacity",
- min_value=0.1,
- max_value=1.0,
- value=0.7,
- step=0.1,
- key="spatial_map_opacity",
- )
-
- # Determine which task dataset to use and color mode
- if vis_mode == "split":
- # Use binary dataset for split visualization
- dataset = train_data_dict["binary"]
- color_mode = "split"
- selected_task = "binary"
- else:
- # Use the selected task
- dataset = train_data_dict[vis_mode]
- color_mode = "target_class"
- selected_task = vis_mode
-
- # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
- gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment]
-
- # Fix antimeridian issues
- gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
-
- # Add binned labels and split information from current dataset
- gdf["target_class"] = dataset.y.binned.to_numpy()
- gdf["split"] = dataset.split.to_numpy()
- gdf["raw_value"] = dataset.z.to_numpy()
-
- # Add information from all three tasks for tooltip
- gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
- gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
- gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
- gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
- gdf["density_raw"] = train_data_dict["density"].z.to_numpy()
-
- # Convert to WGS84 for pydeck
- gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326") # type: ignore[assignment]
-
- # Assign colors based on the selected mode
- gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)
-
- # Convert to GeoJSON format and add elevation for 3D visualization
- geojson_data = []
- # Normalize raw values for elevation (only for count and density)
- use_elevation = vis_mode in ["count", "density"]
- if use_elevation:
- raw_values = gdf_wgs84["raw_value"]
- min_val, max_val = raw_values.min(), raw_values.max()
- # Normalize to 0-1 range for better 3D visualization
- if max_val > min_val:
- gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
- else:
- gdf_wgs84["elevation"] = 0
-
- for _, row in gdf_wgs84.iterrows():
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": {
- "target_class": str(row["target_class"]),
- "split": str(row["split"]),
- "raw_value": float(row["raw_value"]),
- "fill_color": row["fill_color"],
- "elevation": float(row["elevation"]) if use_elevation else 0,
- "binary_label": str(row["binary_label"]),
- "count_category": str(row["count_category"]),
- "count_raw": int(row["count_raw"]),
- "density_category": str(row["density_category"]),
- "density_raw": f"{float(row['density_raw']):.4f}",
- },
- }
- geojson_data.append(feature)
-
- # Create pydeck layer
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=use_elevation,
- wireframe=False,
- get_fill_color="properties.fill_color",
- get_line_color=[80, 80, 80],
- line_width_min_pixels=0.5,
- get_elevation="properties.elevation" if use_elevation else 0,
- elevation_scale=500000, # Scale normalized values (0-1) to 500km height
- pickable=True,
- )
-
- # Set initial view state (centered on the Arctic)
- # Adjust pitch and zoom based on whether we're using elevation
- view_state = pdk.ViewState(
- latitude=70,
- longitude=0,
- zoom=2 if not use_elevation else 1.5,
- pitch=0 if not use_elevation else 45,
- )
-
- # Create deck
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": "Binary: {binary_label}
"
- "Count Category: {count_category}
"
- "Count Raw: {count_raw}
"
- "Density Category: {density_category}
"
- "Density Raw: {density_raw}
"
- "Split: {split}",
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- # Render the map
- st.pydeck_chart(deck)
-
- # Show info about 3D visualization
- if use_elevation:
- st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")
-
- # Add legend
- with st.expander("Legend", expanded=True):
- if color_mode == "target_class":
- st.markdown("**Target Classes:**")
- categories = dataset.y.binned.cat.categories.tolist()
- colors_palette = get_palette(selected_task, len(categories))
- intervals = dataset.y.intervals
-
- # For count and density tasks, show intervals
- if selected_task in ["count", "density"]:
- for i, cat in enumerate(categories):
- color = colors_palette[i]
- interval_min, interval_max = intervals[i]
-
- # Format interval display
- if interval_min is None or interval_max is None:
- interval_str = ""
- elif selected_task == "count":
- # Integer values for count
- if interval_min == interval_max:
- interval_str = f" ({int(interval_min)})"
- else:
- interval_str = f" ({int(interval_min)}-{int(interval_max)})"
- else: # density
- # Percentage values for density
- if interval_min == interval_max:
- interval_str = f" ({interval_min * 100:.4f}%)"
- else:
- interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
-
- st.markdown(
- f'