diff --git a/src/entropice/dashboard/plots/embeddings.py b/src/entropice/dashboard/plots/embeddings.py
new file mode 100644
index 0000000..3166df0
--- /dev/null
+++ b/src/entropice/dashboard/plots/embeddings.py
@@ -0,0 +1,376 @@
+"""Render the AlphaEarth visualization tab."""
+
+import geopandas as gpd
+import matplotlib.colors as mcolors
+import numpy as np
+import plotly.graph_objects as go
+import pydeck as pdk
+import xarray as xr
+
+from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
+from entropice.dashboard.utils.geometry import fix_hex_geometry
+
+
+def create_embedding_map(
+    embedding_values: xr.DataArray,
+    grid_gdf: gpd.GeoDataFrame,
+    make_3d_map: bool,
+) -> pdk.Deck:
+    """Create a spatial distribution map for AlphaEarth embeddings.
+
+    Every grid cell that has a finite embedding value is drawn as a GeoJSON
+    polygon, colored by its value after normalizing to the 2nd-98th percentile
+    range of the displayed cells. Cells whose embedding value is NaN are left
+    out of the map. If more than 100,000 cells are passed in, a random
+    subsample of 100,000 cells (fixed seed) is drawn instead.
+
+    Args:
+        embedding_values (xr.DataArray): DataArray containing the already filtered AlphaEarth embeddings.
+            Must have a ``cell_ids`` dimension; presumably one value per cell (band/agg/year
+            already selected by the caller) - TODO confirm.
+        grid_gdf (gpd.GeoDataFrame): GeoDataFrame containing grid cell geometries. ``cell_id`` must be
+            available either as a column or as the named index. The frame is not modified.
+        make_3d_map (bool): Whether to render the map in 3D (extruded) or 2D.
+
+    Returns:
+        pdk.Deck: A PyDeck map visualization of the AlphaEarth embeddings.
+
+    """
+    # Subsample if too many cells for performance
+    n_cells = len(embedding_values["cell_ids"])
+    if n_cells > 100000:
+        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
+        cell_indices = rng.choice(n_cells, size=100000, replace=False)
+        embedding_values = embedding_values.isel(cell_ids=cell_indices)
+
+    # Convert to DataFrame for easier merging
+    embedding_df = embedding_values.to_dataframe(name="embedding_value")
+
+    # Reproject once to WGS84 for pydeck; to_crs returns a new GeoDataFrame,
+    # so the caller's grid is left untouched without an explicit copy.
+    gdf = grid_gdf.to_crs("EPSG:4326")
+
+    # Reset index if cell_id is already the index
+    if gdf.index.name == "cell_id":
+        gdf = gdf.reset_index()
+
+    # Filter grid to only cells that have embedding data, then attach the values
+    gdf = gdf[gdf["cell_id"].isin(embedding_df.index)]
+    gdf = gdf.set_index("cell_id")
+    gdf = gdf.join(embedding_df, how="inner")
+
+    # NaN has no meaningful color or elevation and is not valid JSON,
+    # so cells without a finite value are not drawn at all.
+    gdf = gdf.dropna(subset=["embedding_value"])
+
+    # Fix antimeridian issues for hex cells
+    gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
+
+    # Get colormap for embeddings
+    cmap = get_cmap("AlphaEarth")
+
+    # Normalize the embedding values to [0, 1] for color mapping
+    # Use percentiles to avoid outliers
+    values = gdf["embedding_value"].to_numpy()
+    if values.size > 0:
+        vmin, vmax = np.nanpercentile(values, [2, 98])
+    else:
+        # Nothing to draw: fall through to the constant branch below
+        vmin = vmax = 0.0
+
+    if vmax > vmin:
+        normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)
+    else:
+        normalized_values = np.zeros_like(values)
+
+    # Map normalized values to colors
+    fill_colors = [hex_to_rgb(mcolors.to_hex(cmap(val))) for val in normalized_values]
+
+    # Build the GeoJSON features column-wise; iterating rows of a large
+    # GeoDataFrame with iterrows() is far slower than zipping the columns.
+    geojson_data = [
+        {
+            "type": "Feature",
+            "geometry": geometry.__geo_interface__,
+            "properties": {
+                "fill_color": fill_color,
+                # Raw value for the tooltip
+                "embedding_value": float(value),
+                # Normalized value drives the extrusion height (3D only)
+                "elevation": float(normalized) if make_3d_map else 0,
+            },
+        }
+        for geometry, fill_color, value, normalized in zip(gdf["geometry"], fill_colors, values, normalized_values)
+    ]
+
+    # Create pydeck layer
+    layer = pdk.Layer(
+        "GeoJsonLayer",
+        geojson_data,
+        opacity=0.7,
+        stroked=True,
+        filled=True,
+        extruded=make_3d_map,
+        wireframe=False,
+        get_fill_color="properties.fill_color",
+        get_line_color=[80, 80, 80],
+        line_width_min_pixels=0.5,
+        get_elevation="properties.elevation" if make_3d_map else 0,
+        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
+        pickable=True,
+    )
+
+    # Set initial view state (centered on the Arctic)
+    # Adjust pitch and zoom based on whether we're using 3D
+    view_state = pdk.ViewState(
+        latitude=70,
+        longitude=0,
+        zoom=2 if not make_3d_map else 1.5,
+        pitch=0 if not make_3d_map else 45,
+    )
+
+    # Build tooltip HTML
+    tooltip_html = "Embedding Value: {embedding_value}"
+
+    # Create deck
+    deck = pdk.Deck(
+        layers=[layer],
+        initial_view_state=view_state,
+        tooltip={
+            "html": tooltip_html,
+            "style": {"backgroundColor": "steelblue", "color": "white"},
+        },
+        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+    )
+
+    return deck
+
+
+def create_embedding_trend_plot(embedding_values: xr.DataArray) -> go.Figure:
+    """Plot per-year spatial statistics of AlphaEarth embeddings.
+
+    The figure shows the mean as a line with markers, the 10th-90th percentile
+    range as a shaded green band and the min/max envelope as dashed gray lines.
+    All statistics are taken across the ``cell_ids`` dimension. When more than
+    10,000 cells are given, a fixed-seed random subsample of 10,000 cells is used.
+
+    Args:
+        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings with a 'year' dimension.
+
+    Returns:
+        go.Figure: A Plotly figure showing the trend of embeddings over time.
+
+    """
+    # Keep the reductions cheap on very large grids
+    total_cells = len(embedding_values["cell_ids"])
+    if total_cells > 10000:
+        sampler = np.random.default_rng(42)  # Fixed seed for reproducibility
+        kept = sampler.choice(total_cells, size=10000, replace=False)
+        embedding_values = embedding_values.isel(cell_ids=kept)
+
+    years = embedding_values["year"].values
+
+    # One value per year for each statistic, reduced over all cells
+    yearly_mean = embedding_values.mean(dim="cell_ids").to_numpy()
+    yearly_min = embedding_values.min(dim="cell_ids").to_numpy()
+    yearly_max = embedding_values.max(dim="cell_ids").to_numpy()
+    yearly_p10 = embedding_values.quantile(0.10, dim="cell_ids").to_numpy()
+    yearly_p90 = embedding_values.quantile(0.90, dim="cell_ids").to_numpy()
+
+    # Trace order matters: each "tonexty" fill refers to the trace added just before it,
+    # and later traces are drawn on top of earlier ones.
+    traces = [
+        # Min/max envelope (background) - dashed lines
+        go.Scatter(
+            x=years,
+            y=yearly_min,
+            mode="lines",
+            line={"color": "lightgray", "width": 1, "dash": "dash"},
+            name="Min/Max Range",
+            showlegend=True,
+        ),
+        go.Scatter(
+            x=years,
+            y=yearly_max,
+            mode="lines",
+            fill="tonexty",
+            fillcolor="rgba(200, 200, 200, 0.1)",
+            line={"color": "lightgray", "width": 1, "dash": "dash"},
+            showlegend=False,
+        ),
+        # 10th-90th percentile band
+        go.Scatter(
+            x=years,
+            y=yearly_p10,
+            mode="lines",
+            line={"width": 0},
+            showlegend=False,
+            hoverinfo="skip",
+        ),
+        go.Scatter(
+            x=years,
+            y=yearly_p90,
+            mode="lines",
+            fill="tonexty",
+            fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
+            line={"width": 0},
+            name="10th-90th Percentile",
+        ),
+        # Mean line on top
+        go.Scatter(
+            x=years,
+            y=yearly_mean,
+            mode="lines+markers",
+            name="Mean",
+            line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
+            marker={"size": 6},
+        ),
+    ]
+
+    fig = go.Figure()
+    for trace in traces:
+        fig.add_trace(trace)
+
+    horizontal_legend = {
+        "orientation": "h",
+        "yanchor": "bottom",
+        "y": 1.02,
+        "xanchor": "right",
+        "x": 1,
+    }
+    fig.update_layout(
+        title="Embedding Values Over Time (Spatial Statistics)",
+        xaxis_title="Year",
+        yaxis_title="Embedding Value",
+        yaxis={"zeroline": True, "zerolinewidth": 2, "zerolinecolor": "gray"},
+        height=450,
+        hovermode="x unified",
+        legend=horizontal_legend,
+    )
+
+    # One tick per year so the axis shows whole years
+    fig.update_xaxes(dtick=1)
+
+    return fig
+
+
+def create_embedding_distribution_plot(embedding_values: xr.DataArray) -> go.Figure:
+    """Create a distribution plot showing the min/max, 10th/90th percentiles, and mean of AlphaEarth bands.
+
+    Statistics are computed per band over all remaining values of that band,
+    ignoring NaNs. Bands without any finite value are omitted from the plot;
+    if no band has data, the returned figure simply contains empty traces.
+    When more than 10,000 cells are given, a fixed-seed random subsample of
+    10,000 cells is used.
+
+    Args:
+        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings.
+            Must have ``cell_ids`` and ``band`` dimensions.
+
+    Returns:
+        go.Figure: A Plotly figure showing the distribution of embeddings.
+
+    """
+    # Subsample if too many cells for performance
+    n_cells = len(embedding_values["cell_ids"])
+    if n_cells > 10000:
+        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
+        cell_indices = rng.choice(n_cells, size=10000, replace=False)
+        embedding_values = embedding_values.isel(cell_ids=cell_indices)
+
+    # Collect the per-band statistics in plain lists. Plotly accepts lists directly,
+    # so no DataFrame is needed, and an empty result cannot raise a KeyError the way
+    # indexing a column-less DataFrame would.
+    band_labels: list[str] = []
+    stats: dict[str, list[float]] = {"Mean": [], "Min": [], "Max": [], "P10": [], "P90": []}
+    for band in embedding_values["band"].values:
+        band_data = embedding_values.sel(band=band).values.flatten()
+        # Remove NaN values
+        band_data = band_data[~np.isnan(band_data)]
+        if band_data.size == 0:
+            continue
+
+        p10, p90 = np.percentile(band_data, [10, 90])
+        band_labels.append(str(band))
+        stats["Mean"].append(float(np.mean(band_data)))
+        stats["Min"].append(float(np.min(band_data)))
+        stats["Max"].append(float(np.max(band_data)))
+        stats["P10"].append(float(p10))
+        stats["P90"].append(float(p90))
+
+    fig = go.Figure()
+
+    # Add min/max range first (background) - dashed lines
+    fig.add_trace(
+        go.Scatter(
+            x=band_labels,
+            y=stats["Min"],
+            mode="lines",
+            line={"color": "lightgray", "width": 1, "dash": "dash"},
+            name="Min/Max Range",
+            showlegend=True,
+        )
+    )
+
+    fig.add_trace(
+        go.Scatter(
+            x=band_labels,
+            y=stats["Max"],
+            mode="lines",
+            fill="tonexty",  # Fills down to the previously added (Min) trace
+            fillcolor="rgba(200, 200, 200, 0.1)",
+            line={"color": "lightgray", "width": 1, "dash": "dash"},
+            showlegend=False,
+        )
+    )
+
+    # Add 10th-90th percentile band
+    fig.add_trace(
+        go.Scatter(
+            x=band_labels,
+            y=stats["P10"],
+            mode="lines",
+            line={"width": 0},
+            showlegend=False,
+            hoverinfo="skip",
+        )
+    )
+
+    fig.add_trace(
+        go.Scatter(
+            x=band_labels,
+            y=stats["P90"],
+            mode="lines",
+            fill="tonexty",  # Fills down to the previously added (P10) trace
+            fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
+            line={"width": 0},
+            name="10th-90th Percentile",
+        )
+    )
+
+    # Add mean line on top
+    fig.add_trace(
+        go.Scatter(
+            x=band_labels,
+            y=stats["Mean"],
+            mode="lines+markers",
+            name="Mean",
+            line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
+            marker={"size": 4},
+        )
+    )
+
+    fig.update_layout(
+        title="Embedding Distribution by Band (Spatial Statistics)",
+        xaxis_title="Band",
+        yaxis_title="Embedding Value",
+        height=450,
+        hovermode="x unified",
+        legend={
+            "orientation": "h",
+            "yanchor": "bottom",
+            "y": 1.02,
+            "xanchor": "right",
+            "x": 1,
+        },
+    )
+
+    return fig
diff --git a/src/entropice/dashboard/sections/alphaearth.py b/src/entropice/dashboard/sections/alphaearth.py
new file mode 100644
index 0000000..58cf9c4
--- /dev/null
+++ b/src/entropice/dashboard/sections/alphaearth.py
@@ -0,0 +1,238 @@
+"""AlphaEarth embeddings dashboard section."""
+
+import geopandas as gpd
+import matplotlib.colors as mcolors
+import numpy as np
+import streamlit as st
+import xarray as xr
+
+from entropice.dashboard.plots.embeddings import (
+ create_embedding_distribution_plot,
+ create_embedding_map,
+ create_embedding_trend_plot,
+)
+from entropice.dashboard.sections.dataset_statistics import render_member_details
+from entropice.dashboard.utils.colors import get_cmap
+from entropice.dashboard.utils.stats import MemberStatistics
+
+
+def _get_band_agg_options(embeddings: xr.Dataset):
+    """Let the user pick an embedding band and an aggregation method.
+
+    Two select boxes are rendered side by side; their options are the values of
+    the dataset's ``band`` and ``agg`` coordinates, with the first entry preselected.
+
+    Args:
+        embeddings: Dataset providing the ``band`` and ``agg`` coordinates.
+
+    Returns:
+        The selected ``(band, aggregation)`` pair.
+
+    """
+    band_choices = embeddings["band"].values.tolist()
+    agg_choices = embeddings["agg"].values.tolist()
+
+    band_col, agg_col = st.columns([2, 2])
+
+    with band_col:
+        band = st.selectbox(
+            "Select Embedding Band",
+            options=band_choices,
+            index=0,
+            help="Select which embedding band to visualize on the map.",
+            key="embedding_band_select",
+        )
+
+    with agg_col:
+        aggregation = st.selectbox(
+            "Select Aggregation Method",
+            options=agg_choices,
+            index=0,
+            help="Select the aggregation method for the embeddings to visualize.",
+            key="embedding_agg_select",
+        )
+
+    return band, aggregation
+
+
+@st.fragment
+def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame):
+    """Render the embedding map together with its controls and color legend.
+
+    Shows a year slider (only when the data has more than one year), a 3D toggle,
+    the pydeck map produced by ``create_embedding_map`` and a gradient legend.
+
+    Args:
+        embedding_values: Embedding values of a single band and aggregation, with a
+            ``cell_ids`` dimension and optionally a ``year`` dimension.
+        grid_gdf: GeoDataFrame with the grid cell geometries, passed on to the map.
+
+    """
+    st.subheader("AlphaEarth Embedding Map")
+
+    st.markdown(
+        """
+        This interactive map visualizes the spatial distribution of the selected embedding band across
+        the Arctic region. Each grid cell is colored according to its embedding value, revealing spatial
+        patterns in the satellite imagery features. High-resolution embeddings can indicate areas with
+        distinctive characteristics that may be relevant for RTS detection.
+
+        **Map controls:**
+        - **Hover** over cells to see exact embedding values
+        - **3D mode**: Elevation represents embedding magnitude - higher areas have larger values
+        - **Rotate** (3D mode): Hold Ctrl/Cmd and drag to rotate the view
+        - **Zoom/Pan**: Scroll to zoom, click and drag to pan
+        """
+    )
+
+    cols = st.columns([4, 1])
+    with cols[0]:
+        # Only a real "year" dimension needs a selection. A scalar "year" coordinate
+        # means the year is already fixed and there is nothing to iterate over.
+        if "year" in embedding_values.dims:
+            year_values = [int(y) for y in embedding_values["year"].values]
+            if min(year_values) < max(year_values):
+                year = st.slider(
+                    "Select Year",
+                    min_value=min(year_values),
+                    max_value=max(year_values),
+                    value=max(year_values),
+                    step=1,
+                    help="Select the year for which to visualize the embeddings.",
+                    key="embedding_year_select",
+                )
+            else:
+                # A slider needs a non-empty range; with a single year there is no choice
+                year = year_values[0]
+            embedding_values = embedding_values.sel(year=year)
+    with cols[1]:
+        # Explicit key: other sections render a checkbox with the same label and default,
+        # and identical key-less widgets on one page clash on their auto-generated ID.
+        make_3d_map = st.checkbox("3D Map", value=True, key="embedding_3d_map")
+
+    # Check if subsampling will occur
+    n_cells = len(embedding_values["cell_ids"])
+    if n_cells > 100000:
+        st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")
+
+    map_deck = create_embedding_map(
+        embedding_values=embedding_values,
+        grid_gdf=grid_gdf,
+        make_3d_map=make_3d_map,
+    )
+
+    st.pydeck_chart(map_deck, width="stretch")
+
+    # Add legend
+    with st.expander("Legend", expanded=True):
+        st.markdown("**Embedding Value**")
+
+        # Percentile range of the selected values. The map computes its range on the
+        # cells it actually draws, so after subsampling the two can differ slightly.
+        values = embedding_values.values.flatten()
+        values = values[~np.isnan(values)]
+        if values.size > 0:
+            vmin, vmax = np.nanpercentile(values, [2, 98])
+            vmin_str = f"{vmin:.4f}"
+            vmax_str = f"{vmax:.4f}"
+        else:
+            # No finite values: there is no range to report
+            vmin_str = vmax_str = "n/a"
+
+        # Get the same colormap used in the map
+        cmap = get_cmap("AlphaEarth")
+        # Sample 4 colors from the colormap to create the gradient
+        gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
+        gradient_css = ", ".join(gradient_colors)
+
+        # Create a simple gradient legend: low value, gradient bar, high value
+        st.markdown(
+            f'<div style="display: flex; align-items: center; gap: 10px; margin-bottom: 10px;">'
+            f'<span style="min-width: 60px; text-align: right;">{vmin_str}</span>'
+            f'<div style="flex: 1; height: 20px; background: linear-gradient(to right, {gradient_css});"></div>'
+            f'<span style="min-width: 60px;">{vmax_str}</span>'
+            f"</div>",
+            unsafe_allow_html=True,
+        )
+
+        st.caption(
+            "Color intensity represents embedding values from low (green) to high (yellow). "
+            "Values are normalized using the 2nd-98th percentile range to avoid outliers."
+        )
+
+        if make_3d_map:
+            st.markdown("---")
+            st.markdown("**3D Elevation:**")
+            st.caption(
+                "Height represents normalized embedding values. Rotate the map by holding Ctrl/Cmd and dragging."
+            )
+
+
+def _render_trend(embeddin_values: xr.DataArray):
+    """Render the "trends over time" section for the selected embedding values.
+
+    Writes an explanatory text, a caption summarizing the selection (band,
+    aggregation, number of years and cells), a notice when the plot will be
+    computed on a 10,000-cell subsample, and finally the trend plot itself.
+
+    Args:
+        embeddin_values: Embedding values with ``band``, ``agg``, ``year`` and
+            ``cell_ids`` coordinates.
+
+    """
+    st.subheader("AlphaEarth Embedding Trends Over Time")
+
+    st.markdown(
+        """
+        This visualization shows how embedding values have changed over time across the study area.
+        The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
+        patterns in the satellite imagery embeddings that may correlate with environmental changes.
+
+        **Understanding the plot:**
+        - **Mean line** (dark green): Average embedding value across all grid cells for each year
+        - **10th-90th percentile band** (light green): Range containing 80% of the values, showing
+          typical variation
+        - **Min/Max range** (gray): Full extent of values, highlighting outliers
+        """
+    )
+
+    def _single_value_or_multiple(coord_name):
+        # A coordinate with exactly one entry is shown by value, anything else as "multiple"
+        coord = embeddin_values[coord_name]
+        return coord.values.item() if coord.size == 1 else "multiple"
+
+    # Show dataset filtering info
+    band_val = _single_value_or_multiple("band")
+    agg_val = _single_value_or_multiple("agg")
+    n_years = len(embeddin_values["year"])
+    n_cells = len(embeddin_values["cell_ids"])
+    st.caption(
+        f"π **Dataset selection:** Band `{band_val}`, Aggregation `{agg_val}` "
+        f"({n_years} years, {n_cells:,} cells)"
+    )
+
+    # The plot function subsamples above this size, so tell the user
+    if n_cells > 10000:
+        st.info(
+            f"π **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
+            "for performance. Statistics remain representative."
+        )
+
+    st.plotly_chart(create_embedding_trend_plot(embedding_values=embeddin_values), width="stretch")
+
+
+def _render_distribution(embeddin_values: xr.DataArray):
+    """Render the per-band distribution section for the selected aggregation.
+
+    Writes an explanatory text, a caption summarizing the selection (aggregation,
+    number of bands and cells), a notice when the plot will be computed on a
+    10,000-cell subsample, and finally the distribution plot itself.
+
+    Args:
+        embeddin_values: Embedding values with ``agg``, ``band`` and ``cell_ids``
+            coordinates.
+
+    """
+    st.subheader("AlphaEarth Embedding Distribution")
+
+    st.markdown(
+        """
+        This plot shows the statistical distribution of embedding values across all 64 embedding
+        dimensions (bands). AlphaEarth embeddings are learned representations from satellite imagery,
+        with each band capturing different aspects of the landscape (e.g., vegetation, terrain, ice
+        cover, land use).
+
+        **Understanding the plot:**
+        - **X-axis**: Embedding bands (A00-A63), each representing a learned feature from satellite
+          imagery
+        - **Mean line** (dark green): Average value across all grid cells for each band
+        - **10th-90th percentile band** (light green): Central distribution of values, excluding
+          outliers
+        - **Min/Max range** (gray): Full value range showing extreme values
+
+        Different bands may capture different landscape features - bands with higher variance often
+        represent more spatially heterogeneous characteristics.
+        """
+    )
+
+    # Show dataset filtering info
+    agg_coord = embeddin_values["agg"]
+    if agg_coord.size == 1:
+        agg_val = agg_coord.values.item()
+    else:
+        agg_val = "multiple"
+    n_bands = len(embeddin_values["band"])
+    n_cells = len(embeddin_values["cell_ids"])
+    st.caption(f"π **Dataset selection:** Aggregation `{agg_val}` ({n_bands} bands, {n_cells:,} cells)")
+
+    # The plot function subsamples above this size, so tell the user
+    if n_cells > 10000:
+        st.info(
+            f"π **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
+            "for performance. Statistics remain representative."
+        )
+
+    st.plotly_chart(create_embedding_distribution_plot(embedding_values=embeddin_values), width="stretch")
+
+
+@st.fragment
+def render_alphaearth_tab(embeddings: xr.Dataset, grid_gdf: gpd.GeoDataFrame, member_stats: MemberStatistics):
+    """Render the AlphaEarth visualization tab.
+
+    The tab consists of the member statistics, the band/aggregation selection,
+    the per-band distribution plot, the trend plot (only when the data carries
+    a ``year`` dimension or coordinate) and the embedding map.
+
+    Args:
+        embeddings: The AlphaEarth dataset member, lazily loaded.
+        grid_gdf: GeoDataFrame with grid cell geometries
+        member_stats: Statistics for the AlphaEarth member.
+
+    """
+    with st.expander("AlphaEarth Embedding Statistics", expanded=True):
+        render_member_details("AlphaEarth", member_stats)
+    st.divider()
+
+    band, aggregation = _get_band_agg_options(embeddings)
+    # Load the values of the chosen aggregation into memory once for all plots below
+    embedding_values = embeddings["embeddings"].sel(agg=aggregation).compute()
+
+    # The distribution plot covers all bands
+    _render_distribution(embedding_values)
+    st.divider()
+
+    # Trend and map only look at the selected band
+    has_year = "year" in embedding_values.dims or "year" in embedding_values.coords
+    band_values = embedding_values.sel(band=band)
+    if has_year:
+        _render_trend(band_values)
+        st.divider()
+
+    _render_embedding_map(band_values, grid_gdf)
diff --git a/src/entropice/dashboard/sections/areas.py b/src/entropice/dashboard/sections/areas.py
index dd4c732..5745679 100644
--- a/src/entropice/dashboard/sections/areas.py
+++ b/src/entropice/dashboard/sections/areas.py
@@ -23,7 +23,7 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame):
key="metric",
)
with cols[1]:
- make_3d_map = cast(bool, st.checkbox("3D Map", value=True, key="area_3d_map"))
+ make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
st.pydeck_chart(map_deck)
diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py
index 7105402..e0aa834 100644
--- a/src/entropice/dashboard/sections/dataset_statistics.py
+++ b/src/entropice/dashboard/sections/dataset_statistics.py
@@ -436,6 +436,42 @@ def _render_aggregation_selection(
return dimension_filters
+def render_member_details(member: str, member_stats: MemberStatistics):
+ """Render detailed information for a single member.
+
+ Writes a heading with the member name, followed by the member's variable
+ names and its dimensions (name and size), each joined into one line.
+
+ Args:
+ member: Member dataset name
+ member_stats: Statistics for the member; ``variable_names`` and
+ ``dimensions`` (a mapping of dimension name to size) are read.
+
+ """
+ st.markdown(f"### {member}")
+
+ # NOTE(review): both markdown calls below pass unsafe_allow_html=True, yet the
+ # f-strings visible here contain no HTML (the second list even starts with an
+ # empty f''). Presumably badge markup belongs around the values - confirm the
+ # literals are complete before relying on this rendering.
+
+ # Variables
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
+ [
+ f'{v}'
+ for v in member_stats.variable_names
+ ]
+ )
+ st.markdown(vars_html, unsafe_allow_html=True)
+
+ # Dimensions
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+ f''
+ f"{dim_name}: {dim_size:,}"
+ for dim_name, dim_size in member_stats.dimensions.items()
+ ]
+ )
+ st.markdown(dim_html, unsafe_allow_html=True)
+
+
def render_ensemble_details(
selected_members: list[L2SourceDataset],
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
@@ -502,33 +538,9 @@ def render_ensemble_details(
st.dataframe(details_df, hide_index=True, width="stretch")
# Individual member details
- for member, member_stats in selected_member_stats.items():
- st.markdown(f"### {member}")
-
- # Variables
- st.markdown("**Variables:**")
- vars_html = " ".join(
- [
- f'{v}'
- for v in member_stats.variable_names
- ]
- )
- st.markdown(vars_html, unsafe_allow_html=True)
-
- # Dimensions
- st.markdown("**Dimensions:**")
- dim_html = " ".join(
- [
- f''
- f"{dim_name}: {dim_size:,}"
- for dim_name, dim_size in member_stats.dimensions.items()
- ]
- )
- st.markdown(dim_html, unsafe_allow_html=True)
-
- st.markdown("---")
+ for member, stats in selected_member_stats.items():
+ render_member_details(member, stats)
+ st.divider()
def _render_configuration_summary(
diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py
index 04b3176..b627f61 100644
--- a/src/entropice/dashboard/sections/experiment_results.py
+++ b/src/entropice/dashboard/sections/experiment_results.py
@@ -52,6 +52,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Experiment",
options=["All", *experiments],
index=0,
+ key="exp_results_experiment",
)
else:
selected_experiment = "All"
@@ -61,6 +62,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Task",
options=["All", *tasks],
index=0,
+ key="exp_results_task",
)
with filter_cols[2]:
@@ -68,6 +70,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Model",
options=["All", *models],
index=0,
+ key="exp_results_model",
)
with filter_cols[3]:
@@ -75,6 +78,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Grid",
options=["All", *grids],
index=0,
+ key="exp_results_grid",
)
# Apply filters
diff --git a/src/entropice/dashboard/sections/targets.py b/src/entropice/dashboard/sections/targets.py
index 8c41cd1..1903c0f 100644
--- a/src/entropice/dashboard/sections/targets.py
+++ b/src/entropice/dashboard/sections/targets.py
@@ -170,6 +170,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
"Select Target Dataset",
options=sorted(train_data_dict.keys()),
index=0,
+ key="target_map_dataset",
),
)
with cols[1]:
@@ -179,6 +180,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
"Select Task",
options=sorted(train_data_dict[selected_target].keys()),
index=0,
+ key="target_map_task",
),
)
with cols[2]:
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index 83c9796..cc3419c 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -15,8 +15,9 @@ from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
+from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.ml.training import TrainingSettings
-from entropice.utils.types import GridConfig
+from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
def _fix_hex_geometry(geom):
@@ -239,3 +240,13 @@ def load_all_training_results() -> list[TrainingResult]:
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
+
+
+def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Task, TrainingSet]]:
+    """Create the training set of every target dataset / task combination.
+
+    Args:
+        ensemble: Dataset ensemble the training sets are created from.
+
+    Returns:
+        Nested mapping ``target -> task -> training set`` covering all entries of
+        ``all_target_datasets`` and ``all_tasks``.
+
+    """
+    return {
+        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
+        for target in all_target_datasets
+    }
diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py
index d061bc4..e88b83c 100644
--- a/src/entropice/dashboard/utils/stats.py
+++ b/src/entropice/dashboard/utils/stats.py
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass
from typing import Literal
import pandas as pd
+import xarray as xr
from stopuhr import stopwatch
import entropice.spatial.grids
@@ -39,11 +40,19 @@ class MemberStatistics:
size_bytes: int # Size of this member's data on disk in bytes
@classmethod
- def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
+ def compute(
+ cls,
+ e: DatasetEnsemble,
+ member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
+ ) -> dict[L2SourceDataset, "MemberStatistics"]:
"""Pre-compute the statistics for a specific dataset member."""
+ member_datasets = member_datasets or {}
member_stats = {}
for member in e.members:
- ds = e.read_member(member, lazy=True)
+ if member in member_datasets:
+ ds = member_datasets[member]
+ else:
+ ds = e.read_member(member, lazy=True)
size_bytes = ds.nbytes
n_cols_member = len(ds.data_vars)
@@ -113,7 +122,11 @@ class DatasetStatistics:
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
@classmethod
- def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
+ def from_ensemble(
+ cls,
+ e: DatasetEnsemble,
+ member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
+ ) -> "DatasetStatistics":
"""Compute dataset statistics from a DatasetEnsemble."""
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
total_cells = len(grid_gdf)
@@ -123,7 +136,7 @@ class DatasetStatistics:
# darts_mllabels does not support year-based temporal modes
continue
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
- member_statistics = MemberStatistics.compute(e)
+ member_statistics = MemberStatistics.compute(e, member_datasets=member_datasets)
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
diff --git a/src/entropice/dashboard/views/dataset_page.py b/src/entropice/dashboard/views/dataset_page.py
index e2b0a19..032511d 100644
--- a/src/entropice/dashboard/views/dataset_page.py
+++ b/src/entropice/dashboard/views/dataset_page.py
@@ -3,21 +3,20 @@
from typing import cast
import streamlit as st
+import xarray as xr
from stopuhr import stopwatch
+from entropice.dashboard.sections.alphaearth import render_alphaearth_tab
from entropice.dashboard.sections.areas import render_area_information_tab
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
from entropice.dashboard.sections.targets import render_target_information_tab
+from entropice.dashboard.utils.loaders import load_training_sets
from entropice.dashboard.utils.stats import DatasetStatistics
-from entropice.ml.dataset import DatasetEnsemble, TrainingSet
+from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
GridConfig,
L2SourceDataset,
- TargetDataset,
- Task,
TemporalMode,
- all_target_datasets,
- all_tasks,
grid_configs,
)
@@ -38,6 +37,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
options=grid_options,
index=0,
help="Select the grid system and resolution level",
+ key="dataset_page_grid",
)
# Find the selected grid config
@@ -48,12 +48,13 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
"Temporal Mode",
options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
index=0,
- format_func=lambda x: "Synopsis (all years)"
+ format_func=lambda x: "Synopsis (mean + trend)"
if x == "synopsis"
else "Years-as-Features"
if x == "feature"
else f"Year {x}",
help="Select temporal mode: 'synopsis' for temporal features or specific year",
+ key="dataset_page_temporal_mode",
)
# Members selection
@@ -108,23 +109,20 @@ def render_dataset_page():
st.divider()
+ member_datasets = cast(
+ dict[L2SourceDataset, xr.Dataset],
+ {member: ensemble.read_member(member, lazy=True) for member in ensemble.members},
+ )
# Render dataset statistics section
- stats = DatasetStatistics.from_ensemble(ensemble)
+ stats = DatasetStatistics.from_ensemble(ensemble, member_datasets=member_datasets)
render_ensemble_details(ensemble.members, stats.members)
st.divider()
# Load data and precompute visualizations
- # First, load for all task - target combinations the training data
- train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {}
- for target in all_target_datasets:
- train_data_dict[target] = {}
- for task in all_tasks:
- train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
# Preload the grid GeoDataFrame
grid_gdf = ensemble.read_grid()
- era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
# Create tabs for different data views
tab_names = ["π― Targets", "π Areas"]
# Add tabs for each member based on what's in the ensemble
@@ -132,21 +130,27 @@ def render_dataset_page():
tab_names.append("π AlphaEarth")
if "ArcticDEM" in ensemble.members:
tab_names.append("ποΈ ArcticDEM")
+ era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
if era5_members:
tab_names.append("π‘οΈ ERA5")
tabs = st.tabs(tab_names)
with tabs[0]:
st.header("π― Target Labels Visualization")
- if False: #! debug
+ if False: # ! DEBUG
+ train_data_dict = load_training_sets(ensemble)
render_target_information_tab(train_data_dict)
with tabs[1]:
st.header("π Areas Visualization")
- render_area_information_tab(grid_gdf)
+ if False: # ! DEBUG
+ render_area_information_tab(grid_gdf)
tab_index = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_index]:
st.header("π AlphaEarth Visualization")
+ alphaearth_ds = member_datasets["AlphaEarth"]
+ alphaearth_stats = stats.members["AlphaEarth"]
+ render_alphaearth_tab(alphaearth_ds, grid_gdf, alphaearth_stats)
tab_index += 1
if "ArcticDEM" in ensemble.members:
with tabs[tab_index]: