diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py
new file mode 100644
index 0000000..6801050
--- /dev/null
+++ b/src/entropice/dashboard/plots/source_data.py
@@ -0,0 +1,798 @@
+"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
+
+import antimeridian
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import pydeck as pdk
+import streamlit as st
+import xarray as xr
+from shapely.geometry import shape
+
+from entropice.dashboard.plots.colors import get_cmap
+
+
+def _fix_hex_geometry(geom):
+ """Fix hexagon geometry crossing the antimeridian."""
+ try:
+ return shape(antimeridian.fix_shape(geom))
+ except ValueError as e:
+ st.error(f"Error fixing geometry: {e}")
+ return geom
+
+
+def render_alphaearth_overview(ds: xr.Dataset):
+ """Render overview statistics for AlphaEarth embeddings data.
+
+ Args:
+ ds: xarray Dataset containing AlphaEarth embeddings.
+
+ """
+ st.subheader("π AlphaEarth Embeddings Statistics")
+
+ # Overall statistics
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+
+ with col2:
+ st.metric("Embedding Dimensions", f"{len(ds['band'])}")
+
+ with col3:
+ st.metric("Years Available", f"{len(ds['year'])}")
+
+ with col4:
+ st.metric("Aggregations", f"{len(ds['agg'])}")
+
+ # Show temporal coverage
+ st.markdown("**Temporal Coverage:**")
+ years = sorted(ds["year"].values)
+ st.write(f"Years: {min(years)} - {max(years)}")
+
+ # Show aggregations
+ st.markdown("**Available Aggregations:**")
+ aggs = ds["agg"].to_numpy()
+ st.write(", ".join(str(a) for a in aggs))
+
+
+@st.fragment
+def render_alphaearth_plots(ds: xr.Dataset):
+ """Render interactive plots for AlphaEarth embeddings data.
+
+ Args:
+ ds: xarray Dataset containing AlphaEarth embeddings.
+
+ """
+ st.markdown("---")
+ st.markdown("**Embedding Distribution by Band**")
+
+ embeddings_data = ds["embeddings"]
+
+ # Select year and aggregation for visualization
+ col1, col2 = st.columns(2)
+ with col1:
+ selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
+ with col2:
+ selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")
+
+ # Get data for selected year and aggregation
+ year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)
+
+ # Calculate statistics for each band
+ band_stats = []
+ for band_idx in range(len(ds["band"])):
+ band_data = year_agg_data.isel(band=band_idx).values.flatten()
+ band_data = band_data[~np.isnan(band_data)] # Remove NaN values
+
+ if len(band_data) > 0:
+ band_stats.append(
+ {
+ "Band": band_idx,
+ "Mean": float(np.mean(band_data)),
+ "Std": float(np.std(band_data)),
+ "Min": float(np.min(band_data)),
+ "25%": float(np.percentile(band_data, 25)),
+ "Median": float(np.median(band_data)),
+ "75%": float(np.percentile(band_data, 75)),
+ "Max": float(np.max(band_data)),
+ }
+ )
+
+ band_df = pd.DataFrame(band_stats)
+
+ # Create plot showing distribution across bands
+ fig = go.Figure()
+
+ # Add min/max range first (background)
+ fig.add_trace(
+ go.Scatter(
+ x=band_df["Band"],
+ y=band_df["Min"],
+ mode="lines",
+ line={"color": "lightgray", "width": 1, "dash": "dash"},
+ name="Min/Max Range",
+ showlegend=True,
+ )
+ )
+
+ fig.add_trace(
+ go.Scatter(
+ x=band_df["Band"],
+ y=band_df["Max"],
+ mode="lines",
+ fill="tonexty",
+ fillcolor="rgba(200, 200, 200, 0.1)",
+ line={"color": "lightgray", "width": 1, "dash": "dash"},
+ showlegend=False,
+ )
+ )
+
+ # Add std band
+ fig.add_trace(
+ go.Scatter(
+ x=band_df["Band"],
+ y=band_df["Mean"] - band_df["Std"],
+ mode="lines",
+ line={"width": 0},
+ showlegend=False,
+ hoverinfo="skip",
+ )
+ )
+
+ fig.add_trace(
+ go.Scatter(
+ x=band_df["Band"],
+ y=band_df["Mean"] + band_df["Std"],
+ mode="lines",
+ fill="tonexty",
+ fillcolor="rgba(31, 119, 180, 0.2)",
+ line={"width": 0},
+            name="±1 Std",
+ )
+ )
+
+ # Add mean line on top
+ fig.add_trace(
+ go.Scatter(
+ x=band_df["Band"],
+ y=band_df["Mean"],
+ mode="lines+markers",
+ name="Mean",
+ line={"color": "#1f77b4", "width": 2},
+ marker={"size": 4},
+ )
+ )
+
+ fig.update_layout(
+ title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
+ xaxis_title="Band",
+ yaxis_title="Embedding Value",
+ height=450,
+ hovermode="x unified",
+ )
+
+ st.plotly_chart(fig, use_container_width=True)
+
+ # Band statistics
+ with st.expander("π Statistics by Embedding Band", expanded=False):
+ st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")
+
+ # Calculate statistics for each band
+ band_stats = []
+ for band_idx in range(min(10, len(ds["band"]))): # Show first 10 bands
+ band_data = embeddings_data.isel(band=band_idx)
+ band_stats.append(
+ {
+ "Band": int(band_idx),
+ "Mean": float(band_data.mean().values),
+ "Std": float(band_data.std().values),
+ "Min": float(band_data.min().values),
+ "Max": float(band_data.max().values),
+ }
+ )
+
+ band_df = pd.DataFrame(band_stats)
+ st.dataframe(band_df, use_container_width=True, hide_index=True)
+
+ if len(ds["band"]) > 10:
+ st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
+
+
+def render_arcticdem_overview(ds: xr.Dataset):
+ """Render overview statistics for ArcticDEM terrain data.
+
+ Args:
+ ds: xarray Dataset containing ArcticDEM data.
+
+ """
+    st.subheader("🏔️ ArcticDEM Terrain Statistics")
+
+ # Overall statistics
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+ st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+
+ with col2:
+ st.metric("Variables", f"{len(ds.data_vars)}")
+
+ with col3:
+ st.metric("Aggregations", f"{len(ds['aggregations'])}")
+
+ # Show available variables
+ st.markdown("**Available Variables:**")
+ variables = list(ds.data_vars)
+ st.write(", ".join(variables))
+
+ # Show aggregations
+ st.markdown("**Available Aggregations:**")
+ aggs = ds["aggregations"].to_numpy()
+ st.write(", ".join(str(a) for a in aggs))
+
+ # Statistics by variable
+ st.markdown("---")
+ st.markdown("**Variable Statistics (across all aggregations)**")
+
+ var_stats = []
+ for var_name in ds.data_vars:
+ var_data = ds[var_name]
+ var_stats.append(
+ {
+ "Variable": var_name,
+ "Mean": float(var_data.mean().values),
+ "Std": float(var_data.std().values),
+ "Min": float(var_data.min().values),
+ "Max": float(var_data.max().values),
+ "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
+ }
+ )
+
+ stats_df = pd.DataFrame(var_stats)
+ st.dataframe(stats_df, use_container_width=True, hide_index=True)
+
+
+@st.fragment
+def render_arcticdem_plots(ds: xr.Dataset):
+ """Render interactive plots for ArcticDEM terrain data.
+
+ Args:
+ ds: xarray Dataset containing ArcticDEM data.
+
+ """
+ st.markdown("---")
+ st.markdown("**Variable Distributions**")
+
+ variables = list(ds.data_vars)
+
+ # Select a variable to visualize
+ selected_var = st.selectbox("Select variable to visualize", options=variables, key="arcticdem_var_select")
+
+ if selected_var:
+ var_data = ds[selected_var]
+
+ # Create histogram
+ fig = go.Figure()
+
+ for agg in ds["aggregations"].to_numpy():
+ agg_data = var_data.sel(aggregations=agg).to_numpy().flatten()
+ agg_data = agg_data[~np.isnan(agg_data)]
+
+ fig.add_trace(
+ go.Histogram(
+ x=agg_data,
+ name=str(agg),
+ opacity=0.7,
+ nbinsx=50,
+ )
+ )
+
+ fig.update_layout(
+ title=f"Distribution of {selected_var} by Aggregation",
+ xaxis_title=selected_var,
+ yaxis_title="Count",
+ barmode="overlay",
+ height=400,
+ )
+
+ st.plotly_chart(fig, use_container_width=True)
+
+
+def render_era5_overview(ds: xr.Dataset, temporal_type: str):
+ """Render overview statistics for ERA5 climate data.
+
+ Args:
+ ds: xarray Dataset containing ERA5 data.
+ temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
+
+ """
+    st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")
+
+ # Overall statistics
+ has_agg = "aggregations" in ds.dims
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+
+ with col2:
+ st.metric("Variables", f"{len(ds.data_vars)}")
+
+ with col3:
+ st.metric("Time Steps", f"{len(ds['time'])}")
+
+ with col4:
+ if has_agg:
+ st.metric("Aggregations", f"{len(ds['aggregations'])}")
+ else:
+ st.metric("Aggregations", "1")
+
+ # Show available variables
+ st.markdown("**Available Variables:**")
+ variables = list(ds.data_vars)
+ st.write(", ".join(variables))
+
+ # Show temporal range
+ st.markdown("**Temporal Range:**")
+ time_values = pd.to_datetime(ds["time"].values)
+ st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}")
+
+ if has_agg:
+ st.markdown("**Available Aggregations:**")
+ aggs = ds["aggregations"].to_numpy()
+ st.write(", ".join(str(a) for a in aggs))
+
+ # Statistics by variable
+ st.markdown("---")
+ st.markdown("**Variable Statistics (across all time steps and aggregations)**")
+
+ var_stats = []
+ for var_name in ds.data_vars:
+ var_data = ds[var_name]
+ var_stats.append(
+ {
+ "Variable": var_name,
+ "Mean": float(var_data.mean().values),
+ "Std": float(var_data.std().values),
+ "Min": float(var_data.min().values),
+ "Max": float(var_data.max().values),
+ "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values),
+ }
+ )
+
+ stats_df = pd.DataFrame(var_stats)
+ st.dataframe(stats_df, use_container_width=True, hide_index=True)
+
+
+@st.fragment
+def render_era5_plots(ds: xr.Dataset, temporal_type: str):
+ """Render interactive plots for ERA5 climate data.
+
+ Args:
+ ds: xarray Dataset containing ERA5 data.
+ temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
+
+ """
+ st.markdown("---")
+ st.markdown("**Temporal Trends**")
+
+ variables = list(ds.data_vars)
+ has_agg = "aggregations" in ds.dims
+
+ selected_var = st.selectbox(
+ "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
+ )
+
+ if selected_var:
+ var_data = ds[selected_var]
+
+ # Calculate mean over space for each time step
+ if has_agg:
+ # Average over aggregations first, then over cells
+ time_series = var_data.mean(dim=["cell_ids", "aggregations"])
+ else:
+ time_series = var_data.mean(dim="cell_ids")
+
+ time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()})
+
+ fig = go.Figure()
+
+ fig.add_trace(
+ go.Scatter(
+ x=time_df["Time"],
+ y=time_df["Value"],
+ mode="lines+markers",
+ name=selected_var,
+ line={"width": 2},
+ )
+ )
+
+ fig.update_layout(
+ title=f"Temporal Trend of {selected_var} (Spatial Mean)",
+ xaxis_title="Time",
+ yaxis_title=selected_var,
+ height=400,
+ hovermode="x unified",
+ )
+
+ st.plotly_chart(fig, use_container_width=True)
+
+
+@st.fragment
+def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
+ """Render interactive pydeck map for AlphaEarth embeddings.
+
+ Args:
+ ds: xarray Dataset containing AlphaEarth data.
+ targets: GeoDataFrame with geometry for each cell.
+ grid: Grid type ('hex' or 'healpix').
+
+ """
+    st.subheader("🗺️ AlphaEarth Spatial Distribution")
+
+ # Controls
+ col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
+
+ with col1:
+ selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
+
+ with col2:
+ selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
+
+ with col3:
+ selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
+
+ with col4:
+ opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
+
+ # Extract data for selected parameters
+ data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)
+
+ # Create GeoDataFrame
+ gdf = targets.copy()
+ gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
+ gdf = gdf.set_index("cell_id")
+
+ # Add values
+ values_df = data_values.to_dataframe(name="value")
+ gdf = gdf.join(values_df, how="inner")
+
+ # Convert to WGS84 first
+ gdf_wgs84 = gdf.to_crs("EPSG:4326")
+
+ # Fix geometries after CRS conversion
+ if grid == "hex":
+ gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
+
+ # Normalize values for color mapping
+ values = gdf_wgs84["value"].to_numpy()
+ vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers
+ normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
+
+ # Apply colormap
+ cmap = get_cmap("embeddings")
+ colors = [cmap(val) for val in normalized]
+ gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+
+ # Create GeoJSON
+ geojson_data = []
+ for _, row in gdf_wgs84.iterrows():
+ feature = {
+ "type": "Feature",
+ "geometry": row["geometry"].__geo_interface__,
+ "properties": {
+ "value": float(row["value"]),
+ "fill_color": row["fill_color"],
+ },
+ }
+ geojson_data.append(feature)
+
+ # Create pydeck layer
+ layer = pdk.Layer(
+ "GeoJsonLayer",
+ geojson_data,
+ opacity=opacity,
+ stroked=True,
+ filled=True,
+ get_fill_color="properties.fill_color",
+ get_line_color=[80, 80, 80],
+ line_width_min_pixels=0.5,
+ pickable=True,
+ )
+
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+
+ deck = pdk.Deck(
+ layers=[layer],
+ initial_view_state=view_state,
+ tooltip={"html": "Value: {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
+ map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+ )
+
+ st.pydeck_chart(deck)
+
+ # Show statistics
+ st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}")
+
+
+@st.fragment
+def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
+ """Render interactive pydeck map for ArcticDEM terrain data.
+
+ Args:
+ ds: xarray Dataset containing ArcticDEM data.
+ targets: GeoDataFrame with geometry for each cell.
+ grid: Grid type ('hex' or 'healpix').
+
+ """
+    st.subheader("🗺️ ArcticDEM Spatial Distribution")
+
+ # Controls
+ variables = list(ds.data_vars)
+ col1, col2, col3 = st.columns([3, 3, 1])
+
+ with col1:
+ selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")
+
+ with col2:
+ selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")
+
+ with col3:
+ opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")
+
+ # Extract data for selected parameters
+ data_values = ds[selected_var].sel(aggregations=selected_agg)
+
+ # Create GeoDataFrame
+ gdf = targets.copy()
+ gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
+ gdf = gdf.set_index("cell_id")
+
+ # Add values
+ values_df = data_values.to_dataframe(name="value")
+ gdf = gdf.join(values_df, how="inner")
+
+ # Add all aggregation values for tooltip
+ if len(ds["aggregations"]) > 1:
+ for agg in ds["aggregations"].values:
+ agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
+ # Drop the aggregations column to avoid conflicts
+ if "aggregations" in agg_data.columns:
+ agg_data = agg_data.drop(columns=["aggregations"])
+ gdf = gdf.join(agg_data, how="left")
+
+ # Convert to WGS84 first
+ gdf_wgs84 = gdf.to_crs("EPSG:4326")
+
+ # Fix geometries after CRS conversion
+ if grid == "hex":
+ gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
+
+ # Normalize values for color mapping
+ values = gdf_wgs84["value"].values
+ values_clean = values[~np.isnan(values)]
+
+ if len(values_clean) > 0:
+ vmin, vmax = np.nanpercentile(values_clean, [2, 98])
+ normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
+
+ # Apply colormap
+ cmap = get_cmap(f"arcticdem_{selected_var}")
+ colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
+ gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+
+        # Create GeoJSON
+        geojson_data = []
+        for _, row in gdf_wgs84.iterrows():
+            properties = {
+                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
+                "fill_color": row["fill_color"],
+            }
+            # Add all aggregation values if available
+            if len(ds["aggregations"]) > 1:
+                for agg in ds["aggregations"].values:
+                    agg_col = f"agg_{agg}"
+                    if agg_col in row.index:
+                        properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
+
+            feature = {
+                "type": "Feature",
+                "geometry": row["geometry"].__geo_interface__,
+                "properties": properties,
+            }
+            geojson_data.append(feature)
+
+        # Create pydeck layer
+        layer = pdk.Layer(
+            "GeoJsonLayer",
+            geojson_data,
+            opacity=opacity,
+            stroked=True,
+            filled=True,
+            get_fill_color="properties.fill_color",
+            get_line_color=[80, 80, 80],
+            line_width_min_pixels=0.5,
+            pickable=True,
+        )
+
+        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+
+        # Build tooltip HTML for ArcticDEM
+        if len(ds["aggregations"]) > 1:
+            tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}<br/>"]
+            tooltip_lines.append("All aggregations:<br/>")
+            for agg in ds["aggregations"].values:
+                tooltip_lines.append(f"  {agg}: {{agg_{agg}}}<br/>")
+            tooltip_html = "".join(tooltip_lines)
+        else:
+            tooltip_html = f"{selected_var}: {{value}}"
+
+        deck = pdk.Deck(
+            layers=[layer],
+            initial_view_state=view_state,
+            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
+            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+        )
+
+        st.pydeck_chart(deck)
+
+        # Show statistics
+        st.caption(
+            f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | "
+            f"Std: {np.nanstd(values_clean):.2f}"
+        )
+ else:
+ st.warning("No valid data available for selected parameters")
+
+
+@st.fragment
+def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
+ """Render interactive pydeck map for ERA5 climate data.
+
+ Args:
+ ds: xarray Dataset containing ERA5 data.
+ targets: GeoDataFrame with geometry for each cell.
+ grid: Grid type ('hex' or 'healpix').
+ temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
+
+ """
+    st.subheader("🗺️ ERA5 Spatial Distribution")
+
+ # Controls
+ variables = list(ds.data_vars)
+ has_agg = "aggregations" in ds.dims
+
+ if has_agg:
+ col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
+ else:
+ col1, col2, col3 = st.columns([3, 3, 1])
+
+ with col1:
+ selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
+
+ with col2:
+ # Convert time to readable format
+ time_values = pd.to_datetime(ds["time"].values)
+ time_options = {str(t): t for t in time_values}
+ selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
+
+ if has_agg:
+ with col3:
+ selected_agg = st.selectbox(
+ "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
+ )
+ with col4:
+ opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
+ else:
+ with col3:
+ opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
+
+ # Extract data for selected parameters
+ time_val = time_options[selected_time]
+ if has_agg:
+ data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
+ else:
+ data_values = ds[selected_var].sel(time=time_val)
+
+ # Create GeoDataFrame
+ gdf = targets.copy()
+ gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
+ gdf = gdf.set_index("cell_id")
+
+ # Add values
+ values_df = data_values.to_dataframe(name="value")
+ gdf = gdf.join(values_df, how="inner")
+
+ # Add all aggregation values for tooltip if has_agg
+ if has_agg and len(ds["aggregations"]) > 1:
+ for agg in ds["aggregations"].values:
+ agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
+ # Drop dimension columns to avoid conflicts
+ cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
+ if cols_to_drop:
+ agg_data = agg_data.drop(columns=cols_to_drop)
+ gdf = gdf.join(agg_data, how="left")
+
+ # Convert to WGS84 first
+ gdf_wgs84 = gdf.to_crs("EPSG:4326")
+
+ # Fix geometries after CRS conversion
+ if grid == "hex":
+ gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
+
+ # Normalize values for color mapping
+ values = gdf_wgs84["value"].values
+ values_clean = values[~np.isnan(values)]
+
+ if len(values_clean) > 0:
+ vmin, vmax = np.nanpercentile(values_clean, [2, 98])
+ normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
+
+ # Apply colormap - use variable-specific colors
+ cmap = get_cmap(f"era5_{selected_var}")
+ colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
+ gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+
+ # Create GeoJSON
+ geojson_data = []
+ for _, row in gdf_wgs84.iterrows():
+ properties = {
+ "value": float(row["value"]) if not np.isnan(row["value"]) else None,
+ "fill_color": row["fill_color"],
+ }
+ # Add all aggregation values if available
+ if has_agg and len(ds["aggregations"]) > 1:
+ for agg in ds["aggregations"].values:
+ agg_col = f"agg_{agg}"
+ if agg_col in row.index:
+ properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None
+
+ feature = {
+ "type": "Feature",
+ "geometry": row["geometry"].__geo_interface__,
+ "properties": properties,
+ }
+ geojson_data.append(feature)
+
+        # Create pydeck layer
+        layer = pdk.Layer(
+            "GeoJsonLayer",
+            geojson_data,
+            opacity=opacity,
+            stroked=True,
+            filled=True,
+            get_fill_color="properties.fill_color",
+            get_line_color=[80, 80, 80],
+            line_width_min_pixels=0.5,
+            pickable=True,
+        )
+
+        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+
+        # Build tooltip HTML for ERA5
+        if has_agg and len(ds["aggregations"]) > 1:
+            tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}<br/>"]
+            tooltip_lines.append("All aggregations:<br/>")
+            for agg in ds["aggregations"].values:
+                tooltip_lines.append(f"  {agg}: {{agg_{agg}}}<br/>")
+            tooltip_html = "".join(tooltip_lines)
+        else:
+            tooltip_html = f"{selected_var}: {{value}}"
+
+        deck = pdk.Deck(
+            layers=[layer],
+            initial_view_state=view_state,
+            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
+            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+        )
+
+        st.pydeck_chart(deck)
+
+        # Show statistics
+        st.caption(
+            f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
+            f"Std: {np.nanstd(values_clean):.4f}"
+        )
+ else:
+ st.warning("No valid data available for selected parameters")
diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py
index d4b13e8..c1ee2d4 100644
--- a/src/entropice/dashboard/training_data_page.py
+++ b/src/entropice/dashboard/training_data_page.py
@@ -2,8 +2,19 @@
import streamlit as st
+from entropice.dashboard.plots.source_data import (
+ render_alphaearth_map,
+ render_alphaearth_overview,
+ render_alphaearth_plots,
+ render_arcticdem_map,
+ render_arcticdem_overview,
+ render_arcticdem_plots,
+ render_era5_map,
+ render_era5_overview,
+ render_era5_plots,
+)
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
-from entropice.dashboard.utils.data import load_all_training_data
+from entropice.dashboard.utils.data import load_all_training_data, load_source_data
from entropice.dataset import DatasetEnsemble
@@ -107,30 +118,134 @@ def render_training_data_page():
# Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`")
- # Load training data for all three tasks
- train_data_dict = load_all_training_data(ensemble)
+ # Create tabs for different data views
+ tab_names = ["π Labels"]
- # Calculate total samples (use binary as reference)
- total_samples = len(train_data_dict["binary"])
- train_samples = (train_data_dict["binary"].split == "train").sum().item()
- test_samples = (train_data_dict["binary"].split == "test").sum().item()
+ # Add tabs for each member
+ for member in ensemble.members:
+ if member == "AlphaEarth":
+ tab_names.append("π AlphaEarth")
+ elif member == "ArcticDEM":
+            tab_names.append("🏔️ ArcticDEM")
+ elif member.startswith("ERA5"):
+ # Group ERA5 temporal variants
+            if "🌡️ ERA5" not in tab_names:
+                tab_names.append("🌡️ ERA5")
- st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
+ tabs = st.tabs(tab_names)
- # Render distribution histograms
- st.markdown("---")
- render_all_distribution_histograms(train_data_dict)
+ # Labels tab
+ with tabs[0]:
+ st.markdown("### Target Labels Distribution and Spatial Visualization")
- st.markdown("---")
+ # Load training data for all three tasks
+ with st.spinner("Loading training data for all tasks..."):
+ train_data_dict = load_all_training_data(ensemble)
- # Render spatial map (as a fragment for efficient re-rendering)
- # Extract geometries from the X.data dataframe (which has geometry as a column)
- # The index should be cell_id
- binary_dataset = train_data_dict["binary"]
- assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
+ # Calculate total samples (use binary as reference)
+ total_samples = len(train_data_dict["binary"])
+ train_samples = (train_data_dict["binary"].split == "train").sum().item()
+ test_samples = (train_data_dict["binary"].split == "test").sum().item()
- render_spatial_map(train_data_dict)
+ st.success(
+ f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
+ )
+
+ # Render distribution histograms
+ st.markdown("---")
+ render_all_distribution_histograms(train_data_dict)
+
+ st.markdown("---")
+
+ # Render spatial map
+ binary_dataset = train_data_dict["binary"]
+ assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
+
+ render_spatial_map(train_data_dict)
+
+ st.balloons()
+
+ # AlphaEarth tab
+ tab_idx = 1
+ if "AlphaEarth" in ensemble.members:
+ with tabs[tab_idx]:
+ st.markdown("### AlphaEarth Embeddings Analysis")
+
+ with st.spinner("Loading AlphaEarth data..."):
+ alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
+
+ st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
+
+ render_alphaearth_overview(alphaearth_ds)
+ render_alphaearth_plots(alphaearth_ds)
+
+ st.markdown("---")
+
+ render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
+
+ st.balloons()
+
+ tab_idx += 1
+
+ # ArcticDEM tab
+ if "ArcticDEM" in ensemble.members:
+ with tabs[tab_idx]:
+ st.markdown("### ArcticDEM Terrain Analysis")
+
+ with st.spinner("Loading ArcticDEM data..."):
+ arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
+
+ st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
+
+ render_arcticdem_overview(arcticdem_ds)
+ render_arcticdem_plots(arcticdem_ds)
+
+ st.markdown("---")
+
+ render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
+
+ st.balloons()
+
+ tab_idx += 1
+
+ # ERA5 tab (combining all temporal variants)
+ era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
+ if era5_members:
+ with tabs[tab_idx]:
+ st.markdown("### ERA5 Climate Data Analysis")
+
+ # Let user select which ERA5 temporal aggregation to view
+ era5_options = {
+ "ERA5-yearly": "Yearly",
+ "ERA5-seasonal": "Seasonal (Winter/Summer)",
+ "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
+ }
+
+ available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
+
+ selected_era5 = st.selectbox(
+ "Select ERA5 temporal aggregation",
+ options=list(available_era5.keys()),
+ format_func=lambda x: available_era5[x],
+ key="era5_temporal_select",
+ )
+
+ if selected_era5:
+ temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
+
+ with st.spinner(f"Loading {selected_era5} data..."):
+ era5_ds, targets = load_source_data(ensemble, selected_era5)
+
+ st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
+
+ render_era5_overview(era5_ds, temporal_type)
+ render_era5_plots(era5_ds, temporal_type)
+
+ st.markdown("---")
+
+ render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
+
+ st.balloons()
- # Add more components and visualizations as needed for training data.
else:
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/data.py
index ac0824e..70c0a21 100644
--- a/src/entropice/dashboard/utils/data.py
+++ b/src/entropice/dashboard/utils/data.py
@@ -99,3 +99,23 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
"count": e.create_cat_training_dataset("count"),
"density": e.create_cat_training_dataset("density"),
}
+
+
+@st.cache_data
+def load_source_data(e: DatasetEnsemble, source: str):
+ """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
+
+ Args:
+ e: DatasetEnsemble object.
+ source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
+
+ Returns:
+ xarray.Dataset with the raw data for the specified source.
+
+ """
+ targets = e._read_target()
+
+ # Load the member data lazily to get metadata
+ ds = e._read_member(source, targets, lazy=False)
+
+ return ds, targets