diff --git a/src/entropice/dashboard/plots/climate.py b/src/entropice/dashboard/plots/climate.py new file mode 100644 index 0000000..4318a9c --- /dev/null +++ b/src/entropice/dashboard/plots/climate.py @@ -0,0 +1,421 @@ +"""Plots for visualizing ERA5 climate data.""" + +import geopandas as gpd +import matplotlib.colors as mcolors +import numpy as np +import plotly.graph_objects as go +import pydeck as pdk +import xarray as xr +from plotly.subplots import make_subplots + +from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb +from entropice.dashboard.utils.geometry import fix_hex_geometry + + +def create_climate_map( + climate_values: xr.DataArray, + grid_gdf: gpd.GeoDataFrame, + variable_name: str, + make_3d_map: bool, +) -> pdk.Deck: + """Create a spatial distribution map for ERA5 climate variables. + + Args: + climate_values: Series with cell_ids as index and climate values + grid_gdf: GeoDataFrame containing grid cell geometries + variable_name: Name of the climate variable being visualized + make_3d_map: Whether to render the map in 3D (extruded) or 2D + + Returns: + pdk.Deck: A PyDeck map visualization of the climate variable + + """ + # Subsample if too many cells for performance + n_cells = len(climate_values["cell_ids"]) + if n_cells > 100000: + rng = np.random.default_rng(42) + cell_indices = rng.choice(n_cells, size=100000, replace=False) + climate_values = climate_values.isel(cell_ids=cell_indices) + # Create a copy to avoid modifying the original + gdf = grid_gdf.copy().to_crs("EPSG:4326") + + # Convert to DataFrame for easier merging + climate_df = climate_values.to_dataframe(name="climate_value") + + # Reset index if cell_id is already the index + if gdf.index.name == "cell_id": + gdf = gdf.reset_index() + + # Filter grid to only cells that have climate data + gdf = gdf[gdf["cell_id"].isin(climate_df.index)] + gdf = gdf.set_index("cell_id") + + # Merge climate values with grid geometries + gdf = gdf.join(climate_df, how="inner") + + # 
Convert to WGS84 for pydeck + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix antimeridian issues for hex cells + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) + + # Get colormap + cmap = get_cmap(variable_name) + + # Normalize the climate values to [0, 1] for color mapping + values = gdf_wgs84["climate_value"].to_numpy() + + # Use percentiles to avoid outliers + vmin, vmax = np.nanpercentile(values, [2, 98]) + if vmax > vmin: + normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1) + else: + normalized_values = np.zeros_like(values) + + # Map normalized values to colors + colors = [cmap(val) for val in normalized_values] + rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors] + gdf_wgs84["fill_color"] = rgb_colors + + # Store climate value for tooltip + gdf_wgs84["climate_value_display"] = values + + # Store normalized values for elevation (if 3D) + gdf_wgs84["elevation"] = normalized_values + + # Convert to GeoJSON format + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "fill_color": row["fill_color"], + "climate_value": float(row["climate_value_display"]), + "elevation": float(row["elevation"]) if make_3d_map else 0, + }, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=0.7, + stroked=True, + filled=True, + extruded=make_3d_map, + wireframe=False, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + get_elevation="properties.elevation" if make_3d_map else 0, + elevation_scale=500000, + pickable=True, + ) + + # Set initial view state + view_state = pdk.ViewState( + latitude=70, + longitude=0, + zoom=2 if not make_3d_map else 1.5, + pitch=0 if not make_3d_map else 45, + ) + + # Build tooltip HTML + tooltip_html = f"{variable_name}: {{climate_value}}" + + # Create deck + deck = 
pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + return deck + + +def create_climate_trend_plot(climate_data: xr.DataArray, variable_name: str) -> go.Figure: + """Create a trend plot for climate variables over time. + + Args: + climate_data: DataArray containing climate variable with 'year' dimension + variable_name: Name of the variable being plotted + + Returns: + Plotly Figure with trend plot + + """ + # Subsample if too many cells for performance + n_cells = len(climate_data["cell_ids"]) + if n_cells > 10000: + rng = np.random.default_rng(42) + cell_indices = rng.choice(n_cells, size=10000, replace=False) + climate_data = climate_data.isel({"cell_ids": cell_indices}) + + # Get years + years = climate_data["year"].to_numpy() + + # Calculate statistics over space for each year + mean_values = climate_data.mean(dim="cell_ids").to_numpy() + min_values = climate_data.min(dim="cell_ids").to_numpy() + max_values = climate_data.max(dim="cell_ids").to_numpy() + p10_values = climate_data.quantile(0.10, dim="cell_ids").to_numpy() + p90_values = climate_data.quantile(0.90, dim="cell_ids").to_numpy() + + fig = go.Figure() + + # Add min/max range (background) + fig.add_trace( + go.Scatter( + x=years, + y=min_values, + mode="lines", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + name="Min/Max Range", + showlegend=True, + ) + ) + + fig.add_trace( + go.Scatter( + x=years, + y=max_values, + mode="lines", + fill="tonexty", + fillcolor="rgba(200, 200, 200, 0.1)", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + showlegend=False, + ) + ) + + # Add 10th-90th percentile band + fig.add_trace( + go.Scatter( + x=years, + y=p10_values, + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=years, 
+ y=p90_values, + mode="lines", + fill="tonexty", + fillcolor="rgba(33, 150, 243, 0.2)", + line={"width": 0}, + name="10th-90th Percentile", + ) + ) + + # Add mean line + fig.add_trace( + go.Scatter( + x=years, + y=mean_values, + mode="lines+markers", + name="Mean", + line={"color": "#1976D2", "width": 2}, + marker={"size": 6}, + ) + ) + + fig.update_layout( + title=f"{variable_name} Over Time (Spatial Statistics)", + xaxis_title="Year", + yaxis_title=variable_name, + height=450, + hovermode="x unified", + legend={ + "orientation": "h", + "yanchor": "bottom", + "y": 1.02, + "xanchor": "right", + "x": 1, + }, + ) + + fig.update_xaxes(dtick=1) + + return fig + + +def create_climate_distribution_plot(climate_ds: xr.Dataset, variables: list[str]) -> go.Figure: + """Create distribution plots for climate variables. + + Args: + climate_ds: Xarray Dataset containing ERA5 climate features + variables: List of variable names to plot + + Returns: + Plotly Figure with distribution plots + + """ + # Subsample if too many cells for performance + n_cells = len(climate_ds.cell_ids) + if n_cells > 10000: + rng = np.random.default_rng(42) + cell_indices = rng.choice(n_cells, size=10000, replace=False) + climate_ds = climate_ds.isel(cell_ids=cell_indices) + + # Create subplots - one row per variable + n_rows = len(variables) + fig = make_subplots( + rows=n_rows, + cols=1, + subplot_titles=[v.replace("_", " ").title() for v in variables], + vertical_spacing=0.15 / max(n_rows, 1), + ) + + # Color palette + colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + ] + + for row_idx, variable in enumerate(variables, start=1): + if variable not in climate_ds.data_vars: + continue + + # Get data + data = climate_ds[variable] + values = data.to_numpy().flatten() + values = values[~np.isnan(values)] + + if len(values) == 0: + continue + + # Add violin plot + fig.add_trace( + go.Violin( + y=values, + name=variable.replace("_", " 
").title(), + box_visible=True, + meanline_visible=True, + line_color=colors[row_idx % len(colors)], + showlegend=False, + ), + row=row_idx, + col=1, + ) + + # Update layout + fig.update_layout( + height=300 * n_rows, + title_text="Climate Variable Distributions", + showlegend=False, + ) + + # Update y-axes labels + for i in range(n_rows): + fig.update_yaxes(title_text="Value", row=i + 1, col=1) + + return fig + + +def create_temperature_comparison_plot(climate_ds: xr.Dataset) -> go.Figure: + """Create a comparison plot for temperature-related variables. + + Args: + climate_ds: Xarray Dataset containing temperature variables + + Returns: + Plotly Figure with temperature comparison + + """ + # Temperature variables to compare + temp_vars = ["t2m_max", "t2m_mean", "t2m_min"] + available_vars = [v for v in temp_vars if v in climate_ds.data_vars] + + if not available_vars: + return go.Figure() + + # Get years + years = climate_ds["year"].to_numpy() + + fig = go.Figure() + + colors = {"t2m_max": "#d32f2f", "t2m_mean": "#1976d2", "t2m_min": "#0288d1"} + + for var in available_vars: + # Calculate spatial mean for each year + values = climate_ds[var].mean(dim="cell_ids").to_numpy() + + fig.add_trace( + go.Scatter( + x=years, + y=values - 273.15, # Convert to Celsius + mode="lines+markers", + name=var.replace("_", " ").title(), + line={"color": colors.get(var, "#666666"), "width": 2}, + marker={"size": 4}, + ) + ) + + # Add freezing point reference line + fig.add_hline(y=0, line_dash="dash", line_color="gray", annotation_text="Freezing Point (0Β°C)") + + fig.update_layout( + title="Temperature Extremes Over Time", + xaxis_title="Year", + yaxis_title="Temperature (Β°C)", + height=450, + hovermode="x unified", + ) + + return fig + + +def create_seasonal_pattern_plot(climate_ds: xr.Dataset, variable: str) -> go.Figure: + """Create a plot showing seasonal patterns. 
+ + Args: + climate_ds: Xarray Dataset with month dimension + variable: Variable name to plot + + Returns: + Plotly Figure with seasonal patterns + + """ + if variable not in climate_ds.data_vars or "month" not in climate_ds.dims: + return go.Figure() + + # Get unique months/seasons + months = climate_ds["month"].to_numpy() + + # Calculate mean across space and years for each month + values = climate_ds[variable].mean(dim=["cell_ids", "year"]).to_numpy() + + fig = go.Figure() + + fig.add_trace( + go.Bar( + x=months, + y=values, + marker_color="#1976d2", + ) + ) + + fig.update_layout( + title=f"{variable.replace('_', ' ').title()} by Season", + xaxis_title="Season", + yaxis_title=variable.replace("_", " ").title(), + height=400, + ) + + return fig diff --git a/src/entropice/dashboard/plots/terrain.py b/src/entropice/dashboard/plots/terrain.py new file mode 100644 index 0000000..8794227 --- /dev/null +++ b/src/entropice/dashboard/plots/terrain.py @@ -0,0 +1,418 @@ +"""Plots for visualizing ArcticDEM terrain features.""" + +import geopandas as gpd +import matplotlib.colors as mcolors +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pydeck as pdk +import xarray as xr +from plotly.subplots import make_subplots + +from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb +from entropice.dashboard.utils.geometry import fix_hex_geometry + + +def create_terrain_map( + terrain_values: pd.Series, + grid_gdf: gpd.GeoDataFrame, + variable_name: str, + make_3d_map: bool, +) -> pdk.Deck: + """Create a spatial distribution map for ArcticDEM terrain features. 
+ + Args: + terrain_values: Series with cell_ids as index and terrain values + grid_gdf: GeoDataFrame containing grid cell geometries + variable_name: Name of the terrain variable being visualized + make_3d_map: Whether to render the map in 3D (extruded) or 2D + + Returns: + pdk.Deck: A PyDeck map visualization of the terrain feature + + """ + # Subsample if too many cells for performance + n_cells = len(terrain_values) + if n_cells > 100000: + rng = np.random.default_rng(42) + cell_indices = rng.choice(n_cells, size=100000, replace=False) + terrain_values = terrain_values.iloc[cell_indices] + + # Create a copy to avoid modifying the original + gdf = grid_gdf.copy().to_crs("EPSG:4326") + + # Reset index if cell_id is already the index + if gdf.index.name == "cell_id": + gdf = gdf.reset_index() + + # Filter grid to only cells that have terrain data + gdf = gdf[gdf["cell_id"].isin(terrain_values.index)] + gdf = gdf.set_index("cell_id") + + # Merge terrain values with grid geometries + gdf = gdf.join(terrain_values.to_frame("terrain_value"), how="inner") + + # Convert to WGS84 for pydeck + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix antimeridian issues for hex cells + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) + + # Get colormap - use special colormap for aspect (circular) + if "aspect" in variable_name.lower(): + cmap = get_cmap("aspect") + else: + cmap = get_cmap("terrain") + + # Normalize the terrain values to [0, 1] for color mapping + values = gdf_wgs84["terrain_value"].to_numpy() + + # Handle aspect specially (0-360 degrees, circular) + if "aspect" in variable_name.lower(): + vmin, vmax = 0, 360 + normalized_values = values / 360 + else: + # Use percentiles to avoid outliers + vmin, vmax = np.nanpercentile(values, [2, 98]) + if vmax > vmin: + normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1) + else: + normalized_values = np.zeros_like(values) + + # Map normalized values to colors + colors = [cmap(val) for val in 
normalized_values] + rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors] + gdf_wgs84["fill_color"] = rgb_colors + + # Store terrain value for tooltip + gdf_wgs84["terrain_value_display"] = values + + # Store normalized values for elevation (if 3D) + gdf_wgs84["elevation"] = normalized_values + + # Convert to GeoJSON format + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "fill_color": row["fill_color"], + "terrain_value": float(row["terrain_value_display"]), + "elevation": float(row["elevation"]) if make_3d_map else 0, + }, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=0.7, + stroked=True, + filled=True, + extruded=make_3d_map, + wireframe=False, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + get_elevation="properties.elevation" if make_3d_map else 0, + elevation_scale=500000, + pickable=True, + ) + + # Set initial view state + view_state = pdk.ViewState( + latitude=70, + longitude=0, + zoom=2 if not make_3d_map else 1.5, + pitch=0 if not make_3d_map else 45, + ) + + # Build tooltip HTML + tooltip_html = f"{variable_name}: {{terrain_value}}" + + # Create deck + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + return deck + + +def create_terrain_distribution_plot(arcticdem_ds: xr.Dataset, features: list[str]) -> go.Figure: + """Create distribution plots for terrain features. 
+ + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + features: List of feature names to plot + + Returns: + Plotly Figure with distribution plots + + """ + # Subsample if too many cells for performance + n_cells = len(arcticdem_ds.cell_ids) + if n_cells > 10000: + rng = np.random.default_rng(42) + cell_indices = rng.choice(n_cells, size=10000, replace=False) + arcticdem_ds = arcticdem_ds.isel(cell_ids=cell_indices) + + # Determine aggregation types available + aggs = list(arcticdem_ds.coords["aggregations"].values) + + # Create subplots - one row per aggregation type + n_rows = len(aggs) + fig = make_subplots( + rows=n_rows, + cols=1, + subplot_titles=[f"{agg.title()} Values" for agg in aggs], + vertical_spacing=0.15 / max(n_rows, 1), + ) + + # Color palette for features + colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", + ] + + for row_idx, agg in enumerate(aggs, start=1): + for feat_idx, feature in enumerate(features): + # Get the data for this feature and aggregation + var_name = f"{feature}" + if var_name not in arcticdem_ds.data_vars: + continue + + # Extract values for this aggregation + data = arcticdem_ds[var_name].sel(aggregations=agg) + values = data.to_numpy().flatten() + values = values[~np.isnan(values)] + + if len(values) == 0: + continue + + # Add violin plot + fig.add_trace( + go.Violin( + y=values, + name=feature.replace("_", " ").title(), + box_visible=True, + meanline_visible=True, + line_color=colors[feat_idx % len(colors)], + showlegend=(row_idx == 1), + ), + row=row_idx, + col=1, + ) + + # Update layout + fig.update_layout( + height=300 * n_rows, + title_text="Terrain Feature Distributions by Aggregation Type", + showlegend=True, + ) + + # Update y-axes labels + for i in range(n_rows): + fig.update_yaxes(title_text="Value", row=i + 1, col=1) + + return fig + + +def create_aspect_rose_diagram(aspect_values: np.ndarray) -> 
go.Figure: + """Create a rose diagram (circular histogram) for aspect values. + + Args: + aspect_values: Array of aspect values in degrees (0-360) + + Returns: + Plotly Figure with rose diagram + + """ + # Remove NaN values + aspect_values = aspect_values[~np.isnan(aspect_values)] + + # Subsample if too many values for performance + if len(aspect_values) > 50000: + rng = np.random.default_rng(42) + indices = rng.choice(len(aspect_values), size=50000, replace=False) + aspect_values = aspect_values[indices] + + if len(aspect_values) == 0: + # Return empty figure + return go.Figure() + + # Create bins for aspect (every 10 degrees) + bins = np.arange(0, 361, 10) + bin_counts, _ = np.histogram(aspect_values, bins=bins) + + # Calculate bin centers in degrees and radians + bin_centers_deg = (bins[:-1] + bins[1:]) / 2 + bin_centers_rad = np.deg2rad(bin_centers_deg) + + # Close the circle + bin_centers_rad = np.append(bin_centers_rad, bin_centers_rad[0]) + bin_counts = np.append(bin_counts, bin_counts[0]) + + # Create polar bar chart + fig = go.Figure() + + fig.add_trace( + go.Barpolar( + r=bin_counts, + theta=np.rad2deg(bin_centers_rad), + width=10, + marker={ + "color": bin_centers_rad, + "colorscale": "HSV", + "cmin": 0, + "cmax": 2 * np.pi, + "showscale": False, + }, + ) + ) + + fig.update_layout( + title="Aspect Rose Diagram", + polar={ + "radialaxis": {"title": "Frequency", "showticklabels": True}, + "angularaxis": { + "direction": "clockwise", + "rotation": 90, + "tickmode": "array", + "tickvals": [0, 45, 90, 135, 180, 225, 270, 315], + "ticktext": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + }, + }, + showlegend=False, + height=500, + ) + + return fig + + +def create_slope_aspect_scatter(slope_values: np.ndarray, aspect_values: np.ndarray) -> go.Figure: + """Create a scatter plot showing the relationship between slope and aspect. 
+ + Args: + slope_values: Array of slope values + aspect_values: Array of aspect values in degrees (0-360) + + Returns: + Plotly Figure with scatter plot + + """ + # Create DataFrame and remove NaN + df = pd.DataFrame({"slope": slope_values.flatten(), "aspect": aspect_values.flatten()}) + df = df.dropna() + + if len(df) == 0: + return go.Figure() + + # Sample if too many points for performance + if len(df) > 50000: + df = df.sample(n=50000, random_state=42) + + # Create 2D histogram (density plot) + fig = go.Figure() + + fig.add_trace( + go.Histogram2d( + x=df["aspect"], + y=df["slope"], + colorscale="Viridis", + nbinsx=36, # 10-degree bins for aspect + nbinsy=50, + ) + ) + + fig.update_layout( + title="Slope vs Aspect Distribution", + xaxis_title="Aspect (degrees)", + yaxis_title="Slope (degrees)", + height=500, + ) + + # Add directional labels on x-axis + fig.update_xaxes( + tickmode="array", + tickvals=[0, 45, 90, 135, 180, 225, 270, 315, 360], + ticktext=["N", "NE", "E", "SE", "S", "SW", "W", "NW", "N"], + ) + + return fig + + +def create_correlation_heatmap(arcticdem_ds: xr.Dataset, features: list[str], agg: str) -> go.Figure: + """Create a correlation heatmap for terrain features. 
+ + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + features: List of feature names to include + agg: Aggregation type to use + + Returns: + Plotly Figure with correlation heatmap + + """ + # Extract data for each feature + data_dict = {} + for feature in features: + var_name = f"{feature}" + if var_name in arcticdem_ds.data_vars: + data = arcticdem_ds[var_name].sel(aggregations=agg) + values = data.to_numpy().flatten() + data_dict[feature] = values + + if not data_dict: + return go.Figure() + + # Create DataFrame + df = pd.DataFrame(data_dict) + df = df.dropna() + + # Sample if too many rows + if len(df) > 50000: + df = df.sample(n=50000, random_state=42) + + # Calculate correlation matrix + corr = df.corr() + + # Create heatmap + fig = go.Figure( + data=go.Heatmap( + z=corr.values, + x=[f.replace("_", " ").title() for f in corr.columns], + y=[f.replace("_", " ").title() for f in corr.index], + colorscale="RdBu", + zmid=0, + zmin=-1, + zmax=1, + text=np.round(corr.values, 2), + texttemplate="%{text}", + textfont={"size": 10}, + colorbar={"title": "Correlation"}, + ) + ) + + fig.update_layout( + title=f"Feature Correlation Matrix ({agg.title()})", + height=600, + xaxis={"side": "bottom"}, + ) + + return fig diff --git a/src/entropice/dashboard/sections/alphaearth.py b/src/entropice/dashboard/sections/alphaearth.py index 58cf9c4..7f0862c 100644 --- a/src/entropice/dashboard/sections/alphaearth.py +++ b/src/entropice/dashboard/sections/alphaearth.py @@ -1,5 +1,7 @@ """AlphaEarth embeddings dashboard section.""" +from typing import cast + import geopandas as gpd import matplotlib.colors as mcolors import numpy as np @@ -61,21 +63,18 @@ def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataF """ ) - cols = st.columns([4, 1]) - with cols[0]: - if "year" in embedding_values.dims or "year" in embedding_values.coords: - year_values = embedding_values["year"].values.tolist() - year = st.slider( - "Select Year", - 
min_value=int(min(year_values)), - max_value=int(max(year_values)), - value=int(max(year_values)), - step=1, - help="Select the year for which to visualize the embeddings.", - ) - embedding_values = embedding_values.sel(year=year) - with cols[1]: - make_3d_map = st.checkbox("3D Map", value=True) + if "year" in embedding_values.dims or "year" in embedding_values.coords: + year_values = embedding_values["year"].values.tolist() + year = st.slider( + "Select Year", + min_value=int(min(year_values)), + max_value=int(max(year_values)), + value=int(max(year_values)), + step=1, + help="Select the year for which to visualize the embeddings.", + ) + embedding_values = embedding_values.sel(year=year) + make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="embedding_map_3d")) # Check if subsampling will occur n_cells = len(embedding_values["cell_ids"]) diff --git a/src/entropice/dashboard/sections/arcticdem.py b/src/entropice/dashboard/sections/arcticdem.py new file mode 100644 index 0000000..be00308 --- /dev/null +++ b/src/entropice/dashboard/sections/arcticdem.py @@ -0,0 +1,317 @@ +"""ArcticDEM terrain features dashboard section.""" + +from typing import cast + +import geopandas as gpd +import matplotlib.colors as mcolors +import streamlit as st +import xarray as xr + +from entropice.dashboard.plots.terrain import ( + create_aspect_rose_diagram, + create_correlation_heatmap, + create_slope_aspect_scatter, + create_terrain_distribution_plot, + create_terrain_map, +) +from entropice.dashboard.sections.dataset_statistics import render_member_details +from entropice.dashboard.utils.colors import get_cmap +from entropice.dashboard.utils.stats import MemberStatistics + + +@st.fragment +def _render_terrain_map(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame): + """Visualize spatial distribution of terrain features. 
+ + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + grid_gdf: GeoDataFrame with grid cell geometries + + """ + st.subheader("Spatial Distribution of Terrain Features") + + # Get available features and aggregations + available_vars = list(arcticdem_ds.data_vars) + aggs = list(arcticdem_ds.coords["aggregations"].values) + + # Feature selection with grouping + feature_groups = { + "TPI (Topographic Position Index)": ["tpi_small", "tpi_medium", "tpi_large"], + "TRI (Terrain Ruggedness Index)": ["tri_small", "tri_medium", "tri_large"], + "VRM (Vector Ruggedness Measure)": ["vrm_small", "vrm_medium", "vrm_large"], + "Slope & Curvature": ["slope", "curvature"], + "Aspect": ["aspect"], + } + + # Flatten to get all features + all_features = [f for features in feature_groups.values() for f in features if f in available_vars] + + cols = st.columns([2, 2, 1]) + with cols[0]: + # Feature selection with nice formatting + selected_feature = st.selectbox( + "Terrain Feature", + options=all_features, + format_func=lambda x: x.replace("_", " ").title(), + key="terrain_feature", + ) + + with cols[1]: + # Aggregation selection + selected_agg = st.selectbox( + "Aggregation", + options=aggs, + key="terrain_agg", + ) + + with cols[2]: + st.write("\n") + make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="terrain_map_3d")) + + # Extract the data + if selected_feature not in arcticdem_ds.data_vars: + st.error(f"Feature {selected_feature} not found in dataset") + return + + terrain_data = arcticdem_ds[selected_feature].sel(aggregations=selected_agg) + terrain_series = terrain_data.to_series() + + # Check if subsampling will occur + n_cells = len(terrain_series) + if n_cells > 100000: + st.info(f"πŸ—ΊοΈ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.") + + # Create map + map_deck = create_terrain_map(terrain_series, grid_gdf, selected_feature, make_3d_map) + st.pydeck_chart(map_deck) + + # Add legend + 
with st.expander("Legend", expanded=True): + st.markdown(f"**{selected_feature.replace('_', ' ').title()} ({selected_agg})**") + + values = terrain_series.dropna() + if len(values) > 0: + vmin, vmax = values.min(), values.max() + vmean = values.mean() + vstd = values.std() + + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Min", f"{vmin:.2f}") + with col2: + st.metric("Mean", f"{vmean:.2f}") + with col3: + st.metric("Max", f"{vmax:.2f}") + with col4: + st.metric("Std Dev", f"{vstd:.2f}") + + # Color scale visualization + if "aspect" in selected_feature.lower(): + cmap = get_cmap("aspect") + else: + cmap = get_cmap("terrain") + + gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]] + gradient_css = ", ".join(gradient_colors) + + gradient_style = f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;" + st.markdown( + f""" +
+
+ {vmin:.2f} + {vmax:.2f} +
+
+
+ """, + unsafe_allow_html=True, + ) + else: + st.warning("No data available for the selected feature and aggregation") + + +@st.fragment +def _render_terrain_distributions(arcticdem_ds: xr.Dataset): + """Display distribution plots for terrain features. + + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + + """ + st.subheader("Terrain Feature Distributions") + + # Get available features + available_vars = list(arcticdem_ds.data_vars) + + # Feature group selection + feature_groups = { + "Topographic Position (TPI)": ["tpi_small", "tpi_medium", "tpi_large"], + "Terrain Ruggedness (TRI)": ["tri_small", "tri_medium", "tri_large"], + "Vector Ruggedness (VRM)": ["vrm_small", "vrm_medium", "vrm_large"], + "All Scale-Invariant": ["slope", "aspect", "curvature"], + } + + selected_group = st.selectbox( + "Feature Group", + options=list(feature_groups.keys()), + key="terrain_dist_group", + ) + + features_to_plot = [f for f in feature_groups[selected_group] if f in available_vars] + + if features_to_plot: + # Check if subsampling will occur + n_cells = len(arcticdem_ds.cell_ids) + if n_cells > 10000: + st.info( + f"πŸ“Š **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} " + "for performance. Statistics remain representative." + ) + + fig = create_terrain_distribution_plot(arcticdem_ds, features_to_plot) + st.plotly_chart(fig, width="stretch") + + st.markdown( + f""" + Distribution of **{selected_group}** features across different aggregation types. + Violin plots show the full distribution with embedded box plots for quartiles. + """ + ) + else: + st.warning(f"No features available for {selected_group}") + + +@st.fragment +def _render_terrain_correlations(arcticdem_ds: xr.Dataset): + """Display correlation analysis for terrain features. 
+ + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + + """ + st.subheader("Feature Correlation Analysis") + + # Get available features + available_vars = list(arcticdem_ds.data_vars) + aggs = list(arcticdem_ds.coords["aggregations"].values) + + # Select aggregation for correlation + selected_agg = st.selectbox( + "Aggregation Type", + options=aggs, + key="terrain_corr_agg", + ) + + # Select which features to include + all_features = [str(f) for f in available_vars if f in available_vars] + + if len(all_features) >= 2: + fig = create_correlation_heatmap(arcticdem_ds, all_features, selected_agg) + st.plotly_chart(fig, width="stretch") + + st.markdown( + """ + Correlation heatmap shows relationships between different terrain features. + Strong positive correlations (red) or negative correlations (blue) can indicate + related terrain characteristics. + """ + ) + else: + st.warning("Need at least 2 features for correlation analysis") + + +@st.fragment +def _render_slope_aspect_analysis(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame): + """Specialized visualization for slope and aspect relationships. 
+ + Args: + arcticdem_ds: Xarray Dataset containing ArcticDEM terrain features + grid_gdf: GeoDataFrame with grid cell geometries + + """ + st.subheader("Slope & Aspect Analysis") + + # Check if slope and aspect are available + if "slope" not in arcticdem_ds.data_vars or "aspect" not in arcticdem_ds.data_vars: + st.warning("Slope and/or aspect data not available in this dataset") + return + + # Get aggregations + aggs = list(arcticdem_ds.coords["aggregations"].values) + + # Select aggregation + selected_agg = st.selectbox( + "Aggregation Type", + options=aggs, + key="slope_aspect_agg", + ) + + # Extract slope and aspect data + slope_data = arcticdem_ds["slope"].sel(aggregations=selected_agg).values + aspect_data = arcticdem_ds["aspect"].sel(aggregations=selected_agg).values + + # Check if subsampling will occur + n_values = len(slope_data.flatten()) + if n_values > 50000: + st.info( + f"πŸ“Š **Dataset subsampled:** Using 50,000 randomly selected values out of {n_values:,} " + "for performance. Distributions remain representative." + ) + + # Create two columns for visualizations + col1, col2 = st.columns(2) + + with col1: + st.markdown("**Aspect Rose Diagram**") + fig_rose = create_aspect_rose_diagram(aspect_data) + st.plotly_chart(fig_rose, width="stretch") + + st.markdown( + """ + The rose diagram shows the directional distribution of terrain aspect. + Each bar represents the frequency of slopes facing that direction. + """ + ) + + with col2: + st.markdown("**Slope vs Aspect Distribution**") + fig_scatter = create_slope_aspect_scatter(slope_data, aspect_data) + st.plotly_chart(fig_scatter, width="stretch") + + st.markdown( + """ + This 2D histogram shows how slope steepness relates to aspect direction. + Patterns can reveal preferential slope orientations (e.g., due to prevailing winds or sun exposure). 
+ """ + ) + + +def render_arcticdem_tab(arcticdem_ds: xr.Dataset, grid_gdf: gpd.GeoDataFrame, arcticdem_stats: MemberStatistics): + """Render the ArcticDEM visualization tab. + + Args: + arcticdem_ds: The ArcticDEM dataset member, lazily loaded. + grid_gdf: GeoDataFrame with grid cell geometries + arcticdem_stats: Statistics for the ArcticDEM member. + + """ + # Render different visualizations + with st.expander("ArcticDEM Statistics", expanded=True): + render_member_details("ArcticDEM", arcticdem_stats) + + st.divider() + + _render_terrain_map(arcticdem_ds, grid_gdf) + + # st.divider() + + # _render_terrain_distributions(arcticdem_ds) + + st.divider() + + _render_terrain_correlations(arcticdem_ds) + + st.divider() + + _render_slope_aspect_analysis(arcticdem_ds, grid_gdf) diff --git a/src/entropice/dashboard/sections/areas.py b/src/entropice/dashboard/sections/areas.py index 5745679..678a8a1 100644 --- a/src/entropice/dashboard/sections/areas.py +++ b/src/entropice/dashboard/sections/areas.py @@ -14,16 +14,14 @@ from entropice.dashboard.utils.colors import get_cmap def _render_area_map(grid_gdf: gpd.GeoDataFrame): st.subheader("Spatial Distribution of Grid Cell Areas") - cols = st.columns([4, 1]) - with cols[0]: - metric = st.selectbox( - "Metric", - options=["cell_area", "land_area", "water_area", "land_ratio"], - format_func=lambda x: x.replace("_", " ").title(), - key="metric", - ) - with cols[1]: - make_3d_map = cast(bool, st.checkbox("3D Map", value=True)) + metric = st.selectbox( + "Metric", + options=["cell_area", "land_area", "water_area", "land_ratio"], + format_func=lambda x: x.replace("_", " ").title(), + key="metric", + ) + + make_3d_map = cast(bool, st.toggle("3D Map", value=True, key="area_map_3d")) map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map) st.pydeck_chart(map_deck) diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py index e0aa834..4d0048a 100644 --- 
def _get_climate_variable_aggregation_season_option(era5_ds: xr.Dataset):
    """Ask the user which variable, aggregation and season should be displayed.

    The season and aggregation select boxes are only drawn when the dataset has a
    ``month`` / ``aggregations`` dimension; otherwise the respective choice is ``None``.

    Args:
        era5_ds: ERA5 member dataset whose data variables and coordinates provide the options.

    Returns:
        Tuple ``(selected_var, selected_agg, selected_month)``.

    """
    selected_var = st.selectbox("Select Climate Variable", list(era5_ds.data_vars), key="climate_var_map")

    # Widgets appear in creation order, so the season box stays ahead of the aggregation box.
    selected_month = None
    if "month" in era5_ds.dims:
        month_options = era5_ds.month.to_numpy().tolist()
        selected_month = cast(str, st.selectbox("Select Season", options=month_options, key="climate_month"))

    selected_agg = None
    if "aggregations" in era5_ds.dims:
        agg_options = era5_ds.aggregations.to_numpy().tolist()
        selected_agg = cast(str, st.selectbox("Select Aggregation Method", options=agg_options, key="climate_agg"))

    return selected_var, selected_agg, selected_month
@st.fragment
def _render_climate_variable_map(climate_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame, selected_var: str):
    """Visualize the spatial distribution of one climate variable, with a colour legend.

    Args:
        climate_values: ERA5 values of the selected variable. If a ``year`` dimension is
            present, a slider picks the year that is drawn.
        grid_gdf: GeoDataFrame with grid cell geometries.
        selected_var: Name of the selected climate variable (used for the colormap and labels).

    """
    st.subheader("Spatial Distribution of Climate Variables")

    if "year" in climate_values.dims:
        years = climate_values.year.to_numpy()
        selected_year = cast(
            int, st.select_slider("Select Year", options=years.tolist(), value=int(years.max()), key="climate_year")
        )
        climate_values = climate_values.sel(year=selected_year)

    make_3d = cast(bool, st.toggle("3D Map", value=True, key="climate_map_3d"))

    # Count cells along the cell dimension (the same way create_climate_map decides to
    # subsample) instead of relying on whichever dimension happens to come first.
    n_cells = len(climate_values["cell_ids"])
    if n_cells > 100000:
        st.info(f"Showing 100,000 / {n_cells:,} cells for performance")

    deck = create_climate_map(climate_values, grid_gdf, selected_var, make_3d)
    st.pydeck_chart(deck, use_container_width=True)

    with st.expander("Legend", expanded=True):
        st.markdown(f"**{selected_var.replace('_', ' ').title()}**")

        # The range is computed over all selected cells; the map applies the same
        # percentiles to the cells it actually draws, so the two can differ slightly
        # when the map subsamples.
        values = climate_values.values.flatten()
        values = values[~np.isnan(values)]

        if values.size == 0:
            # Without a single valid value there is no range to label.
            st.caption("No valid values available for the selected variable.")
        else:
            vmin, vmax = np.nanpercentile(values, [2, 98])
            vmin_str = f"{vmin:.4f}"
            vmax_str = f"{vmax:.4f}"

            cmap = get_cmap(selected_var)
            gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
            gradient_css = ", ".join(gradient_colors)
            gradient_style = (
                f"height: 20px; background: linear-gradient(to right, {gradient_css}); border-radius: 4px;"
            )

            # NOTE(review): the legend markup was missing from the extracted source (only the
            # {vmin_str}/{vmax_str} placeholders survived). The HTML below is a reconstruction:
            # a gradient bar with the range labels underneath - confirm against the repository.
            legend_html = (
                f'<div style="{gradient_style}"></div>'
                '<div style="display: flex; justify-content: space-between; font-size: 12px;">'
                f"<span>{vmin_str}</span><span>{vmax_str}</span>"
                "</div>"
            )
            st.markdown(legend_html, unsafe_allow_html=True)

            st.caption(
                "Color intensity represents climate values from low to high. "
                "Values are normalized using the 2nd-98th percentile range to avoid outliers."
            )

        if make_3d:
            st.markdown("---")
            st.markdown("**3D Elevation:**")
            st.caption("Height represents normalized climate values. Rotate the map by holding Ctrl/Cmd and dragging.")
def _render_climate_temporal_trends(
    climate_values: xr.DataArray, selected_var: str, selected_agg: str | None, selected_month: str | None
):
    """Display temporal trends of the selected climate variable.

    Args:
        climate_values: ERA5 values of one variable; read along its ``year`` and ``cell_ids`` dimensions.
        selected_var: Name of the selected climate variable.
        selected_agg: Selected aggregation method, or ``None`` when the dataset has none.
        selected_month: Selected month/season, or ``None`` when the dataset has none.

    """
    st.subheader("Climate Variable Temporal Trends")

    if "year" not in climate_values.dims:
        st.info("Temporal trends require yearly dimension")
        return

    st.markdown(
        """
        This visualization shows how climate variables have changed over time across the study area.
        The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
        patterns in climate variables that may correlate with environmental changes.

        **Understanding the plot:**
        - **Mean line** (blue): Average value across all grid cells for each year
        - **10th-90th percentile band** (light blue): Range containing 80% of the values, showing
          typical variation
        - **Min/Max range** (gray): Full extent of values, highlighting outliers
        """
    )

    n_years = len(climate_values["year"])
    n_cells = len(climate_values["cell_ids"])

    # Compare against None so that a legitimate but falsy label (e.g. 0) is still reported.
    agg_info = f", Aggregation: `{selected_agg}`" if selected_agg is not None else ""
    month_info = f", Season: `{selected_month}`" if selected_month is not None else ""
    # \N{BAR CHART} keeps the emoji independent of the file encoding.
    st.caption(
        f"\N{BAR CHART} **Dataset selection:** Variable: `{selected_var}` "
        f"({n_years} years, {n_cells:,} cells{agg_info}{month_info})"
    )

    # NOTE(review): the 10,000-cell limit presumably mirrors the subsampling done inside
    # create_climate_trend_plot - confirm the two stay in sync.
    if n_cells > 10000:
        st.info(
            f"\N{BAR CHART} **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )

    fig = create_climate_trend_plot(climate_values, selected_var)
    # Same sizing argument as the plotly charts of the ArcticDEM section.
    st.plotly_chart(fig, width="stretch")


@st.fragment
def render_era5_tab(
    era5_member_datasets: dict[str, xr.Dataset],
    grid_gdf: gpd.GeoDataFrame,
    era5_member_stats: dict[str, MemberStatistics],
):
    """Render the ERA5 visualization tab.

    Lets the user pick an ERA5 member (when more than one is available), shows its
    statistics, then renders the map and - if the data has a ``year`` dimension - the
    temporal trends for the chosen variable / aggregation / season.

    Args:
        era5_member_datasets: Dictionary of ERA5 member datasets (yearly, seasonal, shoulder).
        grid_gdf: GeoDataFrame with grid cell geometries.
        era5_member_stats: Dictionary of MemberStatistics, keyed like ``era5_member_datasets``.

    """
    if not era5_member_datasets:
        st.warning("No ERA5 data available")
        return

    available_members = list(era5_member_datasets.keys())

    if len(available_members) == 1:
        selected_member = available_members[0]
        st.info(f"Showing ERA5 data: {selected_member}")
    else:
        selected_member = st.selectbox(
            "Select ERA5 Aggregation Type",
            available_members,
            format_func=lambda x: x.replace("_", " ").title(),
            key="era5_member",
        )

    era5_stats = era5_member_stats[selected_member]
    era5_ds = era5_member_datasets[selected_member]

    with st.expander(f"{selected_member.replace('_', ' ').title()} Statistics", expanded=True):
        render_member_details(selected_member, era5_stats)

    st.divider()

    selected_var, selected_agg, selected_month = _get_climate_variable_aggregation_season_option(era5_ds)

    climate_values = era5_ds[selected_var]
    # "is not None" rather than truthiness: a falsy coordinate label (e.g. 0) is still a
    # real selection and must narrow the data.
    if selected_agg is not None:
        climate_values = climate_values.sel(aggregations=selected_agg)
    if selected_month is not None:
        climate_values = climate_values.sel(month=selected_month)
    # Materialize once so the map and the trend plot work on in-memory values.
    climate_values = climate_values.compute()

    _render_climate_variable_map(climate_values, grid_gdf, selected_var)

    if "year" in climate_values.dims:
        st.divider()

        _render_climate_temporal_trends(climate_values, selected_var, selected_agg, selected_month)
show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False)) + show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False, key="target_map_show_split")) training_set = train_data_dict[selected_target][selected_task] map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task) diff --git a/src/entropice/dashboard/views/dataset_page.py b/src/entropice/dashboard/views/dataset_page.py index 032511d..a6c8340 100644 --- a/src/entropice/dashboard/views/dataset_page.py +++ b/src/entropice/dashboard/views/dataset_page.py @@ -1,14 +1,16 @@ """Data page: Visualization of the data.""" -from typing import cast +from typing import Literal, cast import streamlit as st import xarray as xr from stopuhr import stopwatch from entropice.dashboard.sections.alphaearth import render_alphaearth_tab +from entropice.dashboard.sections.arcticdem import render_arcticdem_tab from entropice.dashboard.sections.areas import render_area_information_tab from entropice.dashboard.sections.dataset_statistics import render_ensemble_details +from entropice.dashboard.sections.era5 import render_era5_tab from entropice.dashboard.sections.targets import render_target_information_tab from entropice.dashboard.utils.loaders import load_training_sets from entropice.dashboard.utils.stats import DatasetStatistics @@ -130,20 +132,21 @@ def render_dataset_page(): tab_names.append("🌍 AlphaEarth") if "ArcticDEM" in ensemble.members: tab_names.append("πŸ”οΈ ArcticDEM") - era5_members = [m for m in ensemble.members if m.startswith("ERA5")] + era5_members = cast( + list[Literal["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]], + [m for m in ensemble.members if m.startswith("ERA5")], + ) if era5_members: tab_names.append("🌑️ ERA5") tabs = st.tabs(tab_names) with tabs[0]: st.header("🎯 Target Labels Visualization") - if False: # ! 
DEBUG - train_data_dict = load_training_sets(ensemble) - render_target_information_tab(train_data_dict) + train_data_dict = load_training_sets(ensemble) + render_target_information_tab(train_data_dict) with tabs[1]: st.header("πŸ“ Areas Visualization") - if False: # ! DEBUG - render_area_information_tab(grid_gdf) + render_area_information_tab(grid_gdf) tab_index = 2 if "AlphaEarth" in ensemble.members: with tabs[tab_index]: @@ -155,10 +158,16 @@ def render_dataset_page(): if "ArcticDEM" in ensemble.members: with tabs[tab_index]: st.header("πŸ”οΈ ArcticDEM Visualization") + arcticdem_ds = member_datasets["ArcticDEM"].compute() + arcticdem_stats = stats.members["ArcticDEM"] + render_arcticdem_tab(arcticdem_ds, grid_gdf, arcticdem_stats) tab_index += 1 if era5_members: with tabs[tab_index]: st.header("🌑️ ERA5 Visualization") + era5_member_dataset = {m: member_datasets[m] for m in era5_members} + era5_member_stats = {m: stats.members[m] for m in era5_members} + render_era5_tab(era5_member_dataset, grid_gdf, era5_member_stats) st.balloons() stopwatch.summary()