Finalize the Training Data Page

This commit is contained in:
Tobias Hölzer 2025-12-19 15:36:45 +01:00
parent 696bef39c2
commit 8338efb31e
2 changed files with 271 additions and 116 deletions

View file

@ -29,32 +29,36 @@ def render_alphaearth_overview(ds: xr.Dataset):
ds: xarray Dataset containing AlphaEarth embeddings.
"""
st.subheader("📊 AlphaEarth Embeddings Statistics")
st.subheader("📊 Data Overview")
# Overall statistics
# Key metrics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
st.metric("Embedding Dimensions", f"{len(ds['band'])}")
st.metric("Embedding Dims", f"{len(ds['band'])}")
with col3:
st.metric("Years Available", f"{len(ds['year'])}")
years = sorted(ds["year"].values)
st.metric("Years", f"{min(years)}{max(years)}")
with col4:
st.metric("Aggregations", f"{len(ds['agg'])}")
# Show temporal coverage
st.markdown("**Temporal Coverage:**")
years = sorted(ds["year"].values)
st.write(f"Years: {min(years)} - {max(years)}")
# Show aggregations
st.markdown("**Available Aggregations:**")
# Show aggregations as badges in an expander
with st.expander(" Data Details", expanded=False):
st.markdown("**Spatial Aggregations:**")
aggs = ds["agg"].to_numpy()
st.write(", ".join(str(a) for a in aggs))
aggs_html = " ".join(
[
f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
for a in aggs
]
)
st.markdown(aggs_html, unsafe_allow_html=True)
@st.fragment
@ -207,13 +211,13 @@ def render_arcticdem_overview(ds: xr.Dataset):
ds: xarray Dataset containing ArcticDEM data.
"""
st.subheader("🏔️ ArcticDEM Terrain Statistics")
st.subheader("📊 Data Overview")
# Overall statistics
# Key metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
st.metric("Variables", f"{len(ds.data_vars)}")
@ -221,19 +225,22 @@ def render_arcticdem_overview(ds: xr.Dataset):
with col3:
st.metric("Aggregations", f"{len(ds['aggregations'])}")
# Show available variables
st.markdown("**Available Variables:**")
variables = list(ds.data_vars)
st.write(", ".join(variables))
# Show aggregations
st.markdown("**Available Aggregations:**")
# Show details in expander
with st.expander(" Data Details", expanded=False):
st.markdown("**Spatial Aggregations:**")
aggs = ds["aggregations"].to_numpy()
st.write(", ".join(str(a) for a in aggs))
aggs_html = " ".join(
[
f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
for a in aggs
]
)
st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable
st.markdown("---")
st.markdown("**Variable Statistics (across all aggregations)**")
st.markdown("**📈 Variable Statistics**")
var_stats = []
for var_name in ds.data_vars:
@ -307,45 +314,50 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
"""
st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")
st.subheader("📊 Data Overview")
# Overall statistics
# Key metrics
has_agg = "aggregations" in ds.dims
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
st.metric("Variables", f"{len(ds.data_vars)}")
with col3:
st.metric("Time Steps", f"{len(ds['time'])}")
time_values = pd.to_datetime(ds["time"].values)
st.metric("Time Steps", f"{time_values.min().strftime('%Y')}{time_values.max().strftime('%Y')}")
with col4:
if has_agg:
st.metric("Aggregations", f"{len(ds['aggregations'])}")
else:
st.metric("Aggregations", "1")
st.metric("Temporal Type", temporal_type.capitalize())
# Show available variables
st.markdown("**Available Variables:**")
variables = list(ds.data_vars)
st.write(", ".join(variables))
# Show temporal range
st.markdown("**Temporal Range:**")
time_values = pd.to_datetime(ds["time"].values)
st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}")
# Show details in expander
with st.expander(" Data Details", expanded=False):
st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
st.markdown(
f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}"
)
if has_agg:
st.markdown("**Available Aggregations:**")
st.markdown("**Spatial Aggregations:**")
aggs = ds["aggregations"].to_numpy()
st.write(", ".join(str(a) for a in aggs))
aggs_html = " ".join(
[
f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
for a in aggs
]
)
st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable
st.markdown("---")
st.markdown("**Variable Statistics (across all time steps and aggregations)**")
st.markdown("**📈 Variable Statistics**")
var_stats = []
for var_name in ds.data_vars:
@ -380,36 +392,115 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims
if has_agg:
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
)
with col2:
selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select"
)
with col3:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
else:
col1, col2 = st.columns([3, 1])
with col1:
selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
)
with col2:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
if selected_var:
var_data = ds[selected_var]
# Calculate mean over space for each time step
if has_agg:
# Average over aggregations first, then over cells
time_series = var_data.mean(dim=["cell_ids", "aggregations"])
else:
time_series = var_data.mean(dim="cell_ids")
# Calculate statistics over space for each time step
time_values = pd.to_datetime(ds["time"].to_numpy())
time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()})
if has_agg:
# Select specific aggregation, then calculate stats over cells
var_data_agg = var_data.sel(aggregations=selected_agg)
time_mean = var_data_agg.mean(dim="cell_ids").to_numpy()
time_std = var_data_agg.std(dim="cell_ids").to_numpy()
time_min = var_data_agg.min(dim="cell_ids").to_numpy()
time_max = var_data_agg.max(dim="cell_ids").to_numpy()
else:
time_mean = var_data.mean(dim="cell_ids").to_numpy()
time_std = var_data.std(dim="cell_ids").to_numpy()
time_min = var_data.min(dim="cell_ids").to_numpy()
time_max = var_data.max(dim="cell_ids").to_numpy()
fig = go.Figure()
# Add min/max range first (background) - optional
if show_minmax:
fig.add_trace(
go.Scatter(
x=time_df["Time"],
y=time_df["Value"],
mode="lines+markers",
name=selected_var,
line={"width": 2},
x=time_values,
y=time_min,
mode="lines",
line={"color": "lightgray", "width": 1, "dash": "dash"},
name="Min/Max Range",
showlegend=True,
)
)
fig.add_trace(
go.Scatter(
x=time_values,
y=time_max,
mode="lines",
fill="tonexty",
fillcolor="rgba(200, 200, 200, 0.1)",
line={"color": "lightgray", "width": 1, "dash": "dash"},
showlegend=False,
)
)
# Add std band - optional
if show_std:
fig.add_trace(
go.Scatter(
x=time_values,
y=time_mean - time_std,
mode="lines",
line={"width": 0},
showlegend=False,
hoverinfo="skip",
)
)
fig.add_trace(
go.Scatter(
x=time_values,
y=time_mean + time_std,
mode="lines",
fill="tonexty",
fillcolor="rgba(31, 119, 180, 0.2)",
line={"width": 0},
name="±1 Std",
)
)
# Add mean line on top
fig.add_trace(
go.Scatter(
x=time_values,
y=time_mean,
mode="lines+markers",
name="Mean",
line={"color": "#1f77b4", "width": 2},
marker={"size": 4},
)
)
title_suffix = f" (Aggregation: {selected_agg})" if has_agg else ""
fig.update_layout(
title=f"Temporal Trend of {selected_var} (Spatial Mean)",
title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
xaxis_title="Time",
yaxis_title=selected_var,
height=400,
@ -431,19 +522,27 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
"""
st.subheader("🗺️ AlphaEarth Spatial Distribution")
# Controls
col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
# Year slider (full width)
years = sorted(ds["year"].values)
selected_year = st.slider(
"Year",
min_value=int(years[0]),
max_value=int(years[-1]),
value=int(years[-1]),
step=1,
key="alphaearth_year",
)
# Other controls
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
with col2:
selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
with col3:
with col2:
selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
with col4:
with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
# Extract data for selected parameters
@ -475,6 +574,9 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = normalized
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
@ -484,24 +586,28 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
"properties": {
"value": float(row["value"]),
"fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
},
}
geojson_data.append(feature)
# Create pydeck layer
# Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
deck = pdk.Deck(
layers=[layer],
@ -582,12 +688,16 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
}
# Add all aggregation values if available
if len(ds["aggregations"]) > 1:
@ -603,20 +713,23 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
}
geojson_data.append(feature)
# Create pydeck layer
# Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ArcticDEM
if len(ds["aggregations"]) > 1:
@ -663,37 +776,43 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims
# Top row: Variable, Aggregation (if applicable), and Opacity
if has_agg:
col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
else:
col1, col2, col3 = st.columns([3, 3, 1])
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
# Convert time to readable format
time_values = pd.to_datetime(ds["time"].values)
time_options = {str(t): t for t in time_values}
selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
if has_agg:
with col3:
selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
)
with col4:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
else:
with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
else:
col1, col2 = st.columns([4, 1])
with col1:
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
# Bottom row: Time slider (full width)
time_values = pd.to_datetime(ds["time"].values)
time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
selected_time_idx = st.slider(
"Time",
min_value=0,
max_value=len(time_values) - 1,
value=len(time_values) - 1,
format="",
key=f"era5_{temporal_type}_time_slider",
)
st.caption(f"Selected: {time_labels[selected_time_idx]}")
selected_time = time_values[selected_time_idx]
# Extract data for selected parameters
time_val = time_options[selected_time]
if has_agg:
data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg)
else:
data_values = ds[selected_var].sel(time=time_val)
data_values = ds[selected_var].sel(time=selected_time)
# Create GeoDataFrame
gdf = targets.copy()
@ -707,7 +826,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
# Add all aggregation values for tooltip if has_agg
if has_agg and len(ds["aggregations"]) > 1:
for agg in ds["aggregations"].values:
agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}")
# Drop dimension columns to avoid conflicts
cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
if cols_to_drop:
@ -734,12 +853,16 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
}
# Add all aggregation values if available
if has_agg and len(ds["aggregations"]) > 1:
@ -755,20 +878,23 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
}
geojson_data.append(feature)
# Create pydeck layer
# Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ERA5
if has_agg and len(ds["aggregations"]) > 1:

View file

@ -125,32 +125,61 @@ def render_training_data_page():
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Display target information
col1, col2 = st.columns(2)
# High-level summary metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric(label="Target", value=stats["target"].replace("darts_", ""))
st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2:
st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}")
# Display member statistics
st.markdown("**Member Statistics:**")
st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with col3:
st.metric(label="Data Sources", value=len(stats["members"]))
# Detailed member statistics in expandable section
with st.expander("📦 Data Source Details", expanded=False):
for member, member_stats in stats["members"].items():
with st.expander(f"📦 {member}", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.markdown(f"**Number of Features:** {member_stats['num_features']}")
st.markdown(f"**Number of Variables:** {member_stats['num_variables']}")
with col2:
st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`")
st.markdown(f"### {member}")
# Display variables as a compact list
st.markdown(f"**Variables ({member_stats['num_variables']}):**")
vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]])
st.markdown(vars_str)
# Create metrics for this member
metric_cols = st.columns(4)
with metric_cols[0]:
st.metric("Features", member_stats["num_features"])
with metric_cols[1]:
st.metric("Variables", member_stats["num_variables"])
with metric_cols[2]:
# Display dimensions in a more readable format
dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
# Display total features
st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}")
# Show variables as colored badges
st.markdown("**Variables:**")
vars_html = " ".join(
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"]
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Show dimension details
st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
st.markdown("---")