1064 lines
35 KiB
Python
1064 lines
35 KiB
Python
"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
|
||
|
||
import antimeridian
|
||
import geopandas as gpd
|
||
import numpy as np
|
||
import pandas as pd
|
||
import plotly.graph_objects as go
|
||
import pydeck as pdk
|
||
import streamlit as st
|
||
import xarray as xr
|
||
from shapely.geometry import shape
|
||
|
||
from entropice.dashboard.utils.colors import get_cmap
|
||
|
||
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate it from temporal aggregations
|
||
|
||
|
||
def _fix_hex_geometry(geom):
    """Return *geom* repaired for antimeridian crossing, or unchanged on failure.

    The geometry is passed through ``antimeridian.fix_shape`` and the result is
    turned back into a shapely geometry with ``shape``. A ``ValueError`` raised
    by either step is reported in the Streamlit UI via ``st.error`` and the
    input geometry is returned as-is so the caller can keep rendering.
    """
    try:
        repaired = shape(antimeridian.fix_shape(geom))
    except ValueError as err:
        st.error(f"Error fixing geometry: {err}")
        repaired = geom
    return repaired
|
||
|
||
|
||
def render_alphaearth_overview(ds: xr.Dataset):
    """Display headline metrics and the aggregation list for AlphaEarth embeddings.

    Four metric tiles are shown: cell count, embedding dimensionality, year
    span and number of aggregations. They are followed by a collapsed expander
    that lists every value of the ``agg`` coordinate as a small HTML badge.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings; the ``cell_ids``,
            ``band``, ``year`` and ``agg`` coordinates are read.

    """
    st.subheader("📊 Data Overview")

    year_values = ds["year"].values
    tiles = (
        ("Cells", f"{len(ds['cell_ids']):,}"),
        ("Embedding Dims", f"{len(ds['band'])}"),
        ("Years", f"{min(year_values)}–{max(year_values)}"),
        ("Aggregations", f"{len(ds['agg'])}"),
    )
    # One Streamlit column per metric tile, laid out left to right.
    for column, (label, shown) in zip(st.columns(len(tiles)), tiles):
        with column:
            st.metric(label, shown)

    badge_style = (
        "background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; "
        "border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;"
    )
    with st.expander("ℹ️ Data Details", expanded=False):
        st.markdown("**Spatial Aggregations:**")
        badges = [f'<span style="{badge_style}">{name}</span>' for name in ds["agg"].to_numpy()]
        st.markdown(" ".join(badges), unsafe_allow_html=True)
|
||
|
||
|
||
@st.fragment
def render_alphaearth_plots(ds: xr.Dataset):
    """Render interactive plots for AlphaEarth embeddings data.

    The user chooses a year and an aggregation. For that selection, the
    per-band statistics of the embedding values across all cells are plotted,
    with NaNs excluded: the mean line, a ±1 standard-deviation band and a
    min/max envelope. If the selection contains no valid values at all, a
    warning is shown instead of the plot. A collapsed expander then lists
    mean/std/min/max of the first 10 bands, computed over the whole
    ``embeddings`` variable.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings; reads the
            ``embeddings`` variable and the ``year``, ``agg`` and ``band``
            coordinates.

    """
    st.markdown("---")
    st.markdown("**Embedding Distribution by Band**")

    embeddings_data = ds["embeddings"]
    n_bands = len(ds["band"])

    # Select year and aggregation for visualization
    col1, col2 = st.columns(2)
    with col1:
        selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
    with col2:
        selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")

    year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)

    # Per-band statistics over all cells; bands without any valid value are skipped.
    band_stats = []
    for band_idx in range(n_bands):
        band_data = year_agg_data.isel(band=band_idx).values.flatten()
        band_data = band_data[~np.isnan(band_data)]
        if band_data.size == 0:
            continue
        band_stats.append(
            {
                "Band": band_idx,
                "Mean": float(np.mean(band_data)),
                "Std": float(np.std(band_data)),
                "Min": float(np.min(band_data)),
                "25%": float(np.percentile(band_data, 25)),
                "Median": float(np.median(band_data)),
                "75%": float(np.percentile(band_data, 75)),
                "Max": float(np.max(band_data)),
            }
        )

    if not band_stats:
        # A DataFrame built from an empty list has no columns, so indexing
        # band_df["Band"] below would raise KeyError.
        st.warning(f"No valid embedding values for year {selected_year} and aggregation {selected_agg}.")
    else:
        band_df = pd.DataFrame(band_stats)
        fig = go.Figure()

        # Min/max envelope first so it sits in the background.
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Min"],
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Max"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

        # ±1 std band: an invisible lower edge, then the upper edge filled down to it.
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] - band_df["Std"],
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] + band_df["Std"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.2)",
                line={"width": 0},
                name="±1 Std",
            )
        )

        # Mean line on top
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"],
                mode="lines+markers",
                name="Mean",
                line={"color": "#1f77b4", "width": 2},
                marker={"size": 4},
            )
        )

        fig.update_layout(
            title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
            xaxis_title="Band",
            yaxis_title="Embedding Value",
            height=450,
            hovermode="x unified",
        )

        st.plotly_chart(fig, width="stretch")

    # Band statistics over the full variable (every year and aggregation)
    with st.expander("📈 Statistics by Embedding Band", expanded=False):
        st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")

        overall_stats = []
        for band_idx in range(min(10, n_bands)):  # Show first 10 bands
            band_data = embeddings_data.isel(band=band_idx)
            overall_stats.append(
                {
                    "Band": int(band_idx),
                    "Mean": float(band_data.mean().values),
                    "Std": float(band_data.std().values),
                    "Min": float(band_data.min().values),
                    "Max": float(band_data.max().values),
                }
            )

        st.dataframe(pd.DataFrame(overall_stats), width="stretch", hide_index=True)

        if n_bands > 10:
            st.info(f"Showing first 10 of {n_bands} embedding dimensions")
|
||
|
||
|
||
def render_arcticdem_overview(ds: xr.Dataset):
    """Display headline metrics and per-variable statistics for ArcticDEM data.

    The view contains three metric tiles (cells, variables, aggregations) and
    a collapsed expander listing the ``aggregations`` coordinate as HTML
    badges. Below them is a table with mean, std, min, max and the share of
    missing values for every data variable.

    Args:
        ds: xarray Dataset containing ArcticDEM data; reads ``cell_ids`` and
            ``aggregations`` coordinates and all data variables.

    """
    st.subheader("📊 Data Overview")

    tiles = (
        ("Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Aggregations", f"{len(ds['aggregations'])}"),
    )
    for column, (label, shown) in zip(st.columns(len(tiles)), tiles):
        with column:
            st.metric(label, shown)

    badge_style = (
        "background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; "
        "border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;"
    )
    with st.expander("ℹ️ Data Details", expanded=False):
        st.markdown("**Spatial Aggregations:**")
        badges = [f'<span style="{badge_style}">{name}</span>' for name in ds["aggregations"].to_numpy()]
        st.markdown(" ".join(badges), unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**📈 Variable Statistics**")

    def describe(name, values):
        # One table row of summary statistics for a single data variable.
        return {
            "Variable": name,
            "Mean": float(values.mean().values),
            "Std": float(values.std().values),
            "Min": float(values.min().values),
            "Max": float(values.max().values),
            "Missing %": float((values.isnull().sum() / values.size * 100).values),
        }

    summary = pd.DataFrame([describe(name, values) for name, values in ds.data_vars.items()])
    st.dataframe(summary, width="stretch", hide_index=True)
|
||
|
||
|
||
@st.fragment
def render_arcticdem_plots(ds: xr.Dataset):
    """Plot overlaid histograms of one ArcticDEM variable, one per aggregation.

    The user picks a data variable. For every value of the ``aggregations``
    coordinate, the non-NaN values of that variable are drawn as a
    semi-transparent 50-bin histogram, and the histograms are overlaid.

    Args:
        ds: xarray Dataset containing ArcticDEM data.

    """
    st.markdown("---")
    st.markdown("**Variable Distributions**")

    chosen = st.selectbox(
        "Select variable to visualize", options=list(ds.data_vars), key="arcticdem_var_select"
    )
    if not chosen:
        return

    source = ds[chosen]
    histogram = go.Figure()
    for agg in ds["aggregations"].to_numpy():
        samples = source.sel(aggregations=agg).to_numpy().ravel()
        histogram.add_trace(
            go.Histogram(
                x=samples[~np.isnan(samples)],
                name=str(agg),
                opacity=0.7,
                nbinsx=50,
            )
        )

    histogram.update_layout(
        title=f"Distribution of {chosen} by Aggregation",
        xaxis_title=chosen,
        yaxis_title="Count",
        barmode="overlay",
        height=400,
    )
    st.plotly_chart(histogram, width="stretch")
|
||
|
||
|
||
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
    """Display headline metrics and per-variable statistics for ERA5 data.

    Four metric tiles are shown: cells, variables, the year span of ``time``,
    and either the aggregation count or the temporal type. A collapsed
    expander gives the temporal type, the full date range and, when an
    ``aggregations`` dimension exists, the aggregation badges. Below them is a
    table with mean, std, min, max and missing-value share for every data
    variable.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.

    """
    st.subheader("📊 Data Overview")

    has_agg = "aggregations" in ds.dims
    timestamps = pd.to_datetime(ds["time"].values)
    first_ts, last_ts = timestamps.min(), timestamps.max()

    if has_agg:
        last_tile = ("Aggregations", f"{len(ds['aggregations'])}")
    else:
        last_tile = ("Temporal Type", temporal_type.capitalize())
    tiles = (
        ("Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Time Steps", f"{first_ts.strftime('%Y')}–{last_ts.strftime('%Y')}"),
        last_tile,
    )
    for column, (label, shown) in zip(st.columns(len(tiles)), tiles):
        with column:
            st.metric(label, shown)

    with st.expander("ℹ️ Data Details", expanded=False):
        st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
        st.markdown(
            f"**Date Range:** {first_ts.strftime('%Y-%m-%d')} to {last_ts.strftime('%Y-%m-%d')}"
        )

        if has_agg:
            st.markdown("**Spatial Aggregations:**")
            badge_style = (
                "background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; "
                "border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;"
            )
            badges = [f'<span style="{badge_style}">{name}</span>' for name in ds["aggregations"].to_numpy()]
            st.markdown(" ".join(badges), unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**📈 Variable Statistics**")

    def describe(name, values):
        # One table row of summary statistics for a single data variable.
        return {
            "Variable": name,
            "Mean": float(values.mean().values),
            "Std": float(values.std().values),
            "Min": float(values.min().values),
            "Max": float(values.max().values),
            "Missing %": float((values.isnull().sum() / values.size * 100).values),
        }

    summary = pd.DataFrame([describe(name, values) for name, values in ds.data_vars.items()])
    st.dataframe(summary, width="stretch", hide_index=True)
|
||
|
||
|
||
@st.fragment
def render_era5_plots(ds: xr.Dataset, temporal_type: str):
    """Plot the spatial mean of an ERA5 variable over time, with optional spread bands.

    The user picks a variable and, if the dataset has an ``aggregations``
    dimension, one aggregation. For every time step the values are reduced
    over ``cell_ids`` to mean, std, min and max. The mean is drawn as a line;
    a ±1 std band and a min/max envelope can be toggled on or off.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'; used to make
            widget keys unique per temporal resolution.

    """
    st.markdown("---")
    st.markdown("**Temporal Trends**")

    variables = list(ds.data_vars)
    has_agg = "aggregations" in ds.dims

    # Layout: an extra aggregation column only when that dimension exists.
    if has_agg:
        var_col, agg_col, toggle_col = st.columns([2, 2, 1])
    else:
        var_col, toggle_col = st.columns([3, 1])
        agg_col = None

    with var_col:
        selected_var = st.selectbox(
            "Select variable to visualize",
            options=variables,
            key=f"era5_{temporal_type}_var_select",
        )
    selected_agg = None
    if agg_col is not None:
        with agg_col:
            selected_agg = st.selectbox(
                "Aggregation",
                options=ds["aggregations"].values,
                key=f"era5_{temporal_type}_agg_select",
            )
    with toggle_col:
        show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
        show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")

    if not selected_var:
        return

    time_values = pd.to_datetime(ds["time"].to_numpy())
    series = ds[selected_var].sel(aggregations=selected_agg) if has_agg else ds[selected_var]

    # Spatial statistics for each time step.
    time_mean = series.mean(dim="cell_ids").to_numpy()
    time_std = series.std(dim="cell_ids").to_numpy()
    time_min = series.min(dim="cell_ids").to_numpy()
    time_max = series.max(dim="cell_ids").to_numpy()

    fig = go.Figure()

    if show_minmax:
        # Background envelope: the lower edge first, then the upper edge filled down to it.
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=time_min,
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=time_max,
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

    if show_std:
        lower, upper = time_mean - time_std, time_mean + time_std
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=lower,
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=upper,
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.2)",
                line={"width": 0},
                name="±1 Std",
            )
        )

    # Mean line drawn last so it stays on top.
    fig.add_trace(
        go.Scatter(
            x=time_values,
            y=time_mean,
            mode="lines+markers",
            name="Mean",
            line={"color": "#1f77b4", "width": 2},
            marker={"size": 4},
        )
    )

    title_suffix = f" (Aggregation: {selected_agg})" if has_agg else ""
    fig.update_layout(
        title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
        xaxis_title="Time",
        yaxis_title=selected_var,
        height=400,
        hovermode="x unified",
    )
    st.plotly_chart(fig, width="stretch")
|
||
|
||
|
||
@st.fragment
def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for AlphaEarth embeddings.

    The user picks a year, an aggregation and an embedding band. Every cell
    present in both ``ds`` and ``targets`` is drawn as an extruded polygon.
    Its colour and height encode the value, normalised to the 2nd–98th
    percentile range of the valid values. Cells with a NaN value are drawn
    grey and flat. If no cell has a valid value, a warning is shown instead
    of the map.

    Args:
        ds: xarray Dataset containing AlphaEarth data (``embeddings`` variable
            with ``year``, ``agg``, ``band`` and ``cell_ids`` coordinates).
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'); hex geometries are repaired for
            antimeridian crossings.

    """
    st.subheader("🗺️ AlphaEarth Spatial Distribution")

    # Offer only years that exist in the dataset. A plain integer slider
    # would also allow years inside gaps, and .sel(year=...) would then raise.
    years = [int(y) for y in sorted(ds["year"].values)]
    if len(years) > 1:
        selected_year = st.select_slider("Year", options=years, value=years[-1], key="alphaearth_year")
    else:
        # A single year leaves nothing to choose.
        selected_year = years[0]

    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
    with col2:
        selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")

    data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)

    # Join only the value column: to_dataframe also emits the scalar
    # coordinates (year, agg, band) as columns, which are not needed here.
    gdf = targets[targets["cell_id"].isin(ds["cell_ids"].values)].set_index("cell_id")
    gdf = gdf.join(data_values.to_dataframe(name="value")[["value"]], how="inner")

    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        st.warning("No valid data available for selected parameters")
        return
    valid_values = values[valid]

    # Percentile bounds keep outliers from washing out the colour scale.
    vmin, vmax = np.percentile(valid_values, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # 2nd and 98th percentiles coincide (near-constant field): avoid the
        # division by zero and place every valid cell mid-scale.
        normalized = np.where(valid, 0.5, np.nan)

    cmap = get_cmap("embeddings")
    no_data_rgba = (0.5, 0.5, 0.5, 0.5)
    fill_colors = []
    for norm_val, is_valid in zip(normalized, valid):
        rgba = cmap(norm_val) if is_valid else no_data_rgba
        fill_colors.append([int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)])
    elevations = np.where(valid, normalized, 0.0)

    geojson_data = [
        {
            "type": "Feature",
            "geometry": geom.__geo_interface__,
            "properties": {
                # NaN is not valid JSON; emit null for missing values instead.
                "value": float(val) if is_valid else None,
                "fill_color": color,
                "elevation": float(elev),
            },
        }
        for geom, val, is_valid, color, elev in zip(
            gdf_wgs84["geometry"], values, valid, fill_colors, elevations
        )
    ]

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Value:</b> {value}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Summary statistics of the valid values, plus the percentile range used for colouring.
    st.caption(
        f"Min: {valid_values.min():.4f} | Max: {valid_values.max():.4f} | "
        f"Mean: {valid_values.mean():.4f} | Std: {valid_values.std():.4f} | "
        f"Color scale (2nd–98th percentile): {vmin:.4f}–{vmax:.4f}"
    )
|
||
|
||
|
||
@st.fragment
def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for ArcticDEM terrain data.

    The user picks a variable and an aggregation. Every cell present in both
    ``ds`` and ``targets`` is drawn as an extruded polygon whose colour and
    height encode the value, normalised to the 2nd–98th percentile range of
    the valid values. Cells with NaN values are drawn grey and flat. When the
    dataset has more than one aggregation, the tooltip lists the variable's
    value for every aggregation. If no cell has a valid value, a warning is
    shown instead of the map.

    Args:
        ds: xarray Dataset containing ArcticDEM data (``aggregations`` and
            ``cell_ids`` coordinates).
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'); hex geometries are repaired for
            antimeridian crossings.

    """
    st.subheader("🗺️ ArcticDEM Spatial Distribution")

    variables = list(ds.data_vars)
    col1, col2, col3 = st.columns([3, 3, 1])
    with col1:
        selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")

    var_data = ds[selected_var]
    data_values = var_data.sel(aggregations=selected_agg)

    # Join only the value columns. to_dataframe also emits the scalar
    # "aggregations" coordinate as a column, which would collide between joins.
    gdf = targets[targets["cell_id"].isin(ds["cell_ids"].values)].set_index("cell_id")
    gdf = gdf.join(data_values.to_dataframe(name="value")[["value"]], how="inner")

    # Values of every aggregation, shown in the tooltip when there is more than one.
    multi_agg = len(ds["aggregations"]) > 1
    agg_cols = []
    if multi_agg:
        for agg in ds["aggregations"].values:
            agg_col = f"agg_{agg}"
            agg_df = var_data.sel(aggregations=agg).to_dataframe(name=agg_col)[[agg_col]]
            gdf = gdf.join(agg_df, how="left")
            agg_cols.append(agg_col)

    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        st.warning("No valid data available for selected parameters")
        return
    valid_values = values[valid]

    # Percentile bounds keep outliers from washing out the colour scale.
    vmin, vmax = np.percentile(valid_values, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # 2nd and 98th percentiles coincide (near-constant field): avoid the
        # division by zero and place every valid cell mid-scale.
        normalized = np.where(valid, 0.5, np.nan)

    cmap = get_cmap(f"arcticdem_{selected_var}")
    no_data_rgba = (0.5, 0.5, 0.5, 0.5)
    fill_colors = []
    for norm_val, is_valid in zip(normalized, valid):
        rgba = cmap(norm_val) if is_valid else no_data_rgba
        fill_colors.append([int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)])
    elevations = np.where(valid, normalized, 0.0)

    agg_arrays = {col: gdf_wgs84[col].to_numpy(dtype=float) for col in agg_cols}
    geojson_data = []
    for i, geom in enumerate(gdf_wgs84["geometry"]):
        properties = {
            # NaN is not valid JSON; emit null for missing values instead.
            "value": float(values[i]) if valid[i] else None,
            "fill_color": fill_colors[i],
            "elevation": float(elevations[i]),
        }
        for col, col_values in agg_arrays.items():
            properties[col] = None if np.isnan(col_values[i]) else float(col_values[i])
        geojson_data.append({"type": "Feature", "geometry": geom.__geo_interface__, "properties": properties})

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    # Build tooltip HTML for ArcticDEM
    if multi_agg:
        tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
        tooltip_lines.append("<b>All aggregations:</b><br/>")
        for agg in ds["aggregations"].values:
            tooltip_lines.append(f" {agg}: {{agg_{agg}}}<br/>")
        tooltip_html = "".join(tooltip_lines)
    else:
        tooltip_html = f"<b>{selected_var}:</b> {{value}}"

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Summary statistics of the valid values, plus the percentile range used for colouring.
    st.caption(
        f"Min: {valid_values.min():.2f} | Max: {valid_values.max():.2f} | "
        f"Mean: {valid_values.mean():.2f} | Std: {valid_values.std():.2f} | "
        f"Color scale (2nd–98th percentile): {vmin:.2f}–{vmax:.2f}"
    )
|
||
|
||
|
||
@st.fragment
def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for grid cell areas.

    The user picks one of the area metrics. Every cell is drawn as an extruded
    polygon whose colour and height encode that metric, normalised to the
    2nd–98th percentile range of the valid values. Cells with a NaN metric are
    drawn grey and flat. The tooltip always shows all four metrics. If no cell
    has a valid value, a warning is shown instead of the map.

    Args:
        grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
        grid: Grid type ('hex' or 'healpix'); hex geometries are repaired for
            antimeridian crossings.

    """
    st.subheader("🗺️ Grid Cell Areas Distribution")

    col1, col2 = st.columns([3, 1])
    with col1:
        area_metric = st.selectbox(
            "Area Metric",
            options=["cell_area", "land_area", "water_area", "land_ratio"],
            format_func=lambda x: x.replace("_", " ").title(),
            key="areas_metric",
        )
    with col2:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="areas_map_opacity",
        )

    # to_crs returns a new frame (inplace=False), so grid_gdf itself is not modified below.
    gdf_wgs84 = grid_gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84[area_metric].to_numpy(dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        st.warning("No valid data available for selected parameters")
        return
    valid_values = values[valid]

    # Percentile bounds keep outliers from washing out the colour scale.
    vmin, vmax = np.percentile(valid_values, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # 2nd and 98th percentiles coincide (e.g. land_ratio of 1.0 nearly
        # everywhere): avoid the division by zero and place valid cells mid-scale.
        normalized = np.where(valid, 0.5, np.nan)

    # The same colormap is used for every metric.
    # NOTE(review): land_ratio (a 0–1 share) may deserve a dedicated colormap.
    cmap = get_cmap("terrain")
    no_data_rgba = (0.5, 0.5, 0.5, 0.5)
    fill_colors = []
    for norm_val, is_valid in zip(normalized, valid):
        rgba = cmap(norm_val) if is_valid else no_data_rgba
        fill_colors.append([int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)])
    # NaN elevations would serialise as invalid JSON; missing cells stay flat.
    elevations = np.where(valid, normalized, 0.0)

    geojson_data = [
        {
            "type": "Feature",
            "geometry": geom.__geo_interface__,
            "properties": {
                "cell_area": f"{float(cell):.2f}",
                "land_area": f"{float(land):.2f}",
                "water_area": f"{float(water):.2f}",
                "land_ratio": f"{float(ratio):.2%}",
                "fill_color": color,
                "elevation": float(elev),
            },
        }
        for geom, cell, land, water, ratio, color, elev in zip(
            gdf_wgs84["geometry"],
            gdf_wgs84["cell_area"],
            gdf_wgs84["land_area"],
            gdf_wgs84["water_area"],
            gdf_wgs84["land_ratio"],
            fill_colors,
            elevations,
        )
    ]

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell Area:</b> {cell_area} km²<br/>"
            "<b>Land Area:</b> {land_area} km²<br/>"
            "<b>Water Area:</b> {water_area} km²<br/>"
            "<b>Land Ratio:</b> {land_ratio}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Summary statistics of the valid values, plus the percentile range used for colouring.
    st.caption(
        f"Min: {valid_values.min():.2f} | Max: {valid_values.max():.2f} | "
        f"Mean: {valid_values.mean():.2f} | Std: {valid_values.std():.2f} | "
        f"Color scale (2nd–98th percentile): {vmin:.2f}–{vmax:.2f}"
    )

    # Show additional info
    st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
|
||
|
||
|
||
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
    """Render interactive pydeck map for ERA5 climate data.

    The user picks a variable, a time step and, only when the dataset has an
    ``aggregations`` dimension, an aggregation. Every cell present in both
    ``ds`` and ``targets`` is drawn as an extruded polygon whose colour and
    height encode the value, normalised to the 2nd–98th percentile range of
    the valid values. Cells with NaN values are drawn grey and flat. With more
    than one aggregation, the tooltip lists every aggregation's value at the
    selected time. If there is no time step or no valid value, a warning is
    shown instead of the map.

    Args:
        ds: xarray Dataset containing ERA5 data with ``time`` and ``cell_ids``
            coordinates (and optionally ``aggregations``).
        targets: GeoDataFrame with a ``cell_id`` column and geometry for each cell.
        grid: Grid type ('hex' or 'healpix'); hex geometries are repaired for
            antimeridian crossings.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'; used to make
            widget keys unique per temporal resolution.

    """
    st.subheader("🗺️ ERA5 Spatial Distribution")

    variables = list(ds.data_vars)
    has_agg = "aggregations" in ds.dims
    selected_agg = None

    # Top row: Variable, Aggregation (if applicable), and Opacity
    if has_agg:
        col1, col2, col3 = st.columns([2, 2, 1])
        with col1:
            selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
        with col2:
            selected_agg = st.selectbox(
                "Aggregation",
                options=ds["aggregations"].values,
                key=f"era5_{temporal_type}_agg",
            )
        with col3:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
    else:
        col1, col2 = st.columns([4, 1])
        with col1:
            selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
        with col2:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")

    # Bottom row: time selection (full width)
    time_values = pd.to_datetime(ds["time"].values)
    if len(time_values) == 0:
        st.warning("No valid data available for selected parameters")
        return
    time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
    if len(time_values) > 1:
        selected_time_idx = st.slider(
            "Time",
            min_value=0,
            max_value=len(time_values) - 1,
            value=len(time_values) - 1,
            format="",
            key=f"era5_{temporal_type}_time_slider",
        )
    else:
        # A single time step leaves nothing to choose; skip the degenerate 0..0 slider.
        selected_time_idx = 0
    st.caption(f"Selected: {time_labels[selected_time_idx]}")
    selected_time = time_values[selected_time_idx]

    var_at_time = ds[selected_var].sel(time=selected_time)
    data_values = var_at_time.sel(aggregations=selected_agg) if has_agg else var_at_time

    # Join only the value columns. to_dataframe also emits the scalar "time"
    # and "aggregations" coordinates as columns, which would collide between joins.
    gdf = targets[targets["cell_id"].isin(ds["cell_ids"].values)].set_index("cell_id")
    gdf = gdf.join(data_values.to_dataframe(name="value")[["value"]], how="inner")

    # Values of every aggregation, shown in the tooltip when there is more than one.
    multi_agg = has_agg and len(ds["aggregations"]) > 1
    agg_cols = []
    if multi_agg:
        for agg in ds["aggregations"].values:
            agg_col = f"agg_{agg}"
            agg_df = var_at_time.sel(aggregations=agg).to_dataframe(name=agg_col)[[agg_col]]
            gdf = gdf.join(agg_df, how="left")
            agg_cols.append(agg_col)

    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    values = gdf_wgs84["value"].to_numpy(dtype=float)
    valid = ~np.isnan(values)
    if not valid.any():
        st.warning("No valid data available for selected parameters")
        return
    valid_values = values[valid]

    # Percentile bounds keep outliers from washing out the colour scale.
    vmin, vmax = np.percentile(valid_values, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # 2nd and 98th percentiles coincide (near-constant field): avoid the
        # division by zero and place every valid cell mid-scale.
        normalized = np.where(valid, 0.5, np.nan)

    # Apply colormap - use variable-specific colors
    cmap = get_cmap(f"era5_{selected_var}")
    no_data_rgba = (0.5, 0.5, 0.5, 0.5)
    fill_colors = []
    for norm_val, is_valid in zip(normalized, valid):
        rgba = cmap(norm_val) if is_valid else no_data_rgba
        fill_colors.append([int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)])
    elevations = np.where(valid, normalized, 0.0)

    agg_arrays = {col: gdf_wgs84[col].to_numpy(dtype=float) for col in agg_cols}
    geojson_data = []
    for i, geom in enumerate(gdf_wgs84["geometry"]):
        properties = {
            # NaN is not valid JSON; emit null for missing values instead.
            "value": float(values[i]) if valid[i] else None,
            "fill_color": fill_colors[i],
            "elevation": float(elevations[i]),
        }
        for col, col_values in agg_arrays.items():
            properties[col] = None if np.isnan(col_values[i]) else float(col_values[i])
        geojson_data.append({"type": "Feature", "geometry": geom.__geo_interface__, "properties": properties})

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    # Build tooltip HTML for ERA5
    if multi_agg:
        tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
        tooltip_lines.append("<b>All aggregations:</b><br/>")
        for agg in ds["aggregations"].values:
            tooltip_lines.append(f" {agg}: {{agg_{agg}}}<br/>")
        tooltip_html = "".join(tooltip_lines)
    else:
        tooltip_html = f"<b>{selected_var}:</b> {{value}}"

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Summary statistics of the valid values, plus the percentile range used for colouring.
    st.caption(
        f"Min: {valid_values.min():.4f} | Max: {valid_values.max():.4f} | "
        f"Mean: {valid_values.mean():.4f} | Std: {valid_values.std():.4f} | "
        f"Color scale (2nd–98th percentile): {vmin:.4f}–{vmax:.4f}"
    )
|