entropice/src/entropice/dashboard/plots/source_data.py

1064 lines
35 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
import antimeridian
import geopandas as gpd
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
import xarray as xr
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_cmap
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate it from temporal aggregations
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that crosses the antimeridian.

    The geometry is run through ``antimeridian.fix_shape`` and the result is
    converted back into a shapely geometry with ``shape``. If either step raises
    a ``ValueError``, the error is reported in the Streamlit app and the input
    geometry is handed back unchanged so rendering can carry on.

    Args:
        geom: Geometry of a single cell (called per element via ``GeoSeries.apply``).

    Returns:
        The repaired geometry, or ``geom`` itself when repairing failed.
    """
    try:
        repaired = shape(antimeridian.fix_shape(geom))
    except ValueError as exc:
        st.error(f"Error fixing geometry: {exc}")
        repaired = geom
    return repaired
def render_alphaearth_overview(ds: xr.Dataset):
    """Render overview statistics for AlphaEarth embeddings data.

    Displays four headline metrics and an expander with badges.

    The metrics are the number of cells, the number of embedding dimensions,
    the covered year range and the number of spatial aggregations. The expander
    lists the spatial aggregation names as badges.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings. The ``cell_ids``,
            ``band``, ``year`` and ``agg`` coordinates are read.
    """
    st.subheader("📊 Data Overview")

    # Key metrics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Cells", f"{len(ds['cell_ids']):,}")
    with col2:
        st.metric("Embedding Dims", f"{len(ds['band'])}")
    with col3:
        years = sorted(ds["year"].values)
        first_year, last_year = years[0], years[-1]
        # Join the bounds with an en dash. Without a separator, 2018 and 2023
        # would render as the unreadable "20182023".
        year_range = f"{first_year}" if first_year == last_year else f"{first_year}–{last_year}"
        st.metric("Years", year_range)
    with col4:
        st.metric("Aggregations", f"{len(ds['agg'])}")

    # Show aggregations as badges in an expander
    with st.expander(" Data Details", expanded=False):
        st.markdown("**Spatial Aggregations:**")
        aggs = ds["agg"].to_numpy()
        aggs_html = " ".join(
            f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
            f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
            for a in aggs
        )
        st.markdown(aggs_html, unsafe_allow_html=True)
@st.fragment
def render_alphaearth_plots(ds: xr.Dataset):
    """Render interactive plots for AlphaEarth embeddings data.

    The user selects a year and a spatial aggregation. For every embedding band,
    the non-NaN values over all cells are summarized. They are drawn as a mean
    line with a ±1 std band and a dashed min/max envelope.

    An expander also lists mean/std/min/max of the first 10 bands over the whole
    ``embeddings`` variable.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings. Must contain an
            ``embeddings`` variable indexed by ``year``, ``agg`` and ``band``.
    """
    st.markdown("---")
    st.markdown("**Embedding Distribution by Band**")
    embeddings_data = ds["embeddings"]
    n_bands = len(ds["band"])

    # Select year and aggregation for visualization
    col1, col2 = st.columns(2)
    with col1:
        selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
    with col2:
        selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")

    # Get data for selected year and aggregation
    year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)

    # Per-band distribution statistics; bands with no valid value are skipped.
    band_stats = []
    for band_idx in range(n_bands):
        band_data = year_agg_data.isel(band=band_idx).values.flatten()
        band_data = band_data[~np.isnan(band_data)]  # Remove NaN values
        if len(band_data) > 0:
            band_stats.append(
                {
                    "Band": band_idx,
                    "Mean": float(np.mean(band_data)),
                    "Std": float(np.std(band_data)),
                    "Min": float(np.min(band_data)),
                    "25%": float(np.percentile(band_data, 25)),
                    "Median": float(np.median(band_data)),
                    "75%": float(np.percentile(band_data, 75)),
                    "Max": float(np.max(band_data)),
                }
            )

    if band_stats:
        band_df = pd.DataFrame(band_stats)

        # Create plot showing distribution across bands
        fig = go.Figure()

        # Add min/max range first (background); the second trace fills down to the first.
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Min"],
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Max"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

        # Add std band (invisible lower edge, filled upper edge)
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] - band_df["Std"],
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] + band_df["Std"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.2)",
                line={"width": 0},
                name="±1 Std",
            )
        )

        # Add mean line on top
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"],
                mode="lines+markers",
                name="Mean",
                line={"color": "#1f77b4", "width": 2},
                marker={"size": 4},
            )
        )

        fig.update_layout(
            title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
            xaxis_title="Band",
            yaxis_title="Embedding Value",
            height=450,
            hovermode="x unified",
        )
        st.plotly_chart(fig, width="stretch")
    else:
        # With no valid band, band_df would have no columns and indexing
        # band_df["Band"] would raise a KeyError.
        st.warning(f"No valid embedding values for year {selected_year} and aggregation {selected_agg}.")

    # Band statistics over the whole dataset
    with st.expander("📈 Statistics by Embedding Band", expanded=False):
        st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")
        overall_stats = []
        for band_idx in range(min(10, n_bands)):  # Show first 10 bands
            band_data = embeddings_data.isel(band=band_idx)
            overall_stats.append(
                {
                    "Band": int(band_idx),
                    "Mean": float(band_data.mean().values),
                    "Std": float(band_data.std().values),
                    "Min": float(band_data.min().values),
                    "Max": float(band_data.max().values),
                }
            )
        st.dataframe(pd.DataFrame(overall_stats), width="stretch", hide_index=True)
        if n_bands > 10:
            st.info(f"Showing first 10 of {n_bands} embedding dimensions")
def render_arcticdem_overview(ds: xr.Dataset):
    """Show a summary of an ArcticDEM terrain dataset.

    The summary has three parts:

    - headline metrics: cell count, variable count and aggregation count;
    - a collapsible badge list of the spatial aggregation names;
    - a table with mean, std, min, max and missing-value percentage for every
      data variable.

    Args:
        ds: ArcticDEM xarray Dataset with ``cell_ids`` and ``aggregations``
            coordinates.
    """
    st.subheader("📊 Data Overview")

    agg_coord = ds["aggregations"]
    metric_specs = [
        ("Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Aggregations", f"{len(agg_coord)}"),
    ]
    for column, (label, shown_value) in zip(st.columns(3), metric_specs):
        column.metric(label, shown_value)

    badge_template = (
        '<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
        'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{}</span>'
    )
    with st.expander(" Data Details", expanded=False):
        st.markdown("**Spatial Aggregations:**")
        badges = [badge_template.format(agg_name) for agg_name in agg_coord.to_numpy()]
        st.markdown(" ".join(badges), unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**📈 Variable Statistics**")

    def summarize(name):
        # One table row: summary statistics plus share of missing entries.
        arr = ds[name]
        return {
            "Variable": name,
            "Mean": float(arr.mean().values),
            "Std": float(arr.std().values),
            "Min": float(arr.min().values),
            "Max": float(arr.max().values),
            "Missing %": float((arr.isnull().sum() / arr.size * 100).values),
        }

    stats_table = pd.DataFrame([summarize(name) for name in ds.data_vars])
    st.dataframe(stats_table, width="stretch", hide_index=True)
@st.fragment
def render_arcticdem_plots(ds: xr.Dataset):
    """Plot per-aggregation histograms for one ArcticDEM variable.

    A selectbox picks the variable. Its values for every spatial aggregation
    are drawn as overlaid, semi-transparent 50-bin histograms, with NaNs
    dropped.

    Args:
        ds: ArcticDEM xarray Dataset with an ``aggregations`` dimension.
    """
    st.markdown("---")
    st.markdown("**Variable Distributions**")

    selected_var = st.selectbox(
        "Select variable to visualize",
        options=list(ds.data_vars),
        key="arcticdem_var_select",
    )
    if not selected_var:
        return

    source = ds[selected_var]
    histograms = []
    for agg_name in ds["aggregations"].to_numpy():
        flat = source.sel(aggregations=agg_name).to_numpy().flatten()
        histograms.append(
            go.Histogram(
                x=flat[~np.isnan(flat)],
                name=str(agg_name),
                opacity=0.7,
                nbinsx=50,
            )
        )

    fig = go.Figure(data=histograms)
    fig.update_layout(
        title=f"Distribution of {selected_var} by Aggregation",
        xaxis_title=selected_var,
        yaxis_title="Count",
        barmode="overlay",
        height=400,
    )
    st.plotly_chart(fig, width="stretch")
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
    """Render overview statistics for ERA5 climate data.

    The overview has three parts:

    - headline metrics: cells, variables, covered year range, and either the
      aggregation count or the temporal type;
    - an expander with the exact date range and, when present, the spatial
      aggregation names;
    - a per-variable statistics table.

    Args:
        ds: xarray Dataset containing ERA5 data with ``cell_ids`` and ``time``
            coordinates and an optional ``aggregations`` dimension.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
    """
    st.subheader("📊 Data Overview")

    has_agg = "aggregations" in ds.dims
    # Needed both for the year-range metric and for the details expander.
    time_values = pd.to_datetime(ds["time"].values)
    first_year = time_values.min().strftime("%Y")
    last_year = time_values.max().strftime("%Y")

    # Key metrics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Cells", f"{len(ds['cell_ids']):,}")
    with col2:
        st.metric("Variables", f"{len(ds.data_vars)}")
    with col3:
        # Join the bounds with an en dash instead of gluing the years together.
        year_range = first_year if first_year == last_year else f"{first_year}–{last_year}"
        st.metric("Time Steps", year_range)
    with col4:
        if has_agg:
            st.metric("Aggregations", f"{len(ds['aggregations'])}")
        else:
            st.metric("Temporal Type", temporal_type.capitalize())

    # Show details in expander
    with st.expander(" Data Details", expanded=False):
        st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
        st.markdown(
            f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}"
        )
        if has_agg:
            st.markdown("**Spatial Aggregations:**")
            aggs = ds["aggregations"].to_numpy()
            aggs_html = " ".join(
                f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
                f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
                for a in aggs
            )
            st.markdown(aggs_html, unsafe_allow_html=True)

    # Statistics by variable
    st.markdown("---")
    st.markdown("**📈 Variable Statistics**")
    var_stats = []
    for var_name in ds.data_vars:
        var_data = ds[var_name]
        var_stats.append(
            {
                "Variable": var_name,
                "Mean": float(var_data.mean().values),
                "Std": float(var_data.std().values),
                "Min": float(var_data.min().values),
                "Max": float(var_data.max().values),
                "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
            }
        )
    stats_df = pd.DataFrame(var_stats)
    st.dataframe(stats_df, width="stretch", hide_index=True)
@st.fragment
def render_era5_plots(ds: xr.Dataset, temporal_type: str):
    """Plot how the spatial statistics of an ERA5 variable evolve over time.

    The user chooses a variable and, if the dataset has an ``aggregations``
    dimension, an aggregation. The plot shows the mean over all cells for each
    time step. Two optional overlays add a ±1 standard deviation band and the
    min/max envelope across cells.

    Args:
        ds: xarray Dataset containing ERA5 data with ``time`` and ``cell_ids``.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'; embedded in the
            widget keys so each temporal resolution keeps its own state.
    """
    st.markdown("---")
    st.markdown("**Temporal Trends**")

    key_prefix = f"era5_{temporal_type}"
    has_agg = "aggregations" in ds.dims

    # Layout: [variable | aggregation | toggles] with aggregations, else [variable | toggles].
    layout = st.columns([2, 2, 1]) if has_agg else st.columns([3, 1])
    with layout[0]:
        selected_var = st.selectbox(
            "Select variable to visualize",
            options=list(ds.data_vars),
            key=f"{key_prefix}_var_select",
        )
    if has_agg:
        with layout[1]:
            selected_agg = st.selectbox(
                "Aggregation",
                options=ds["aggregations"].values,
                key=f"{key_prefix}_agg_select",
            )
    with layout[-1]:
        show_std = st.checkbox("Show ±1 Std", value=True, key=f"{key_prefix}_show_std")
        show_minmax = st.checkbox("Show Min/Max", value=False, key=f"{key_prefix}_show_minmax")

    if not selected_var:
        return

    # Reduce over cells for every time step (after fixing the aggregation, if any).
    series = ds[selected_var]
    if has_agg:
        series = series.sel(aggregations=selected_agg)
    time_axis = pd.to_datetime(ds["time"].to_numpy())
    mean_vals = series.mean(dim="cell_ids").to_numpy()
    std_vals = series.std(dim="cell_ids").to_numpy()
    min_vals = series.min(dim="cell_ids").to_numpy()
    max_vals = series.max(dim="cell_ids").to_numpy()

    fig = go.Figure()

    if show_minmax:
        # Dashed lower bound, then upper bound filled down to it.
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=min_vals,
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=max_vals,
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

    if show_std:
        # Invisible lower edge, then filled upper edge.
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=mean_vals - std_vals,
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=mean_vals + std_vals,
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.2)",
                line={"width": 0},
                name="±1 Std",
            )
        )

    # Mean line is drawn last so it sits on top of the bands.
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=mean_vals,
            mode="lines+markers",
            name="Mean",
            line={"color": "#1f77b4", "width": 2},
            marker={"size": 4},
        )
    )

    if has_agg:
        title_suffix = f" (Aggregation: {selected_agg})"
    else:
        title_suffix = ""
    fig.update_layout(
        title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
        xaxis_title="Time",
        yaxis_title=selected_var,
        height=400,
        hovermode="x unified",
    )
    st.plotly_chart(fig, width="stretch")
@st.fragment
def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for AlphaEarth embeddings.

    The user picks a year, an aggregation, an embedding band and the layer
    opacity. The selected values are joined onto the target cell geometries,
    colored with the ``"embeddings"`` colormap and extruded in 3D.

    Colors and heights are normalized between the 2nd and 98th percentile of
    the valid values, so a few outliers do not flatten the scale. Cells without
    a value are drawn grey and flat. A constant field is drawn at mid-scale.

    Args:
        ds: xarray Dataset containing AlphaEarth data (``embeddings`` variable
            indexed by ``year``, ``agg``, ``band``; cells in ``cell_ids``).
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'). Hex cells are repaired where they
            cross the antimeridian.
    """
    st.subheader("🗺️ AlphaEarth Spatial Distribution")

    # Year slider (full width)
    years = sorted(ds["year"].values)
    if len(years) > 1:
        selected_year = st.slider(
            "Year",
            min_value=int(years[0]),
            max_value=int(years[-1]),
            value=int(years[-1]),
            step=1,
            key="alphaearth_year",
        )
    else:
        # A one-position slider is pointless, and recent Streamlit versions
        # reject min_value == max_value. Select the only year directly.
        selected_year = int(years[0])
        st.caption(f"Year: {selected_year}")

    # Other controls
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
    with col2:
        selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")

    # Extract data for selected parameters
    data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)

    # Keep only cells present in the dataset and attach their values
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Convert to WGS84 first, then fix geometries after CRS conversion
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    values_clean = values[~np.isnan(values)]
    if values_clean.size == 0:
        st.warning("No valid data available for selected parameters")
        return

    # Normalize with percentiles to avoid outliers dominating the color scale
    vmin, vmax = np.percentile(values_clean, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Constant field: (values - vmin) / 0 would turn every cell into NaN.
        normalized = np.where(np.isnan(values), np.nan, 0.5)

    # Apply colormap; missing values get a neutral grey
    cmap = get_cmap("embeddings")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
    # Elevation follows the normalized value; missing cells stay flat
    gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]

    # Create GeoJSON (NaN is not valid JSON, so missing values become None)
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Value:</b> {value}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics: true min/max of the data plus the clipped color range
    st.caption(
        f"Min: {np.min(values_clean):.4f} | Max: {np.max(values_clean):.4f} | "
        f"Mean: {np.mean(values_clean):.4f} | Std: {np.std(values_clean):.4f} | "
        f"Color scale (P2–P98): {vmin:.4f} – {vmax:.4f}"
    )
@st.fragment
def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for ArcticDEM terrain data.

    The selected variable/aggregation is joined onto the target cell
    geometries, colored with the ``arcticdem_<var>`` colormap and extruded in
    3D. With more than one aggregation, the tooltip also lists the value of
    every aggregation for the hovered cell.

    Colors and heights are normalized between the 2nd and 98th percentile of
    the valid values. Cells without a value are drawn grey and flat. A
    constant field is drawn at mid-scale.

    Args:
        ds: xarray Dataset containing ArcticDEM data with an ``aggregations``
            dimension.
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'). Hex cells are repaired where they
            cross the antimeridian.
    """
    st.subheader("🗺️ ArcticDEM Spatial Distribution")

    # Controls
    variables = list(ds.data_vars)
    col1, col2, col3 = st.columns([3, 3, 1])
    with col1:
        selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")

    aggregations = ds["aggregations"].values
    show_all_aggs = len(aggregations) > 1

    # Extract data for selected parameters
    data_values = ds[selected_var].sel(aggregations=selected_agg)

    # Keep only cells present in the dataset and attach their values
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Add all aggregation values for tooltip
    if show_all_aggs:
        for agg in aggregations:
            agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
            # Drop the aggregations column to avoid column-name conflicts on join
            if "aggregations" in agg_data.columns:
                agg_data = agg_data.drop(columns=["aggregations"])
            gdf = gdf.join(agg_data, how="left")

    # Convert to WGS84 first, then fix geometries after CRS conversion
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    values_clean = values[~np.isnan(values)]
    if len(values_clean) == 0:
        st.warning("No valid data available for selected parameters")
        return

    # Normalize with percentiles to avoid outliers dominating the color scale
    vmin, vmax = np.percentile(values_clean, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Constant field: (values - vmin) / 0 would turn every cell into NaN.
        normalized = np.where(np.isnan(values), np.nan, 0.5)

    # Apply colormap; missing values get a neutral grey
    cmap = get_cmap(f"arcticdem_{selected_var}")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
    # Elevation follows the normalized value; missing cells stay flat
    gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]

    # Create GeoJSON (NaN is not valid JSON, so missing values become None)
    agg_columns = [f"agg_{agg}" for agg in aggregations] if show_all_aggs else []
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        properties = {
            "value": float(row["value"]) if not np.isnan(row["value"]) else None,
            "fill_color": row["fill_color"],
            "elevation": float(row["elevation"]),
        }
        for agg_col in agg_columns:
            if agg_col in row.index:
                properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
        geojson_data.append(
            {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
        )

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    # Build tooltip HTML for ArcticDEM
    if show_all_aggs:
        tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
        tooltip_lines.append("<b>All aggregations:</b><br/>")
        for agg in aggregations:
            tooltip_lines.append(f"&nbsp;&nbsp;{agg}: {{agg_{agg}}}<br/>")
        tooltip_html = "".join(tooltip_lines)
    else:
        tooltip_html = f"<b>{selected_var}:</b> {{value}}"

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics: true min/max of the data plus the clipped color range
    st.caption(
        f"Min: {np.min(values_clean):.2f} | Max: {np.max(values_clean):.2f} | "
        f"Mean: {np.mean(values_clean):.2f} | Std: {np.std(values_clean):.2f} | "
        f"Color scale (P2–P98): {vmin:.2f} – {vmax:.2f}"
    )
@st.fragment
def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for grid cell areas.

    The selected area metric drives both the cell color (``"terrain"``
    colormap) and the 3D extrusion height. Both are normalized between the 2nd
    and 98th percentile of the valid values. Cells without a value are drawn
    grey and flat. A constant metric is drawn at mid-scale.

    The tooltip always shows all four metrics of the hovered cell.

    Args:
        grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
        grid: Grid type ('hex' or 'healpix'). Hex cells are repaired where they
            cross the antimeridian.
    """
    st.subheader("🗺️ Grid Cell Areas Distribution")

    # Controls
    col1, col2 = st.columns([3, 1])
    with col1:
        area_metric = st.selectbox(
            "Area Metric",
            options=["cell_area", "land_area", "water_area", "land_ratio"],
            format_func=lambda x: x.replace("_", " ").title(),
            key="areas_metric",
        )
    with col2:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="areas_map_opacity",
        )

    # Work on a copy so the caller's frame is never modified
    gdf = grid_gdf.copy()

    # Convert to WGS84 first, then fix geometries after CRS conversion
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Get values for the selected metric
    values = gdf_wgs84[area_metric].to_numpy(dtype=float)
    values_clean = values[~np.isnan(values)]
    if values_clean.size == 0:
        st.warning("No valid data available for selected parameters")
        return

    # Normalize with percentiles to avoid outliers dominating the color scale.
    # Degenerate ranges are common here (e.g. water_area is 0 for most cells).
    vmin, vmax = np.percentile(values_clean, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Constant field: (values - vmin) / 0 would turn every cell into NaN.
        normalized = np.where(np.isnan(values), np.nan, 0.5)

    # Every metric, including land_ratio, uses the terrain colormap; missing
    # values get a neutral grey.
    cmap = get_cmap("terrain")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]

    # Elevation follows the normalized value; missing cells stay flat
    gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]

    # Create GeoJSON
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "cell_area": f"{float(row['cell_area']):.2f}",
                "land_area": f"{float(row['land_area']):.2f}",
                "water_area": f"{float(row['water_area']):.2f}",
                "land_ratio": f"{float(row['land_ratio']):.2%}",
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell Area:</b> {cell_area} km²<br/>"
            "<b>Land Area:</b> {land_area} km²<br/>"
            "<b>Water Area:</b> {water_area} km²<br/>"
            "<b>Land Ratio:</b> {land_ratio}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics: true min/max of the data plus the clipped color range
    st.caption(
        f"Min: {np.min(values_clean):.2f} | Max: {np.max(values_clean):.2f} | "
        f"Mean: {np.mean(values_clean):.2f} | Std: {np.std(values_clean):.2f} | "
        f"Color scale (P2–P98): {vmin:.2f} – {vmax:.2f}"
    )

    # Show additional info
    st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
    """Render interactive pydeck map for ERA5 climate data.

    The user selects a variable, the layer opacity and a time step. An
    aggregation selector appears only when the dataset has an
    ``aggregations`` dimension. The matching values are joined onto the target
    cell geometries, colored with the ``era5_<var>`` colormap and extruded in
    3D.

    Colors and heights are normalized between the 2nd and 98th percentile of
    the valid values. Cells without a value are drawn grey and flat. A
    constant field is drawn at mid-scale. With more than one aggregation, the
    tooltip lists every aggregation's value for the hovered cell.

    Args:
        ds: xarray Dataset containing ERA5 data with ``time`` and ``cell_ids``.
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'). Hex cells are repaired where they
            cross the antimeridian.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'. It is embedded in
            the widget keys so each temporal resolution keeps its own state.
    """
    st.subheader("🗺️ ERA5 Spatial Distribution")

    # Controls
    variables = list(ds.data_vars)
    has_agg = "aggregations" in ds.dims

    # Top row: Variable, Aggregation (if applicable), and Opacity
    if has_agg:
        col1, col2, col3 = st.columns([2, 2, 1])
        with col1:
            selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
        with col2:
            selected_agg = st.selectbox(
                "Aggregation",
                options=ds["aggregations"].values,
                key=f"era5_{temporal_type}_agg",
            )
        with col3:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
    else:
        col1, col2 = st.columns([4, 1])
        with col1:
            selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
        with col2:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")

    # Bottom row: Time slider (full width), indexed by position so labels can be dates
    time_values = pd.to_datetime(ds["time"].values)
    time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
    if len(time_values) > 1:
        selected_time_idx = st.slider(
            "Time",
            min_value=0,
            max_value=len(time_values) - 1,
            value=len(time_values) - 1,
            format="",
            key=f"era5_{temporal_type}_time_slider",
        )
    else:
        # A one-position slider is pointless, and recent Streamlit versions
        # reject min_value == max_value. Select the only time step directly.
        selected_time_idx = 0
    st.caption(f"Selected: {time_labels[selected_time_idx]}")
    selected_time = time_values[selected_time_idx]

    # Extract data for selected parameters
    if has_agg:
        data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg)
    else:
        data_values = ds[selected_var].sel(time=selected_time)

    # Keep only cells present in the dataset and attach their values
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Add all aggregation values for tooltip if has_agg
    show_all_aggs = has_agg and len(ds["aggregations"]) > 1
    aggregations = ds["aggregations"].values if show_all_aggs else []
    for agg in aggregations:
        agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}")
        # Drop dimension columns to avoid column-name conflicts on join
        cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
        if cols_to_drop:
            agg_data = agg_data.drop(columns=cols_to_drop)
        gdf = gdf.join(agg_data, how="left")

    # Convert to WGS84 first, then fix geometries after CRS conversion
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    values_clean = values[~np.isnan(values)]
    if len(values_clean) == 0:
        st.warning("No valid data available for selected parameters")
        return

    # Normalize with percentiles to avoid outliers dominating the color scale
    vmin, vmax = np.percentile(values_clean, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Constant field: (values - vmin) / 0 would turn every cell into NaN.
        normalized = np.where(np.isnan(values), np.nan, 0.5)

    # Apply colormap - use variable-specific colors; missing values get grey
    cmap = get_cmap(f"era5_{selected_var}")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
    # Elevation follows the normalized value; missing cells stay flat
    gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]

    # Create GeoJSON (NaN is not valid JSON, so missing values become None)
    agg_columns = [f"agg_{agg}" for agg in aggregations]
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        properties = {
            "value": float(row["value"]) if not np.isnan(row["value"]) else None,
            "fill_color": row["fill_color"],
            "elevation": float(row["elevation"]),
        }
        for agg_col in agg_columns:
            if agg_col in row.index:
                properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
        geojson_data.append(
            {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
        )

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    # Build tooltip HTML for ERA5
    if show_all_aggs:
        tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
        tooltip_lines.append("<b>All aggregations:</b><br/>")
        for agg in aggregations:
            tooltip_lines.append(f"&nbsp;&nbsp;{agg}: {{agg_{agg}}}<br/>")
        tooltip_html = "".join(tooltip_lines)
    else:
        tooltip_html = f"<b>{selected_var}:</b> {{value}}"

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics: true min/max of the data plus the clipped color range
    st.caption(
        f"Min: {np.min(values_clean):.4f} | Max: {np.max(values_clean):.4f} | "
        f"Mean: {np.mean(values_clean):.4f} | Std: {np.std(values_clean):.4f} | "
        f"Color scale (P2–P98): {vmin:.4f} – {vmax:.4f}"
    )