Add data sources viz to dashboard
This commit is contained in:
parent
f5ea72e05e
commit
3d6417ef6b
3 changed files with 952 additions and 19 deletions
798
src/entropice/dashboard/plots/source_data.py
Normal file
798
src/entropice/dashboard/plots/source_data.py
Normal file
|
|
@ -0,0 +1,798 @@
|
|||
"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
|
||||
|
||||
import antimeridian
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_cmap
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
|
||||
"""Fix hexagon geometry crossing the antimeridian."""
|
||||
try:
|
||||
return shape(antimeridian.fix_shape(geom))
|
||||
except ValueError as e:
|
||||
st.error(f"Error fixing geometry: {e}")
|
||||
return geom
|
||||
|
||||
|
||||
def render_alphaearth_overview(ds: xr.Dataset):
|
||||
"""Render overview statistics for AlphaEarth embeddings data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing AlphaEarth embeddings.
|
||||
|
||||
"""
|
||||
st.subheader("📊 AlphaEarth Embeddings Statistics")
|
||||
|
||||
# Overall statistics
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
|
||||
|
||||
with col2:
|
||||
st.metric("Embedding Dimensions", f"{len(ds['band'])}")
|
||||
|
||||
with col3:
|
||||
st.metric("Years Available", f"{len(ds['year'])}")
|
||||
|
||||
with col4:
|
||||
st.metric("Aggregations", f"{len(ds['agg'])}")
|
||||
|
||||
# Show temporal coverage
|
||||
st.markdown("**Temporal Coverage:**")
|
||||
years = sorted(ds["year"].values)
|
||||
st.write(f"Years: {min(years)} - {max(years)}")
|
||||
|
||||
# Show aggregations
|
||||
st.markdown("**Available Aggregations:**")
|
||||
aggs = ds["agg"].to_numpy()
|
||||
st.write(", ".join(str(a) for a in aggs))
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_alphaearth_plots(ds: xr.Dataset):
|
||||
"""Render interactive plots for AlphaEarth embeddings data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing AlphaEarth embeddings.
|
||||
|
||||
"""
|
||||
st.markdown("---")
|
||||
st.markdown("**Embedding Distribution by Band**")
|
||||
|
||||
embeddings_data = ds["embeddings"]
|
||||
|
||||
# Select year and aggregation for visualization
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
|
||||
with col2:
|
||||
selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")
|
||||
|
||||
# Get data for selected year and aggregation
|
||||
year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)
|
||||
|
||||
# Calculate statistics for each band
|
||||
band_stats = []
|
||||
for band_idx in range(len(ds["band"])):
|
||||
band_data = year_agg_data.isel(band=band_idx).values.flatten()
|
||||
band_data = band_data[~np.isnan(band_data)] # Remove NaN values
|
||||
|
||||
if len(band_data) > 0:
|
||||
band_stats.append(
|
||||
{
|
||||
"Band": band_idx,
|
||||
"Mean": float(np.mean(band_data)),
|
||||
"Std": float(np.std(band_data)),
|
||||
"Min": float(np.min(band_data)),
|
||||
"25%": float(np.percentile(band_data, 25)),
|
||||
"Median": float(np.median(band_data)),
|
||||
"75%": float(np.percentile(band_data, 75)),
|
||||
"Max": float(np.max(band_data)),
|
||||
}
|
||||
)
|
||||
|
||||
band_df = pd.DataFrame(band_stats)
|
||||
|
||||
# Create plot showing distribution across bands
|
||||
fig = go.Figure()
|
||||
|
||||
# Add min/max range first (background)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=band_df["Band"],
|
||||
y=band_df["Min"],
|
||||
mode="lines",
|
||||
line={"color": "lightgray", "width": 1, "dash": "dash"},
|
||||
name="Min/Max Range",
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=band_df["Band"],
|
||||
y=band_df["Max"],
|
||||
mode="lines",
|
||||
fill="tonexty",
|
||||
fillcolor="rgba(200, 200, 200, 0.1)",
|
||||
line={"color": "lightgray", "width": 1, "dash": "dash"},
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Add std band
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=band_df["Band"],
|
||||
y=band_df["Mean"] - band_df["Std"],
|
||||
mode="lines",
|
||||
line={"width": 0},
|
||||
showlegend=False,
|
||||
hoverinfo="skip",
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=band_df["Band"],
|
||||
y=band_df["Mean"] + band_df["Std"],
|
||||
mode="lines",
|
||||
fill="tonexty",
|
||||
fillcolor="rgba(31, 119, 180, 0.2)",
|
||||
line={"width": 0},
|
||||
name="±1 Std",
|
||||
)
|
||||
)
|
||||
|
||||
# Add mean line on top
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=band_df["Band"],
|
||||
y=band_df["Mean"],
|
||||
mode="lines+markers",
|
||||
name="Mean",
|
||||
line={"color": "#1f77b4", "width": 2},
|
||||
marker={"size": 4},
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
|
||||
xaxis_title="Band",
|
||||
yaxis_title="Embedding Value",
|
||||
height=450,
|
||||
hovermode="x unified",
|
||||
)
|
||||
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
# Band statistics
|
||||
with st.expander("📈 Statistics by Embedding Band", expanded=False):
|
||||
st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")
|
||||
|
||||
# Calculate statistics for each band
|
||||
band_stats = []
|
||||
for band_idx in range(min(10, len(ds["band"]))): # Show first 10 bands
|
||||
band_data = embeddings_data.isel(band=band_idx)
|
||||
band_stats.append(
|
||||
{
|
||||
"Band": int(band_idx),
|
||||
"Mean": float(band_data.mean().values),
|
||||
"Std": float(band_data.std().values),
|
||||
"Min": float(band_data.min().values),
|
||||
"Max": float(band_data.max().values),
|
||||
}
|
||||
)
|
||||
|
||||
band_df = pd.DataFrame(band_stats)
|
||||
st.dataframe(band_df, use_container_width=True, hide_index=True)
|
||||
|
||||
if len(ds["band"]) > 10:
|
||||
st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
|
||||
|
||||
|
||||
def render_arcticdem_overview(ds: xr.Dataset):
|
||||
"""Render overview statistics for ArcticDEM terrain data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ArcticDEM data.
|
||||
|
||||
"""
|
||||
st.subheader("🏔️ ArcticDEM Terrain Statistics")
|
||||
|
||||
# Overall statistics
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
|
||||
|
||||
with col2:
|
||||
st.metric("Variables", f"{len(ds.data_vars)}")
|
||||
|
||||
with col3:
|
||||
st.metric("Aggregations", f"{len(ds['aggregations'])}")
|
||||
|
||||
# Show available variables
|
||||
st.markdown("**Available Variables:**")
|
||||
variables = list(ds.data_vars)
|
||||
st.write(", ".join(variables))
|
||||
|
||||
# Show aggregations
|
||||
st.markdown("**Available Aggregations:**")
|
||||
aggs = ds["aggregations"].to_numpy()
|
||||
st.write(", ".join(str(a) for a in aggs))
|
||||
|
||||
# Statistics by variable
|
||||
st.markdown("---")
|
||||
st.markdown("**Variable Statistics (across all aggregations)**")
|
||||
|
||||
var_stats = []
|
||||
for var_name in ds.data_vars:
|
||||
var_data = ds[var_name]
|
||||
var_stats.append(
|
||||
{
|
||||
"Variable": var_name,
|
||||
"Mean": float(var_data.mean().values),
|
||||
"Std": float(var_data.std().values),
|
||||
"Min": float(var_data.min().values),
|
||||
"Max": float(var_data.max().values),
|
||||
"Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
|
||||
}
|
||||
)
|
||||
|
||||
stats_df = pd.DataFrame(var_stats)
|
||||
st.dataframe(stats_df, use_container_width=True, hide_index=True)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_arcticdem_plots(ds: xr.Dataset):
|
||||
"""Render interactive plots for ArcticDEM terrain data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ArcticDEM data.
|
||||
|
||||
"""
|
||||
st.markdown("---")
|
||||
st.markdown("**Variable Distributions**")
|
||||
|
||||
variables = list(ds.data_vars)
|
||||
|
||||
# Select a variable to visualize
|
||||
selected_var = st.selectbox("Select variable to visualize", options=variables, key="arcticdem_var_select")
|
||||
|
||||
if selected_var:
|
||||
var_data = ds[selected_var]
|
||||
|
||||
# Create histogram
|
||||
fig = go.Figure()
|
||||
|
||||
for agg in ds["aggregations"].to_numpy():
|
||||
agg_data = var_data.sel(aggregations=agg).to_numpy().flatten()
|
||||
agg_data = agg_data[~np.isnan(agg_data)]
|
||||
|
||||
fig.add_trace(
|
||||
go.Histogram(
|
||||
x=agg_data,
|
||||
name=str(agg),
|
||||
opacity=0.7,
|
||||
nbinsx=50,
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Distribution of {selected_var} by Aggregation",
|
||||
xaxis_title=selected_var,
|
||||
yaxis_title="Count",
|
||||
barmode="overlay",
|
||||
height=400,
|
||||
)
|
||||
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
|
||||
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
|
||||
"""Render overview statistics for ERA5 climate data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ERA5 data.
|
||||
temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
|
||||
|
||||
"""
|
||||
st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")
|
||||
|
||||
# Overall statistics
|
||||
has_agg = "aggregations" in ds.dims
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
|
||||
|
||||
with col2:
|
||||
st.metric("Variables", f"{len(ds.data_vars)}")
|
||||
|
||||
with col3:
|
||||
st.metric("Time Steps", f"{len(ds['time'])}")
|
||||
|
||||
with col4:
|
||||
if has_agg:
|
||||
st.metric("Aggregations", f"{len(ds['aggregations'])}")
|
||||
else:
|
||||
st.metric("Aggregations", "1")
|
||||
|
||||
# Show available variables
|
||||
st.markdown("**Available Variables:**")
|
||||
variables = list(ds.data_vars)
|
||||
st.write(", ".join(variables))
|
||||
|
||||
# Show temporal range
|
||||
st.markdown("**Temporal Range:**")
|
||||
time_values = pd.to_datetime(ds["time"].values)
|
||||
st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}")
|
||||
|
||||
if has_agg:
|
||||
st.markdown("**Available Aggregations:**")
|
||||
aggs = ds["aggregations"].to_numpy()
|
||||
st.write(", ".join(str(a) for a in aggs))
|
||||
|
||||
# Statistics by variable
|
||||
st.markdown("---")
|
||||
st.markdown("**Variable Statistics (across all time steps and aggregations)**")
|
||||
|
||||
var_stats = []
|
||||
for var_name in ds.data_vars:
|
||||
var_data = ds[var_name]
|
||||
var_stats.append(
|
||||
{
|
||||
"Variable": var_name,
|
||||
"Mean": float(var_data.mean().values),
|
||||
"Std": float(var_data.std().values),
|
||||
"Min": float(var_data.min().values),
|
||||
"Max": float(var_data.max().values),
|
||||
"Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
|
||||
}
|
||||
)
|
||||
|
||||
stats_df = pd.DataFrame(var_stats)
|
||||
st.dataframe(stats_df, use_container_width=True, hide_index=True)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_era5_plots(ds: xr.Dataset, temporal_type: str):
|
||||
"""Render interactive plots for ERA5 climate data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ERA5 data.
|
||||
temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
|
||||
|
||||
"""
|
||||
st.markdown("---")
|
||||
st.markdown("**Temporal Trends**")
|
||||
|
||||
variables = list(ds.data_vars)
|
||||
has_agg = "aggregations" in ds.dims
|
||||
|
||||
selected_var = st.selectbox(
|
||||
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
|
||||
)
|
||||
|
||||
if selected_var:
|
||||
var_data = ds[selected_var]
|
||||
|
||||
# Calculate mean over space for each time step
|
||||
if has_agg:
|
||||
# Average over aggregations first, then over cells
|
||||
time_series = var_data.mean(dim=["cell_ids", "aggregations"])
|
||||
else:
|
||||
time_series = var_data.mean(dim="cell_ids")
|
||||
|
||||
time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()})
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=time_df["Time"],
|
||||
y=time_df["Value"],
|
||||
mode="lines+markers",
|
||||
name=selected_var,
|
||||
line={"width": 2},
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Temporal Trend of {selected_var} (Spatial Mean)",
|
||||
xaxis_title="Time",
|
||||
yaxis_title=selected_var,
|
||||
height=400,
|
||||
hovermode="x unified",
|
||||
)
|
||||
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
|
||||
"""Render interactive pydeck map for AlphaEarth embeddings.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing AlphaEarth data.
|
||||
targets: GeoDataFrame with geometry for each cell.
|
||||
grid: Grid type ('hex' or 'healpix').
|
||||
|
||||
"""
|
||||
st.subheader("🗺️ AlphaEarth Spatial Distribution")
|
||||
|
||||
# Controls
|
||||
col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
|
||||
|
||||
with col1:
|
||||
selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
|
||||
|
||||
with col2:
|
||||
selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
|
||||
|
||||
with col3:
|
||||
selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
|
||||
|
||||
with col4:
|
||||
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
|
||||
|
||||
# Extract data for selected parameters
|
||||
data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)
|
||||
|
||||
# Create GeoDataFrame
|
||||
gdf = targets.copy()
|
||||
gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
|
||||
gdf = gdf.set_index("cell_id")
|
||||
|
||||
# Add values
|
||||
values_df = data_values.to_dataframe(name="value")
|
||||
gdf = gdf.join(values_df, how="inner")
|
||||
|
||||
# Convert to WGS84 first
|
||||
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||
|
||||
# Fix geometries after CRS conversion
|
||||
if grid == "hex":
|
||||
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
|
||||
|
||||
# Normalize values for color mapping
|
||||
values = gdf_wgs84["value"].to_numpy()
|
||||
vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers
|
||||
normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
|
||||
|
||||
# Apply colormap
|
||||
cmap = get_cmap("embeddings")
|
||||
colors = [cmap(val) for val in normalized]
|
||||
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
|
||||
|
||||
# Create GeoJSON
|
||||
geojson_data = []
|
||||
for _, row in gdf_wgs84.iterrows():
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": {
|
||||
"value": float(row["value"]),
|
||||
"fill_color": row["fill_color"],
|
||||
},
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
|
||||
# Create pydeck layer
|
||||
layer = pdk.Layer(
|
||||
"GeoJsonLayer",
|
||||
geojson_data,
|
||||
opacity=opacity,
|
||||
stroked=True,
|
||||
filled=True,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color=[80, 80, 80],
|
||||
line_width_min_pixels=0.5,
|
||||
pickable=True,
|
||||
)
|
||||
|
||||
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
|
||||
|
||||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": "<b>Value:</b> {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
st.pydeck_chart(deck)
|
||||
|
||||
# Show statistics
|
||||
st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
|
||||
"""Render interactive pydeck map for ArcticDEM terrain data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ArcticDEM data.
|
||||
targets: GeoDataFrame with geometry for each cell.
|
||||
grid: Grid type ('hex' or 'healpix').
|
||||
|
||||
"""
|
||||
st.subheader("🗺️ ArcticDEM Spatial Distribution")
|
||||
|
||||
# Controls
|
||||
variables = list(ds.data_vars)
|
||||
col1, col2, col3 = st.columns([3, 3, 1])
|
||||
|
||||
with col1:
|
||||
selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
|
||||
|
||||
with col2:
|
||||
selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
|
||||
|
||||
with col3:
|
||||
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")
|
||||
|
||||
# Extract data for selected parameters
|
||||
data_values = ds[selected_var].sel(aggregations=selected_agg)
|
||||
|
||||
# Create GeoDataFrame
|
||||
gdf = targets.copy()
|
||||
gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
|
||||
gdf = gdf.set_index("cell_id")
|
||||
|
||||
# Add values
|
||||
values_df = data_values.to_dataframe(name="value")
|
||||
gdf = gdf.join(values_df, how="inner")
|
||||
|
||||
# Add all aggregation values for tooltip
|
||||
if len(ds["aggregations"]) > 1:
|
||||
for agg in ds["aggregations"].values:
|
||||
agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
|
||||
# Drop the aggregations column to avoid conflicts
|
||||
if "aggregations" in agg_data.columns:
|
||||
agg_data = agg_data.drop(columns=["aggregations"])
|
||||
gdf = gdf.join(agg_data, how="left")
|
||||
|
||||
# Convert to WGS84 first
|
||||
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||
|
||||
# Fix geometries after CRS conversion
|
||||
if grid == "hex":
|
||||
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
|
||||
|
||||
# Normalize values for color mapping
|
||||
values = gdf_wgs84["value"].values
|
||||
values_clean = values[~np.isnan(values)]
|
||||
|
||||
if len(values_clean) > 0:
|
||||
vmin, vmax = np.nanpercentile(values_clean, [2, 98])
|
||||
normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
|
||||
|
||||
# Apply colormap
|
||||
cmap = get_cmap(f"arcticdem_{selected_var}")
|
||||
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
|
||||
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
|
||||
|
||||
# Create GeoJSON
|
||||
geojson_data = []
|
||||
for _, row in gdf_wgs84.iterrows():
|
||||
properties = {
|
||||
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
|
||||
"fill_color": row["fill_color"],
|
||||
}
|
||||
# Add all aggregation values if available
|
||||
if len(ds["aggregations"]) > 1:
|
||||
for agg in ds["aggregations"].values:
|
||||
agg_col = f"agg_{agg}"
|
||||
if agg_col in row.index:
|
||||
properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
|
||||
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": properties,
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
|
||||
# Create pydeck layer
|
||||
layer = pdk.Layer(
|
||||
"GeoJsonLayer",
|
||||
geojson_data,
|
||||
opacity=opacity,
|
||||
stroked=True,
|
||||
filled=True,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color=[80, 80, 80],
|
||||
line_width_min_pixels=0.5,
|
||||
pickable=True,
|
||||
)
|
||||
|
||||
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
|
||||
|
||||
# Build tooltip HTML for ArcticDEM
|
||||
if len(ds["aggregations"]) > 1:
|
||||
tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
|
||||
tooltip_lines.append("<b>All aggregations:</b><br/>")
|
||||
for agg in ds["aggregations"].values:
|
||||
tooltip_lines.append(f" {agg}: {{agg_{agg}}}<br/>")
|
||||
tooltip_html = "".join(tooltip_lines)
|
||||
else:
|
||||
tooltip_html = f"<b>{selected_var}:</b> {{value}}"
|
||||
|
||||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
st.pydeck_chart(deck)
|
||||
|
||||
# Show statistics
|
||||
st.caption(
|
||||
f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | "
|
||||
f"Std: {np.nanstd(values_clean):.2f}"
|
||||
)
|
||||
else:
|
||||
st.warning("No valid data available for selected parameters")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
|
||||
"""Render interactive pydeck map for ERA5 climate data.
|
||||
|
||||
Args:
|
||||
ds: xarray Dataset containing ERA5 data.
|
||||
targets: GeoDataFrame with geometry for each cell.
|
||||
grid: Grid type ('hex' or 'healpix').
|
||||
temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
|
||||
|
||||
"""
|
||||
st.subheader("🗺️ ERA5 Spatial Distribution")
|
||||
|
||||
# Controls
|
||||
variables = list(ds.data_vars)
|
||||
has_agg = "aggregations" in ds.dims
|
||||
|
||||
if has_agg:
|
||||
col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
|
||||
else:
|
||||
col1, col2, col3 = st.columns([3, 3, 1])
|
||||
|
||||
with col1:
|
||||
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
|
||||
|
||||
with col2:
|
||||
# Convert time to readable format
|
||||
time_values = pd.to_datetime(ds["time"].values)
|
||||
time_options = {str(t): t for t in time_values}
|
||||
selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
|
||||
|
||||
if has_agg:
|
||||
with col3:
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
|
||||
)
|
||||
with col4:
|
||||
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
|
||||
else:
|
||||
with col3:
|
||||
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
|
||||
|
||||
# Extract data for selected parameters
|
||||
time_val = time_options[selected_time]
|
||||
if has_agg:
|
||||
data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
|
||||
else:
|
||||
data_values = ds[selected_var].sel(time=time_val)
|
||||
|
||||
# Create GeoDataFrame
|
||||
gdf = targets.copy()
|
||||
gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
|
||||
gdf = gdf.set_index("cell_id")
|
||||
|
||||
# Add values
|
||||
values_df = data_values.to_dataframe(name="value")
|
||||
gdf = gdf.join(values_df, how="inner")
|
||||
|
||||
# Add all aggregation values for tooltip if has_agg
|
||||
if has_agg and len(ds["aggregations"]) > 1:
|
||||
for agg in ds["aggregations"].values:
|
||||
agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
|
||||
# Drop dimension columns to avoid conflicts
|
||||
cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
|
||||
if cols_to_drop:
|
||||
agg_data = agg_data.drop(columns=cols_to_drop)
|
||||
gdf = gdf.join(agg_data, how="left")
|
||||
|
||||
# Convert to WGS84 first
|
||||
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||
|
||||
# Fix geometries after CRS conversion
|
||||
if grid == "hex":
|
||||
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
|
||||
|
||||
# Normalize values for color mapping
|
||||
values = gdf_wgs84["value"].values
|
||||
values_clean = values[~np.isnan(values)]
|
||||
|
||||
if len(values_clean) > 0:
|
||||
vmin, vmax = np.nanpercentile(values_clean, [2, 98])
|
||||
normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
|
||||
|
||||
# Apply colormap - use variable-specific colors
|
||||
cmap = get_cmap(f"era5_{selected_var}")
|
||||
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
|
||||
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
|
||||
|
||||
# Create GeoJSON
|
||||
geojson_data = []
|
||||
for _, row in gdf_wgs84.iterrows():
|
||||
properties = {
|
||||
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
|
||||
"fill_color": row["fill_color"],
|
||||
}
|
||||
# Add all aggregation values if available
|
||||
if has_agg and len(ds["aggregations"]) > 1:
|
||||
for agg in ds["aggregations"].values:
|
||||
agg_col = f"agg_{agg}"
|
||||
if agg_col in row.index:
|
||||
properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
|
||||
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": properties,
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
|
||||
# Create pydeck layer
|
||||
layer = pdk.Layer(
|
||||
"GeoJsonLayer",
|
||||
geojson_data,
|
||||
opacity=opacity,
|
||||
stroked=True,
|
||||
filled=True,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color=[80, 80, 80],
|
||||
line_width_min_pixels=0.5,
|
||||
pickable=True,
|
||||
)
|
||||
|
||||
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
|
||||
|
||||
# Build tooltip HTML for ERA5
|
||||
if has_agg and len(ds["aggregations"]) > 1:
|
||||
tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
|
||||
tooltip_lines.append("<b>All aggregations:</b><br/>")
|
||||
for agg in ds["aggregations"].values:
|
||||
tooltip_lines.append(f" {agg}: {{agg_{agg}}}<br/>")
|
||||
tooltip_html = "".join(tooltip_lines)
|
||||
else:
|
||||
tooltip_html = f"<b>{selected_var}:</b> {{value}}"
|
||||
|
||||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
st.pydeck_chart(deck)
|
||||
|
||||
# Show statistics
|
||||
st.caption(
|
||||
f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
|
||||
f"Std: {np.nanstd(values_clean):.4f}"
|
||||
)
|
||||
else:
|
||||
st.warning("No valid data available for selected parameters")
|
||||
|
|
@ -2,8 +2,19 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.source_data import (
|
||||
render_alphaearth_map,
|
||||
render_alphaearth_overview,
|
||||
render_alphaearth_plots,
|
||||
render_arcticdem_map,
|
||||
render_arcticdem_overview,
|
||||
render_arcticdem_plots,
|
||||
render_era5_map,
|
||||
render_era5_overview,
|
||||
render_era5_plots,
|
||||
)
|
||||
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
||||
from entropice.dashboard.utils.data import load_all_training_data
|
||||
from entropice.dashboard.utils.data import load_all_training_data, load_source_data
|
||||
from entropice.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
|
|
@ -107,7 +118,28 @@ def render_training_data_page():
|
|||
# Display dataset ID in a styled container
|
||||
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
||||
|
||||
# Create tabs for different data views
|
||||
tab_names = ["📊 Labels"]
|
||||
|
||||
# Add tabs for each member
|
||||
for member in ensemble.members:
|
||||
if member == "AlphaEarth":
|
||||
tab_names.append("🌍 AlphaEarth")
|
||||
elif member == "ArcticDEM":
|
||||
tab_names.append("🏔️ ArcticDEM")
|
||||
elif member.startswith("ERA5"):
|
||||
# Group ERA5 temporal variants
|
||||
if "🌡️ ERA5" not in tab_names:
|
||||
tab_names.append("🌡️ ERA5")
|
||||
|
||||
tabs = st.tabs(tab_names)
|
||||
|
||||
# Labels tab
|
||||
with tabs[0]:
|
||||
st.markdown("### Target Labels Distribution and Spatial Visualization")
|
||||
|
||||
# Load training data for all three tasks
|
||||
with st.spinner("Loading training data for all tasks..."):
|
||||
train_data_dict = load_all_training_data(ensemble)
|
||||
|
||||
# Calculate total samples (use binary as reference)
|
||||
|
|
@ -115,7 +147,9 @@ def render_training_data_page():
|
|||
train_samples = (train_data_dict["binary"].split == "train").sum().item()
|
||||
test_samples = (train_data_dict["binary"].split == "test").sum().item()
|
||||
|
||||
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
|
||||
st.success(
|
||||
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
|
||||
)
|
||||
|
||||
# Render distribution histograms
|
||||
st.markdown("---")
|
||||
|
|
@ -123,14 +157,95 @@ def render_training_data_page():
|
|||
|
||||
st.markdown("---")
|
||||
|
||||
# Render spatial map (as a fragment for efficient re-rendering)
|
||||
# Extract geometries from the X.data dataframe (which has geometry as a column)
|
||||
# The index should be cell_id
|
||||
# Render spatial map
|
||||
binary_dataset = train_data_dict["binary"]
|
||||
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
|
||||
|
||||
render_spatial_map(train_data_dict)
|
||||
|
||||
# Add more components and visualizations as needed for training data.
|
||||
st.balloons()
|
||||
|
||||
# AlphaEarth tab
|
||||
tab_idx = 1
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
st.markdown("### AlphaEarth Embeddings Analysis")
|
||||
|
||||
with st.spinner("Loading AlphaEarth data..."):
|
||||
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
|
||||
|
||||
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
|
||||
|
||||
render_alphaearth_overview(alphaearth_ds)
|
||||
render_alphaearth_plots(alphaearth_ds)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
|
||||
|
||||
st.balloons()
|
||||
|
||||
tab_idx += 1
|
||||
|
||||
# ArcticDEM tab
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
st.markdown("### ArcticDEM Terrain Analysis")
|
||||
|
||||
with st.spinner("Loading ArcticDEM data..."):
|
||||
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
|
||||
|
||||
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
|
||||
|
||||
render_arcticdem_overview(arcticdem_ds)
|
||||
render_arcticdem_plots(arcticdem_ds)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
|
||||
|
||||
st.balloons()
|
||||
|
||||
tab_idx += 1
|
||||
|
||||
# ERA5 tab (combining all temporal variants)
|
||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||
if era5_members:
|
||||
with tabs[tab_idx]:
|
||||
st.markdown("### ERA5 Climate Data Analysis")
|
||||
|
||||
# Let user select which ERA5 temporal aggregation to view
|
||||
era5_options = {
|
||||
"ERA5-yearly": "Yearly",
|
||||
"ERA5-seasonal": "Seasonal (Winter/Summer)",
|
||||
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
|
||||
}
|
||||
|
||||
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
|
||||
|
||||
selected_era5 = st.selectbox(
|
||||
"Select ERA5 temporal aggregation",
|
||||
options=list(available_era5.keys()),
|
||||
format_func=lambda x: available_era5[x],
|
||||
key="era5_temporal_select",
|
||||
)
|
||||
|
||||
if selected_era5:
|
||||
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
|
||||
|
||||
with st.spinner(f"Loading {selected_era5} data..."):
|
||||
era5_ds, targets = load_source_data(ensemble, selected_era5)
|
||||
|
||||
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
|
||||
|
||||
render_era5_overview(era5_ds, temporal_type)
|
||||
render_era5_plots(era5_ds, temporal_type)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
|
||||
|
||||
st.balloons()
|
||||
|
||||
else:
|
||||
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
|
||||
|
|
|
|||
|
|
@ -99,3 +99,23 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
|
|||
"count": e.create_cat_training_dataset("count"),
|
||||
"density": e.create_cat_training_dataset("density"),
|
||||
}
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_source_data(e: DatasetEnsemble, source: str):
|
||||
"""Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
|
||||
|
||||
Args:
|
||||
e: DatasetEnsemble object.
|
||||
source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
|
||||
|
||||
Returns:
|
||||
xarray.Dataset with the raw data for the specified source.
|
||||
|
||||
"""
|
||||
targets = e._read_target()
|
||||
|
||||
# Load the member data lazily to get metadata
|
||||
ds = e._read_member(source, targets, lazy=False)
|
||||
|
||||
return ds, targets
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue