Add ArcticDEM and ERA5 tab

This commit is contained in:
Tobias Hölzer 2026-01-17 03:56:42 +01:00
parent 26de80ee89
commit c358bb63bc
9 changed files with 1416 additions and 36 deletions

View file

@ -0,0 +1,421 @@
"""Plots for visualizing ERA5 climate data."""
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import plotly.graph_objects as go
import pydeck as pdk
import xarray as xr
from plotly.subplots import make_subplots
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
from entropice.dashboard.utils.geometry import fix_hex_geometry
def create_climate_map(
    climate_values: xr.DataArray,
    grid_gdf: gpd.GeoDataFrame,
    variable_name: str,
    make_3d_map: bool,
) -> pdk.Deck:
    """Create a spatial distribution map for an ERA5 climate variable.

    Cells whose value is NaN or infinite are left out of the map: such values
    cannot be written as strict JSON and would otherwise end up in the GeoJSON
    payload handed to the deck.

    Args:
        climate_values: DataArray with a ``cell_ids`` dimension holding one
            climate value per grid cell.
        grid_gdf: GeoDataFrame containing grid cell geometries, with ``cell_id``
            either as the index name or as a column.
        variable_name: Name of the climate variable; selects the colormap and
            labels the tooltip.
        make_3d_map: Whether to render the map in 3D (extruded) or 2D.

    Returns:
        pdk.Deck: A PyDeck map visualization of the climate variable.

    """
    max_cells = 100000

    # Subsample if too many cells for performance (fixed seed -> same subset on every rerun)
    n_cells = len(climate_values["cell_ids"])
    if n_cells > max_cells:
        rng = np.random.default_rng(42)
        cell_indices = rng.choice(n_cells, size=max_cells, replace=False)
        climate_values = climate_values.isel(cell_ids=cell_indices)

    # Convert to DataFrame for easier merging
    climate_df = climate_values.to_dataframe(name="climate_value")

    # to_crs returns a new frame, so the caller's grid is left untouched and a
    # single reprojection to WGS84 (needed by pydeck) is enough.
    gdf = grid_gdf.to_crs("EPSG:4326")
    if gdf.index.name == "cell_id":
        gdf = gdf.reset_index()

    # Keep only grid cells that have climate data, then attach the values
    gdf = gdf[gdf["cell_id"].isin(climate_df.index)].set_index("cell_id")
    gdf = gdf.join(climate_df, how="inner")

    # Drop cells without a finite value (NaN / inf are not valid JSON numbers)
    finite_mask = np.isfinite(gdf["climate_value"].to_numpy(dtype=float))
    gdf = gdf[finite_mask]

    # Fix antimeridian issues for hex cells
    geometries = gdf["geometry"].apply(fix_hex_geometry)

    cmap = get_cmap(variable_name)

    # Normalize to [0, 1] between the 2nd and 98th percentile to avoid outliers
    values = gdf["climate_value"].to_numpy(dtype=float)
    normalized_values = np.zeros(len(values), dtype=float)
    if len(values) > 0:
        vmin, vmax = np.percentile(values, [2, 98])
        if vmax > vmin:
            normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)

    # Map normalized values to RGB fill colors
    fill_colors = [hex_to_rgb(mcolors.to_hex(cmap(val))) for val in normalized_values]

    # Build the GeoJSON features; the normalized value doubles as elevation in 3D
    geojson_data = [
        {
            "type": "Feature",
            "geometry": geometry.__geo_interface__,
            "properties": {
                "fill_color": fill_color,
                "climate_value": float(value),
                "elevation": float(elevation) if make_3d_map else 0,
            },
        }
        for geometry, fill_color, value, elevation in zip(geometries, fill_colors, values, normalized_values)
    ]

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        elevation_scale=500000,
        pickable=True,
    )

    # Set initial view state
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )

    # The doubled braces leave a literal {climate_value} placeholder in the tooltip template
    tooltip_html = f"<b>{variable_name}:</b> {{climate_value}}"

    return pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
def create_climate_trend_plot(climate_data: xr.DataArray, variable_name: str) -> go.Figure:
    """Plot per-year spatial statistics of a climate variable.

    Args:
        climate_data: DataArray with ``cell_ids`` and ``year`` dimensions.
        variable_name: Name of the variable; used in the title and y-axis label.

    Returns:
        Plotly Figure showing the min/max envelope, the 10th-90th percentile band
        and the mean, all computed across cells for each year.

    """
    # Work on a reproducible random subset when there are many cells
    cell_count = len(climate_data["cell_ids"])
    if cell_count > 10000:
        sampler = np.random.default_rng(42)
        keep = sampler.choice(cell_count, size=10000, replace=False)
        climate_data = climate_data.isel({"cell_ids": keep})

    x_years = climate_data["year"].to_numpy()

    # Reduce over space so one value per year remains
    spatial_mean = climate_data.mean(dim="cell_ids").to_numpy()
    spatial_min = climate_data.min(dim="cell_ids").to_numpy()
    spatial_max = climate_data.max(dim="cell_ids").to_numpy()
    lower_decile = climate_data.quantile(0.10, dim="cell_ids").to_numpy()
    upper_decile = climate_data.quantile(0.90, dim="cell_ids").to_numpy()

    # Trace order matters: fill="tonexty" fills towards the trace added just before
    traces = [
        # Min/max envelope (background)
        go.Scatter(
            x=x_years,
            y=spatial_min,
            mode="lines",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            name="Min/Max Range",
            showlegend=True,
        ),
        go.Scatter(
            x=x_years,
            y=spatial_max,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(200, 200, 200, 0.1)",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            showlegend=False,
        ),
        # 10th-90th percentile band
        go.Scatter(
            x=x_years,
            y=lower_decile,
            mode="lines",
            line={"width": 0},
            showlegend=False,
            hoverinfo="skip",
        ),
        go.Scatter(
            x=x_years,
            y=upper_decile,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(33, 150, 243, 0.2)",
            line={"width": 0},
            name="10th-90th Percentile",
        ),
        # Mean line on top
        go.Scatter(
            x=x_years,
            y=spatial_mean,
            mode="lines+markers",
            name="Mean",
            line={"color": "#1976D2", "width": 2},
            marker={"size": 6},
        ),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    horizontal_legend = {
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "right",
        "x": 1,
    }
    fig.update_layout(
        title=f"{variable_name} Over Time (Spatial Statistics)",
        xaxis_title="Year",
        yaxis_title=variable_name,
        height=450,
        hovermode="x unified",
        legend=horizontal_legend,
    )
    # One tick per year
    fig.update_xaxes(dtick=1)
    return fig
def create_climate_distribution_plot(climate_ds: xr.Dataset, variables: list[str]) -> go.Figure:
    """Create violin distribution plots, one subplot row per climate variable.

    Args:
        climate_ds: Xarray Dataset containing ERA5 climate features with a
            ``cell_ids`` dimension.
        variables: List of variable names to plot. Names missing from the dataset
            and variables without any non-NaN value leave their row empty.

    Returns:
        Plotly Figure with distribution plots; an empty Figure when ``variables``
        is empty (a subplot grid needs at least one row).

    """
    # make_subplots rejects rows=0, so bail out early like the other plot helpers do
    if not variables:
        return go.Figure()

    # Subsample if too many cells for performance (fixed seed for reproducibility)
    n_cells = len(climate_ds.cell_ids)
    if n_cells > 10000:
        rng = np.random.default_rng(42)
        cell_indices = rng.choice(n_cells, size=10000, replace=False)
        climate_ds = climate_ds.isel(cell_ids=cell_indices)

    # Create subplots - one row per variable
    n_rows = len(variables)
    fig = make_subplots(
        rows=n_rows,
        cols=1,
        subplot_titles=[v.replace("_", " ").title() for v in variables],
        vertical_spacing=0.15 / n_rows,
    )

    # Color palette
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
    ]

    for var_idx, variable in enumerate(variables):
        if variable not in climate_ds.data_vars:
            continue

        # Flatten all dimensions and drop NaN before plotting
        values = climate_ds[variable].to_numpy().flatten()
        values = values[~np.isnan(values)]
        if len(values) == 0:
            continue

        fig.add_trace(
            go.Violin(
                y=values,
                name=variable.replace("_", " ").title(),
                box_visible=True,
                meanline_visible=True,
                # Zero-based index so the first variable gets the first palette color
                line_color=colors[var_idx % len(colors)],
                showlegend=False,
            ),
            row=var_idx + 1,  # subplot rows are 1-based
            col=1,
        )

    fig.update_layout(
        height=300 * n_rows,
        title_text="Climate Variable Distributions",
        showlegend=False,
    )

    # Update y-axes labels
    for row in range(1, n_rows + 1):
        fig.update_yaxes(title_text="Value", row=row, col=1)
    return fig
def create_temperature_comparison_plot(climate_ds: xr.Dataset) -> go.Figure:
    """Compare the yearly spatial means of the 2 m temperature variables.

    Args:
        climate_ds: Xarray Dataset that may contain ``t2m_max``, ``t2m_mean`` and
            ``t2m_min`` with ``cell_ids`` and ``year`` dimensions.

    Returns:
        Plotly Figure with one line per available temperature variable, or an
        empty Figure when none of them is present.

    """
    line_colors = {"t2m_max": "#d32f2f", "t2m_mean": "#1976d2", "t2m_min": "#0288d1"}

    # Only plot the candidates that actually exist in the dataset
    present = [name for name in ("t2m_max", "t2m_mean", "t2m_min") if name in climate_ds.data_vars]
    if len(present) == 0:
        return go.Figure()

    x_years = climate_ds["year"].to_numpy()
    fig = go.Figure()

    for name in present:
        # Average across cells for every year
        yearly_mean = climate_ds[name].mean(dim="cell_ids").to_numpy()
        trace = go.Scatter(
            x=x_years,
            y=yearly_mean - 273.15,  # subtract 273.15: values are treated as Kelvin and shown in Celsius
            mode="lines+markers",
            name=name.replace("_", " ").title(),
            line={"color": line_colors.get(name, "#666666"), "width": 2},
            marker={"size": 4},
        )
        fig.add_trace(trace)

    # Horizontal reference line at the freezing point
    fig.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Freezing Point (0°C)")

    fig.update_layout(
        title="Temperature Extremes Over Time",
        xaxis_title="Year",
        yaxis_title="Temperature (°C)",
        height=450,
        hovermode="x unified",
    )
    return fig
def create_seasonal_pattern_plot(climate_ds: xr.Dataset, variable: str) -> go.Figure:
    """Create a bar plot of a variable's mean value per month/season.

    Args:
        climate_ds: Xarray Dataset with a ``month`` dimension.
        variable: Variable name to plot.

    Returns:
        Plotly Figure with one bar per ``month`` entry, or an empty Figure when
        the variable is missing or has no ``month`` dimension.

    """
    if variable not in climate_ds.data_vars:
        return go.Figure()

    data = climate_ds[variable]
    # Check the variable itself: the dataset may have a month dimension that this variable lacks
    if "month" not in data.dims:
        return go.Figure()

    months = data["month"].to_numpy()

    # Average over space and years, but only over the dimensions this variable
    # actually has; naming a missing dimension makes the reduction fail.
    reduce_dims = [dim for dim in ("cell_ids", "year") if dim in data.dims]
    seasonal_mean = data.mean(dim=reduce_dims) if reduce_dims else data
    values = seasonal_mean.to_numpy()

    pretty_name = variable.replace("_", " ").title()

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=months,
            y=values,
            marker_color="#1976d2",
        )
    )
    fig.update_layout(
        title=f"{pretty_name} by Season",
        xaxis_title="Season",
        yaxis_title=pretty_name,
        height=400,
    )
    return fig

View file

@ -0,0 +1,418 @@
"""Plots for visualizing ArcticDEM terrain features."""
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import xarray as xr
from plotly.subplots import make_subplots
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
from entropice.dashboard.utils.geometry import fix_hex_geometry
def create_terrain_map(
    terrain_values: pd.Series,
    grid_gdf: gpd.GeoDataFrame,
    variable_name: str,
    make_3d_map: bool,
) -> pdk.Deck:
    """Create a spatial distribution map for an ArcticDEM terrain feature.

    Cells whose value is NaN or infinite are left out of the map: such values
    cannot be written as strict JSON and would otherwise end up in the GeoJSON
    payload handed to the deck.

    Args:
        terrain_values: Series with cell ids as index and terrain values.
        grid_gdf: GeoDataFrame containing grid cell geometries, with ``cell_id``
            either as the index name or as a column.
        variable_name: Name of the terrain variable. Names containing "aspect"
            use the aspect colormap and a fixed 0-360 scale.
        make_3d_map: Whether to render the map in 3D (extruded) or 2D.

    Returns:
        pdk.Deck: A PyDeck map visualization of the terrain feature.

    """
    max_cells = 100000

    # Subsample if too many cells for performance (fixed seed -> same subset on every rerun)
    n_cells = len(terrain_values)
    if n_cells > max_cells:
        rng = np.random.default_rng(42)
        cell_indices = rng.choice(n_cells, size=max_cells, replace=False)
        terrain_values = terrain_values.iloc[cell_indices]

    # to_crs returns a new frame, so the caller's grid is left untouched and a
    # single reprojection to WGS84 (needed by pydeck) is enough.
    gdf = grid_gdf.to_crs("EPSG:4326")
    if gdf.index.name == "cell_id":
        gdf = gdf.reset_index()

    # Keep only grid cells that have terrain data, then attach the values
    gdf = gdf[gdf["cell_id"].isin(terrain_values.index)].set_index("cell_id")
    gdf = gdf.join(terrain_values.to_frame("terrain_value"), how="inner")

    # Drop cells without a finite value (NaN / inf are not valid JSON numbers)
    finite_mask = np.isfinite(gdf["terrain_value"].to_numpy(dtype=float))
    gdf = gdf[finite_mask]

    # Fix antimeridian issues for hex cells
    geometries = gdf["geometry"].apply(fix_hex_geometry)

    # Aspect is circular (0-360 degrees) and gets its own colormap and scaling
    is_aspect = "aspect" in variable_name.lower()
    cmap = get_cmap("aspect" if is_aspect else "terrain")

    # Normalize the terrain values to [0, 1] for color mapping
    values = gdf["terrain_value"].to_numpy(dtype=float)
    normalized_values = np.zeros(len(values), dtype=float)
    if is_aspect:
        # Clip so out-of-range inputs cannot yield negative or >1 elevations
        normalized_values = np.clip(values / 360, 0, 1)
    elif len(values) > 0:
        # Use percentiles to avoid outliers
        vmin, vmax = np.percentile(values, [2, 98])
        if vmax > vmin:
            normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)

    # Map normalized values to RGB fill colors
    fill_colors = [hex_to_rgb(mcolors.to_hex(cmap(val))) for val in normalized_values]

    # Build the GeoJSON features; the normalized value doubles as elevation in 3D
    geojson_data = [
        {
            "type": "Feature",
            "geometry": geometry.__geo_interface__,
            "properties": {
                "fill_color": fill_color,
                "terrain_value": float(value),
                "elevation": float(elevation) if make_3d_map else 0,
            },
        }
        for geometry, fill_color, value, elevation in zip(geometries, fill_colors, values, normalized_values)
    ]

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        elevation_scale=500000,
        pickable=True,
    )

    # Set initial view state
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )

    # The doubled braces leave a literal {terrain_value} placeholder in the tooltip template
    tooltip_html = f"<b>{variable_name}:</b> {{terrain_value}}"

    return pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
def create_terrain_distribution_plot(arcticdem_ds: xr.Dataset, features: list[str]) -> go.Figure:
    """Draw violin plots of terrain features, one subplot row per aggregation.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features with
            ``cell_ids`` and ``aggregations`` coordinates.
        features: List of feature names to plot; names missing from the dataset
            are skipped.

    Returns:
        Plotly Figure with distribution plots.

    """
    # Work on a reproducible random subset when there are many cells
    total_cells = len(arcticdem_ds.cell_ids)
    if total_cells > 10000:
        sampler = np.random.default_rng(42)
        picked = sampler.choice(total_cells, size=10000, replace=False)
        arcticdem_ds = arcticdem_ds.isel(cell_ids=picked)

    # Every aggregation type gets its own subplot row
    aggregation_names = list(arcticdem_ds.coords["aggregations"].values)
    row_count = len(aggregation_names)

    fig = make_subplots(
        rows=row_count,
        cols=1,
        subplot_titles=[f"{name.title()} Values" for name in aggregation_names],
        vertical_spacing=0.15 / max(row_count, 1),
    )

    # One color per feature, cycled when there are more features than colors
    palette = (
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    )

    for row_number, aggregation in enumerate(aggregation_names, start=1):
        for position, feature in enumerate(features):
            if feature not in arcticdem_ds.data_vars:
                continue

            # Values of this feature for the current aggregation, NaN removed
            samples = arcticdem_ds[feature].sel(aggregations=aggregation).to_numpy().flatten()
            samples = samples[~np.isnan(samples)]
            if samples.size == 0:
                continue

            violin = go.Violin(
                y=samples,
                name=feature.replace("_", " ").title(),
                box_visible=True,
                meanline_visible=True,
                line_color=palette[position % len(palette)],
                # Legend entries only for the first row to avoid repeats
                showlegend=(row_number == 1),
            )
            fig.add_trace(violin, row=row_number, col=1)

    fig.update_layout(
        height=300 * row_count,
        title_text="Terrain Feature Distributions by Aggregation Type",
        showlegend=True,
    )

    # Label the y-axis of every row
    for row_number in range(1, row_count + 1):
        fig.update_yaxes(title_text="Value", row=row_number, col=1)
    return fig
def create_aspect_rose_diagram(aspect_values: np.ndarray) -> go.Figure:
    """Create a rose diagram (circular histogram) for aspect values.

    Args:
        aspect_values: Array of aspect values in degrees (0-360). NaN entries are
            ignored; values outside 0-360 fall outside every bin.

    Returns:
        Plotly Figure with one polar bar per 10-degree bin, or an empty Figure
        when no non-NaN value is left.

    """
    bin_width = 10

    # Remove NaN values
    aspect_values = aspect_values[~np.isnan(aspect_values)]

    # Subsample if too many values for performance (fixed seed for reproducibility)
    if len(aspect_values) > 50000:
        rng = np.random.default_rng(42)
        indices = rng.choice(len(aspect_values), size=50000, replace=False)
        aspect_values = aspect_values[indices]

    if len(aspect_values) == 0:
        return go.Figure()

    # Histogram over 10-degree bins covering the full circle
    bin_edges = np.arange(0, 360 + bin_width, bin_width)
    bin_counts, _ = np.histogram(aspect_values, bins=bin_edges)
    bin_centers_deg = (bin_edges[:-1] + bin_edges[1:]) / 2

    # No "closing" point is appended: bars already tile the circle, and repeating
    # the first bin would draw a second, overlapping bar at 5 degrees.
    fig = go.Figure()
    fig.add_trace(
        go.Barpolar(
            r=bin_counts,
            theta=bin_centers_deg,
            width=bin_width,
            marker={
                # Color encodes direction across the full 0-360 degree range
                "color": bin_centers_deg,
                "colorscale": "HSV",
                "cmin": 0,
                "cmax": 360,
                "showscale": False,
            },
        )
    )
    fig.update_layout(
        title="Aspect Rose Diagram",
        polar={
            "radialaxis": {"title": "Frequency", "showticklabels": True},
            # Compass layout: 0 degrees (N) at the top, angles increasing clockwise
            "angularaxis": {
                "direction": "clockwise",
                "rotation": 90,
                "tickmode": "array",
                "tickvals": [0, 45, 90, 135, 180, 225, 270, 315],
                "ticktext": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"],
            },
        },
        showlegend=False,
        height=500,
    )
    return fig
def create_slope_aspect_scatter(slope_values: np.ndarray, aspect_values: np.ndarray) -> go.Figure:
    """Show the joint distribution of slope and aspect as a 2D histogram.

    Args:
        slope_values: Array of slope values.
        aspect_values: Array of aspect values in degrees (0-360), with the same
            number of elements as ``slope_values``.

    Returns:
        Plotly Figure with a density plot (aspect on x, slope on y), or an empty
        Figure when no pair without NaN remains.

    """
    # Pair the flattened arrays and keep only rows where both values are present
    paired = pd.DataFrame(
        {
            "slope": slope_values.flatten(),
            "aspect": aspect_values.flatten(),
        }
    ).dropna()
    if paired.empty:
        return go.Figure()

    # Cap the number of points for performance (fixed seed for reproducibility)
    if len(paired) > 50000:
        paired = paired.sample(n=50000, random_state=42)

    density = go.Histogram2d(
        x=paired["aspect"],
        y=paired["slope"],
        colorscale="Viridis",
        nbinsx=36,  # 10-degree bins for aspect
        nbinsy=50,
    )
    fig = go.Figure()
    fig.add_trace(density)

    fig.update_layout(
        title="Slope vs Aspect Distribution",
        xaxis_title="Aspect (degrees)",
        yaxis_title="Slope (degrees)",
        height=500,
    )

    # Replace numeric aspect ticks with compass directions
    compass_ticks = [0, 45, 90, 135, 180, 225, 270, 315, 360]
    compass_labels = ["N", "NE", "E", "SE", "S", "SW", "W", "NW", "N"]
    fig.update_xaxes(tickmode="array", tickvals=compass_ticks, ticktext=compass_labels)
    return fig
def create_correlation_heatmap(arcticdem_ds: xr.Dataset, features: list[str], agg: str) -> go.Figure:
    """Build a heatmap of pairwise correlations between terrain features.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features with
            an ``aggregations`` coordinate.
        features: List of feature names to include; names missing from the
            dataset are skipped.
        agg: Aggregation type to select for every feature.

    Returns:
        Plotly Figure with the correlation heatmap, or an empty Figure when none
        of the requested features exists.

    """
    # One flattened column per available feature
    columns = {
        feature: arcticdem_ds[feature].sel(aggregations=agg).to_numpy().flatten()
        for feature in features
        if feature in arcticdem_ds.data_vars
    }
    if len(columns) == 0:
        return go.Figure()

    # Keep only rows where every feature has a value
    table = pd.DataFrame(columns).dropna()

    # Cap the number of rows for performance (fixed seed for reproducibility)
    if len(table) > 50000:
        table = table.sample(n=50000, random_state=42)

    corr = table.corr()
    matrix = corr.to_numpy()
    column_labels = [name.replace("_", " ").title() for name in corr.columns]
    row_labels = [name.replace("_", " ").title() for name in corr.index]

    heatmap = go.Heatmap(
        z=matrix,
        x=column_labels,
        y=row_labels,
        colorscale="RdBu",
        # Fixed symmetric range so zero correlation sits at the colorscale midpoint
        zmid=0,
        zmin=-1,
        zmax=1,
        text=np.round(matrix, 2),
        texttemplate="%{text}",
        textfont={"size": 10},
        colorbar={"title": "Correlation"},
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=f"Feature Correlation Matrix ({agg.title()})",
        height=600,
        xaxis={"side": "bottom"},
    )
    return fig

View file

@ -1,5 +1,7 @@
"""AlphaEarth embeddings dashboard section.""" """AlphaEarth embeddings dashboard section."""
from typing import cast
import geopandas as gpd import geopandas as gpd
import matplotlib.colors as mcolors import matplotlib.colors as mcolors
import numpy as np import numpy as np
@ -61,8 +63,6 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
""" """
) )
cols = st.columns([4, 1])
with cols[0]:
if "year" in embedding_values.dims or "year" in embedding_values.coords: if "year" in embedding_values.dims or "year" in embedding_values.coords:
year_values = embedding_values["year"].values.tolist() year_values = embedding_values["year"].values.tolist()
year = st.slider( year = st.slider(
@ -74,8 +74,7 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
help="Select the year for which to visualize the embeddings.", help="Select the year for which to visualize the embeddings.",
) )
embedding_values = embedding_values.sel(year=year) embedding_values = embedding_values.sel(year=year)
with cols[1]: make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="embedding_map_3d"))
make_3d_map = st.checkbox("3D Map", value=True)
# Check if subsampling will occur # Check if subsampling will occur
n_cells = len(embedding_values["cell_ids"]) n_cells = len(embedding_values["cell_ids"])

View file

@ -0,0 +1,317 @@
"""ArcticDEM terrain features dashboard section."""
from typing import cast
import geopandas as gpd
import matplotlib.colors as mcolors
import streamlit as st
import xarray as xr
from entropice.dashboard.plots.terrain import (
create_aspect_rose_diagram,
create_correlation_heatmap,
create_slope_aspect_scatter,
create_terrain_distribution_plot,
create_terrain_map,
)
from entropice.dashboard.sections.dataset_statistics import render_member_details
from entropice.dashboard.utils.colors import get_cmap
from entropice.dashboard.utils.stats import MemberStatistics
@st.fragment
def _render_terrain_map(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
    """Visualize spatial distribution of terrain features.

    Renders feature / aggregation / 3D selectors, the pydeck map and a legend.
    The legend's gradient is labelled with the bounds of the map's color scale
    (0-360 for aspect, otherwise the 2nd-98th percentile), not with the raw
    min/max, which are shown separately as metrics.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
        grid_gdf: GeoDataFrame with grid cell geometries

    """
    st.subheader("Spatial Distribution of Terrain Features")

    # Get available features and aggregations
    available_vars = list(arcticdem_ds.data_vars)
    aggs = list(arcticdem_ds.coords["aggregations"].values)

    # Feature selection with grouping
    feature_groups = {
        "TPI (Topographic Position Index)": ["tpi_small", "tpi_medium", "tpi_large"],
        "TRI (Terrain Ruggedness Index)": ["tri_small", "tri_medium", "tri_large"],
        "VRM (Vector Ruggedness Measure)": ["vrm_small", "vrm_medium", "vrm_large"],
        "Slope & Curvature": ["slope", "curvature"],
        "Aspect": ["aspect"],
    }
    # Flatten to get all known features that exist in the dataset
    all_features = [f for features in feature_groups.values() for f in features if f in available_vars]

    cols = st.columns([2, 2, 1])
    with cols[0]:
        # Feature selection with nice formatting
        selected_feature = st.selectbox(
            "Terrain Feature",
            options=all_features,
            format_func=lambda x: x.replace("_", " ").title(),
            key="terrain_feature",
        )
    with cols[1]:
        # Aggregation selection
        selected_agg = st.selectbox(
            "Aggregation",
            options=aggs,
            key="terrain_agg",
        )
    with cols[2]:
        # Empty line as a vertical spacer above the toggle
        st.write("\n")
        make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="terrain_map_3d"))

    # Extract the data (also covers the case where no feature could be selected)
    if selected_feature not in arcticdem_ds.data_vars:
        st.error(f"Feature {selected_feature} not found in dataset")
        return
    terrain_data = arcticdem_ds[selected_feature].sel(aggregations=selected_agg)
    terrain_series = terrain_data.to_series()

    # Check if subsampling will occur
    n_cells = len(terrain_series)
    if n_cells > 100000:
        st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")

    # Create map
    map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map)
    st.pydeck_chart(map_deck)

    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown(f"**{selected_feature.replace('_', ' ').title()} ({selected_agg})**")
        values = terrain_series.dropna()
        if len(values) == 0:
            st.warning("No data available for the selected feature and aggregation")
            return

        # Summary statistics over all non-NaN values
        summary = [
            ("Min", values.min()),
            ("Mean", values.mean()),
            ("Max", values.max()),
            ("Std Dev", values.std()),
        ]
        for column, (label, stat) in zip(st.columns(len(summary)), summary):
            with column:
                st.metric(label, f"{stat:.2f}")

        # Color scale visualization: mirror how create_terrain_map scales colors.
        # The map computes its percentiles on the (possibly subsampled) mapped
        # cells, so the labels are a close approximation for very large grids.
        if "aspect" in selected_feature.lower():
            cmap = get_cmap("aspect")
            scale_min, scale_max = 0.0, 360.0
        else:
            cmap = get_cmap("terrain")
            scale_min, scale_max = values.quantile([0.02, 0.98]).tolist()

        gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
        gradient_css = ", ".join(gradient_colors)
        gradient_style = f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;"
        legend_html = (
            "\n"
            '<div style="margin-top: 10px;">\n'
            '<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">\n'
            f'<span style="font-size: 0.9em;">{scale_min:.2f}</span>\n'
            f'<span style="font-size: 0.9em;">{scale_max:.2f}</span>\n'
            "</div>\n"
            f'<div style="{gradient_style}"></div>\n'
            "</div>\n"
        )
        st.markdown(legend_html, unsafe_allow_html=True)
@st.fragment
def _render_terrain_distributions(arcticdem_ds: xr.Dataset):
    """Render the feature-group selector and the matching distribution plot.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features

    """
    st.subheader("Terrain Feature Distributions")

    present_vars = list(arcticdem_ds.data_vars)

    # Selectable groups of related features
    groups = {
        "Topographic Position (TPI)": ["tpi_small", "tpi_medium", "tpi_large"],
        "Terrain Ruggedness (TRI)": ["tri_small", "tri_medium", "tri_large"],
        "Vector Ruggedness (VRM)": ["vrm_small", "vrm_medium", "vrm_large"],
        "All Scale-Invariant": ["slope", "aspect", "curvature"],
    }
    selected_group = st.selectbox(
        "Feature Group",
        options=list(groups),
        key="terrain_dist_group",
    )

    # Only features of the chosen group that exist in the dataset
    features_to_plot = [name for name in groups[selected_group] if name in present_vars]
    if not features_to_plot:
        st.warning(f"No features available for {selected_group}")
        return

    # Tell the user up front when the plot helper is going to subsample
    cell_count = len(arcticdem_ds.cell_ids)
    if cell_count > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {cell_count:,} "
            "for performance. Statistics remain representative."
        )

    distribution_fig = create_terrain_distribution_plot(arcticdem_ds, features_to_plot)
    st.plotly_chart(distribution_fig, width="stretch")

    description = (
        "\n"
        f"Distribution of **{selected_group}** features across different aggregation types.\n"
        "Violin plots show the full distribution with embedded box plots for quartiles.\n"
    )
    st.markdown(description)
@st.fragment
def _render_terrain_correlations(arcticdem_ds: xr.Dataset):
    """Display correlation analysis for terrain features.

    Lets the user pick an aggregation type and shows the correlation heatmap of
    all data variables for it; warns when fewer than two variables exist.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features

    """
    st.subheader("Feature Correlation Analysis")

    aggs = list(arcticdem_ds.coords["aggregations"].values)

    # Select aggregation for correlation
    selected_agg = st.selectbox(
        "Aggregation Type",
        options=aggs,
        key="terrain_corr_agg",
    )

    # Every data variable takes part; names are coerced to str for the plot helper
    all_features = [str(name) for name in arcticdem_ds.data_vars]

    if len(all_features) >= 2:
        fig = create_correlation_heatmap(arcticdem_ds, all_features, selected_agg)
        st.plotly_chart(fig, width="stretch")
        st.markdown(
            "\n"
            "Correlation heatmap shows relationships between different terrain features.\n"
            "Strong positive correlations (red) or negative correlations (blue) can indicate\n"
            "related terrain characteristics.\n"
        )
    else:
        st.warning("Need at least 2 features for correlation analysis")
@st.fragment
def _render_slope_aspect_analysis(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
    """Render the aspect rose diagram next to the slope-vs-aspect histogram.

    Args:
        arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
        grid_gdf: GeoDataFrame with grid cell geometries (not used by this section)

    """
    st.subheader("Slope & Aspect Analysis")

    # Both variables are required for this section
    has_both = "slope" in arcticdem_ds.data_vars and "aspect" in arcticdem_ds.data_vars
    if not has_both:
        st.warning("Slope and/or aspect data not available in this dataset")
        return

    aggregation_options = list(arcticdem_ds.coords["aggregations"].values)
    selected_agg = st.selectbox(
        "Aggregation Type",
        options=aggregation_options,
        key="slope_aspect_agg",
    )

    # Raw arrays for the chosen aggregation
    slope_data = arcticdem_ds["slope"].sel(aggregations=selected_agg).values
    aspect_data = arcticdem_ds["aspect"].sel(aggregations=selected_agg).values

    # Tell the user up front when the plot helpers are going to subsample
    n_values = slope_data.size
    if n_values > 50000:
        st.info(
            f"📊 **Dataset subsampled:** Using 50,000 randomly selected values out of {n_values:,} "
            "for performance. Distributions remain representative."
        )

    rose_column, histogram_column = st.columns(2)

    with rose_column:
        st.markdown("**Aspect Rose Diagram**")
        st.plotly_chart(create_aspect_rose_diagram(aspect_data), width="stretch")
        st.markdown(
            "\n"
            "The rose diagram shows the directional distribution of terrain aspect.\n"
            "Each bar represents the frequency of slopes facing that direction.\n"
        )

    with histogram_column:
        st.markdown("**Slope vs Aspect Distribution**")
        st.plotly_chart(create_slope_aspect_scatter(slope_data, aspect_data), width="stretch")
        st.markdown(
            "\n"
            "This 2D histogram shows how slope steepness relates to aspect direction.\n"
            "Patterns can reveal preferential slope orientations (e.g., due to prevailing winds or sun exposure).\n"
        )
def render_arcticdem_tab(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame, arcticdem_stats: MemberStatistics):
    """Render the ArcticDEM visualization tab.

    Shows the member statistics in an expander, followed by the terrain map, the
    correlation analysis and the slope/aspect analysis, each preceded by a divider.

    Args:
        arcticdem_ds: The ArcticDEM dataset member, lazily loaded.
        grid_gdf: GeoDataFrame with grid cell geometries
        arcticdem_stats: Statistics for the ArcticDEM member.

    """
    with st.expander("ArcticDEM Statistics", expanded=True):
        render_member_details("ArcticDEM", arcticdem_stats)

    # NOTE: _render_terrain_distributions(arcticdem_ds) is currently left out of
    # this tab (it was commented out between the map and the correlations).
    sections = (
        lambda: _render_terrain_map(arcticdem_ds, grid_gdf),
        lambda: _render_terrain_correlations(arcticdem_ds),
        lambda: _render_slope_aspect_analysis(arcticdem_ds, grid_gdf),
    )
    for render_section in sections:
        st.divider()
        render_section()

View file

@ -14,16 +14,14 @@ from entropice.dashboard.utils.colors import get_cmap
def _render_area_map(grid_gdf: gpd.GeoDataFrame): def _render_area_map(grid_gdf: gpd.GeoDataFrame):
st.subheader("Spatial Distribution of Grid Cell Areas") st.subheader("Spatial Distribution of Grid Cell Areas")
cols = st.columns([4, 1])
with cols[0]:
metric = st.selectbox( metric = st.selectbox(
"Metric", "Metric",
options=["cell_area", "land_area", "water_area", "land_ratio"], options=["cell_area", "land_area", "water_area", "land_ratio"],
format_func=lambda x: x.replace("_", " ").title(), format_func=lambda x: x.replace("_", " ").title(),
key="metric", key="metric",
) )
with cols[1]:
make_3d_map = cast(bool, st.checkbox("3D Map", value=True)) make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d"))
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map) map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
st.pydeck_chart(map_deck) st.pydeck_chart(map_deck)

View file

@ -446,8 +446,6 @@ def render_member_details(member: str, member_stats: MemberStatistics):
member_stats: Statistics for the member member_stats: Statistics for the member
""" """
st.markdown(f"### {member}")
# Variables # Variables
st.markdown("**Variables:**") st.markdown("**Variables:**")
vars_html = " ".join( vars_html = " ".join(
@ -539,6 +537,7 @@ def render_ensemble_details(
# Individual member details # Individual member details
for member, stats in selected_member_stats.items(): for member, stats in selected_member_stats.items():
st.markdown(f"### {member}")
render_member_details(member, stats) render_member_details(member, stats)
st.divider() st.divider()

View file

@ -0,0 +1,219 @@
"""ERA5 climate data dashboard section."""
from typing import cast
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import streamlit as st
import xarray as xr
from entropice.dashboard.plots.climate import (
create_climate_map,
create_climate_trend_plot,
)
from entropice.dashboard.sections.dataset_statistics import render_member_details
from entropice.dashboard.utils.colors import get_cmap
from entropice.dashboard.utils.stats import MemberStatistics
def _get_climate_variable_aggregation_season_option(era5_ds: xr.Dataset):
    """Render the variable, season and aggregation selectors for an ERA5 dataset.

    The variable select box is always shown. The season and aggregation select
    boxes are only shown when the dataset has a ``month`` or ``aggregations``
    dimension, respectively.

    Args:
        era5_ds: ERA5 dataset whose data variables and coordinates populate the select boxes

    Returns:
        Tuple ``(selected_var, selected_agg, selected_month)``. ``selected_agg`` and
        ``selected_month`` are None when the dataset lacks the matching dimension.

    """
    variable_names = list(era5_ds.data_vars)
    selected_var = st.selectbox("Select Climate Variable", variable_names, key="climate_var_map")

    # Season selector (only for datasets with a "month" dimension)
    selected_month = None
    if "month" in era5_ds.dims:
        season_options = era5_ds.month.to_numpy().tolist()
        selected_month = cast(str, st.selectbox("Select Season", options=season_options, key="climate_month"))

    # Aggregation selector (only for datasets with an "aggregations" dimension)
    selected_agg = None
    if "aggregations" in era5_ds.dims:
        agg_options = era5_ds.aggregations.to_numpy().tolist()
        selected_agg = cast(str, st.selectbox("Select Aggregation Method", options=agg_options, key="climate_agg"))

    return selected_var, selected_agg, selected_month
@st.fragment
def _render_climate_variable_map(climate_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame, selected_var: str):
    """Render the spatial map of one climate variable together with its colour legend.

    Args:
        climate_values: Xarray DataArray with the values of the selected ERA5 variable.
            Must have a ``cell_ids`` dimension; if a ``year`` dimension is present,
            a slider is shown and a single year is selected before plotting.
        grid_gdf: GeoDataFrame with grid cell geometries
        selected_var: Name of the selected climate variable (used for the colormap
            lookup and the legend title)

    """
    st.subheader("Spatial Distribution of Climate Variables")

    if "year" in climate_values.dims:
        years = climate_values.year.to_numpy()
        selected_year = cast(
            int, st.select_slider("Select Year", options=years.tolist(), value=int(years.max()), key="climate_year")
        )
        climate_values = climate_values.sel(year=selected_year)

    # 3D toggle
    make_3d = cast(bool, st.toggle("3D Map", value=True, key="climate_map_3d"))

    # Count cells along the ``cell_ids`` dimension explicitly. ``len()`` of a DataArray
    # only measures its first dimension, which is not guaranteed to be ``cell_ids``;
    # create_climate_map decides about subsampling based on ``cell_ids`` as well.
    n_cells = len(climate_values["cell_ids"])
    if n_cells > 100000:
        st.info(f"Showing 100,000 / {n_cells:,} cells for performance")

    # Create map
    deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
    st.pydeck_chart(deck, use_container_width=True)

    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown(f"**{selected_var.replace('_', ' ').title()}**")

        # Legend limits are the 2nd / 98th percentile of all non-NaN selected values.
        # NOTE(review): the map may be drawn from a random subsample of the cells, so its
        # colour limits could differ slightly from these - confirm against create_climate_map.
        values = climate_values.values.flatten()
        values = values[~np.isnan(values)]
        if values.size:
            vmin, vmax = np.nanpercentile(values, [2, 98])
            vmin_str = f"{vmin:.4f}"
            vmax_str = f"{vmax:.4f}"
        else:
            # Empty or all-NaN selection: there is nothing to compute percentiles
            # from, so show a placeholder instead of handing NumPy an empty array.
            vmin_str = vmax_str = "n/a"

        # Color scale visualization - use appropriate colormap
        cmap = get_cmap(selected_var)
        gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
        gradient_css = ", ".join(gradient_colors)
        gradient_style = f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;"

        st.markdown(
            f"""
            <div style="margin-top: 10px;">
                <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                    <span style="font-size: 0.9em;">{vmin_str}</span>
                    <span style="font-size: 0.9em;">{vmax_str}</span>
                </div>
                <div style="{gradient_style}"></div>
            </div>
            """,
            unsafe_allow_html=True,
        )

        st.caption(
            "Color intensity represents climate values from low to high. "
            "Values are normalized using the 2nd-98th percentile range to avoid outliers."
        )

        if make_3d:
            st.markdown("---")
            st.markdown("**3D Elevation:**")
            st.caption("Height represents normalized climate values. Rotate the map by holding Ctrl/Cmd and dragging.")
def _render_climate_temporal_trends(
    climate_values: xr.DataArray, selected_var: str, selected_agg: str | None, selected_month: str | None
):
    """Render the temporal-trend section for the selected climate variable.

    Shows an explanatory text, a caption summarising the current selection, an
    optional subsampling notice and the trend plot. If ``climate_values`` has no
    ``year`` dimension, only an info message is shown.

    Args:
        climate_values: Xarray DataArray with the values of the selected ERA5 variable
        selected_var: Name of the selected climate variable
        selected_agg: Selected aggregation method (None if not applicable)
        selected_month: Selected month/season (None if not applicable)

    """
    st.subheader("Climate Variable Temporal Trends")

    # Nothing to plot without a time axis
    if "year" not in climate_values.dims:
        st.info("Temporal trends require yearly dimension")
        return

    explanation = """
    This visualization shows how climate variables have changed over time across the study area.
    The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
    patterns in climate variables that may correlate with environmental changes.
    **Understanding the plot:**
    - **Mean line** (blue): Average value across all grid cells for each year
    - **10th-90th percentile band** (light blue): Range containing 80% of the values, showing
    typical variation
    - **Min/Max range** (gray): Full extent of values, highlighting outliers
    """
    st.markdown(explanation)

    # Summarise what is being plotted; aggregation / season only appear when set
    n_years = len(climate_values["year"])
    n_cells = len(climate_values["cell_ids"])
    details = f"{n_years} years, {n_cells:,} cells"
    if selected_agg:
        details += f", Aggregation: `{selected_agg}`"
    if selected_month:
        details += f", Season: `{selected_month}`"
    st.caption(f"📊 **Dataset selection:** Variable: `{selected_var}` ({details})")

    # NOTE(review): the 10,000-cell threshold presumably mirrors the subsampling done
    # inside create_climate_trend_plot - confirm the two stay in sync.
    if n_cells > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )

    trend_fig = create_climate_trend_plot(climate_values, selected_var)
    st.plotly_chart(trend_fig, use_container_width=True)
@st.fragment
def render_era5_tab(
    era5_member_datasets: dict[str, xr.Dataset],
    grid_gdf: gpd.GeoDataFrame,
    era5_member_stats: dict[str, MemberStatistics],
):
    """Render the ERA5 visualization tab.

    Lets the user pick an ERA5 member (if more than one is available), shows the
    member statistics, then renders the spatial map and - when the selected data
    has a ``year`` dimension - the temporal trends for the chosen variable.

    Args:
        era5_member_datasets: Dictionary of ERA5 member datasets (yearly, seasonal, shoulder)
        grid_gdf: GeoDataFrame with grid cell geometries
        era5_member_stats: Dictionary of MemberStatistics for each ERA5 member.

    """
    if not era5_member_datasets:
        st.warning("No ERA5 data available")
        return

    # Member selection
    available_members = list(era5_member_datasets)
    if len(available_members) == 1:
        selected_member = available_members[0]
        st.info(f"Showing ERA5 data: {selected_member}")
    else:
        selected_member = st.selectbox(
            "Select ERA5 Aggregation Type",
            available_members,
            format_func=lambda x: x.replace("_", " ").title(),
            key="era5_member",
        )

    era5_stats = era5_member_stats[selected_member]

    # Load selected dataset
    era5_ds = era5_member_datasets[selected_member]

    # Render different visualizations
    with st.expander(f"{selected_member.replace('_', ' ').title()} Statistics", expanded=True):
        render_member_details(selected_member, era5_stats)

    st.divider()

    selected_var, selected_agg, selected_month = _get_climate_variable_aggregation_season_option(era5_ds)
    climate_values = era5_ds[selected_var]

    # Compare against None instead of relying on truthiness, so a falsy coordinate
    # value (e.g. 0) is still applied. The selectors are built from the dataset's
    # dimensions, so also check that this particular variable has the dimension
    # before selecting along it; otherwise drop the selection so that it is not
    # reported further down as having been applied.
    if selected_agg is not None and "aggregations" in climate_values.dims:
        climate_values = climate_values.sel(aggregations=selected_agg)
    else:
        selected_agg = None
    if selected_month is not None and "month" in climate_values.dims:
        climate_values = climate_values.sel(month=selected_month)
    else:
        selected_month = None

    # Materialise once so that both visualizations work on the same computed values
    climate_values = climate_values.compute()

    _render_climate_variable_map(climate_values, grid_gdf, selected_var)

    if "year" in climate_values.dims:
        st.divider()
        _render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)

View file

@ -185,9 +185,9 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
) )
with cols[2]: with cols[2]:
# Controls whether a 3D map or a 2D map is shown # Controls whether a 3D map or a 2D map is shown
make_3d_map = cast(bool, st.checkbox("3D Map", value=True)) make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="target_map_3d"))
# Controls what should be shown, either the split or the labels / values # Controls what should be shown, either the split or the labels / values
show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False)) show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False, key="target_map_show_split"))
training_set = train_data_dict[selected_target][selected_task] training_set = train_data_dict[selected_target][selected_task]
map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task) map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task)

View file

@ -1,14 +1,16 @@
"""Data page: Visualization of the data.""" """Data page: Visualization of the data."""
from typing import cast from typing import Literal, cast
import streamlit as st import streamlit as st
import xarray as xr import xarray as xr
from stopuhr import stopwatch from stopuhr import stopwatch
from entropice.dashboard.sections.alphaearth import render_alphaearth_tab from entropice.dashboard.sections.alphaearth import render_alphaearth_tab
from entropice.dashboard.sections.arcticdem import render_arcticdem_tab
from entropice.dashboard.sections.areas import render_area_information_tab from entropice.dashboard.sections.areas import render_area_information_tab
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
from entropice.dashboard.sections.era5 import render_era5_tab
from entropice.dashboard.sections.targets import render_target_information_tab from entropice.dashboard.sections.targets import render_target_information_tab
from entropice.dashboard.utils.loaders import load_training_sets from entropice.dashboard.utils.loaders import load_training_sets
from entropice.dashboard.utils.stats import DatasetStatistics from entropice.dashboard.utils.stats import DatasetStatistics
@ -130,19 +132,20 @@ def render_dataset_page():
tab_names.append("🌍 AlphaEarth") tab_names.append("🌍 AlphaEarth")
if "ArcticDEM" in ensemble.members: if "ArcticDEM" in ensemble.members:
tab_names.append("🏔️ ArcticDEM") tab_names.append("🏔️ ArcticDEM")
era5_members = [m for m in ensemble.members if m.startswith("ERA5")] era5_members = cast(
list[Literal["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]],
[m for m in ensemble.members if m.startswith("ERA5")],
)
if era5_members: if era5_members:
tab_names.append("🌡️ ERA5") tab_names.append("🌡️ ERA5")
tabs = st.tabs(tab_names) tabs = st.tabs(tab_names)
with tabs[0]: with tabs[0]:
st.header("🎯 Target Labels Visualization") st.header("🎯 Target Labels Visualization")
if False: # ! DEBUG
train_data_dict = load_training_sets(ensemble) train_data_dict = load_training_sets(ensemble)
render_target_information_tab(train_data_dict) render_target_information_tab(train_data_dict)
with tabs[1]: with tabs[1]:
st.header("📐 Areas Visualization") st.header("📐 Areas Visualization")
if False: # ! DEBUG
render_area_information_tab(grid_gdf) render_area_information_tab(grid_gdf)
tab_index = 2 tab_index = 2
if "AlphaEarth" in ensemble.members: if "AlphaEarth" in ensemble.members:
@ -155,10 +158,16 @@ def render_dataset_page():
if "ArcticDEM" in ensemble.members: if "ArcticDEM" in ensemble.members:
with tabs[tab_index]: with tabs[tab_index]:
st.header("🏔️ ArcticDEM Visualization") st.header("🏔️ ArcticDEM Visualization")
arcticdem_ds = member_datasets["ArcticDEM"].compute()
arcticdem_stats = stats.members["ArcticDEM"]
render_arcticdem_tab(arcticdem_ds, grid_gdf, arcticdem_stats)
tab_index += 1 tab_index += 1
if era5_members: if era5_members:
with tabs[tab_index]: with tabs[tab_index]:
st.header("🌡️ ERA5 Visualization") st.header("🌡️ ERA5 Visualization")
era5_member_dataset = {m: member_datasets[m] for m in era5_members}
era5_member_stats = {m: stats.members[m] for m in era5_members}
render_era5_tab(era5_member_dataset, grid_gdf, era5_member_stats)
st.balloons() st.balloons()
stopwatch.summary() stopwatch.summary()