Add alphaearth tab

This commit is contained in:
Tobias Hölzer 2026-01-17 02:06:33 +01:00
parent 3581f9b80f
commit 26de80ee89
9 changed files with 709 additions and 49 deletions

View file

@ -0,0 +1,376 @@
"""Render the AlphaEarth visualization tab."""
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import plotly.graph_objects as go
import pydeck as pdk
import xarray as xr
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
from entropice.dashboard.utils.geometry import fix_hex_geometry
def create_embedding_map(
    embedding_values: xr.DataArray,
    grid_gdf: gpd.GeoDataFrame,
    make_3d_map: bool,
) -> pdk.Deck:
    """Create a spatial distribution map for AlphaEarth embeddings.

    Args:
        embedding_values (xr.DataArray): DataArray containing the already filtered AlphaEarth embeddings.
        grid_gdf (gpd.GeoDataFrame): GeoDataFrame containing grid cell geometries.
        make_3d_map (bool): Whether to render the map in 3D (extruded) or 2D.

    Returns:
        pdk.Deck: A PyDeck map visualization of the AlphaEarth embeddings.
    """
    # Subsample if too many cells for performance
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 100000:
        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
        cell_indices = rng.choice(n_cells, size=100000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=cell_indices)
    # Create a copy to avoid modifying the original
    gdf = grid_gdf.copy()
    # Convert to DataFrame for easier merging
    embedding_df = embedding_values.to_dataframe(name="embedding_value")
    # Reset index if cell_id is already the index
    if gdf.index.name == "cell_id":
        gdf = gdf.reset_index()
    # Filter grid to only cells that have embedding data
    gdf = gdf[gdf["cell_id"].isin(embedding_df.index)]
    gdf = gdf.set_index("cell_id")
    # Merge embedding values with grid geometries
    gdf = gdf.join(embedding_df, how="inner")
    # Convert to WGS84 for pydeck. A single reprojection suffices; the grid was
    # previously reprojected twice (once before and once after the join), which
    # is redundant work on large grids.
    gdf_wgs84 = gdf.to_crs("EPSG:4326")
    # Fix antimeridian issues for hex cells
    gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
    # Get colormap for embeddings
    cmap = get_cmap("AlphaEarth")
    # Normalize the embedding values to [0, 1] for color mapping.
    # Use the 2nd-98th percentile range so outliers don't dominate the scale.
    values = gdf_wgs84["embedding_value"].values
    vmin, vmax = np.nanpercentile(values, [2, 98])
    if vmax > vmin:
        normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Degenerate case: all values are (effectively) identical
        normalized_values = np.zeros_like(values)
    # Map normalized values to colors
    colors = [cmap(val) for val in normalized_values]
    rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors]
    gdf_wgs84["fill_color"] = rgb_colors
    # Store embedding value for tooltip
    gdf_wgs84["embedding_value_display"] = values
    # Store normalized values for elevation (if 3D)
    gdf_wgs84["elevation"] = normalized_values
    # Convert rows to GeoJSON features for the pydeck GeoJsonLayer
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "fill_color": row["fill_color"],
                "embedding_value": float(row["embedding_value_display"]),
                "elevation": float(row["elevation"]) if make_3d_map else 0,
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )
    # Set initial view state (centered on the Arctic).
    # Adjust pitch and zoom based on whether we're using 3D.
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )
    # Build tooltip HTML
    tooltip_html = "<b>Embedding Value:</b> {embedding_value}"
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    return deck
def create_embedding_trend_plot(embedding_values: xr.DataArray) -> go.Figure:
    """Create a trend plot for AlphaEarth embeddings over time.

    Contains a line plot with shaded areas representing the 10th to 90th percentiles.
    Min and Max values are marked through a dashed line.

    Args:
        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings with a 'year' dimension.

    Returns:
        go.Figure: A Plotly figure showing the trend of embeddings over time.
    """
    # Cap the number of cells so the statistics stay cheap to compute
    total_cells = len(embedding_values["cell_ids"])
    if total_cells > 10000:
        sampler = np.random.default_rng(42)  # fixed seed keeps the subsample reproducible
        picked = sampler.choice(total_cells, size=10000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=picked)

    # Aggregate over space (cell_ids) so each statistic becomes a per-year series
    years = embedding_values["year"].values
    mean_series = embedding_values.mean(dim="cell_ids").to_numpy()
    min_series = embedding_values.min(dim="cell_ids").to_numpy()
    max_series = embedding_values.max(dim="cell_ids").to_numpy()
    p10_series = embedding_values.quantile(0.10, dim="cell_ids").to_numpy()
    p90_series = embedding_values.quantile(0.90, dim="cell_ids").to_numpy()

    # Trace order matters: background envelopes first, mean line last (on top)
    traces = [
        # Dashed min line, then the max line filled down to it (min/max envelope)
        go.Scatter(
            x=years,
            y=min_series,
            mode="lines",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            name="Min/Max Range",
            showlegend=True,
        ),
        go.Scatter(
            x=years,
            y=max_series,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(200, 200, 200, 0.1)",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            showlegend=False,
        ),
        # Invisible 10th percentile line anchors the shaded percentile band
        go.Scatter(
            x=years,
            y=p10_series,
            mode="lines",
            line={"width": 0},
            showlegend=False,
            hoverinfo="skip",
        ),
        go.Scatter(
            x=years,
            y=p90_series,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
            line={"width": 0},
            name="10th-90th Percentile",
        ),
        go.Scatter(
            x=years,
            y=mean_series,
            mode="lines+markers",
            name="Mean",
            line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
            marker={"size": 6},
        ),
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        title="Embedding Values Over Time (Spatial Statistics)",
        xaxis_title="Year",
        yaxis_title="Embedding Value",
        yaxis={"zeroline": True, "zerolinewidth": 2, "zerolinecolor": "gray"},
        height=450,
        hovermode="x unified",
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )
    # Format x-axis to show years as integers
    fig.update_xaxes(dtick=1)
    return fig
def create_embedding_distribution_plot(embedding_values: xr.DataArray) -> go.Figure:
    """Create a distribution plot showing the min/max, 10th/90th percentiles, and mean of AlphaEarth bands.

    Args:
        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings.

    Returns:
        go.Figure: A Plotly figure showing the distribution of embeddings.
    """
    # Function-scoped import kept (pandas is not imported at module level here),
    # but hoisted to the top of the function instead of appearing mid-body.
    import pandas as pd

    # Subsample if too many cells for performance
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 10000:
        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
        cell_indices = rng.choice(n_cells, size=10000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=cell_indices)
    # Get band dimension
    bands = embedding_values["band"].values
    # Calculate statistics for each band across all cells
    band_stats = []
    for band in bands:
        band_data = embedding_values.sel(band=band).values.flatten()
        # Remove NaN values so the statistics below are well-defined
        band_data = band_data[~np.isnan(band_data)]
        if len(band_data) > 0:
            band_stats.append(
                {
                    "Band": str(band),
                    "Mean": float(np.mean(band_data)),
                    "Min": float(np.min(band_data)),
                    "Max": float(np.max(band_data)),
                    "P10": float(np.percentile(band_data, 10)),
                    "P90": float(np.percentile(band_data, 90)),
                }
            )
    # Create DataFrame from statistics
    band_df = pd.DataFrame(band_stats)
    fig = go.Figure()
    if band_df.empty:
        # Every band was entirely NaN: return an empty figure instead of
        # raising a KeyError on the missing "Band"/"Min"/... columns below.
        fig.update_layout(
            title="Embedding Distribution by Band (no valid data)",
            xaxis_title="Band",
            yaxis_title="Embedding Value",
            height=450,
        )
        return fig
    # Add min/max range first (background) - dashed lines
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Min"],
            mode="lines",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            name="Min/Max Range",
            showlegend=True,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Max"],
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(200, 200, 200, 0.1)",
            line={"color": "lightgray", "width": 1, "dash": "dash"},
            showlegend=False,
        )
    )
    # Add 10th-90th percentile band
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["P10"],
            mode="lines",
            line={"width": 0},
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["P90"],
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
            line={"width": 0},
            name="10th-90th Percentile",
        )
    )
    # Add mean line on top
    fig.add_trace(
        go.Scatter(
            x=band_df["Band"],
            y=band_df["Mean"],
            mode="lines+markers",
            name="Mean",
            line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
            marker={"size": 4},
        )
    )
    fig.update_layout(
        title="Embedding Distribution by Band (Spatial Statistics)",
        xaxis_title="Band",
        yaxis_title="Embedding Value",
        height=450,
        hovermode="x unified",
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )
    return fig

View file

@ -0,0 +1,238 @@
"""AlphaEarth embeddings dashboard section."""
import geopandas as gpd
import matplotlib.colors as mcolors
import numpy as np
import streamlit as st
import xarray as xr
from entropice.dashboard.plots.embeddings import (
create_embedding_distribution_plot,
create_embedding_map,
create_embedding_trend_plot,
)
from entropice.dashboard.sections.dataset_statistics import render_member_details
from entropice.dashboard.utils.colors import get_cmap
from entropice.dashboard.utils.stats import MemberStatistics
def _get_band_agg_options(embeddings: xr.Dataset):
    """Ask the user which embedding band and aggregation method to visualize."""
    band_options = embeddings["band"].values.tolist()
    agg_options = embeddings["agg"].values.tolist()
    left, right = st.columns([2, 2])
    with left:
        selected_band = st.selectbox(
            "Select Embedding Band",
            options=band_options,
            index=0,
            help="Select which embedding band to visualize on the map.",
            key="embedding_band_select",
        )
    with right:
        selected_agg = st.selectbox(
            "Select Aggregation Method",
            options=agg_options,
            index=0,
            help="Select the aggregation method for the embeddings to visualize.",
            key="embedding_agg_select",
        )
    return selected_band, selected_agg
@st.fragment
def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame):
    """Render the AlphaEarth map section: year/3D controls, the pydeck map, and a legend.

    Args:
        embedding_values: Embeddings for a single band (optionally with a 'year' dimension).
        grid_gdf: GeoDataFrame with grid cell geometries.
    """
    st.subheader("AlphaEarth Embedding Map")
    st.markdown(
        """
        This interactive map visualizes the spatial distribution of the selected embedding band across
        the Arctic region. Each grid cell is colored according to its embedding value, revealing spatial
        patterns in the satellite imagery features. High-resolution embeddings can indicate areas with
        distinctive characteristics that may be relevant for RTS detection.

        **Map controls:**

        - **Hover** over cells to see exact embedding values
        - **3D mode**: Elevation represents embedding magnitude - higher areas have larger values
        - **Rotate** (3D mode): Hold Ctrl/Cmd and drag to rotate the view
        - **Zoom/Pan**: Scroll to zoom, click and drag to pan
        """
    )
    cols = st.columns([4, 1])
    with cols[0]:
        # Only offer a year slider when the data actually has a temporal axis
        if "year" in embedding_values.dims or "year" in embedding_values.coords:
            year_values = embedding_values["year"].values.tolist()
            year = st.slider(
                "Select Year",
                min_value=int(min(year_values)),
                max_value=int(max(year_values)),
                value=int(max(year_values)),
                step=1,
                help="Select the year for which to visualize the embeddings.",
                key="embedding_map_year",  # explicit key, consistent with the other widgets
            )
            embedding_values = embedding_values.sel(year=year)
    with cols[1]:
        # Explicit key: an identical unkeyed 'st.checkbox("3D Map", value=True)'
        # exists on the Areas tab; two widgets with the same auto-generated ID
        # would collide in Streamlit.
        make_3d_map = st.checkbox("3D Map", value=True, key="embedding_3d_map")
    # Check if subsampling will occur (mirrors the 100k limit in create_embedding_map)
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 100000:
        st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")
    map_deck = create_embedding_map(
        embedding_values=embedding_values,
        grid_gdf=grid_gdf,
        make_3d_map=make_3d_map,
    )
    st.pydeck_chart(map_deck, width="stretch")
    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown("**Embedding Value**")
        # Get the actual values to show accurate min/max (same as in the map function)
        values = embedding_values.values.flatten()
        values = values[~np.isnan(values)]
        vmin, vmax = np.nanpercentile(values, [2, 98])
        vmin_str = f"{vmin:.4f}"
        vmax_str = f"{vmax:.4f}"
        # Get the same colormap used in the map
        cmap = get_cmap("AlphaEarth")
        # Sample 4 colors from the colormap to create the gradient
        gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
        gradient_css = ", ".join(gradient_colors)
        # Create a simple gradient legend
        st.markdown(
            f'<div style="display: flex; align-items: center; margin-top: 10px; margin-bottom: 10px;">'
            f'<span style="margin-right: 10px;">{vmin_str}</span>'
            f'<div style="flex: 1; height: 20px; background: linear-gradient(to right, '
            f'{gradient_css}); border: 1px solid #ccc;"></div>'
            f'<span style="margin-left: 10px;">{vmax_str}</span>'
            f"</div>",
            unsafe_allow_html=True,
        )
        st.caption(
            "Color intensity represents embedding values from low (green) to high (yellow). "
            "Values are normalized using the 2nd-98th percentile range to avoid outliers."
        )
        if make_3d_map:
            st.markdown("---")
            st.markdown("**3D Elevation:**")
            st.caption(
                "Height represents normalized embedding values. Rotate the map by holding Ctrl/Cmd and dragging."
            )
def _render_trend(embedding_values: xr.DataArray):
    """Render the temporal trend section for the selected embedding band.

    Args:
        embedding_values: Embeddings for a single band with a 'year' dimension.
    """
    # NOTE: the parameter was previously misspelled 'embeddin_values'; it is
    # private and called positionally, so the rename is safe for callers.
    st.subheader("AlphaEarth Embedding Trends Over Time")
    st.markdown(
        """
        This visualization shows how embedding values have changed over time across the study area.
        The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
        patterns in the satellite imagery embeddings that may correlate with environmental changes.

        **Understanding the plot:**

        - **Mean line** (dark green): Average embedding value across all grid cells for each year
        - **10th-90th percentile band** (light green): Range containing 80% of the values, showing
          typical variation
        - **Min/Max range** (gray): Full extent of values, highlighting outliers
        """
    )
    # Show dataset filtering info
    band_val = embedding_values["band"].values.item() if embedding_values["band"].size == 1 else "multiple"
    agg_val = embedding_values["agg"].values.item() if embedding_values["agg"].size == 1 else "multiple"
    st.caption(
        f"📊 **Dataset selection:** Band `{band_val}`, Aggregation `{agg_val}` "
        f"({len(embedding_values['year'])} years, {len(embedding_values['cell_ids']):,} cells)"
    )
    # Check if subsampling will occur (mirrors the 10k limit in the plot function)
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )
    fig = create_embedding_trend_plot(embedding_values=embedding_values)
    st.plotly_chart(fig, width="stretch")
def _render_distribution(embedding_values: xr.DataArray):
    """Render the per-band distribution section for the selected aggregation.

    Args:
        embedding_values: Embeddings for a single aggregation across all bands.
    """
    # NOTE: the parameter was previously misspelled 'embeddin_values'; it is
    # private and called positionally, so the rename is safe for callers.
    st.subheader("AlphaEarth Embedding Distribution")
    st.markdown(
        """
        This plot shows the statistical distribution of embedding values across all 64 embedding
        dimensions (bands). AlphaEarth embeddings are learned representations from satellite imagery,
        with each band capturing different aspects of the landscape (e.g., vegetation, terrain, ice
        cover, land use).

        **Understanding the plot:**

        - **X-axis**: Embedding bands (A00-A63), each representing a learned feature from satellite
          imagery
        - **Mean line** (dark green): Average value across all grid cells for each band
        - **10th-90th percentile band** (light green): Central distribution of values, excluding
          outliers
        - **Min/Max range** (gray): Full value range showing extreme values

        Different bands may capture different landscape features - bands with higher variance often
        represent more spatially heterogeneous characteristics.
        """
    )
    # Show dataset filtering info
    agg_val = embedding_values["agg"].values.item() if embedding_values["agg"].size == 1 else "multiple"
    n_bands = len(embedding_values["band"])
    n_cells = len(embedding_values["cell_ids"])
    st.caption(f"📊 **Dataset selection:** Aggregation `{agg_val}` ({n_bands} bands, {n_cells:,} cells)")
    # Check if subsampling will occur (mirrors the 10k limit in the plot function)
    if n_cells > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )
    fig = create_embedding_distribution_plot(embedding_values=embedding_values)
    st.plotly_chart(fig, width="stretch")
@st.fragment
def render_alphaearth_tab(embeddings: xr.Dataset, grid_gdf: gpd.GeoDataFrame, member_stats: MemberStatistics):
    """Render the AlphaEarth visualization tab.

    Args:
        embeddings: The AlphaEarth dataset member, lazily loaded.
        grid_gdf: GeoDataFrame with grid cell geometries
        member_stats: Statistics for the AlphaEarth member.
    """
    # Summary statistics at the top, collapsible
    with st.expander("AlphaEarth Embedding Statistics", expanded=True):
        render_member_details("AlphaEarth", member_stats)
    st.divider()
    # Let the user pick band + aggregation, then materialize the selection once
    selected_band, selected_agg = _get_band_agg_options(embeddings)
    selected = embeddings["embeddings"].sel(agg=selected_agg).compute()
    _render_distribution(selected)
    st.divider()
    # The trend plot only makes sense when a temporal axis is present
    has_year = "year" in selected.dims or "year" in selected.coords
    if has_year:
        _render_trend(selected.sel(band=selected_band))
        st.divider()
    _render_embedding_map(selected.sel(band=selected_band), grid_gdf)

View file

@ -23,7 +23,7 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame):
key="metric",
)
with cols[1]:
make_3d_map = cast(bool, st.checkbox("3D Map", value=True, key="area_3d_map"))
make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
st.pydeck_chart(map_deck)

View file

@ -436,6 +436,42 @@ def _render_aggregation_selection(
return dimension_filters
def render_member_details(member: str, member_stats: MemberStatistics):
    """Render detailed information for a single member.

    Displays variables and dimensions with styled badges.

    Args:
        member: Member dataset name
        member_stats: Statistics for the member
    """
    # Shared badge markup; only the colors and the text differ between sections.
    badge = (
        '<span style="background-color: {bg}; color: {fg}; padding: 4px 8px; '
        'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{text}</span>'
    )
    st.markdown(f"### {member}")
    # Variables as blue badges
    st.markdown("**Variables:**")
    variable_badges = (
        badge.format(bg="#e3f2fd", fg="#1976d2", text=v) for v in member_stats.variable_names
    )
    st.markdown(" ".join(variable_badges), unsafe_allow_html=True)
    # Dimensions as purple badges with thousands-separated sizes
    st.markdown("**Dimensions:**")
    dimension_badges = (
        badge.format(bg="#f3e5f5", fg="#7b1fa2", text=f"{dim_name}: {dim_size:,}")
        for dim_name, dim_size in member_stats.dimensions.items()
    )
    st.markdown(" ".join(dimension_badges), unsafe_allow_html=True)
def render_ensemble_details(
selected_members: list[L2SourceDataset],
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
@ -502,33 +538,9 @@ def render_ensemble_details(
st.dataframe(details_df, hide_index=True, width="stretch")
# Individual member details
for member, member_stats in selected_member_stats.items():
st.markdown(f"### {member}")
# Variables
st.markdown("**Variables:**")
vars_html = " ".join(
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats.variable_names
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Dimensions
st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size:,}</span>"
for dim_name, dim_size in member_stats.dimensions.items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
for member, stats in selected_member_stats.items():
render_member_details(member, stats)
st.divider()
def _render_configuration_summary(

View file

@ -52,6 +52,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Experiment",
options=["All", *experiments],
index=0,
key="exp_results_experiment",
)
else:
selected_experiment = "All"
@ -61,6 +62,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Task",
options=["All", *tasks],
index=0,
key="exp_results_task",
)
with filter_cols[2]:
@ -68,6 +70,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Model",
options=["All", *models],
index=0,
key="exp_results_model",
)
with filter_cols[3]:
@ -75,6 +78,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
"Grid",
options=["All", *grids],
index=0,
key="exp_results_grid",
)
# Apply filters

View file

@ -170,6 +170,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
"Select Target Dataset",
options=sorted(train_data_dict.keys()),
index=0,
key="target_map_dataset",
),
)
with cols[1]:
@ -179,6 +180,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
"Select Task",
options=sorted(train_data_dict[selected_target].keys()),
index=0,
key="target_map_task",
),
)
with cols[2]:

View file

@ -15,8 +15,9 @@ from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.ml.training import TrainingSettings
from entropice.utils.types import GridConfig
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
def _fix_hex_geometry(geom):
@ -239,3 +240,13 @@ def load_all_training_results() -> list[TrainingResult]:
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Task, TrainingSet]]:
    """Load training sets for all target-task combinations in the ensemble."""
    # Build the full target -> task -> training-set mapping in one expression.
    return {
        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
        for target in all_target_datasets
    }

View file

@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass
from typing import Literal
import pandas as pd
import xarray as xr
from stopuhr import stopwatch
import entropice.spatial.grids
@ -39,10 +40,18 @@ class MemberStatistics:
size_bytes: int # Size of this member's data on disk in bytes
@classmethod
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
def compute(
cls,
e: DatasetEnsemble,
member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
) -> dict[L2SourceDataset, "MemberStatistics"]:
"""Pre-compute the statistics for a specific dataset member."""
member_datasets = member_datasets or {}
member_stats = {}
for member in e.members:
if member in member_datasets:
ds = member_datasets[member]
else:
ds = e.read_member(member, lazy=True)
size_bytes = ds.nbytes
@ -113,7 +122,11 @@ class DatasetStatistics:
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
@classmethod
def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
def from_ensemble(
cls,
e: DatasetEnsemble,
member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
) -> "DatasetStatistics":
"""Compute dataset statistics from a DatasetEnsemble."""
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
total_cells = len(grid_gdf)
@ -123,7 +136,7 @@ class DatasetStatistics:
# darts_mllabels does not support year-based temporal modes
continue
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
member_statistics = MemberStatistics.compute(e)
member_statistics = MemberStatistics.compute(e, member_datasets=member_datasets)
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())

View file

@ -3,21 +3,20 @@
from typing import cast
import streamlit as st
import xarray as xr
from stopuhr import stopwatch
from entropice.dashboard.sections.alphaearth import render_alphaearth_tab
from entropice.dashboard.sections.areas import render_area_information_tab
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
from entropice.dashboard.sections.targets import render_target_information_tab
from entropice.dashboard.utils.loaders import load_training_sets
from entropice.dashboard.utils.stats import DatasetStatistics
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
GridConfig,
L2SourceDataset,
TargetDataset,
Task,
TemporalMode,
all_target_datasets,
all_tasks,
grid_configs,
)
@ -38,6 +37,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
options=grid_options,
index=0,
help="Select the grid system and resolution level",
key="dataset_page_grid",
)
# Find the selected grid config
@ -48,12 +48,13 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
"Temporal Mode",
options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
index=0,
format_func=lambda x: "Synopsis (all years)"
format_func=lambda x: "Synopsis (mean + trend)"
if x == "synopsis"
else "Years-as-Features"
if x == "feature"
else f"Year {x}",
help="Select temporal mode: 'synopsis' for temporal features or specific year",
key="dataset_page_temporal_mode",
)
# Members selection
@ -108,23 +109,20 @@ def render_dataset_page():
st.divider()
member_datasets = cast(
dict[L2SourceDataset, xr.Dataset],
{member: ensemble.read_member(member, lazy=True) for member in ensemble.members},
)
# Render dataset statistics section
stats = DatasetStatistics.from_ensemble(ensemble)
stats = DatasetStatistics.from_ensemble(ensemble, member_datasets=member_datasets)
render_ensemble_details(ensemble.members, stats.members)
st.divider()
# Load data and precompute visualizations
# First, load for all task - target combinations the training data
train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {}
for target in all_target_datasets:
train_data_dict[target] = {}
for task in all_tasks:
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
# Preload the grid GeoDataFrame
grid_gdf = ensemble.read_grid()
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
# Create tabs for different data views
tab_names = ["🎯 Targets", "📐 Areas"]
# Add tabs for each member based on what's in the ensemble
@ -132,21 +130,27 @@ def render_dataset_page():
tab_names.append("🌍 AlphaEarth")
if "ArcticDEM" in ensemble.members:
tab_names.append("🏔️ ArcticDEM")
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
if era5_members:
tab_names.append("🌡️ ERA5")
tabs = st.tabs(tab_names)
with tabs[0]:
st.header("🎯 Target Labels Visualization")
if False: #! debug
if False: # ! DEBUG
train_data_dict = load_training_sets(ensemble)
render_target_information_tab(train_data_dict)
with tabs[1]:
st.header("📐 Areas Visualization")
if False: # ! DEBUG
render_area_information_tab(grid_gdf)
tab_index = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_index]:
st.header("🌍 AlphaEarth Visualization")
alphaearth_ds = member_datasets["AlphaEarth"]
alphaearth_stats = stats.members["AlphaEarth"]
render_alphaearth_tab(alphaearth_ds, grid_gdf, alphaearth_stats)
tab_index += 1
if "ArcticDEM" in ensemble.members:
with tabs[tab_index]: