Add ArcticDEM and ERA5 tab
This commit is contained in:
parent
26de80ee89
commit
c358bb63bc
9 changed files with 1416 additions and 36 deletions
421
src/entropice/dashboard/plots/climate.py
Normal file
421
src/entropice/dashboard/plots/climate.py
Normal file
|
|
@ -0,0 +1,421 @@
|
|||
"""Plots for visualizing ERA5 climate data."""
|
||||
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
import xarray as xr
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
|
||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||
|
||||
|
||||
def create_climate_map(
|
||||
climate_values: xr.DataArray,
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
variable_name: str,
|
||||
make_3d_map: bool,
|
||||
) -> pdk.Deck:
|
||||
"""Create a spatial distribution map for ERA5 climate variables.
|
||||
|
||||
Args:
|
||||
climate_values: Series with cell_ids as index and climate values
|
||||
grid_gdf: GeoDataFrame containing grid cell geometries
|
||||
variable_name: Name of the climate variable being visualized
|
||||
make_3d_map: Whether to render the map in 3D (extruded) or 2D
|
||||
|
||||
Returns:
|
||||
pdk.Deck: A PyDeck map visualization of the climate variable
|
||||
|
||||
"""
|
||||
# Subsample if too many cells for performance
|
||||
n_cells = len(climate_values["cell_ids"])
|
||||
if n_cells > 100000:
|
||||
rng = np.random.default_rng(42)
|
||||
cell_indices = rng.choice(n_cells, size=100000, replace=False)
|
||||
climate_values = climate_values.isel(cell_ids=cell_indices)
|
||||
# Create a copy to avoid modifying the original
|
||||
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
||||
|
||||
# Convert to DataFrame for easier merging
|
||||
climate_df = climate_values.to_dataframe(name="climate_value")
|
||||
|
||||
# Reset index if cell_id is already the index
|
||||
if gdf.index.name == "cell_id":
|
||||
gdf = gdf.reset_index()
|
||||
|
||||
# Filter grid to only cells that have climate data
|
||||
gdf = gdf[gdf["cell_id"].isin(climate_df.index)]
|
||||
gdf = gdf.set_index("cell_id")
|
||||
|
||||
# Merge climate values with grid geometries
|
||||
gdf = gdf.join(climate_df, how="inner")
|
||||
|
||||
# Convert to WGS84 for pydeck
|
||||
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||
|
||||
# Fix antimeridian issues for hex cells
|
||||
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
|
||||
|
||||
# Get colormap
|
||||
cmap = get_cmap(variable_name)
|
||||
|
||||
# Normalize the climate values to [0, 1] for color mapping
|
||||
values = gdf_wgs84["climate_value"].to_numpy()
|
||||
|
||||
# Use percentiles to avoid outliers
|
||||
vmin, vmax = np.nanpercentile(values, [2, 98])
|
||||
if vmax > vmin:
|
||||
normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)
|
||||
else:
|
||||
normalized_values = np.zeros_like(values)
|
||||
|
||||
# Map normalized values to colors
|
||||
colors = [cmap(val) for val in normalized_values]
|
||||
rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors]
|
||||
gdf_wgs84["fill_color"] = rgb_colors
|
||||
|
||||
# Store climate value for tooltip
|
||||
gdf_wgs84["climate_value_display"] = values
|
||||
|
||||
# Store normalized values for elevation (if 3D)
|
||||
gdf_wgs84["elevation"] = normalized_values
|
||||
|
||||
# Convert to GeoJSON format
|
||||
geojson_data = []
|
||||
for _, row in gdf_wgs84.iterrows():
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": {
|
||||
"fill_color": row["fill_color"],
|
||||
"climate_value": float(row["climate_value_display"]),
|
||||
"elevation": float(row["elevation"]) if make_3d_map else 0,
|
||||
},
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
|
||||
# Create pydeck layer
|
||||
layer = pdk.Layer(
|
||||
"GeoJsonLayer",
|
||||
geojson_data,
|
||||
opacity=0.7,
|
||||
stroked=True,
|
||||
filled=True,
|
||||
extruded=make_3d_map,
|
||||
wireframe=False,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color=[80, 80, 80],
|
||||
line_width_min_pixels=0.5,
|
||||
get_elevation="properties.elevation" if make_3d_map else 0,
|
||||
elevation_scale=500000,
|
||||
pickable=True,
|
||||
)
|
||||
|
||||
# Set initial view state
|
||||
view_state = pdk.ViewState(
|
||||
latitude=70,
|
||||
longitude=0,
|
||||
zoom=2 if not make_3d_map else 1.5,
|
||||
pitch=0 if not make_3d_map else 45,
|
||||
)
|
||||
|
||||
# Build tooltip HTML
|
||||
tooltip_html = f"<b>{variable_name}:</b> {{climate_value}}"
|
||||
|
||||
# Create deck
|
||||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={
|
||||
"html": tooltip_html,
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
return deck
|
||||
|
||||
|
||||
def create_climate_trend_plot(climate_data: xr.DataArray, variable_name: str) -> go.Figure:
|
||||
"""Create a trend plot for climate variables over time.
|
||||
|
||||
Args:
|
||||
climate_data: DataArray containing climate variable with 'year' dimension
|
||||
variable_name: Name of the variable being plotted
|
||||
|
||||
Returns:
|
||||
Plotly Figure with trend plot
|
||||
|
||||
"""
|
||||
# Subsample if too many cells for performance
|
||||
n_cells = len(climate_data["cell_ids"])
|
||||
if n_cells > 10000:
|
||||
rng = np.random.default_rng(42)
|
||||
cell_indices = rng.choice(n_cells, size=10000, replace=False)
|
||||
climate_data = climate_data.isel({"cell_ids": cell_indices})
|
||||
|
||||
# Get years
|
||||
years = climate_data["year"].to_numpy()
|
||||
|
||||
# Calculate statistics over space for each year
|
||||
mean_values = climate_data.mean(dim="cell_ids").to_numpy()
|
||||
min_values = climate_data.min(dim="cell_ids").to_numpy()
|
||||
max_values = climate_data.max(dim="cell_ids").to_numpy()
|
||||
p10_values = climate_data.quantile(0.10, dim="cell_ids").to_numpy()
|
||||
p90_values = climate_data.quantile(0.90, dim="cell_ids").to_numpy()
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Add min/max range (background)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=min_values,
|
||||
mode="lines",
|
||||
line={"color": "lightgray", "width": 1, "dash": "dash"},
|
||||
name="Min/Max Range",
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=max_values,
|
||||
mode="lines",
|
||||
fill="tonexty",
|
||||
fillcolor="rgba(200, 200, 200, 0.1)",
|
||||
line={"color": "lightgray", "width": 1, "dash": "dash"},
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
|
||||
# Add 10th-90th percentile band
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=p10_values,
|
||||
mode="lines",
|
||||
line={"width": 0},
|
||||
showlegend=False,
|
||||
hoverinfo="skip",
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=p90_values,
|
||||
mode="lines",
|
||||
fill="tonexty",
|
||||
fillcolor="rgba(33, 150, 243, 0.2)",
|
||||
line={"width": 0},
|
||||
name="10th-90th Percentile",
|
||||
)
|
||||
)
|
||||
|
||||
# Add mean line
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=mean_values,
|
||||
mode="lines+markers",
|
||||
name="Mean",
|
||||
line={"color": "#1976D2", "width": 2},
|
||||
marker={"size": 6},
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"{variable_name} Over Time (Spatial Statistics)",
|
||||
xaxis_title="Year",
|
||||
yaxis_title=variable_name,
|
||||
height=450,
|
||||
hovermode="x unified",
|
||||
legend={
|
||||
"orientation": "h",
|
||||
"yanchor": "bottom",
|
||||
"y": 1.02,
|
||||
"xanchor": "right",
|
||||
"x": 1,
|
||||
},
|
||||
)
|
||||
|
||||
fig.update_xaxes(dtick=1)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_climate_distribution_plot(climate_ds: xr.Dataset, variables: list[str]) -> go.Figure:
|
||||
"""Create distribution plots for climate variables.
|
||||
|
||||
Args:
|
||||
climate_ds: Xarray Dataset containing ERA5 climate features
|
||||
variables: List of variable names to plot
|
||||
|
||||
Returns:
|
||||
Plotly Figure with distribution plots
|
||||
|
||||
"""
|
||||
# Subsample if too many cells for performance
|
||||
n_cells = len(climate_ds.cell_ids)
|
||||
if n_cells > 10000:
|
||||
rng = np.random.default_rng(42)
|
||||
cell_indices = rng.choice(n_cells, size=10000, replace=False)
|
||||
climate_ds = climate_ds.isel(cell_ids=cell_indices)
|
||||
|
||||
# Create subplots - one row per variable
|
||||
n_rows = len(variables)
|
||||
fig = make_subplots(
|
||||
rows=n_rows,
|
||||
cols=1,
|
||||
subplot_titles=[v.replace("_", " ").title() for v in variables],
|
||||
vertical_spacing=0.15 / max(n_rows, 1),
|
||||
)
|
||||
|
||||
# Color palette
|
||||
colors = [
|
||||
"#1f77b4",
|
||||
"#ff7f0e",
|
||||
"#2ca02c",
|
||||
"#d62728",
|
||||
"#9467bd",
|
||||
"#8c564b",
|
||||
"#e377c2",
|
||||
"#7f7f7f",
|
||||
]
|
||||
|
||||
for row_idx, variable in enumerate(variables, start=1):
|
||||
if variable not in climate_ds.data_vars:
|
||||
continue
|
||||
|
||||
# Get data
|
||||
data = climate_ds[variable]
|
||||
values = data.to_numpy().flatten()
|
||||
values = values[~np.isnan(values)]
|
||||
|
||||
if len(values) == 0:
|
||||
continue
|
||||
|
||||
# Add violin plot
|
||||
fig.add_trace(
|
||||
go.Violin(
|
||||
y=values,
|
||||
name=variable.replace("_", " ").title(),
|
||||
box_visible=True,
|
||||
meanline_visible=True,
|
||||
line_color=colors[row_idx % len(colors)],
|
||||
showlegend=False,
|
||||
),
|
||||
row=row_idx,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
height=300 * n_rows,
|
||||
title_text="Climate Variable Distributions",
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
# Update y-axes labels
|
||||
for i in range(n_rows):
|
||||
fig.update_yaxes(title_text="Value", row=i + 1, col=1)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_temperature_comparison_plot(climate_ds: xr.Dataset) -> go.Figure:
|
||||
"""Create a comparison plot for temperature-related variables.
|
||||
|
||||
Args:
|
||||
climate_ds: Xarray Dataset containing temperature variables
|
||||
|
||||
Returns:
|
||||
Plotly Figure with temperature comparison
|
||||
|
||||
"""
|
||||
# Temperature variables to compare
|
||||
temp_vars = ["t2m_max", "t2m_mean", "t2m_min"]
|
||||
available_vars = [v for v in temp_vars if v in climate_ds.data_vars]
|
||||
|
||||
if not available_vars:
|
||||
return go.Figure()
|
||||
|
||||
# Get years
|
||||
years = climate_ds["year"].to_numpy()
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
colors = {"t2m_max": "#d32f2f", "t2m_mean": "#1976d2", "t2m_min": "#0288d1"}
|
||||
|
||||
for var in available_vars:
|
||||
# Calculate spatial mean for each year
|
||||
values = climate_ds[var].mean(dim="cell_ids").to_numpy()
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=years,
|
||||
y=values - 273.15, # Convert to Celsius
|
||||
mode="lines+markers",
|
||||
name=var.replace("_", " ").title(),
|
||||
line={"color": colors.get(var, "#666666"), "width": 2},
|
||||
marker={"size": 4},
|
||||
)
|
||||
)
|
||||
|
||||
# Add freezing point reference line
|
||||
fig.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Freezing Point (0°C)")
|
||||
|
||||
fig.update_layout(
|
||||
title="Temperature Extremes Over Time",
|
||||
xaxis_title="Year",
|
||||
yaxis_title="Temperature (°C)",
|
||||
height=450,
|
||||
hovermode="x unified",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_seasonal_pattern_plot(climate_ds: xr.Dataset, variable: str) -> go.Figure:
|
||||
"""Create a plot showing seasonal patterns.
|
||||
|
||||
Args:
|
||||
climate_ds: Xarray Dataset with month dimension
|
||||
variable: Variable name to plot
|
||||
|
||||
Returns:
|
||||
Plotly Figure with seasonal patterns
|
||||
|
||||
"""
|
||||
if variable not in climate_ds.data_vars or "month" not in climate_ds.dims:
|
||||
return go.Figure()
|
||||
|
||||
# Get unique months/seasons
|
||||
months = climate_ds["month"].to_numpy()
|
||||
|
||||
# Calculate mean across space and years for each month
|
||||
values = climate_ds[variable].mean(dim=["cell_ids", "year"]).to_numpy()
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=months,
|
||||
y=values,
|
||||
marker_color="#1976d2",
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"{variable.replace('_', ' ').title()} by Season",
|
||||
xaxis_title="Season",
|
||||
yaxis_title=variable.replace("_", " ").title(),
|
||||
height=400,
|
||||
)
|
||||
|
||||
return fig
|
||||
418
src/entropice/dashboard/plots/terrain.py
Normal file
418
src/entropice/dashboard/plots/terrain.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
"""Plots for visualizing ArcticDEM terrain features."""
|
||||
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
import xarray as xr
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
|
||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||
|
||||
|
||||
def create_terrain_map(
|
||||
terrain_values: pd.Series,
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
variable_name: str,
|
||||
make_3d_map: bool,
|
||||
) -> pdk.Deck:
|
||||
"""Create a spatial distribution map for ArcticDEM terrain features.
|
||||
|
||||
Args:
|
||||
terrain_values: Series with cell_ids as index and terrain values
|
||||
grid_gdf: GeoDataFrame containing grid cell geometries
|
||||
variable_name: Name of the terrain variable being visualized
|
||||
make_3d_map: Whether to render the map in 3D (extruded) or 2D
|
||||
|
||||
Returns:
|
||||
pdk.Deck: A PyDeck map visualization of the terrain feature
|
||||
|
||||
"""
|
||||
# Subsample if too many cells for performance
|
||||
n_cells = len(terrain_values)
|
||||
if n_cells > 100000:
|
||||
rng = np.random.default_rng(42)
|
||||
cell_indices = rng.choice(n_cells, size=100000, replace=False)
|
||||
terrain_values = terrain_values.iloc[cell_indices]
|
||||
|
||||
# Create a copy to avoid modifying the original
|
||||
gdf = grid_gdf.copy().to_crs("EPSG:4326")
|
||||
|
||||
# Reset index if cell_id is already the index
|
||||
if gdf.index.name == "cell_id":
|
||||
gdf = gdf.reset_index()
|
||||
|
||||
# Filter grid to only cells that have terrain data
|
||||
gdf = gdf[gdf["cell_id"].isin(terrain_values.index)]
|
||||
gdf = gdf.set_index("cell_id")
|
||||
|
||||
# Merge terrain values with grid geometries
|
||||
gdf = gdf.join(terrain_values.to_frame("terrain_value"), how="inner")
|
||||
|
||||
# Convert to WGS84 for pydeck
|
||||
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||
|
||||
# Fix antimeridian issues for hex cells
|
||||
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
|
||||
|
||||
# Get colormap - use special colormap for aspect (circular)
|
||||
if "aspect" in variable_name.lower():
|
||||
cmap = get_cmap("aspect")
|
||||
else:
|
||||
cmap = get_cmap("terrain")
|
||||
|
||||
# Normalize the terrain values to [0, 1] for color mapping
|
||||
values = gdf_wgs84["terrain_value"].to_numpy()
|
||||
|
||||
# Handle aspect specially (0-360 degrees, circular)
|
||||
if "aspect" in variable_name.lower():
|
||||
vmin, vmax = 0, 360
|
||||
normalized_values = values / 360
|
||||
else:
|
||||
# Use percentiles to avoid outliers
|
||||
vmin, vmax = np.nanpercentile(values, [2, 98])
|
||||
if vmax > vmin:
|
||||
normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)
|
||||
else:
|
||||
normalized_values = np.zeros_like(values)
|
||||
|
||||
# Map normalized values to colors
|
||||
colors = [cmap(val) for val in normalized_values]
|
||||
rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors]
|
||||
gdf_wgs84["fill_color"] = rgb_colors
|
||||
|
||||
# Store terrain value for tooltip
|
||||
gdf_wgs84["terrain_value_display"] = values
|
||||
|
||||
# Store normalized values for elevation (if 3D)
|
||||
gdf_wgs84["elevation"] = normalized_values
|
||||
|
||||
# Convert to GeoJSON format
|
||||
geojson_data = []
|
||||
for _, row in gdf_wgs84.iterrows():
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": {
|
||||
"fill_color": row["fill_color"],
|
||||
"terrain_value": float(row["terrain_value_display"]),
|
||||
"elevation": float(row["elevation"]) if make_3d_map else 0,
|
||||
},
|
||||
}
|
||||
geojson_data.append(feature)
|
||||
|
||||
# Create pydeck layer
|
||||
layer = pdk.Layer(
|
||||
"GeoJsonLayer",
|
||||
geojson_data,
|
||||
opacity=0.7,
|
||||
stroked=True,
|
||||
filled=True,
|
||||
extruded=make_3d_map,
|
||||
wireframe=False,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color=[80, 80, 80],
|
||||
line_width_min_pixels=0.5,
|
||||
get_elevation="properties.elevation" if make_3d_map else 0,
|
||||
elevation_scale=500000,
|
||||
pickable=True,
|
||||
)
|
||||
|
||||
# Set initial view state
|
||||
view_state = pdk.ViewState(
|
||||
latitude=70,
|
||||
longitude=0,
|
||||
zoom=2 if not make_3d_map else 1.5,
|
||||
pitch=0 if not make_3d_map else 45,
|
||||
)
|
||||
|
||||
# Build tooltip HTML
|
||||
tooltip_html = f"<b>{variable_name}:</b> {{terrain_value}}"
|
||||
|
||||
# Create deck
|
||||
deck = pdk.Deck(
|
||||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={
|
||||
"html": tooltip_html,
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
)
|
||||
|
||||
return deck
|
||||
|
||||
|
||||
def create_terrain_distribution_plot(arcticdem_ds: xr.Dataset, features: list[str]) -> go.Figure:
|
||||
"""Create distribution plots for terrain features.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
features: List of feature names to plot
|
||||
|
||||
Returns:
|
||||
Plotly Figure with distribution plots
|
||||
|
||||
"""
|
||||
# Subsample if too many cells for performance
|
||||
n_cells = len(arcticdem_ds.cell_ids)
|
||||
if n_cells > 10000:
|
||||
rng = np.random.default_rng(42)
|
||||
cell_indices = rng.choice(n_cells, size=10000, replace=False)
|
||||
arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices)
|
||||
|
||||
# Determine aggregation types available
|
||||
aggs = list(arcticdem_ds.coords["aggregations"].values)
|
||||
|
||||
# Create subplots - one row per aggregation type
|
||||
n_rows = len(aggs)
|
||||
fig = make_subplots(
|
||||
rows=n_rows,
|
||||
cols=1,
|
||||
subplot_titles=[f"{agg.title()} Values" for agg in aggs],
|
||||
vertical_spacing=0.15 / max(n_rows, 1),
|
||||
)
|
||||
|
||||
# Color palette for features
|
||||
colors = [
|
||||
"#1f77b4",
|
||||
"#ff7f0e",
|
||||
"#2ca02c",
|
||||
"#d62728",
|
||||
"#9467bd",
|
||||
"#8c564b",
|
||||
"#e377c2",
|
||||
"#7f7f7f",
|
||||
"#bcbd22",
|
||||
"#17becf",
|
||||
]
|
||||
|
||||
for row_idx, agg in enumerate(aggs, start=1):
|
||||
for feat_idx, feature in enumerate(features):
|
||||
# Get the data for this feature and aggregation
|
||||
var_name = f"{feature}"
|
||||
if var_name not in arcticdem_ds.data_vars:
|
||||
continue
|
||||
|
||||
# Extract values for this aggregation
|
||||
data = arcticdem_ds[var_name].sel(aggregations=agg)
|
||||
values = data.to_numpy().flatten()
|
||||
values = values[~np.isnan(values)]
|
||||
|
||||
if len(values) == 0:
|
||||
continue
|
||||
|
||||
# Add violin plot
|
||||
fig.add_trace(
|
||||
go.Violin(
|
||||
y=values,
|
||||
name=feature.replace("_", " ").title(),
|
||||
box_visible=True,
|
||||
meanline_visible=True,
|
||||
line_color=colors[feat_idx % len(colors)],
|
||||
showlegend=(row_idx == 1),
|
||||
),
|
||||
row=row_idx,
|
||||
col=1,
|
||||
)
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
height=300 * n_rows,
|
||||
title_text="Terrain Feature Distributions by Aggregation Type",
|
||||
showlegend=True,
|
||||
)
|
||||
|
||||
# Update y-axes labels
|
||||
for i in range(n_rows):
|
||||
fig.update_yaxes(title_text="Value", row=i + 1, col=1)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_aspect_rose_diagram(aspect_values: np.ndarray) -> go.Figure:
|
||||
"""Create a rose diagram (circular histogram) for aspect values.
|
||||
|
||||
Args:
|
||||
aspect_values: Array of aspect values in degrees (0-360)
|
||||
|
||||
Returns:
|
||||
Plotly Figure with rose diagram
|
||||
|
||||
"""
|
||||
# Remove NaN values
|
||||
aspect_values = aspect_values[~np.isnan(aspect_values)]
|
||||
|
||||
# Subsample if too many values for performance
|
||||
if len(aspect_values) > 50000:
|
||||
rng = np.random.default_rng(42)
|
||||
indices = rng.choice(len(aspect_values), size=50000, replace=False)
|
||||
aspect_values = aspect_values[indices]
|
||||
|
||||
if len(aspect_values) == 0:
|
||||
# Return empty figure
|
||||
return go.Figure()
|
||||
|
||||
# Create bins for aspect (every 10 degrees)
|
||||
bins = np.arange(0, 361, 10)
|
||||
bin_counts, _ = np.histogram(aspect_values, bins=bins)
|
||||
|
||||
# Calculate bin centers in degrees and radians
|
||||
bin_centers_deg = (bins[:-1] + bins[1:]) / 2
|
||||
bin_centers_rad = np.deg2rad(bin_centers_deg)
|
||||
|
||||
# Close the circle
|
||||
bin_centers_rad = np.append(bin_centers_rad, bin_centers_rad[0])
|
||||
bin_counts = np.append(bin_counts, bin_counts[0])
|
||||
|
||||
# Create polar bar chart
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Barpolar(
|
||||
r=bin_counts,
|
||||
theta=np.rad2deg(bin_centers_rad),
|
||||
width=10,
|
||||
marker={
|
||||
"color": bin_centers_rad,
|
||||
"colorscale": "HSV",
|
||||
"cmin": 0,
|
||||
"cmax": 2 * np.pi,
|
||||
"showscale": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Aspect Rose Diagram",
|
||||
polar={
|
||||
"radialaxis": {"title": "Frequency", "showticklabels": True},
|
||||
"angularaxis": {
|
||||
"direction": "clockwise",
|
||||
"rotation": 90,
|
||||
"tickmode": "array",
|
||||
"tickvals": [0, 45, 90, 135, 180, 225, 270, 315],
|
||||
"ticktext": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"],
|
||||
},
|
||||
},
|
||||
showlegend=False,
|
||||
height=500,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_slope_aspect_scatter(slope_values: np.ndarray, aspect_values: np.ndarray) -> go.Figure:
|
||||
"""Create a scatter plot showing the relationship between slope and aspect.
|
||||
|
||||
Args:
|
||||
slope_values: Array of slope values
|
||||
aspect_values: Array of aspect values in degrees (0-360)
|
||||
|
||||
Returns:
|
||||
Plotly Figure with scatter plot
|
||||
|
||||
"""
|
||||
# Create DataFrame and remove NaN
|
||||
df = pd.DataFrame({"slope": slope_values.flatten(), "aspect": aspect_values.flatten()})
|
||||
df = df.dropna()
|
||||
|
||||
if len(df) == 0:
|
||||
return go.Figure()
|
||||
|
||||
# Sample if too many points for performance
|
||||
if len(df) > 50000:
|
||||
df = df.sample(n=50000, random_state=42)
|
||||
|
||||
# Create 2D histogram (density plot)
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Histogram2d(
|
||||
x=df["aspect"],
|
||||
y=df["slope"],
|
||||
colorscale="Viridis",
|
||||
nbinsx=36, # 10-degree bins for aspect
|
||||
nbinsy=50,
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Slope vs Aspect Distribution",
|
||||
xaxis_title="Aspect (degrees)",
|
||||
yaxis_title="Slope (degrees)",
|
||||
height=500,
|
||||
)
|
||||
|
||||
# Add directional labels on x-axis
|
||||
fig.update_xaxes(
|
||||
tickmode="array",
|
||||
tickvals=[0, 45, 90, 135, 180, 225, 270, 315, 360],
|
||||
ticktext=["N", "NE", "E", "SE", "S", "SW", "W", "NW", "N"],
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_correlation_heatmap(arcticdem_ds: xr.Dataset, features: list[str], agg: str) -> go.Figure:
|
||||
"""Create a correlation heatmap for terrain features.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
features: List of feature names to include
|
||||
agg: Aggregation type to use
|
||||
|
||||
Returns:
|
||||
Plotly Figure with correlation heatmap
|
||||
|
||||
"""
|
||||
# Extract data for each feature
|
||||
data_dict = {}
|
||||
for feature in features:
|
||||
var_name = f"{feature}"
|
||||
if var_name in arcticdem_ds.data_vars:
|
||||
data = arcticdem_ds[var_name].sel(aggregations=agg)
|
||||
values = data.to_numpy().flatten()
|
||||
data_dict[feature] = values
|
||||
|
||||
if not data_dict:
|
||||
return go.Figure()
|
||||
|
||||
# Create DataFrame
|
||||
df = pd.DataFrame(data_dict)
|
||||
df = df.dropna()
|
||||
|
||||
# Sample if too many rows
|
||||
if len(df) > 50000:
|
||||
df = df.sample(n=50000, random_state=42)
|
||||
|
||||
# Calculate correlation matrix
|
||||
corr = df.corr()
|
||||
|
||||
# Create heatmap
|
||||
fig = go.Figure(
|
||||
data=go.Heatmap(
|
||||
z=corr.values,
|
||||
x=[f.replace("_", " ").title() for f in corr.columns],
|
||||
y=[f.replace("_", " ").title() for f in corr.index],
|
||||
colorscale="RdBu",
|
||||
zmid=0,
|
||||
zmin=-1,
|
||||
zmax=1,
|
||||
text=np.round(corr.values, 2),
|
||||
texttemplate="%{text}",
|
||||
textfont={"size": 10},
|
||||
colorbar={"title": "Correlation"},
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Feature Correlation Matrix ({agg.title()})",
|
||||
height=600,
|
||||
xaxis={"side": "bottom"},
|
||||
)
|
||||
|
||||
return fig
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
"""AlphaEarth embeddings dashboard section."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
|
|
@ -61,8 +63,6 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
|
|||
"""
|
||||
)
|
||||
|
||||
cols = st.columns([4, 1])
|
||||
with cols[0]:
|
||||
if "year" in embedding_values.dims or "year" in embedding_values.coords:
|
||||
year_values = embedding_values["year"].values.tolist()
|
||||
year = st.slider(
|
||||
|
|
@ -74,8 +74,7 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF
|
|||
help="Select the year for which to visualize the embeddings.",
|
||||
)
|
||||
embedding_values = embedding_values.sel(year=year)
|
||||
with cols[1]:
|
||||
make_3d_map = st.checkbox("3D Map", value=True)
|
||||
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="embedding_map_3d"))
|
||||
|
||||
# Check if subsampling will occur
|
||||
n_cells = len(embedding_values["cell_ids"])
|
||||
|
|
|
|||
317
src/entropice/dashboard/sections/arcticdem.py
Normal file
317
src/entropice/dashboard/sections/arcticdem.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
"""ArcticDEM terrain features dashboard section."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.plots.terrain import (
|
||||
create_aspect_rose_diagram,
|
||||
create_correlation_heatmap,
|
||||
create_slope_aspect_scatter,
|
||||
create_terrain_distribution_plot,
|
||||
create_terrain_map,
|
||||
)
|
||||
from entropice.dashboard.sections.dataset_statistics import render_member_details
|
||||
from entropice.dashboard.utils.colors import get_cmap
|
||||
from entropice.dashboard.utils.stats import MemberStatistics
|
||||
|
||||
|
||||
@st.fragment
|
||||
def _render_terrain_map(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
|
||||
"""Visualize spatial distribution of terrain features.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
grid_gdf: GeoDataFrame with grid cell geometries
|
||||
|
||||
"""
|
||||
st.subheader("Spatial Distribution of Terrain Features")
|
||||
|
||||
# Get available features and aggregations
|
||||
available_vars = list(arcticdem_ds.data_vars)
|
||||
aggs = list(arcticdem_ds.coords["aggregations"].values)
|
||||
|
||||
# Feature selection with grouping
|
||||
feature_groups = {
|
||||
"TPI (Topographic Position Index)": ["tpi_small", "tpi_medium", "tpi_large"],
|
||||
"TRI (Terrain Ruggedness Index)": ["tri_small", "tri_medium", "tri_large"],
|
||||
"VRM (Vector Ruggedness Measure)": ["vrm_small", "vrm_medium", "vrm_large"],
|
||||
"Slope & Curvature": ["slope", "curvature"],
|
||||
"Aspect": ["aspect"],
|
||||
}
|
||||
|
||||
# Flatten to get all features
|
||||
all_features = [f for features in feature_groups.values() for f in features if f in available_vars]
|
||||
|
||||
cols = st.columns([2, 2, 1])
|
||||
with cols[0]:
|
||||
# Feature selection with nice formatting
|
||||
selected_feature = st.selectbox(
|
||||
"Terrain Feature",
|
||||
options=all_features,
|
||||
format_func=lambda x: x.replace("_", " ").title(),
|
||||
key="terrain_feature",
|
||||
)
|
||||
|
||||
with cols[1]:
|
||||
# Aggregation selection
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation",
|
||||
options=aggs,
|
||||
key="terrain_agg",
|
||||
)
|
||||
|
||||
with cols[2]:
|
||||
st.write("\n")
|
||||
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="terrain_map_3d"))
|
||||
|
||||
# Extract the data
|
||||
if selected_feature not in arcticdem_ds.data_vars:
|
||||
st.error(f"Feature {selected_feature} not found in dataset")
|
||||
return
|
||||
|
||||
terrain_data = arcticdem_ds[selected_feature].sel(aggregations=selected_agg)
|
||||
terrain_series = terrain_data.to_series()
|
||||
|
||||
# Check if subsampling will occur
|
||||
n_cells = len(terrain_series)
|
||||
if n_cells > 100000:
|
||||
st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")
|
||||
|
||||
# Create map
|
||||
map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map)
|
||||
st.pydeck_chart(map_deck)
|
||||
|
||||
# Add legend
|
||||
with st.expander("Legend", expanded=True):
|
||||
st.markdown(f"**{selected_feature.replace('_', ' ').title()} ({selected_agg})**")
|
||||
|
||||
values = terrain_series.dropna()
|
||||
if len(values) > 0:
|
||||
vmin, vmax = values.min(), values.max()
|
||||
vmean = values.mean()
|
||||
vstd = values.std()
|
||||
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Min", f"{vmin:.2f}")
|
||||
with col2:
|
||||
st.metric("Mean", f"{vmean:.2f}")
|
||||
with col3:
|
||||
st.metric("Max", f"{vmax:.2f}")
|
||||
with col4:
|
||||
st.metric("Std Dev", f"{vstd:.2f}")
|
||||
|
||||
# Color scale visualization
|
||||
if "aspect" in selected_feature.lower():
|
||||
cmap = get_cmap("aspect")
|
||||
else:
|
||||
cmap = get_cmap("terrain")
|
||||
|
||||
gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
|
||||
gradient_css = ", ".join(gradient_colors)
|
||||
|
||||
gradient_style = f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;"
|
||||
st.markdown(
|
||||
f"""
|
||||
<div style="margin-top: 10px;">
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span style="font-size: 0.9em;">{vmin:.2f}</span>
|
||||
<span style="font-size: 0.9em;">{vmax:.2f}</span>
|
||||
</div>
|
||||
<div style="{gradient_style}"></div>
|
||||
</div>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
else:
|
||||
st.warning("No data available for the selected feature and aggregation")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def _render_terrain_distributions(arcticdem_ds: xr.Dataset):
|
||||
"""Display distribution plots for terrain features.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
|
||||
"""
|
||||
st.subheader("Terrain Feature Distributions")
|
||||
|
||||
# Get available features
|
||||
available_vars = list(arcticdem_ds.data_vars)
|
||||
|
||||
# Feature group selection
|
||||
feature_groups = {
|
||||
"Topographic Position (TPI)": ["tpi_small", "tpi_medium", "tpi_large"],
|
||||
"Terrain Ruggedness (TRI)": ["tri_small", "tri_medium", "tri_large"],
|
||||
"Vector Ruggedness (VRM)": ["vrm_small", "vrm_medium", "vrm_large"],
|
||||
"All Scale-Invariant": ["slope", "aspect", "curvature"],
|
||||
}
|
||||
|
||||
selected_group = st.selectbox(
|
||||
"Feature Group",
|
||||
options=list(feature_groups.keys()),
|
||||
key="terrain_dist_group",
|
||||
)
|
||||
|
||||
features_to_plot = [f for f in feature_groups[selected_group] if f in available_vars]
|
||||
|
||||
if features_to_plot:
|
||||
# Check if subsampling will occur
|
||||
n_cells = len(arcticdem_ds.cell_ids)
|
||||
if n_cells > 10000:
|
||||
st.info(
|
||||
f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
|
||||
"for performance. Statistics remain representative."
|
||||
)
|
||||
|
||||
fig = create_terrain_distribution_plot(arcticdem_ds, features_to_plot)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
|
||||
st.markdown(
|
||||
f"""
|
||||
Distribution of **{selected_group}** features across different aggregation types.
|
||||
Violin plots show the full distribution with embedded box plots for quartiles.
|
||||
"""
|
||||
)
|
||||
else:
|
||||
st.warning(f"No features available for {selected_group}")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def _render_terrain_correlations(arcticdem_ds: xr.Dataset):
|
||||
"""Display correlation analysis for terrain features.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
|
||||
"""
|
||||
st.subheader("Feature Correlation Analysis")
|
||||
|
||||
# Get available features
|
||||
available_vars = list(arcticdem_ds.data_vars)
|
||||
aggs = list(arcticdem_ds.coords["aggregations"].values)
|
||||
|
||||
# Select aggregation for correlation
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation Type",
|
||||
options=aggs,
|
||||
key="terrain_corr_agg",
|
||||
)
|
||||
|
||||
# Select which features to include
|
||||
all_features = [str(f) for f in available_vars if f in available_vars]
|
||||
|
||||
if len(all_features) >= 2:
|
||||
fig = create_correlation_heatmap(arcticdem_ds, all_features, selected_agg)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
Correlation heatmap shows relationships between different terrain features.
|
||||
Strong positive correlations (red) or negative correlations (blue) can indicate
|
||||
related terrain characteristics.
|
||||
"""
|
||||
)
|
||||
else:
|
||||
st.warning("Need at least 2 features for correlation analysis")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def _render_slope_aspect_analysis(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame):
|
||||
"""Specialized visualization for slope and aspect relationships.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features
|
||||
grid_gdf: GeoDataFrame with grid cell geometries
|
||||
|
||||
"""
|
||||
st.subheader("Slope & Aspect Analysis")
|
||||
|
||||
# Check if slope and aspect are available
|
||||
if "slope" not in arcticdem_ds.data_vars or "aspect" not in arcticdem_ds.data_vars:
|
||||
st.warning("Slope and/or aspect data not available in this dataset")
|
||||
return
|
||||
|
||||
# Get aggregations
|
||||
aggs = list(arcticdem_ds.coords["aggregations"].values)
|
||||
|
||||
# Select aggregation
|
||||
selected_agg = st.selectbox(
|
||||
"Aggregation Type",
|
||||
options=aggs,
|
||||
key="slope_aspect_agg",
|
||||
)
|
||||
|
||||
# Extract slope and aspect data
|
||||
slope_data = arcticdem_ds["slope"].sel(aggregations=selected_agg).values
|
||||
aspect_data = arcticdem_ds["aspect"].sel(aggregations=selected_agg).values
|
||||
|
||||
# Check if subsampling will occur
|
||||
n_values = len(slope_data.flatten())
|
||||
if n_values > 50000:
|
||||
st.info(
|
||||
f"📊 **Dataset subsampled:** Using 50,000 randomly selected values out of {n_values:,} "
|
||||
"for performance. Distributions remain representative."
|
||||
)
|
||||
|
||||
# Create two columns for visualizations
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.markdown("**Aspect Rose Diagram**")
|
||||
fig_rose = create_aspect_rose_diagram(aspect_data)
|
||||
st.plotly_chart(fig_rose, width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
The rose diagram shows the directional distribution of terrain aspect.
|
||||
Each bar represents the frequency of slopes facing that direction.
|
||||
"""
|
||||
)
|
||||
|
||||
with col2:
|
||||
st.markdown("**Slope vs Aspect Distribution**")
|
||||
fig_scatter = create_slope_aspect_scatter(slope_data, aspect_data)
|
||||
st.plotly_chart(fig_scatter, width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This 2D histogram shows how slope steepness relates to aspect direction.
|
||||
Patterns can reveal preferential slope orientations (e.g., due to prevailing winds or sun exposure).
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def render_arcticdem_tab(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame, arcticdem_stats: MemberStatistics):
|
||||
"""Render the ArcticDEM visualization tab.
|
||||
|
||||
Args:
|
||||
arcticdem_ds: The ArcticDEM dataset member, lazily loaded.
|
||||
grid_gdf: GeoDataFrame with grid cell geometries
|
||||
arcticdem_stats: Statistics for the ArcticDEM member.
|
||||
|
||||
"""
|
||||
# Render different visualizations
|
||||
with st.expander("ArcticDEM Statistics", expanded=True):
|
||||
render_member_details("ArcticDEM", arcticdem_stats)
|
||||
|
||||
st.divider()
|
||||
|
||||
_render_terrain_map(arcticdem_ds, grid_gdf)
|
||||
|
||||
# st.divider()
|
||||
|
||||
# _render_terrain_distributions(arcticdem_ds)
|
||||
|
||||
st.divider()
|
||||
|
||||
_render_terrain_correlations(arcticdem_ds)
|
||||
|
||||
st.divider()
|
||||
|
||||
_render_slope_aspect_analysis(arcticdem_ds, grid_gdf)
|
||||
|
|
@ -14,16 +14,14 @@ from entropice.dashboard.utils.colors import get_cmap
|
|||
def _render_area_map(grid_gdf: gpd.GeoDataFrame):
|
||||
st.subheader("Spatial Distribution of Grid Cell Areas")
|
||||
|
||||
cols = st.columns([4, 1])
|
||||
with cols[0]:
|
||||
metric = st.selectbox(
|
||||
"Metric",
|
||||
options=["cell_area", "land_area", "water_area", "land_ratio"],
|
||||
format_func=lambda x: x.replace("_", " ").title(),
|
||||
key="metric",
|
||||
)
|
||||
with cols[1]:
|
||||
make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
|
||||
|
||||
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d"))
|
||||
|
||||
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
|
||||
st.pydeck_chart(map_deck)
|
||||
|
|
|
|||
|
|
@ -446,8 +446,6 @@ def render_member_details(member: str, member_stats: MemberStatistics):
|
|||
member_stats: Statistics for the member
|
||||
|
||||
"""
|
||||
st.markdown(f"### {member}")
|
||||
|
||||
# Variables
|
||||
st.markdown("**Variables:**")
|
||||
vars_html = " ".join(
|
||||
|
|
@ -539,6 +537,7 @@ def render_ensemble_details(
|
|||
|
||||
# Individual member details
|
||||
for member, stats in selected_member_stats.items():
|
||||
st.markdown(f"### {member}")
|
||||
render_member_details(member, stats)
|
||||
st.divider()
|
||||
|
||||
|
|
|
|||
219
src/entropice/dashboard/sections/era5.py
Normal file
219
src/entropice/dashboard/sections/era5.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""ERA5 climate data dashboard section."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import geopandas as gpd
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.plots.climate import (
|
||||
create_climate_map,
|
||||
create_climate_trend_plot,
|
||||
)
|
||||
from entropice.dashboard.sections.dataset_statistics import render_member_details
|
||||
from entropice.dashboard.utils.colors import get_cmap
|
||||
from entropice.dashboard.utils.stats import MemberStatistics
|
||||
|
||||
|
||||
def _get_climate_variable_aggregation_season_option(era5_ds: xr.Dataset):
|
||||
available_vars = list(era5_ds.data_vars)
|
||||
selected_var = st.selectbox("Select Climate Variable", available_vars, key="climate_var_map")
|
||||
if "month" in era5_ds.dims:
|
||||
months = era5_ds.month.to_numpy()
|
||||
selected_month = cast(str, st.selectbox("Select Season", options=months.tolist(), key="climate_month"))
|
||||
else:
|
||||
selected_month = None
|
||||
if "aggregations" in era5_ds.dims:
|
||||
aggs = era5_ds.aggregations.to_numpy()
|
||||
selected_agg = cast(str, st.selectbox("Select Aggregation Method", options=aggs.tolist(), key="climate_agg"))
|
||||
else:
|
||||
selected_agg = None
|
||||
return selected_var, selected_agg, selected_month
|
||||
|
||||
|
||||
@st.fragment
|
||||
def _render_climate_variable_map(climate_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame, selected_var: str):
|
||||
"""Visualize spatial distribution of climate variables.
|
||||
|
||||
Args:
|
||||
climate_values: Xarray DataArray containing ERA5 climate features
|
||||
grid_gdf: GeoDataFrame with grid cell geometries
|
||||
selected_var: Name of the selected climate variable
|
||||
|
||||
"""
|
||||
st.subheader("Spatial Distribution of Climate Variables")
|
||||
|
||||
if "year" in climate_values.dims:
|
||||
years = climate_values.year.to_numpy()
|
||||
selected_year = cast(
|
||||
int, st.select_slider("Select Year", options=years.tolist(), value=int(years.max()), key="climate_year")
|
||||
)
|
||||
climate_values = climate_values.sel(year=selected_year)
|
||||
|
||||
# 3D toggle
|
||||
make_3d = cast(bool, st.toggle("3D Map", value=True, key="climate_map_3d"))
|
||||
|
||||
# Create map
|
||||
n_cells = len(climate_values)
|
||||
if n_cells > 100000:
|
||||
st.info(f"Showing 100,000 / {n_cells:,} cells for performance")
|
||||
|
||||
deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
|
||||
st.pydeck_chart(deck, use_container_width=True)
|
||||
|
||||
# Add legend
|
||||
with st.expander("Legend", expanded=True):
|
||||
st.markdown(f"**{selected_var.replace('_', ' ').title()}**")
|
||||
|
||||
# Get the actual values to show accurate min/max (same as in the map function)
|
||||
values = climate_values.values.flatten()
|
||||
values = values[~np.isnan(values)]
|
||||
vmin, vmax = np.nanpercentile(values, [2, 98])
|
||||
|
||||
vmin_str = f"{vmin:.4f}"
|
||||
vmax_str = f"{vmax:.4f}"
|
||||
|
||||
# Color scale visualization - use appropriate colormap
|
||||
cmap = get_cmap(selected_var)
|
||||
|
||||
gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
|
||||
gradient_css = ", ".join(gradient_colors)
|
||||
|
||||
gradient_style = f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;"
|
||||
st.markdown(
|
||||
f"""
|
||||
<div style="margin-top: 10px;">
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span style="font-size: 0.9em;">{vmin_str}</span>
|
||||
<span style="font-size: 0.9em;">{vmax_str}</span>
|
||||
</div>
|
||||
<div style="{gradient_style}"></div>
|
||||
</div>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
st.caption(
|
||||
"Color intensity represents climate values from low to high. "
|
||||
"Values are normalized using the 2nd-98th percentile range to avoid outliers."
|
||||
)
|
||||
|
||||
if make_3d:
|
||||
st.markdown("---")
|
||||
st.markdown("**3D Elevation:**")
|
||||
st.caption("Height represents normalized climate values. Rotate the map by holding Ctrl/Cmd and dragging.")
|
||||
|
||||
|
||||
def _render_climate_temporal_trends(
|
||||
climate_values: xr.DataArray, selected_var: str, selected_agg: str | None, selected_month: str | None
|
||||
):
|
||||
"""Display temporal trends of climate variables.
|
||||
|
||||
Args:
|
||||
climate_values: Xarray DataArray containing ERA5 climate features
|
||||
selected_var: Name of the selected climate variable
|
||||
selected_agg: Selected aggregation method
|
||||
selected_month: Selected month/season
|
||||
|
||||
"""
|
||||
st.subheader("Climate Variable Temporal Trends")
|
||||
|
||||
if "year" not in climate_values.dims:
|
||||
st.info("Temporal trends require yearly dimension")
|
||||
return
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows how climate variables have changed over time across the study area.
|
||||
The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
|
||||
patterns in climate variables that may correlate with environmental changes.
|
||||
|
||||
**Understanding the plot:**
|
||||
- **Mean line** (blue): Average value across all grid cells for each year
|
||||
- **10th-90th percentile band** (light blue): Range containing 80% of the values, showing
|
||||
typical variation
|
||||
- **Min/Max range** (gray): Full extent of values, highlighting outliers
|
||||
"""
|
||||
)
|
||||
# Show dataset filtering info
|
||||
n_years = len(climate_values["year"])
|
||||
n_cells = len(climate_values["cell_ids"])
|
||||
|
||||
agg_info = f", Aggregation: `{selected_agg}`" if selected_agg else ""
|
||||
month_info = f", Season: `{selected_month}`" if selected_month else ""
|
||||
st.caption(
|
||||
f"📊 **Dataset selection:** Variable: `{selected_var}` "
|
||||
f"({n_years} years, {n_cells:,} cells{agg_info}{month_info})"
|
||||
)
|
||||
|
||||
# Check if subsampling will occur
|
||||
if n_cells > 10000:
|
||||
st.info(
|
||||
f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
|
||||
"for performance. Statistics remain representative."
|
||||
)
|
||||
|
||||
# Create trend plot
|
||||
fig = create_climate_trend_plot(climate_values, selected_var)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_era5_tab(
|
||||
era5_member_datasets: dict[str, xr.Dataset],
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
era5_member_stats: dict[str, MemberStatistics],
|
||||
):
|
||||
"""Render the ERA5 visualization tab.
|
||||
|
||||
Args:
|
||||
era5_member_datasets: Dictionary of ERA5 member datasets (yearly, seasonal, shoulder)
|
||||
grid_gdf: GeoDataFrame with grid cell geometries
|
||||
era5_member_stats: Dictionary of MemberStatistics for each ERA5 member.
|
||||
|
||||
"""
|
||||
if not era5_member_datasets:
|
||||
st.warning("No ERA5 data available")
|
||||
return
|
||||
|
||||
# Member selection
|
||||
available_members = list(era5_member_datasets.keys())
|
||||
|
||||
if len(available_members) == 1:
|
||||
selected_member = available_members[0]
|
||||
st.info(f"Showing ERA5 data: {selected_member}")
|
||||
else:
|
||||
selected_member = st.selectbox(
|
||||
"Select ERA5 Aggregation Type",
|
||||
available_members,
|
||||
format_func=lambda x: x.replace("_", " ").title(),
|
||||
key="era5_member",
|
||||
)
|
||||
|
||||
era5_stats = era5_member_stats[selected_member]
|
||||
# Load selected dataset
|
||||
era5_ds = era5_member_datasets[selected_member]
|
||||
|
||||
# Render different visualizations
|
||||
with st.expander(f"{selected_member.replace('_', ' ').title()} Statistics", expanded=True):
|
||||
render_member_details(selected_member, era5_stats)
|
||||
|
||||
st.divider()
|
||||
|
||||
selected_var, selected_agg, selected_month = _get_climate_variable_aggregation_season_option(era5_ds)
|
||||
|
||||
climate_values = era5_ds[selected_var]
|
||||
if selected_agg:
|
||||
climate_values = climate_values.sel(aggregations=selected_agg)
|
||||
if selected_month:
|
||||
climate_values = climate_values.sel(month=selected_month)
|
||||
climate_values = climate_values.compute()
|
||||
|
||||
_render_climate_variable_map(climate_values, grid_gdf, selected_var)
|
||||
|
||||
if "year" in climate_values.dims:
|
||||
st.divider()
|
||||
|
||||
_render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)
|
||||
|
|
@ -185,9 +185,9 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
|
|||
)
|
||||
with cols[2]:
|
||||
# Controls weather a 3D map or a 2D map is shown
|
||||
make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
|
||||
make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="target_map_3d"))
|
||||
# Controls what should be shows, either the split or the labels / values
|
||||
show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False))
|
||||
show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False, key="target_map_show_split"))
|
||||
|
||||
training_set = train_data_dict[selected_target][selected_task]
|
||||
map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
"""Data page: Visualization of the data."""
|
||||
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.sections.alphaearth import render_alphaearth_tab
|
||||
from entropice.dashboard.sections.arcticdem import render_arcticdem_tab
|
||||
from entropice.dashboard.sections.areas import render_area_information_tab
|
||||
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
|
||||
from entropice.dashboard.sections.era5 import render_era5_tab
|
||||
from entropice.dashboard.sections.targets import render_target_information_tab
|
||||
from entropice.dashboard.utils.loaders import load_training_sets
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics
|
||||
|
|
@ -130,19 +132,20 @@ def render_dataset_page():
|
|||
tab_names.append("🌍 AlphaEarth")
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
tab_names.append("🏔️ ArcticDEM")
|
||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||
era5_members = cast(
|
||||
list[Literal["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]],
|
||||
[m for m in ensemble.members if m.startswith("ERA5")],
|
||||
)
|
||||
if era5_members:
|
||||
tab_names.append("🌡️ ERA5")
|
||||
tabs = st.tabs(tab_names)
|
||||
|
||||
with tabs[0]:
|
||||
st.header("🎯 Target Labels Visualization")
|
||||
if False: # ! DEBUG
|
||||
train_data_dict = load_training_sets(ensemble)
|
||||
render_target_information_tab(train_data_dict)
|
||||
with tabs[1]:
|
||||
st.header("📐 Areas Visualization")
|
||||
if False: # ! DEBUG
|
||||
render_area_information_tab(grid_gdf)
|
||||
tab_index = 2
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
|
|
@ -155,10 +158,16 @@ def render_dataset_page():
|
|||
if "ArcticDEM" in ensemble.members:
|
||||
with tabs[tab_index]:
|
||||
st.header("🏔️ ArcticDEM Visualization")
|
||||
arcticdem_ds = member_datasets["ArcticDEM"].compute()
|
||||
arcticdem_stats = stats.members["ArcticDEM"]
|
||||
render_arcticdem_tab(arcticdem_ds, grid_gdf, arcticdem_stats)
|
||||
tab_index += 1
|
||||
if era5_members:
|
||||
with tabs[tab_index]:
|
||||
st.header("🌡️ ERA5 Visualization")
|
||||
era5_member_dataset = {m: member_datasets[m] for m in era5_members}
|
||||
era5_member_stats = {m: stats.members[m] for m in era5_members}
|
||||
render_era5_tab(era5_member_dataset, grid_gdf, era5_member_stats)
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue