Add alphaearth tab
This commit is contained in:
parent
3581f9b80f
commit
26de80ee89
9 changed files with 709 additions and 49 deletions
376
src/entropice/dashboard/plots/embeddings.py
Normal file
376
src/entropice/dashboard/plots/embeddings.py
Normal file
|
|
@ -0,0 +1,376 @@
|
||||||
|
"""Render the AlphaEarth visualization tab."""
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import numpy as np
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import pydeck as pdk
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb
|
||||||
|
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_map(
    embedding_values: xr.DataArray,
    grid_gdf: gpd.GeoDataFrame,
    make_3d_map: bool,
) -> pdk.Deck:
    """Create a spatial distribution map for AlphaEarth embeddings.

    Args:
        embedding_values (xr.DataArray): DataArray containing the already filtered AlphaEarth embeddings.
        grid_gdf (gpd.GeoDataFrame): GeoDataFrame containing grid cell geometries.
        make_3d_map (bool): Whether to render the map in 3D (extruded) or 2D.

    Returns:
        pdk.Deck: A PyDeck map visualization of the AlphaEarth embeddings.

    """
    # Subsample if too many cells for performance
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 100000:
        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
        cell_indices = rng.choice(n_cells, size=100000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=cell_indices)

    # Copy (so the caller's frame is untouched) and convert to WGS84 for pydeck.
    # FIX: the original converted to EPSG:4326 here AND again after the join;
    # the second conversion was a no-op, so a single conversion is equivalent.
    gdf_wgs84 = grid_gdf.copy().to_crs("EPSG:4326")

    # Convert embeddings to a DataFrame (indexed by cell id) for easier merging
    embedding_df = embedding_values.to_dataframe(name="embedding_value")

    # Reset index if cell_id is already the index
    if gdf_wgs84.index.name == "cell_id":
        gdf_wgs84 = gdf_wgs84.reset_index()

    # Filter grid to only cells that have embedding data, then align on cell_id
    gdf_wgs84 = gdf_wgs84[gdf_wgs84["cell_id"].isin(embedding_df.index)]
    gdf_wgs84 = gdf_wgs84.set_index("cell_id")

    # Merge embedding values with grid geometries
    gdf_wgs84 = gdf_wgs84.join(embedding_df, how="inner")

    # Fix antimeridian issues for hex cells
    gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)

    # Get colormap for embeddings
    cmap = get_cmap("AlphaEarth")

    # Normalize the embedding values to [0, 1] for color mapping.
    # Percentile clipping (2nd-98th) keeps a few outliers from washing out the colors.
    values = gdf_wgs84["embedding_value"].values
    vmin, vmax = np.nanpercentile(values, [2, 98])

    if vmax > vmin:
        normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Degenerate case: (near-)constant values -> single color, flat elevation
        normalized_values = np.zeros_like(values)

    # Map normalized values to RGB triples for the pydeck fill color
    # (single pass instead of the original two intermediate lists)
    rgb_colors = [hex_to_rgb(mcolors.to_hex(cmap(val))) for val in normalized_values]
    gdf_wgs84["fill_color"] = rgb_colors

    # Raw value feeds the tooltip; normalized value drives extrusion height in 3D
    gdf_wgs84["embedding_value_display"] = values
    gdf_wgs84["elevation"] = normalized_values

    # Convert rows to GeoJSON features for the pydeck GeoJsonLayer
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "fill_color": row["fill_color"],
                "embedding_value": float(row["embedding_value_display"]),
                "elevation": float(row["elevation"]) if make_3d_map else 0,
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )

    # Set initial view state (centered on the Arctic)
    # Adjust pitch and zoom based on whether we're using 3D
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )

    # Build tooltip HTML
    tooltip_html = "<b>Embedding Value:</b> {embedding_value}"

    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    return deck
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_trend_plot(embedding_values: xr.DataArray) -> go.Figure:
    """Create a trend plot for AlphaEarth embeddings over time.

    The figure layers, back to front: a dashed gray min/max envelope, a shaded
    10th-90th percentile band, and a mean line with markers. Statistics are
    taken over the spatial (``cell_ids``) dimension for every year.

    Args:
        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings with a 'year' dimension.

    Returns:
        go.Figure: A Plotly figure showing the trend of embeddings over time.

    """
    # Cap the number of cells used for statistics so the plot stays responsive
    cell_count = len(embedding_values["cell_ids"])
    if cell_count > 10000:
        sampler = np.random.default_rng(42)  # fixed seed keeps the figure reproducible
        chosen = sampler.choice(cell_count, size=10000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=chosen)

    year_axis = embedding_values["year"].values

    # Per-year spatial statistics (aggregated across cells)
    spatial_mean = embedding_values.mean(dim="cell_ids").to_numpy()
    spatial_min = embedding_values.min(dim="cell_ids").to_numpy()
    spatial_max = embedding_values.max(dim="cell_ids").to_numpy()
    spatial_p10 = embedding_values.quantile(0.10, dim="cell_ids").to_numpy()
    spatial_p90 = embedding_values.quantile(0.90, dim="cell_ids").to_numpy()

    fig = go.Figure()

    # Trace order matters: each "tonexty" fill shades against the trace added
    # immediately before it, so envelopes go in before the mean line.
    envelope_style = {"color": "lightgray", "width": 1, "dash": "dash"}

    # Min/max envelope (background, dashed gray)
    fig.add_trace(
        go.Scatter(
            x=year_axis,
            y=spatial_min,
            mode="lines",
            line=envelope_style,
            name="Min/Max Range",
            showlegend=True,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=year_axis,
            y=spatial_max,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(200, 200, 200, 0.1)",
            line=envelope_style,
            showlegend=False,
        )
    )

    # 10th-90th percentile band (invisible lower edge, shaded upper edge)
    fig.add_trace(
        go.Scatter(
            x=year_axis,
            y=spatial_p10,
            mode="lines",
            line={"width": 0},
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=year_axis,
            y=spatial_p90,
            mode="lines",
            fill="tonexty",
            fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
            line={"width": 0},
            name="10th-90th Percentile",
        )
    )

    # Mean line drawn last so it sits on top of both bands
    fig.add_trace(
        go.Scatter(
            x=year_axis,
            y=spatial_mean,
            mode="lines+markers",
            name="Mean",
            line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
            marker={"size": 6},
        )
    )

    fig.update_layout(
        title="Embedding Values Over Time (Spatial Statistics)",
        xaxis_title="Year",
        yaxis_title="Embedding Value",
        yaxis={"zeroline": True, "zerolinewidth": 2, "zerolinecolor": "gray"},
        height=450,
        hovermode="x unified",
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )

    # Years are integers; force one tick per year
    fig.update_xaxes(dtick=1)

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_distribution_plot(embedding_values: xr.DataArray) -> go.Figure:
    """Create a distribution plot showing the min/max, 10th/90th percentiles, and mean of AlphaEarth bands.

    Args:
        embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings.

    Returns:
        go.Figure: A Plotly figure showing the distribution of embeddings.

    """
    # Moved to the top of the function (was buried mid-body in the original);
    # kept function-local so the module's import surface is unchanged.
    import pandas as pd

    # Subsample if too many cells for performance
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 10000:
        rng = np.random.default_rng(42)  # Fixed seed for reproducibility
        cell_indices = rng.choice(n_cells, size=10000, replace=False)
        embedding_values = embedding_values.isel(cell_ids=cell_indices)

    # Get band dimension
    bands = embedding_values["band"].values

    # Calculate statistics for each band across all cells (NaNs dropped per band)
    band_stats = []
    for band in bands:
        band_data = embedding_values.sel(band=band).values.flatten()
        band_data = band_data[~np.isnan(band_data)]

        if len(band_data) > 0:
            band_stats.append(
                {
                    "Band": str(band),
                    "Mean": float(np.mean(band_data)),
                    "Min": float(np.min(band_data)),
                    "Max": float(np.max(band_data)),
                    "P10": float(np.percentile(band_data, 10)),
                    "P90": float(np.percentile(band_data, 90)),
                }
            )

    band_df = pd.DataFrame(band_stats)

    fig = go.Figure()

    # FIX: when every band is entirely NaN, band_stats is empty and indexing
    # band_df["Band"] below would raise KeyError. Return an empty (but still
    # correctly titled/labelled) figure instead.
    if not band_df.empty:
        # Add min/max range first (background) - dashed lines.
        # Trace order matters: each "tonexty" fill shades against the
        # previously added trace.
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Min"],
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Max"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

        # Add 10th-90th percentile band
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["P10"],
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["P90"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(76, 175, 80, 0.2)",  # Green shade for AlphaEarth
                line={"width": 0},
                name="10th-90th Percentile",
            )
        )

        # Add mean line on top
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"],
                mode="lines+markers",
                name="Mean",
                line={"color": "#2E7D32", "width": 2},  # Darker green for AlphaEarth
                marker={"size": 4},
            )
        )

    fig.update_layout(
        title="Embedding Distribution by Band (Spatial Statistics)",
        xaxis_title="Band",
        yaxis_title="Embedding Value",
        height=450,
        hovermode="x unified",
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )

    return fig
|
||||||
238
src/entropice/dashboard/sections/alphaearth.py
Normal file
238
src/entropice/dashboard/sections/alphaearth.py
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
"""AlphaEarth embeddings dashboard section."""
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import numpy as np
|
||||||
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.embeddings import (
|
||||||
|
create_embedding_distribution_plot,
|
||||||
|
create_embedding_map,
|
||||||
|
create_embedding_trend_plot,
|
||||||
|
)
|
||||||
|
from entropice.dashboard.sections.dataset_statistics import render_member_details
|
||||||
|
from entropice.dashboard.utils.colors import get_cmap
|
||||||
|
from entropice.dashboard.utils.stats import MemberStatistics
|
||||||
|
|
||||||
|
|
||||||
|
def _get_band_agg_options(embeddings: xr.Dataset):
    """Let the user pick an embedding band and an aggregation method.

    Renders two side-by-side select boxes and returns the chosen values.
    """
    band_options = embeddings["band"].values.tolist()
    agg_options = embeddings["agg"].values.tolist()

    band_col, agg_col = st.columns([2, 2])

    with band_col:
        selected_band = st.selectbox(
            "Select Embedding Band",
            options=band_options,
            index=0,
            help="Select which embedding band to visualize on the map.",
            key="embedding_band_select",
        )

    with agg_col:
        selected_agg = st.selectbox(
            "Select Aggregation Method",
            options=agg_options,
            index=0,
            help="Select the aggregation method for the embeddings to visualize.",
            key="embedding_agg_select",
        )

    return selected_band, selected_agg
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame):
    """Render the interactive AlphaEarth embedding map section (controls, map, legend)."""
    st.subheader("AlphaEarth Embedding Map")

    st.markdown(
        """
        This interactive map visualizes the spatial distribution of the selected embedding band across
        the Arctic region. Each grid cell is colored according to its embedding value, revealing spatial
        patterns in the satellite imagery features. High-resolution embeddings can indicate areas with
        distinctive characteristics that may be relevant for RTS detection.

        **Map controls:**
        - **Hover** over cells to see exact embedding values
        - **3D mode**: Elevation represents embedding magnitude - higher areas have larger values
        - **Rotate** (3D mode): Hold Ctrl/Cmd and drag to rotate the view
        - **Zoom/Pan**: Scroll to zoom, click and drag to pan
        """
    )

    cols = st.columns([4, 1])
    with cols[0]:
        # Only offer a year slider when the data actually carries a year axis
        if "year" in embedding_values.dims or "year" in embedding_values.coords:
            year_values = embedding_values["year"].values.tolist()
            year = st.slider(
                "Select Year",
                min_value=int(min(year_values)),
                max_value=int(max(year_values)),
                value=int(max(year_values)),
                step=1,
                help="Select the year for which to visualize the embeddings.",
            )
            # Narrow the data to the chosen year before mapping
            embedding_values = embedding_values.sel(year=year)
    with cols[1]:
        # NOTE(review): no explicit key here — if another tab renders a
        # checkbox with the same "3D Map" label, Streamlit's auto-generated
        # widget IDs could collide. Consider key="embedding_3d_map".
        make_3d_map = st.checkbox("3D Map", value=True)

    # Check if subsampling will occur (create_embedding_map caps at 100,000 cells)
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 100000:
        st.info(f"🗺️ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.")

    map_deck = create_embedding_map(
        embedding_values=embedding_values,
        grid_gdf=grid_gdf,
        make_3d_map=make_3d_map,
    )

    st.pydeck_chart(map_deck, width="stretch")

    # Add legend
    with st.expander("Legend", expanded=True):
        st.markdown("**Embedding Value**")

        # Get the actual values to show accurate min/max (same as in the map function)
        # NOTE(review): these percentiles are computed on the FULL dataset,
        # while create_embedding_map computes its color-normalization
        # percentiles AFTER subsampling to 100k cells — the legend bounds may
        # differ slightly from the map's on very large datasets; confirm
        # whether that is acceptable.
        values = embedding_values.values.flatten()
        values = values[~np.isnan(values)]
        vmin, vmax = np.nanpercentile(values, [2, 98])

        vmin_str = f"{vmin:.4f}"
        vmax_str = f"{vmax:.4f}"

        # Get the same colormap used in the map
        cmap = get_cmap("AlphaEarth")
        # Sample 4 colors from the colormap to create the gradient
        gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]]
        gradient_css = ", ".join(gradient_colors)

        # Create a simple gradient legend (HTML so we get a smooth CSS gradient)
        st.markdown(
            f'<div style="display: flex; align-items: center; margin-top: 10px; margin-bottom: 10px;">'
            f'<span style="margin-right: 10px;">{vmin_str}</span>'
            f'<div style="flex: 1; height: 20px; background: linear-gradient(to right, '
            f'{gradient_css}); border: 1px solid #ccc;"></div>'
            f'<span style="margin-left: 10px;">{vmax_str}</span>'
            f"</div>",
            unsafe_allow_html=True,
        )

        st.caption(
            "Color intensity represents embedding values from low (green) to high (yellow). "
            "Values are normalized using the 2nd-98th percentile range to avoid outliers."
        )

        if make_3d_map:
            st.markdown("---")
            st.markdown("**3D Elevation:**")
            st.caption(
                "Height represents normalized embedding values. Rotate the map by holding Ctrl/Cmd and dragging."
            )
|
||||||
|
|
||||||
|
|
||||||
|
def _render_trend(embedding_values: xr.DataArray) -> None:
    """Render the temporal-trend section for the selected embedding band.

    FIX: renamed the misspelled parameter ``embeddin_values`` ->
    ``embedding_values`` for consistency with the rest of this module
    (the only caller passes it positionally).

    Args:
        embedding_values: Embeddings for a single band, with ``year`` and
            ``cell_ids`` dimensions.

    """
    st.subheader("AlphaEarth Embedding Trends Over Time")

    st.markdown(
        """
        This visualization shows how embedding values have changed over time across the study area.
        The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal
        patterns in the satellite imagery embeddings that may correlate with environmental changes.

        **Understanding the plot:**
        - **Mean line** (dark green): Average embedding value across all grid cells for each year
        - **10th-90th percentile band** (light green): Range containing 80% of the values, showing
          typical variation
        - **Min/Max range** (gray): Full extent of values, highlighting outliers
        """
    )

    # Show dataset filtering info (fall back to "multiple" when the coordinate
    # still spans more than one band/aggregation)
    band_val = embedding_values["band"].values.item() if embedding_values["band"].size == 1 else "multiple"
    agg_val = embedding_values["agg"].values.item() if embedding_values["agg"].size == 1 else "multiple"
    st.caption(
        f"📊 **Dataset selection:** Band `{band_val}`, Aggregation `{agg_val}` "
        f"({len(embedding_values['year'])} years, {len(embedding_values['cell_ids']):,} cells)"
    )

    # Warn up front when the plotting helper will subsample (it caps at 10,000 cells)
    n_cells = len(embedding_values["cell_ids"])
    if n_cells > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )

    fig = create_embedding_trend_plot(embedding_values=embedding_values)
    st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
|
def _render_distribution(embedding_values: xr.DataArray) -> None:
    """Render the per-band distribution section for the AlphaEarth embeddings.

    FIX: renamed the misspelled parameter ``embeddin_values`` ->
    ``embedding_values`` for consistency with the rest of this module
    (the only caller passes it positionally).

    Args:
        embedding_values: Embeddings with ``band`` and ``cell_ids`` dimensions.

    """
    st.subheader("AlphaEarth Embedding Distribution")

    st.markdown(
        """
        This plot shows the statistical distribution of embedding values across all 64 embedding
        dimensions (bands). AlphaEarth embeddings are learned representations from satellite imagery,
        with each band capturing different aspects of the landscape (e.g., vegetation, terrain, ice
        cover, land use).

        **Understanding the plot:**
        - **X-axis**: Embedding bands (A00-A63), each representing a learned feature from satellite
          imagery
        - **Mean line** (dark green): Average value across all grid cells for each band
        - **10th-90th percentile band** (light green): Central distribution of values, excluding
          outliers
        - **Min/Max range** (gray): Full value range showing extreme values

        Different bands may capture different landscape features - bands with higher variance often
        represent more spatially heterogeneous characteristics.
        """
    )

    # Show dataset filtering info ("multiple" when more than one aggregation is present)
    agg_val = embedding_values["agg"].values.item() if embedding_values["agg"].size == 1 else "multiple"
    n_bands = len(embedding_values["band"])
    n_cells = len(embedding_values["cell_ids"])
    st.caption(f"📊 **Dataset selection:** Aggregation `{agg_val}` ({n_bands} bands, {n_cells:,} cells)")

    # Warn up front when the plotting helper will subsample (it caps at 10,000 cells)
    if n_cells > 10000:
        st.info(
            f"📊 **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} "
            "for performance. Statistics remain representative."
        )

    fig = create_embedding_distribution_plot(embedding_values=embedding_values)
    st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_alphaearth_tab(embeddings: xr.Dataset, grid_gdf: gpd.GeoDataFrame, member_stats: MemberStatistics):
    """Render the AlphaEarth visualization tab.

    Shows member statistics, then — for the user-selected band/aggregation —
    the per-band distribution, the temporal trend (when a ``year`` axis
    exists), and the spatial map, separated by dividers.

    Args:
        embeddings: The AlphaEarth dataset member, lazily loaded.
        grid_gdf: GeoDataFrame with grid cell geometries
        member_stats: Statistics for the AlphaEarth member.

    """
    # Summary statistics up top, collapsible
    with st.expander("AlphaEarth Embedding Statistics", expanded=True):
        render_member_details("AlphaEarth", member_stats)

    st.divider()

    # User-driven selection, then materialize the chosen aggregation once
    selected_band, selected_agg = _get_band_agg_options(embeddings)
    agg_values = embeddings["embeddings"].sel(agg=selected_agg).compute()

    # Distribution plot uses every band; the remaining sections use the chosen one
    _render_distribution(agg_values)
    st.divider()

    has_year_axis = "year" in agg_values.dims or "year" in agg_values.coords
    if has_year_axis:
        _render_trend(agg_values.sel(band=selected_band))
        st.divider()

    _render_embedding_map(agg_values.sel(band=selected_band), grid_gdf)
|
||||||
|
|
@ -23,7 +23,7 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame):
|
||||||
key="metric",
|
key="metric",
|
||||||
)
|
)
|
||||||
with cols[1]:
|
with cols[1]:
|
||||||
make_3d_map = cast(bool, st.checkbox("3D Map", value=True, key="area_3d_map"))
|
make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
|
||||||
|
|
||||||
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
|
map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map)
|
||||||
st.pydeck_chart(map_deck)
|
st.pydeck_chart(map_deck)
|
||||||
|
|
|
||||||
|
|
@ -436,6 +436,42 @@ def _render_aggregation_selection(
|
||||||
return dimension_filters
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def render_member_details(member: str, member_stats: MemberStatistics):
    """Render detailed information for a single member.

    Displays variables and dimensions with styled badges.

    Args:
        member: Member dataset name
        member_stats: Statistics for the member

    """
    st.markdown(f"### {member}")

    # Variables — one blue HTML badge per variable name
    st.markdown("**Variables:**")
    vars_html = " ".join(
        [
            f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
            f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
            for v in member_stats.variable_names
        ]
    )
    st.markdown(vars_html, unsafe_allow_html=True)

    # Dimensions — one purple badge per "name: size" pair (sizes thousands-separated)
    st.markdown("**Dimensions:**")
    dim_html = " ".join(
        [
            f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
            f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
            f"{dim_name}: {dim_size:,}</span>"
            for dim_name, dim_size in member_stats.dimensions.items()
        ]
    )
    st.markdown(dim_html, unsafe_allow_html=True)
|
||||||
|
|
||||||
|
|
||||||
def render_ensemble_details(
|
def render_ensemble_details(
|
||||||
selected_members: list[L2SourceDataset],
|
selected_members: list[L2SourceDataset],
|
||||||
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
|
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
|
||||||
|
|
@ -502,33 +538,9 @@ def render_ensemble_details(
|
||||||
st.dataframe(details_df, hide_index=True, width="stretch")
|
st.dataframe(details_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
# Individual member details
|
# Individual member details
|
||||||
for member, member_stats in selected_member_stats.items():
|
for member, stats in selected_member_stats.items():
|
||||||
st.markdown(f"### {member}")
|
render_member_details(member, stats)
|
||||||
|
st.divider()
|
||||||
# Variables
|
|
||||||
st.markdown("**Variables:**")
|
|
||||||
vars_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
|
||||||
for v in member_stats.variable_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(vars_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Dimensions
|
|
||||||
st.markdown("**Dimensions:**")
|
|
||||||
dim_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
|
||||||
f"{dim_name}: {dim_size:,}</span>"
|
|
||||||
for dim_name, dim_size in member_stats.dimensions.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
|
|
||||||
def _render_configuration_summary(
|
def _render_configuration_summary(
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
"Experiment",
|
"Experiment",
|
||||||
options=["All", *experiments],
|
options=["All", *experiments],
|
||||||
index=0,
|
index=0,
|
||||||
|
key="exp_results_experiment",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
selected_experiment = "All"
|
selected_experiment = "All"
|
||||||
|
|
@ -61,6 +62,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
"Task",
|
"Task",
|
||||||
options=["All", *tasks],
|
options=["All", *tasks],
|
||||||
index=0,
|
index=0,
|
||||||
|
key="exp_results_task",
|
||||||
)
|
)
|
||||||
|
|
||||||
with filter_cols[2]:
|
with filter_cols[2]:
|
||||||
|
|
@ -68,6 +70,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
"Model",
|
"Model",
|
||||||
options=["All", *models],
|
options=["All", *models],
|
||||||
index=0,
|
index=0,
|
||||||
|
key="exp_results_model",
|
||||||
)
|
)
|
||||||
|
|
||||||
with filter_cols[3]:
|
with filter_cols[3]:
|
||||||
|
|
@ -75,6 +78,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
"Grid",
|
"Grid",
|
||||||
options=["All", *grids],
|
options=["All", *grids],
|
||||||
index=0,
|
index=0,
|
||||||
|
key="exp_results_grid",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
|
|
|
||||||
|
|
@ -170,6 +170,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
|
||||||
"Select Target Dataset",
|
"Select Target Dataset",
|
||||||
options=sorted(train_data_dict.keys()),
|
options=sorted(train_data_dict.keys()),
|
||||||
index=0,
|
index=0,
|
||||||
|
key="target_map_dataset",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
with cols[1]:
|
with cols[1]:
|
||||||
|
|
@ -179,6 +180,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS
|
||||||
"Select Task",
|
"Select Task",
|
||||||
options=sorted(train_data_dict[selected_target].keys()),
|
options=sorted(train_data_dict[selected_target].keys()),
|
||||||
index=0,
|
index=0,
|
||||||
|
key="target_map_task",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
with cols[2]:
|
with cols[2]:
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ from shapely.geometry import shape
|
||||||
import entropice.spatial.grids
|
import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||||
from entropice.ml.training import TrainingSettings
|
from entropice.ml.training import TrainingSettings
|
||||||
from entropice.utils.types import GridConfig
|
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
|
||||||
|
|
||||||
|
|
||||||
def _fix_hex_geometry(geom):
|
def _fix_hex_geometry(geom):
|
||||||
|
|
@ -239,3 +240,13 @@ def load_all_training_results() -> list[TrainingResult]:
|
||||||
# Sort by creation time (most recent first)
|
# Sort by creation time (most recent first)
|
||||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||||
return training_results
|
return training_results
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Task, TrainingSet]]:
|
||||||
|
"""Load training sets for all target-task combinations in the ensemble."""
|
||||||
|
train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {}
|
||||||
|
for target in all_target_datasets:
|
||||||
|
train_data_dict[target] = {}
|
||||||
|
for task in all_tasks:
|
||||||
|
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
|
||||||
|
return train_data_dict
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import xarray as xr
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
import entropice.spatial.grids
|
import entropice.spatial.grids
|
||||||
|
|
@ -39,11 +40,19 @@ class MemberStatistics:
|
||||||
size_bytes: int # Size of this member's data on disk in bytes
|
size_bytes: int # Size of this member's data on disk in bytes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
|
def compute(
|
||||||
|
cls,
|
||||||
|
e: DatasetEnsemble,
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
|
||||||
|
) -> dict[L2SourceDataset, "MemberStatistics"]:
|
||||||
"""Pre-compute the statistics for a specific dataset member."""
|
"""Pre-compute the statistics for a specific dataset member."""
|
||||||
|
member_datasets = member_datasets or {}
|
||||||
member_stats = {}
|
member_stats = {}
|
||||||
for member in e.members:
|
for member in e.members:
|
||||||
ds = e.read_member(member, lazy=True)
|
if member in member_datasets:
|
||||||
|
ds = member_datasets[member]
|
||||||
|
else:
|
||||||
|
ds = e.read_member(member, lazy=True)
|
||||||
size_bytes = ds.nbytes
|
size_bytes = ds.nbytes
|
||||||
|
|
||||||
n_cols_member = len(ds.data_vars)
|
n_cols_member = len(ds.data_vars)
|
||||||
|
|
@ -113,7 +122,11 @@ class DatasetStatistics:
|
||||||
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
|
def from_ensemble(
|
||||||
|
cls,
|
||||||
|
e: DatasetEnsemble,
|
||||||
|
member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None,
|
||||||
|
) -> "DatasetStatistics":
|
||||||
"""Compute dataset statistics from a DatasetEnsemble."""
|
"""Compute dataset statistics from a DatasetEnsemble."""
|
||||||
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
|
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
|
||||||
total_cells = len(grid_gdf)
|
total_cells = len(grid_gdf)
|
||||||
|
|
@ -123,7 +136,7 @@ class DatasetStatistics:
|
||||||
# darts_mllabels does not support year-based temporal modes
|
# darts_mllabels does not support year-based temporal modes
|
||||||
continue
|
continue
|
||||||
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
|
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
|
||||||
member_statistics = MemberStatistics.compute(e)
|
member_statistics = MemberStatistics.compute(e, member_datasets=member_datasets)
|
||||||
|
|
||||||
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
||||||
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
|
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
|
||||||
|
|
|
||||||
|
|
@ -3,21 +3,20 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.dashboard.sections.alphaearth import render_alphaearth_tab
|
||||||
from entropice.dashboard.sections.areas import render_area_information_tab
|
from entropice.dashboard.sections.areas import render_area_information_tab
|
||||||
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
|
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
|
||||||
from entropice.dashboard.sections.targets import render_target_information_tab
|
from entropice.dashboard.sections.targets import render_target_information_tab
|
||||||
|
from entropice.dashboard.utils.loaders import load_training_sets
|
||||||
from entropice.dashboard.utils.stats import DatasetStatistics
|
from entropice.dashboard.utils.stats import DatasetStatistics
|
||||||
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.utils.types import (
|
from entropice.utils.types import (
|
||||||
GridConfig,
|
GridConfig,
|
||||||
L2SourceDataset,
|
L2SourceDataset,
|
||||||
TargetDataset,
|
|
||||||
Task,
|
|
||||||
TemporalMode,
|
TemporalMode,
|
||||||
all_target_datasets,
|
|
||||||
all_tasks,
|
|
||||||
grid_configs,
|
grid_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,6 +37,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
|
||||||
options=grid_options,
|
options=grid_options,
|
||||||
index=0,
|
index=0,
|
||||||
help="Select the grid system and resolution level",
|
help="Select the grid system and resolution level",
|
||||||
|
key="dataset_page_grid",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find the selected grid config
|
# Find the selected grid config
|
||||||
|
|
@ -48,12 +48,13 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble:
|
||||||
"Temporal Mode",
|
"Temporal Mode",
|
||||||
options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
|
options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
|
||||||
index=0,
|
index=0,
|
||||||
format_func=lambda x: "Synopsis (all years)"
|
format_func=lambda x: "Synopsis (mean + trend)"
|
||||||
if x == "synopsis"
|
if x == "synopsis"
|
||||||
else "Years-as-Features"
|
else "Years-as-Features"
|
||||||
if x == "feature"
|
if x == "feature"
|
||||||
else f"Year {x}",
|
else f"Year {x}",
|
||||||
help="Select temporal mode: 'synopsis' for temporal features or specific year",
|
help="Select temporal mode: 'synopsis' for temporal features or specific year",
|
||||||
|
key="dataset_page_temporal_mode",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Members selection
|
# Members selection
|
||||||
|
|
@ -108,23 +109,20 @@ def render_dataset_page():
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
|
member_datasets = cast(
|
||||||
|
dict[L2SourceDataset, xr.Dataset],
|
||||||
|
{member: ensemble.read_member(member, lazy=True) for member in ensemble.members},
|
||||||
|
)
|
||||||
# Render dataset statistics section
|
# Render dataset statistics section
|
||||||
stats = DatasetStatistics.from_ensemble(ensemble)
|
stats = DatasetStatistics.from_ensemble(ensemble, member_datasets=member_datasets)
|
||||||
render_ensemble_details(ensemble.members, stats.members)
|
render_ensemble_details(ensemble.members, stats.members)
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Load data and precompute visualizations
|
# Load data and precompute visualizations
|
||||||
# First, load for all task - target combinations the training data
|
|
||||||
train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {}
|
|
||||||
for target in all_target_datasets:
|
|
||||||
train_data_dict[target] = {}
|
|
||||||
for task in all_tasks:
|
|
||||||
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
|
|
||||||
# Preload the grid GeoDataFrame
|
# Preload the grid GeoDataFrame
|
||||||
grid_gdf = ensemble.read_grid()
|
grid_gdf = ensemble.read_grid()
|
||||||
|
|
||||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
|
||||||
# Create tabs for different data views
|
# Create tabs for different data views
|
||||||
tab_names = ["🎯 Targets", "📐 Areas"]
|
tab_names = ["🎯 Targets", "📐 Areas"]
|
||||||
# Add tabs for each member based on what's in the ensemble
|
# Add tabs for each member based on what's in the ensemble
|
||||||
|
|
@ -132,21 +130,27 @@ def render_dataset_page():
|
||||||
tab_names.append("🌍 AlphaEarth")
|
tab_names.append("🌍 AlphaEarth")
|
||||||
if "ArcticDEM" in ensemble.members:
|
if "ArcticDEM" in ensemble.members:
|
||||||
tab_names.append("🏔️ ArcticDEM")
|
tab_names.append("🏔️ ArcticDEM")
|
||||||
|
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||||
if era5_members:
|
if era5_members:
|
||||||
tab_names.append("🌡️ ERA5")
|
tab_names.append("🌡️ ERA5")
|
||||||
tabs = st.tabs(tab_names)
|
tabs = st.tabs(tab_names)
|
||||||
|
|
||||||
with tabs[0]:
|
with tabs[0]:
|
||||||
st.header("🎯 Target Labels Visualization")
|
st.header("🎯 Target Labels Visualization")
|
||||||
if False: #! debug
|
if False: # ! DEBUG
|
||||||
|
train_data_dict = load_training_sets(ensemble)
|
||||||
render_target_information_tab(train_data_dict)
|
render_target_information_tab(train_data_dict)
|
||||||
with tabs[1]:
|
with tabs[1]:
|
||||||
st.header("📐 Areas Visualization")
|
st.header("📐 Areas Visualization")
|
||||||
render_area_information_tab(grid_gdf)
|
if False: # ! DEBUG
|
||||||
|
render_area_information_tab(grid_gdf)
|
||||||
tab_index = 2
|
tab_index = 2
|
||||||
if "AlphaEarth" in ensemble.members:
|
if "AlphaEarth" in ensemble.members:
|
||||||
with tabs[tab_index]:
|
with tabs[tab_index]:
|
||||||
st.header("🌍 AlphaEarth Visualization")
|
st.header("🌍 AlphaEarth Visualization")
|
||||||
|
alphaearth_ds = member_datasets["AlphaEarth"]
|
||||||
|
alphaearth_stats = stats.members["AlphaEarth"]
|
||||||
|
render_alphaearth_tab(alphaearth_ds, grid_gdf, alphaearth_stats)
|
||||||
tab_index += 1
|
tab_index += 1
|
||||||
if "ArcticDEM" in ensemble.members:
|
if "ArcticDEM" in ensemble.members:
|
||||||
with tabs[tab_index]:
|
with tabs[tab_index]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue