Add data sources viz to dashboard

This commit is contained in:
Tobias Hölzer 2025-12-19 00:24:17 +01:00
parent f5ea72e05e
commit 3d6417ef6b
3 changed files with 952 additions and 19 deletions

View file

@ -0,0 +1,798 @@
"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
import antimeridian
import geopandas as gpd
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
import xarray as xr
from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap
def _fix_hex_geometry(geom):
    """Return *geom* with antimeridian crossings repaired.

    Falls back to the unmodified geometry — after surfacing the error in the
    Streamlit UI — when the antimeridian fix cannot be computed.
    """
    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        # Report in the dashboard rather than crashing the whole page render.
        st.error(f"Error fixing geometry: {e}")
        return geom
def render_alphaearth_overview(ds: xr.Dataset):
    """Show headline statistics for an AlphaEarth embeddings dataset.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings.
    """
    st.subheader("📊 AlphaEarth Embeddings Statistics")

    # Headline metrics, one per column.
    headline = [
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Embedding Dimensions", f"{len(ds['band'])}"),
        ("Years Available", f"{len(ds['year'])}"),
        ("Aggregations", f"{len(ds['agg'])}"),
    ]
    for column, (label, value) in zip(st.columns(4), headline):
        with column:
            st.metric(label, value)

    # Temporal coverage of the embeddings.
    st.markdown("**Temporal Coverage:**")
    years = sorted(ds["year"].values)
    st.write(f"Years: {years[0]} - {years[-1]}")

    # Aggregation labels available in the dataset.
    st.markdown("**Available Aggregations:**")
    st.write(", ".join(str(a) for a in ds["agg"].to_numpy()))
@st.fragment
def render_alphaearth_plots(ds: xr.Dataset):
    """Render interactive plots for AlphaEarth embeddings data.

    Shows a per-band statistics chart (mean, ±1 std band, min/max envelope)
    for a user-selected year/aggregation, plus an expandable table of
    per-band statistics aggregated over all years and aggregations.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings.
    """
    st.markdown("---")
    st.markdown("**Embedding Distribution by Band**")
    embeddings_data = ds["embeddings"]
    # Select year and aggregation for visualization
    col1, col2 = st.columns(2)
    with col1:
        selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
    with col2:
        selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")
    # Get data for selected year and aggregation
    year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)
    # Calculate summary statistics (mean/std/min/quartiles/max) for each band;
    # bands whose values are all NaN are skipped entirely.
    band_stats = []
    for band_idx in range(len(ds["band"])):
        band_data = year_agg_data.isel(band=band_idx).values.flatten()
        band_data = band_data[~np.isnan(band_data)]  # Remove NaN values
        if len(band_data) > 0:
            band_stats.append(
                {
                    "Band": band_idx,
                    "Mean": float(np.mean(band_data)),
                    "Std": float(np.std(band_data)),
                    "Min": float(np.min(band_data)),
                    "25%": float(np.percentile(band_data, 25)),
                    "Median": float(np.median(band_data)),
                    "75%": float(np.percentile(band_data, 75)),
                    "Max": float(np.max(band_data)),
                }
            )
    band_df = pd.DataFrame(band_stats)
    # Create plot showing distribution across bands.
    # NOTE: trace order matters — fill="tonexty" fills to the PREVIOUSLY added
    # trace, so Min must precede Max and (Mean - Std) must precede (Mean + Std).
    fig = go.Figure()
    # Add min/max range first (background)
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Min"],
            mode="lines",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            name="Min/Max Range",
            showlegend=True,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Max"],
            mode="lines",
            fill="tonexty",  # fills down to the Min trace added just above
            fillcolor="rgba(200, 200, 200, 0.1)",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            showlegend=False,
        )
    )
    # Add std band (lower bound first, then upper bound filled to it)
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Mean"] - band_df["Std"],
            mode="lines",
            line={"width": 0},  # invisible line; only anchors the fill
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Mean"] + band_df["Std"],
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(31, 119, 180, 0.2)",
            line={"width": 0},
            name="±1 Std",
        )
    )
    # Add mean line on top
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Mean"],
            mode="lines+markers",
            name="Mean",
            line={"color": "#1f77b4", "width": 2},
            marker={"size": 4},
        )
    )
    fig.update_layout(
        title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
        xaxis_title="Band",
        yaxis_title="Embedding Value",
        height=450,
        hovermode="x unified",
    )
    st.plotly_chart(fig, use_container_width=True)
    # Band statistics across ALL years/aggregations (not just the selection above)
    with st.expander("📈 Statistics by Embedding Band", expanded=False):
        st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")
        # Calculate statistics for each band.
        # NOTE(review): this rebinds `band_stats`/`band_df` from the chart above;
        # harmless since the chart is already rendered, but distinct names would be clearer.
        band_stats = []
        for band_idx in range(min(10, len(ds["band"]))):  # Show first 10 bands
            band_data = embeddings_data.isel(band=band_idx)
            band_stats.append(
                {
                    "Band": int(band_idx),
                    "Mean": float(band_data.mean().values),
                    "Std": float(band_data.std().values),
                    "Min": float(band_data.min().values),
                    "Max": float(band_data.max().values),
                }
            )
        band_df = pd.DataFrame(band_stats)
        st.dataframe(band_df, use_container_width=True, hide_index=True)
        if len(ds["band"]) > 10:
            st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
def render_arcticdem_overview(ds: xr.Dataset):
    """Show headline statistics and per-variable summaries for ArcticDEM data.

    Args:
        ds: xarray Dataset containing ArcticDEM data.
    """
    st.subheader("🏔️ ArcticDEM Terrain Statistics")

    # Headline metrics, one per column.
    headline = [
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Aggregations", f"{len(ds['aggregations'])}"),
    ]
    for column, (label, value) in zip(st.columns(3), headline):
        with column:
            st.metric(label, value)

    # List the dataset's data variables.
    st.markdown("**Available Variables:**")
    st.write(", ".join(list(ds.data_vars)))

    # List the aggregation labels.
    st.markdown("**Available Aggregations:**")
    st.write(", ".join(str(a) for a in ds["aggregations"].to_numpy()))

    # Per-variable summary statistics, reduced over every dimension.
    st.markdown("---")
    st.markdown("**Variable Statistics (across all aggregations)**")
    stats_df = pd.DataFrame(
        [
            {
                "Variable": name,
                "Mean": float(ds[name].mean().values),
                "Std": float(ds[name].std().values),
                "Min": float(ds[name].min().values),
                "Max": float(ds[name].max().values),
                "Missing %": float((ds[name].isnull().sum() / ds[name].size * 100).values),
            }
            for name in ds.data_vars
        ]
    )
    st.dataframe(stats_df, use_container_width=True, hide_index=True)
@st.fragment
def render_arcticdem_plots(ds: xr.Dataset):
    """Render interactive distribution plots for ArcticDEM terrain data.

    Args:
        ds: xarray Dataset containing ArcticDEM data.
    """
    st.markdown("---")
    st.markdown("**Variable Distributions**")
    selected_var = st.selectbox(
        "Select variable to visualize", options=list(ds.data_vars), key="arcticdem_var_select"
    )
    if not selected_var:
        return

    # Overlaid histograms: one trace per aggregation, NaNs dropped.
    var_data = ds[selected_var]
    fig = go.Figure()
    for agg in ds["aggregations"].to_numpy():
        samples = var_data.sel(aggregations=agg).to_numpy().flatten()
        fig.add_trace(
            go.Histogram(
                x=samples[~np.isnan(samples)],
                name=str(agg),
                opacity=0.7,
                nbinsx=50,
            )
        )
    fig.update_layout(
        title=f"Distribution of {selected_var} by Aggregation",
        xaxis_title=selected_var,
        yaxis_title="Count",
        barmode="overlay",
        height=400,
    )
    st.plotly_chart(fig, use_container_width=True)
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
    """Show headline statistics and per-variable summaries for ERA5 climate data.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
    """
    st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")

    # Some ERA5 variants carry an 'aggregations' dimension; some don't.
    has_agg = "aggregations" in ds.dims

    # Headline metrics, one per column.
    headline = [
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Time Steps", f"{len(ds['time'])}"),
        ("Aggregations", f"{len(ds['aggregations'])}" if has_agg else "1"),
    ]
    for column, (label, value) in zip(st.columns(4), headline):
        with column:
            st.metric(label, value)

    # List the dataset's data variables.
    st.markdown("**Available Variables:**")
    st.write(", ".join(list(ds.data_vars)))

    # Temporal extent of the dataset.
    st.markdown("**Temporal Range:**")
    time_values = pd.to_datetime(ds["time"].values)
    st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}")

    if has_agg:
        st.markdown("**Available Aggregations:**")
        st.write(", ".join(str(a) for a in ds["aggregations"].to_numpy()))

    # Per-variable summary statistics, reduced over every dimension.
    st.markdown("---")
    st.markdown("**Variable Statistics (across all time steps and aggregations)**")
    stats_df = pd.DataFrame(
        [
            {
                "Variable": name,
                "Mean": float(ds[name].mean().values),
                "Std": float(ds[name].std().values),
                "Min": float(ds[name].min().values),
                "Max": float(ds[name].max().values),
                "Missing %": float((ds[name].isnull().sum() / ds[name].size * 100).values),
            }
            for name in ds.data_vars
        ]
    )
    st.dataframe(stats_df, use_container_width=True, hide_index=True)
@st.fragment
def render_era5_plots(ds: xr.Dataset, temporal_type: str):
    """Render temporal-trend plots for ERA5 climate data.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
    """
    st.markdown("---")
    st.markdown("**Temporal Trends**")
    has_agg = "aggregations" in ds.dims
    selected_var = st.selectbox(
        "Select variable to visualize", options=list(ds.data_vars), key=f"era5_{temporal_type}_var_select"
    )
    if not selected_var:
        return

    # Collapse space (and aggregations, if present) to one value per time step.
    reduce_dims = ["cell_ids", "aggregations"] if has_agg else "cell_ids"
    time_series = ds[selected_var].mean(dim=reduce_dims)
    time_df = pd.DataFrame(
        {"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()}
    )

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=time_df["Time"],
            y=time_df["Value"],
            mode="lines+markers",
            name=selected_var,
            line={"width": 2},
        )
    )
    fig.update_layout(
        title=f"Temporal Trend of {selected_var} (Spatial Mean)",
        xaxis_title="Time",
        yaxis_title=selected_var,
        height=400,
        hovermode="x unified",
    )
    st.plotly_chart(fig, use_container_width=True)
@st.fragment
def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for AlphaEarth embeddings.

    Fix vs. previous version: NaN cells and all-NaN/constant selections are now
    handled the same way as in ``render_arcticdem_map``/``render_era5_map``
    (gray fill for NaN cells, warning on empty data, no division by zero when
    the 2nd and 98th percentiles coincide).

    Args:
        ds: xarray Dataset containing AlphaEarth data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').
    """
    st.subheader("🗺️ AlphaEarth Spatial Distribution")
    # Controls: year, aggregation, band, and layer opacity
    col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
    with col1:
        selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
    with col3:
        selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
    with col4:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
    # Extract data for selected parameters
    data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)
    # Create GeoDataFrame restricted to cells present in the dataset
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    # Add values (inner join keeps only cells with data)
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")
    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    # Fix geometries after CRS conversion (hex cells may cross the antimeridian)
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Normalize values for color mapping
    values = gdf_wgs84["value"].to_numpy()
    values_clean = values[~np.isnan(values)]
    if len(values_clean) == 0:
        # Same guard as the ArcticDEM/ERA5 map renderers.
        st.warning("No valid data available for selected parameters")
        return
    vmin, vmax = np.nanpercentile(values_clean, [2, 98])  # Use percentiles to avoid outliers
    value_range = vmax - vmin
    if value_range == 0:
        # Constant field: map every valid cell to the colormap midpoint
        # instead of dividing by zero; NaN cells stay NaN for the gray path.
        normalized = np.where(np.isnan(values), np.nan, 0.5)
    else:
        normalized = np.clip((values - vmin) / value_range, 0, 1)
    # Apply colormap; NaN cells get a translucent gray, consistent with the other maps
    cmap = get_cmap("embeddings")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
    # Create GeoJSON features for pydeck
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                # None (JSON null) for NaN so the tooltip doesn't show 'nan'
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer; fill color is read per-feature from properties
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        pickable=True,
    )
    # Default view centered on the Arctic
    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={"html": "<b>Value:</b> {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    st.pydeck_chart(deck)
    # Show statistics
    st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}")
@st.fragment
def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for ArcticDEM terrain data.

    Joins the selected variable/aggregation onto the cell geometries, colors
    cells by a percentile-normalized colormap, and shows all aggregation
    values in the hover tooltip when more than one is available.

    Args:
        ds: xarray Dataset containing ArcticDEM data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').
    """
    st.subheader("🗺️ ArcticDEM Spatial Distribution")
    # Controls: variable, aggregation, and layer opacity
    variables = list(ds.data_vars)
    col1, col2, col3 = st.columns([3, 3, 1])
    with col1:
        selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")
    # Extract data for selected parameters
    data_values = ds[selected_var].sel(aggregations=selected_agg)
    # Create GeoDataFrame restricted to cells present in the dataset
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    # Add values (inner join keeps only cells with data)
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")
    # Add one column per aggregation so the tooltip can show all of them
    if len(ds["aggregations"]) > 1:
        for agg in ds["aggregations"].values:
            agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
            # Drop the aggregations column to avoid join conflicts
            if "aggregations" in agg_data.columns:
                agg_data = agg_data.drop(columns=["aggregations"])
            gdf = gdf.join(agg_data, how="left")
    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    # Fix geometries after CRS conversion (hex cells may cross the antimeridian)
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Normalize values for color mapping
    values = gdf_wgs84["value"].values
    values_clean = values[~np.isnan(values)]
    if len(values_clean) > 0:
        # 2–98 percentile clipping keeps outliers from washing out the scale.
        # NOTE(review): if vmax == vmin (constant field) the division below
        # yields NaN and those cells fall through to the gray branch — confirm
        # this is the intended rendering for constant data.
        vmin, vmax = np.nanpercentile(values_clean, [2, 98])
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
        # Apply colormap; NaN cells get a translucent gray
        cmap = get_cmap(f"arcticdem_{selected_var}")
        colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
        gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
        # Create GeoJSON features with tooltip properties
        geojson_data = []
        for _, row in gdf_wgs84.iterrows():
            properties = {
                # None (JSON null) for NaN so the tooltip doesn't show 'nan'
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            }
            # Add all aggregation values if available
            if len(ds["aggregations"]) > 1:
                for agg in ds["aggregations"].values:
                    agg_col = f"agg_{agg}"
                    if agg_col in row.index:
                        properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
            feature = {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
            geojson_data.append(feature)
        # Create pydeck layer; fill color is read per-feature from properties
        layer = pdk.Layer(
            "GeoJsonLayer",
            geojson_data,
            opacity=opacity,
            stroked=True,
            filled=True,
            get_fill_color="properties.fill_color",
            get_line_color=[80, 80, 80],
            line_width_min_pixels=0.5,
            pickable=True,
        )
        # Default view centered on the Arctic
        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
        # Build tooltip HTML for ArcticDEM. {value} / {agg_*} are pydeck
        # template placeholders substituted per hovered feature.
        if len(ds["aggregations"]) > 1:
            tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
            tooltip_lines.append("<b>All aggregations:</b><br/>")
            for agg in ds["aggregations"].values:
                tooltip_lines.append(f"&nbsp;&nbsp;{agg}: {{agg_{agg}}}<br/>")
            tooltip_html = "".join(tooltip_lines)
        else:
            tooltip_html = f"<b>{selected_var}:</b> {{value}}"
        deck = pdk.Deck(
            layers=[layer],
            initial_view_state=view_state,
            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
        )
        st.pydeck_chart(deck)
        # Show statistics for the valid (non-NaN) values
        st.caption(
            f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | "
            f"Std: {np.nanstd(values_clean):.2f}"
        )
    else:
        st.warning("No valid data available for selected parameters")
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
    """Render interactive pydeck map for ERA5 climate data.

    Joins the selected variable at the selected time step (and aggregation,
    when the dataset has one) onto the cell geometries, colors cells by a
    percentile-normalized colormap, and shows all aggregation values in the
    hover tooltip when more than one is available.

    Args:
        ds: xarray Dataset containing ERA5 data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
    """
    st.subheader("🗺️ ERA5 Spatial Distribution")
    # Controls: variable, time step, optional aggregation, and layer opacity.
    # Layout depends on whether the dataset has an 'aggregations' dimension.
    variables = list(ds.data_vars)
    has_agg = "aggregations" in ds.dims
    if has_agg:
        col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
    else:
        col1, col2, col3 = st.columns([3, 3, 1])
    with col1:
        selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
    with col2:
        # Convert time to readable format; map label back to the timestamp
        time_values = pd.to_datetime(ds["time"].values)
        time_options = {str(t): t for t in time_values}
        selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
    if has_agg:
        with col3:
            selected_agg = st.selectbox(
                "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
            )
        with col4:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
    else:
        with col3:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
    # Extract data for selected parameters
    time_val = time_options[selected_time]
    if has_agg:
        data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
    else:
        data_values = ds[selected_var].sel(time=time_val)
    # Create GeoDataFrame restricted to cells present in the dataset
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")
    # Add values (inner join keeps only cells with data)
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")
    # Add one column per aggregation so the tooltip can show all of them
    if has_agg and len(ds["aggregations"]) > 1:
        for agg in ds["aggregations"].values:
            agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
            # Drop dimension columns to avoid join conflicts
            cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
            if cols_to_drop:
                agg_data = agg_data.drop(columns=cols_to_drop)
            gdf = gdf.join(agg_data, how="left")
    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    # Fix geometries after CRS conversion (hex cells may cross the antimeridian)
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Normalize values for color mapping
    values = gdf_wgs84["value"].values
    values_clean = values[~np.isnan(values)]
    if len(values_clean) > 0:
        # 2–98 percentile clipping keeps outliers from washing out the scale.
        # NOTE(review): if vmax == vmin (constant field) the division below
        # yields NaN and those cells fall through to the gray branch — confirm
        # this is the intended rendering for constant data.
        vmin, vmax = np.nanpercentile(values_clean, [2, 98])
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
        # Apply colormap - use variable-specific colors; NaN cells get a translucent gray
        cmap = get_cmap(f"era5_{selected_var}")
        colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
        gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
        # Create GeoJSON features with tooltip properties
        geojson_data = []
        for _, row in gdf_wgs84.iterrows():
            properties = {
                # None (JSON null) for NaN so the tooltip doesn't show 'nan'
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            }
            # Add all aggregation values if available
            if has_agg and len(ds["aggregations"]) > 1:
                for agg in ds["aggregations"].values:
                    agg_col = f"agg_{agg}"
                    if agg_col in row.index:
                        properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
            feature = {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
            geojson_data.append(feature)
        # Create pydeck layer; fill color is read per-feature from properties
        layer = pdk.Layer(
            "GeoJsonLayer",
            geojson_data,
            opacity=opacity,
            stroked=True,
            filled=True,
            get_fill_color="properties.fill_color",
            get_line_color=[80, 80, 80],
            line_width_min_pixels=0.5,
            pickable=True,
        )
        # Default view centered on the Arctic
        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
        # Build tooltip HTML for ERA5. {value} / {agg_*} are pydeck template
        # placeholders substituted per hovered feature.
        if has_agg and len(ds["aggregations"]) > 1:
            tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
            tooltip_lines.append("<b>All aggregations:</b><br/>")
            for agg in ds["aggregations"].values:
                tooltip_lines.append(f"&nbsp;&nbsp;{agg}: {{agg_{agg}}}<br/>")
            tooltip_html = "".join(tooltip_lines)
        else:
            tooltip_html = f"<b>{selected_var}:</b> {{value}}"
        deck = pdk.Deck(
            layers=[layer],
            initial_view_state=view_state,
            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
        )
        st.pydeck_chart(deck)
        # Show statistics for the valid (non-NaN) values
        st.caption(
            f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
            f"Std: {np.nanstd(values_clean):.4f}"
        )
    else:
        st.warning("No valid data available for selected parameters")

View file

@ -2,8 +2,19 @@
import streamlit as st import streamlit as st
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
render_alphaearth_overview,
render_alphaearth_plots,
render_arcticdem_map,
render_arcticdem_overview,
render_arcticdem_plots,
render_era5_map,
render_era5_overview,
render_era5_plots,
)
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
from entropice.dashboard.utils.data import load_all_training_data from entropice.dashboard.utils.data import load_all_training_data, load_source_data
from entropice.dataset import DatasetEnsemble from entropice.dataset import DatasetEnsemble
@ -107,7 +118,28 @@ def render_training_data_page():
# Display dataset ID in a styled container # Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`") st.info(f"**Dataset ID:** `{ensemble.id()}`")
# Create tabs for different data views
tab_names = ["📊 Labels"]
# Add tabs for each member
for member in ensemble.members:
if member == "AlphaEarth":
tab_names.append("🌍 AlphaEarth")
elif member == "ArcticDEM":
tab_names.append("🏔️ ArcticDEM")
elif member.startswith("ERA5"):
# Group ERA5 temporal variants
if "🌡️ ERA5" not in tab_names:
tab_names.append("🌡️ ERA5")
tabs = st.tabs(tab_names)
# Labels tab
with tabs[0]:
st.markdown("### Target Labels Distribution and Spatial Visualization")
# Load training data for all three tasks # Load training data for all three tasks
with st.spinner("Loading training data for all tasks..."):
train_data_dict = load_all_training_data(ensemble) train_data_dict = load_all_training_data(ensemble)
# Calculate total samples (use binary as reference) # Calculate total samples (use binary as reference)
@ -115,7 +147,9 @@ def render_training_data_page():
train_samples = (train_data_dict["binary"].split == "train").sum().item() train_samples = (train_data_dict["binary"].split == "train").sum().item()
test_samples = (train_data_dict["binary"].split == "test").sum().item() test_samples = (train_data_dict["binary"].split == "test").sum().item()
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks") st.success(
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
)
# Render distribution histograms # Render distribution histograms
st.markdown("---") st.markdown("---")
@ -123,14 +157,95 @@ def render_training_data_page():
st.markdown("---") st.markdown("---")
# Render spatial map (as a fragment for efficient re-rendering) # Render spatial map
# Extract geometries from the X.data dataframe (which has geometry as a column)
# The index should be cell_id
binary_dataset = train_data_dict["binary"] binary_dataset = train_data_dict["binary"]
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
render_spatial_map(train_data_dict) render_spatial_map(train_data_dict)
# Add more components and visualizations as needed for training data. st.balloons()
# AlphaEarth tab
tab_idx = 1
if "AlphaEarth" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### AlphaEarth Embeddings Analysis")
with st.spinner("Loading AlphaEarth data..."):
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
render_alphaearth_overview(alphaearth_ds)
render_alphaearth_plots(alphaearth_ds)
st.markdown("---")
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
st.balloons()
tab_idx += 1
# ArcticDEM tab
if "ArcticDEM" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### ArcticDEM Terrain Analysis")
with st.spinner("Loading ArcticDEM data..."):
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
render_arcticdem_overview(arcticdem_ds)
render_arcticdem_plots(arcticdem_ds)
st.markdown("---")
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
st.balloons()
tab_idx += 1
# ERA5 tab (combining all temporal variants)
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
if era5_members:
with tabs[tab_idx]:
st.markdown("### ERA5 Climate Data Analysis")
# Let user select which ERA5 temporal aggregation to view
era5_options = {
"ERA5-yearly": "Yearly",
"ERA5-seasonal": "Seasonal (Winter/Summer)",
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
}
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
selected_era5 = st.selectbox(
"Select ERA5 temporal aggregation",
options=list(available_era5.keys()),
format_func=lambda x: available_era5[x],
key="era5_temporal_select",
)
if selected_era5:
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
with st.spinner(f"Loading {selected_era5} data..."):
era5_ds, targets = load_source_data(ensemble, selected_era5)
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
render_era5_overview(era5_ds, temporal_type)
render_era5_plots(era5_ds, temporal_type)
st.markdown("---")
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
st.balloons()
else: else:
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.") st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")

View file

@ -99,3 +99,23 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
"count": e.create_cat_training_dataset("count"), "count": e.create_cat_training_dataset("count"),
"density": e.create_cat_training_dataset("density"), "density": e.create_cat_training_dataset("density"),
} }
@st.cache_data
def load_source_data(e: DatasetEnsemble, source: str):
    """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple of ``(ds, targets)``: the raw member data for the specified
        source and the targets table it was read against (the map renderers
        use ``targets`` as a GeoDataFrame).
    """
    # NOTE(review): reaches into DatasetEnsemble private helpers (_read_target /
    # _read_member); consider exposing a public accessor on the ensemble.
    targets = e._read_target()
    # Eagerly load the member data (lazy=False) so downstream plots get concrete arrays
    ds = e._read_member(source, targets, lazy=False)
    return ds, targets