diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py new file mode 100644 index 0000000..6801050 --- /dev/null +++ b/src/entropice/dashboard/plots/source_data.py @@ -0,0 +1,798 @@ +"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5).""" + +import antimeridian +import geopandas as gpd +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pydeck as pdk +import streamlit as st +import xarray as xr +from shapely.geometry import shape + +from entropice.dashboard.plots.colors import get_cmap + + +def _fix_hex_geometry(geom): + """Fix hexagon geometry crossing the antimeridian.""" + try: + return shape(antimeridian.fix_shape(geom)) + except ValueError as e: + st.error(f"Error fixing geometry: {e}") + return geom + + +def render_alphaearth_overview(ds: xr.Dataset): + """Render overview statistics for AlphaEarth embeddings data. + + Args: + ds: xarray Dataset containing AlphaEarth embeddings. + + """ + st.subheader("πŸ“Š AlphaEarth Embeddings Statistics") + + # Overall statistics + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + + with col2: + st.metric("Embedding Dimensions", f"{len(ds['band'])}") + + with col3: + st.metric("Years Available", f"{len(ds['year'])}") + + with col4: + st.metric("Aggregations", f"{len(ds['agg'])}") + + # Show temporal coverage + st.markdown("**Temporal Coverage:**") + years = sorted(ds["year"].values) + st.write(f"Years: {min(years)} - {max(years)}") + + # Show aggregations + st.markdown("**Available Aggregations:**") + aggs = ds["agg"].to_numpy() + st.write(", ".join(str(a) for a in aggs)) + + +@st.fragment +def render_alphaearth_plots(ds: xr.Dataset): + """Render interactive plots for AlphaEarth embeddings data. + + Args: + ds: xarray Dataset containing AlphaEarth embeddings. + + """ + st.markdown("---") + st.markdown("**Embedding Distribution by Band**") + + embeddings_data = ds["embeddings"] + + # Select year and aggregation for visualization + col1, col2 = st.columns(2) + with col1: + selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year") + with col2: + selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg") + + # Get data for selected year and aggregation + year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg) + + # Calculate statistics for each band + band_stats = [] + for band_idx in range(len(ds["band"])): + band_data = year_agg_data.isel(band=band_idx).values.flatten() + band_data = band_data[~np.isnan(band_data)] # Remove NaN values + + if len(band_data) > 0: + band_stats.append( + { + "Band": band_idx, + "Mean": float(np.mean(band_data)), + "Std": float(np.std(band_data)), + "Min": float(np.min(band_data)), + "25%": float(np.percentile(band_data, 25)), + "Median": float(np.median(band_data)), + "75%": float(np.percentile(band_data, 75)), + "Max": float(np.max(band_data)), + } + ) + + band_df = pd.DataFrame(band_stats) + + # Create plot showing distribution across bands + fig = go.Figure() + + # Add min/max range first (background) + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Min"], + mode="lines", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + name="Min/Max Range", + showlegend=True, + ) + ) + + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Max"], + mode="lines", + fill="tonexty", + fillcolor="rgba(200, 200, 200, 0.1)", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + showlegend=False, + ) + ) + + # Add std band + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Mean"] - band_df["Std"], + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Mean"] + band_df["Std"], + mode="lines", + fill="tonexty", + fillcolor="rgba(31, 119, 180, 0.2)", + line={"width": 0}, + name="Β±1 Std", + ) + ) + + # Add mean line on top + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Mean"], + mode="lines+markers", + name="Mean", + line={"color": "#1f77b4", "width": 2}, + marker={"size": 4}, + ) + ) + + fig.update_layout( + title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})", + xaxis_title="Band", + yaxis_title="Embedding Value", + height=450, + hovermode="x unified", + ) + + st.plotly_chart(fig, use_container_width=True) + + # Band statistics + with st.expander("πŸ“ˆ Statistics by Embedding Band", expanded=False): + st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:") + + # Calculate statistics for each band + band_stats = [] + for band_idx in range(min(10, len(ds["band"]))): # Show first 10 bands + band_data = embeddings_data.isel(band=band_idx) + band_stats.append( + { + "Band": int(band_idx), + "Mean": float(band_data.mean().values), + "Std": float(band_data.std().values), + "Min": float(band_data.min().values), + "Max": float(band_data.max().values), + } + ) + + band_df = pd.DataFrame(band_stats) + st.dataframe(band_df, use_container_width=True, hide_index=True) + + if len(ds["band"]) > 10: + st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions") + + +def render_arcticdem_overview(ds: xr.Dataset): + """Render overview statistics for ArcticDEM terrain data. + + Args: + ds: xarray Dataset containing ArcticDEM data. + + """ + st.subheader("πŸ”οΈ ArcticDEM Terrain Statistics") + + # Overall statistics + col1, col2, col3 = st.columns(3) + + with col1: + st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + + with col2: + st.metric("Variables", f"{len(ds.data_vars)}") + + with col3: + st.metric("Aggregations", f"{len(ds['aggregations'])}") + + # Show available variables + st.markdown("**Available Variables:**") + variables = list(ds.data_vars) + st.write(", ".join(variables)) + + # Show aggregations + st.markdown("**Available Aggregations:**") + aggs = ds["aggregations"].to_numpy() + st.write(", ".join(str(a) for a in aggs)) + + # Statistics by variable + st.markdown("---") + st.markdown("**Variable Statistics (across all aggregations)**") + + var_stats = [] + for var_name in ds.data_vars: + var_data = ds[var_name] + var_stats.append( + { + "Variable": var_name, + "Mean": float(var_data.mean().values), + "Std": float(var_data.std().values), + "Min": float(var_data.min().values), + "Max": float(var_data.max().values), + "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values), + } + ) + + stats_df = pd.DataFrame(var_stats) + st.dataframe(stats_df, use_container_width=True, hide_index=True) + + +@st.fragment +def render_arcticdem_plots(ds: xr.Dataset): + """Render interactive plots for ArcticDEM terrain data. + + Args: + ds: xarray Dataset containing ArcticDEM data. + + """ + st.markdown("---") + st.markdown("**Variable Distributions**") + + variables = list(ds.data_vars) + + # Select a variable to visualize + selected_var = st.selectbox("Select variable to visualize", options=variables, key="arcticdem_var_select") + + if selected_var: + var_data = ds[selected_var] + + # Create histogram + fig = go.Figure() + + for agg in ds["aggregations"].to_numpy(): + agg_data = var_data.sel(aggregations=agg).to_numpy().flatten() + agg_data = agg_data[~np.isnan(agg_data)] + + fig.add_trace( + go.Histogram( + x=agg_data, + name=str(agg), + opacity=0.7, + nbinsx=50, + ) + ) + + fig.update_layout( + title=f"Distribution of {selected_var} by Aggregation", + xaxis_title=selected_var, + yaxis_title="Count", + barmode="overlay", + height=400, + ) + + st.plotly_chart(fig, use_container_width=True) + + +def render_era5_overview(ds: xr.Dataset, temporal_type: str): + """Render overview statistics for ERA5 climate data. + + Args: + ds: xarray Dataset containing ERA5 data. + temporal_type: One of 'yearly', 'seasonal', 'shoulder'. + + """ + st.subheader(f"🌑️ ERA5 Climate Statistics ({temporal_type.capitalize()})") + + # Overall statistics + has_agg = "aggregations" in ds.dims + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + + with col2: + st.metric("Variables", f"{len(ds.data_vars)}") + + with col3: + st.metric("Time Steps", f"{len(ds['time'])}") + + with col4: + if has_agg: + st.metric("Aggregations", f"{len(ds['aggregations'])}") + else: + st.metric("Aggregations", "1") + + # Show available variables + st.markdown("**Available Variables:**") + variables = list(ds.data_vars) + st.write(", ".join(variables)) + + # Show temporal range + st.markdown("**Temporal Range:**") + time_values = pd.to_datetime(ds["time"].values) + st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}") + + if has_agg: + st.markdown("**Available Aggregations:**") + aggs = ds["aggregations"].to_numpy() + st.write(", ".join(str(a) for a in aggs)) + + # Statistics by variable + st.markdown("---") + st.markdown("**Variable Statistics (across all time steps and aggregations)**") + + var_stats = [] + for var_name in ds.data_vars: + var_data = ds[var_name] + var_stats.append( + { + "Variable": var_name, + "Mean": float(var_data.mean().values), + "Std": float(var_data.std().values), + "Min": float(var_data.min().values), + "Max": float(var_data.max().values), + "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values), + } + ) + + stats_df = pd.DataFrame(var_stats) + st.dataframe(stats_df, use_container_width=True, hide_index=True) + + +@st.fragment +def render_era5_plots(ds: xr.Dataset, temporal_type: str): + """Render interactive plots for ERA5 climate data. + + Args: + ds: xarray Dataset containing ERA5 data. + temporal_type: One of 'yearly', 'seasonal', 'shoulder'. + + """ + st.markdown("---") + st.markdown("**Temporal Trends**") + + variables = list(ds.data_vars) + has_agg = "aggregations" in ds.dims + + selected_var = st.selectbox( + "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" + ) + + if selected_var: + var_data = ds[selected_var] + + # Calculate mean over space for each time step + if has_agg: + # Average over aggregations first, then over cells + time_series = var_data.mean(dim=["cell_ids", "aggregations"]) + else: + time_series = var_data.mean(dim="cell_ids") + + time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()}) + + fig = go.Figure() + + fig.add_trace( + go.Scatter( + x=time_df["Time"], + y=time_df["Value"], + mode="lines+markers", + name=selected_var, + line={"width": 2}, + ) + ) + + fig.update_layout( + title=f"Temporal Trend of {selected_var} (Spatial Mean)", + xaxis_title="Time", + yaxis_title=selected_var, + height=400, + hovermode="x unified", + ) + + st.plotly_chart(fig, use_container_width=True) + + +@st.fragment +def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): + """Render interactive pydeck map for AlphaEarth embeddings. + + Args: + ds: xarray Dataset containing AlphaEarth data. + targets: GeoDataFrame with geometry for each cell. + grid: Grid type ('hex' or 'healpix'). + + """ + st.subheader("πŸ—ΊοΈ AlphaEarth Spatial Distribution") + + # Controls + col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) + + with col1: + selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year") + + with col2: + selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg") + + with col3: + selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band") + + with col4: + opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity") + + # Extract data for selected parameters + data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band) + + # Create GeoDataFrame + gdf = targets.copy() + gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] + gdf = gdf.set_index("cell_id") + + # Add values + values_df = data_values.to_dataframe(name="value") + gdf = gdf.join(values_df, how="inner") + + # Convert to WGS84 first + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix geometries after CRS conversion + if grid == "hex": + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + + # Normalize values for color mapping + values = gdf_wgs84["value"].to_numpy() + vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers + normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) + + # Apply colormap + cmap = get_cmap("embeddings") + colors = [cmap(val) for val in normalized] + gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + + # Create GeoJSON + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "value": float(row["value"]), + "fill_color": row["fill_color"], + }, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=opacity, + stroked=True, + filled=True, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + pickable=True, + ) + + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={"html": "Value: {value}", "style": {"backgroundColor": "steelblue", "color": "white"}}, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + st.pydeck_chart(deck) + + # Show statistics + st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}") + + +@st.fragment +def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): + """Render interactive pydeck map for ArcticDEM terrain data. + + Args: + ds: xarray Dataset containing ArcticDEM data. + targets: GeoDataFrame with geometry for each cell. + grid: Grid type ('hex' or 'healpix'). + + """ + st.subheader("πŸ—ΊοΈ ArcticDEM Spatial Distribution") + + # Controls + variables = list(ds.data_vars) + col1, col2, col3 = st.columns([3, 3, 1]) + + with col1: + selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var") + + with col2: + selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg") + + with col3: + opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity") + + # Extract data for selected parameters + data_values = ds[selected_var].sel(aggregations=selected_agg) + + # Create GeoDataFrame + gdf = targets.copy() + gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] + gdf = gdf.set_index("cell_id") + + # Add values + values_df = data_values.to_dataframe(name="value") + gdf = gdf.join(values_df, how="inner") + + # Add all aggregation values for tooltip + if len(ds["aggregations"]) > 1: + for agg in ds["aggregations"].values: + agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}") + # Drop the aggregations column to avoid conflicts + if "aggregations" in agg_data.columns: + agg_data = agg_data.drop(columns=["aggregations"]) + gdf = gdf.join(agg_data, how="left") + + # Convert to WGS84 first + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix geometries after CRS conversion + if grid == "hex": + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + + # Normalize values for color mapping + values = gdf_wgs84["value"].values + values_clean = values[~np.isnan(values)] + + if len(values_clean) > 0: + vmin, vmax = np.nanpercentile(values_clean, [2, 98]) + normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) + + # Apply colormap + cmap = get_cmap(f"arcticdem_{selected_var}") + colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] + gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + + # Create GeoJSON + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + properties = { + "value": float(row["value"]) if not np.isnan(row["value"]) else None, + "fill_color": row["fill_color"], + } + # Add all aggregation values if available + if len(ds["aggregations"]) > 1: + for agg in ds["aggregations"].values: + agg_col = f"agg_{agg}" + if agg_col in row.index: + properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None + + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": properties, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=opacity, + stroked=True, + filled=True, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + pickable=True, + ) + + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + + # Build tooltip HTML for ArcticDEM + if len(ds["aggregations"]) > 1: + tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"] + tooltip_lines.append("All aggregations:
") + for agg in ds["aggregations"].values: + tooltip_lines.append(f"  {agg}: {{agg_{agg}}}
") + tooltip_html = "".join(tooltip_lines) + else: + tooltip_html = f"{selected_var}: {{value}}" + + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + st.pydeck_chart(deck) + + # Show statistics + st.caption( + f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | " + f"Std: {np.nanstd(values_clean):.2f}" + ) + else: + st.warning("No valid data available for selected parameters") + + +@st.fragment +def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str): + """Render interactive pydeck map for ERA5 climate data. + + Args: + ds: xarray Dataset containing ERA5 data. + targets: GeoDataFrame with geometry for each cell. + grid: Grid type ('hex' or 'healpix'). + temporal_type: One of 'yearly', 'seasonal', 'shoulder'. + + """ + st.subheader("πŸ—ΊοΈ ERA5 Spatial Distribution") + + # Controls + variables = list(ds.data_vars) + has_agg = "aggregations" in ds.dims + + if has_agg: + col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) + else: + col1, col2, col3 = st.columns([3, 3, 1]) + + with col1: + selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") + + with col2: + # Convert time to readable format + time_values = pd.to_datetime(ds["time"].values) + time_options = {str(t): t for t in time_values} + selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time") + + if has_agg: + with col3: + selected_agg = st.selectbox( + "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg" + ) + with col4: + opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") + else: + with col3: + opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") + + # Extract data for selected parameters + time_val = time_options[selected_time] + if has_agg: + data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg) + else: + data_values = ds[selected_var].sel(time=time_val) + + # Create GeoDataFrame + gdf = targets.copy() + gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] + gdf = gdf.set_index("cell_id") + + # Add values + values_df = data_values.to_dataframe(name="value") + gdf = gdf.join(values_df, how="inner") + + # Add all aggregation values for tooltip if has_agg + if has_agg and len(ds["aggregations"]) > 1: + for agg in ds["aggregations"].values: + agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}") + # Drop dimension columns to avoid conflicts + cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns] + if cols_to_drop: + agg_data = agg_data.drop(columns=cols_to_drop) + gdf = gdf.join(agg_data, how="left") + + # Convert to WGS84 first + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix geometries after CRS conversion + if grid == "hex": + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + + # Normalize values for color mapping + values = gdf_wgs84["value"].values + values_clean = values[~np.isnan(values)] + + if len(values_clean) > 0: + vmin, vmax = np.nanpercentile(values_clean, [2, 98]) + normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) + + # Apply colormap - use variable-specific colors + cmap = get_cmap(f"era5_{selected_var}") + colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] + gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + + # Create GeoJSON + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + properties = { + "value": float(row["value"]) if not np.isnan(row["value"]) else None, + "fill_color": row["fill_color"], + } + # Add all aggregation values if available + if has_agg and len(ds["aggregations"]) > 1: + for agg in ds["aggregations"].values: + agg_col = f"agg_{agg}" + if agg_col in row.index: + properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None + + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": properties, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=opacity, + stroked=True, + filled=True, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + pickable=True, + ) + + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + + # Build tooltip HTML for ERA5 + if has_agg and len(ds["aggregations"]) > 1: + tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"] + tooltip_lines.append("All aggregations:
") + for agg in ds["aggregations"].values: + tooltip_lines.append(f"  {agg}: {{agg_{agg}}}
") + tooltip_html = "".join(tooltip_lines) + else: + tooltip_html = f"{selected_var}: {{value}}" + + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}}, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + st.pydeck_chart(deck) + + # Show statistics + st.caption( + f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | " + f"Std: {np.nanstd(values_clean):.4f}" + ) + else: + st.warning("No valid data available for selected parameters") diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py index d4b13e8..c1ee2d4 100644 --- a/src/entropice/dashboard/training_data_page.py +++ b/src/entropice/dashboard/training_data_page.py @@ -2,8 +2,19 @@ import streamlit as st +from entropice.dashboard.plots.source_data import ( + render_alphaearth_map, + render_alphaearth_overview, + render_alphaearth_plots, + render_arcticdem_map, + render_arcticdem_overview, + render_arcticdem_plots, + render_era5_map, + render_era5_overview, + render_era5_plots, +) from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map -from entropice.dashboard.utils.data import load_all_training_data +from entropice.dashboard.utils.data import load_all_training_data, load_source_data from entropice.dataset import DatasetEnsemble @@ -107,30 +118,134 @@ def render_training_data_page(): # Display dataset ID in a styled container st.info(f"**Dataset ID:** `{ensemble.id()}`") - # Load training data for all three tasks - train_data_dict = load_all_training_data(ensemble) + # Create tabs for different data views + tab_names = ["πŸ“Š Labels"] - # Calculate total samples (use binary as reference) - total_samples = len(train_data_dict["binary"]) - train_samples = (train_data_dict["binary"].split == "train").sum().item() - test_samples = (train_data_dict["binary"].split == "test").sum().item() + # Add tabs for each member + for member in ensemble.members: + if member == "AlphaEarth": + tab_names.append("🌍 AlphaEarth") + elif member == "ArcticDEM": + tab_names.append("πŸ”οΈ ArcticDEM") + elif member.startswith("ERA5"): + # Group ERA5 temporal variants + if "🌑️ ERA5" not in tab_names: + tab_names.append("🌑️ ERA5") - st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks") + tabs = st.tabs(tab_names) - # Render distribution histograms - st.markdown("---") - render_all_distribution_histograms(train_data_dict) + # Labels tab + with tabs[0]: + st.markdown("### Target Labels Distribution and Spatial Visualization") - st.markdown("---") + # Load training data for all three tasks + with st.spinner("Loading training data for all tasks..."): + train_data_dict = load_all_training_data(ensemble) - # Render spatial map (as a fragment for efficient re-rendering) - # Extract geometries from the X.data dataframe (which has geometry as a column) - # The index should be cell_id - binary_dataset = train_data_dict["binary"] - assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" + # Calculate total samples (use binary as reference) + total_samples = len(train_data_dict["binary"]) + train_samples = (train_data_dict["binary"].split == "train").sum().item() + test_samples = (train_data_dict["binary"].split == "test").sum().item() - render_spatial_map(train_data_dict) + st.success( + f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks" + ) + + # Render distribution histograms + st.markdown("---") + render_all_distribution_histograms(train_data_dict) + + st.markdown("---") + + # Render spatial map + binary_dataset = train_data_dict["binary"] + assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" + + render_spatial_map(train_data_dict) + + st.balloons() + + # AlphaEarth tab + tab_idx = 1 + if "AlphaEarth" in ensemble.members: + with tabs[tab_idx]: + st.markdown("### AlphaEarth Embeddings Analysis") + + with st.spinner("Loading AlphaEarth data..."): + alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth") + + st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells") + + render_alphaearth_overview(alphaearth_ds) + render_alphaearth_plots(alphaearth_ds) + + st.markdown("---") + + render_alphaearth_map(alphaearth_ds, targets, ensemble.grid) + + st.balloons() + + tab_idx += 1 + + # ArcticDEM tab + if "ArcticDEM" in ensemble.members: + with tabs[tab_idx]: + st.markdown("### ArcticDEM Terrain Analysis") + + with st.spinner("Loading ArcticDEM data..."): + arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM") + + st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells") + + render_arcticdem_overview(arcticdem_ds) + render_arcticdem_plots(arcticdem_ds) + + st.markdown("---") + + render_arcticdem_map(arcticdem_ds, targets, ensemble.grid) + + st.balloons() + + tab_idx += 1 + + # ERA5 tab (combining all temporal variants) + era5_members = [m for m in ensemble.members if m.startswith("ERA5")] + if era5_members: + with tabs[tab_idx]: + st.markdown("### ERA5 Climate Data Analysis") + + # Let user select which ERA5 temporal aggregation to view + era5_options = { + "ERA5-yearly": "Yearly", + "ERA5-seasonal": "Seasonal (Winter/Summer)", + "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)", + } + + available_era5 = {k: v for k, v in era5_options.items() if k in era5_members} + + selected_era5 = st.selectbox( + "Select ERA5 temporal aggregation", + options=list(available_era5.keys()), + format_func=lambda x: available_era5[x], + key="era5_temporal_select", + ) + + if selected_era5: + temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder' + + with st.spinner(f"Loading {selected_era5} data..."): + era5_ds, targets = load_source_data(ensemble, selected_era5) + + st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells") + + render_era5_overview(era5_ds, temporal_type) + render_era5_plots(era5_ds, temporal_type) + + st.markdown("---") + + render_era5_map(era5_ds, targets, ensemble.grid, temporal_type) + + st.balloons() - # Add more components and visualizations as needed for training data. else: st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.") diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/data.py index ac0824e..70c0a21 100644 --- a/src/entropice/dashboard/utils/data.py +++ b/src/entropice/dashboard/utils/data.py @@ -99,3 +99,23 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD "count": e.create_cat_training_dataset("count"), "density": e.create_cat_training_dataset("density"), } + + +@st.cache_data +def load_source_data(e: DatasetEnsemble, source: str): + """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5). + + Args: + e: DatasetEnsemble object. + source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'. + + Returns: + xarray.Dataset with the raw data for the specified source. + + """ + targets = e._read_target() + + # Load the member data lazily to get metadata + ds = e._read_member(source, targets, lazy=False) + + return ds, targets