diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py
index cee846a..8406c69 100644
--- a/src/entropice/dashboard/plots/source_data.py
+++ b/src/entropice/dashboard/plots/source_data.py
@@ -29,32 +29,36 @@ def render_alphaearth_overview(ds: xr.Dataset):
ds: xarray Dataset containing AlphaEarth embeddings.
"""
- st.subheader("π AlphaEarth Embeddings Statistics")
+ st.subheader("π Data Overview")
- # Overall statistics
+ # Key metrics
col1, col2, col3, col4 = st.columns(4)
with col1:
- st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+ st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
- st.metric("Embedding Dimensions", f"{len(ds['band'])}")
+ st.metric("Embedding Dims", f"{len(ds['band'])}")
with col3:
- st.metric("Years Available", f"{len(ds['year'])}")
+ years = sorted(ds["year"].values)
+ st.metric("Years", f"{min(years)}–{max(years)}")
with col4:
st.metric("Aggregations", f"{len(ds['agg'])}")
- # Show temporal coverage
- st.markdown("**Temporal Coverage:**")
- years = sorted(ds["year"].values)
- st.write(f"Years: {min(years)} - {max(years)}")
-
- # Show aggregations
- st.markdown("**Available Aggregations:**")
- aggs = ds["agg"].to_numpy()
- st.write(", ".join(str(a) for a in aggs))
+ # Show aggregations as badges in an expander
+ with st.expander("ℹ️ Data Details", expanded=False):
+ st.markdown("**Spatial Aggregations:**")
+ aggs = ds["agg"].to_numpy()
+ aggs_html = " ".join(
+ [
+ f'{a}'
+ for a in aggs
+ ]
+ )
+ st.markdown(aggs_html, unsafe_allow_html=True)
@st.fragment
@@ -207,13 +211,13 @@ def render_arcticdem_overview(ds: xr.Dataset):
ds: xarray Dataset containing ArcticDEM data.
"""
- st.subheader("🏔️ ArcticDEM Terrain Statistics")
+ st.subheader("π Data Overview")
- # Overall statistics
+ # Key metrics
col1, col2, col3 = st.columns(3)
with col1:
- st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+ st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
st.metric("Variables", f"{len(ds.data_vars)}")
@@ -221,19 +225,22 @@ def render_arcticdem_overview(ds: xr.Dataset):
with col3:
st.metric("Aggregations", f"{len(ds['aggregations'])}")
- # Show available variables
- st.markdown("**Available Variables:**")
- variables = list(ds.data_vars)
- st.write(", ".join(variables))
-
- # Show aggregations
- st.markdown("**Available Aggregations:**")
- aggs = ds["aggregations"].to_numpy()
- st.write(", ".join(str(a) for a in aggs))
+ # Show details in expander
+ with st.expander("ℹ️ Data Details", expanded=False):
+ st.markdown("**Spatial Aggregations:**")
+ aggs = ds["aggregations"].to_numpy()
+ aggs_html = " ".join(
+ [
+ f'{a}'
+ for a in aggs
+ ]
+ )
+ st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable
st.markdown("---")
- st.markdown("**Variable Statistics (across all aggregations)**")
+ st.markdown("**π Variable Statistics**")
var_stats = []
for var_name in ds.data_vars:
@@ -307,45 +314,50 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
"""
- st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")
+ st.subheader("π Data Overview")
- # Overall statistics
+ # Key metrics
has_agg = "aggregations" in ds.dims
col1, col2, col3, col4 = st.columns(4)
with col1:
- st.metric("Total Cells", f"{len(ds['cell_ids']):,}")
+ st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2:
st.metric("Variables", f"{len(ds.data_vars)}")
with col3:
- st.metric("Time Steps", f"{len(ds['time'])}")
+ time_values = pd.to_datetime(ds["time"].values)
+ st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}")
with col4:
if has_agg:
st.metric("Aggregations", f"{len(ds['aggregations'])}")
else:
- st.metric("Aggregations", "1")
+ st.metric("Temporal Type", temporal_type.capitalize())
- # Show available variables
- st.markdown("**Available Variables:**")
- variables = list(ds.data_vars)
- st.write(", ".join(variables))
+ # Show details in expander
+ with st.expander("ℹ️ Data Details", expanded=False):
+ st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
+ st.markdown(
+ f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}"
+ )
- # Show temporal range
- st.markdown("**Temporal Range:**")
- time_values = pd.to_datetime(ds["time"].values)
- st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}")
-
- if has_agg:
- st.markdown("**Available Aggregations:**")
- aggs = ds["aggregations"].to_numpy()
- st.write(", ".join(str(a) for a in aggs))
+ if has_agg:
+ st.markdown("**Spatial Aggregations:**")
+ aggs = ds["aggregations"].to_numpy()
+ aggs_html = " ".join(
+ [
+ f'{a}'
+ for a in aggs
+ ]
+ )
+ st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable
st.markdown("---")
- st.markdown("**Variable Statistics (across all time steps and aggregations)**")
+ st.markdown("**π Variable Statistics**")
var_stats = []
for var_name in ds.data_vars:
@@ -380,36 +392,115 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims
- selected_var = st.selectbox(
- "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
- )
+ if has_agg:
+ col1, col2, col3 = st.columns([2, 2, 1])
+ with col1:
+ selected_var = st.selectbox(
+ "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
+ )
+ with col2:
+ selected_agg = st.selectbox(
+ "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select"
+ )
+ with col3:
+ show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
+ show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
+ else:
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ selected_var = st.selectbox(
+ "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
+ )
+ with col2:
+ show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
+ show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
if selected_var:
var_data = ds[selected_var]
- # Calculate mean over space for each time step
- if has_agg:
- # Average over aggregations first, then over cells
- time_series = var_data.mean(dim=["cell_ids", "aggregations"])
- else:
- time_series = var_data.mean(dim="cell_ids")
+ # Calculate statistics over space for each time step
+ time_values = pd.to_datetime(ds["time"].to_numpy())
- time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()})
+ if has_agg:
+ # Select specific aggregation, then calculate stats over cells
+ var_data_agg = var_data.sel(aggregations=selected_agg)
+ time_mean = var_data_agg.mean(dim="cell_ids").to_numpy()
+ time_std = var_data_agg.std(dim="cell_ids").to_numpy()
+ time_min = var_data_agg.min(dim="cell_ids").to_numpy()
+ time_max = var_data_agg.max(dim="cell_ids").to_numpy()
+ else:
+ time_mean = var_data.mean(dim="cell_ids").to_numpy()
+ time_std = var_data.std(dim="cell_ids").to_numpy()
+ time_min = var_data.min(dim="cell_ids").to_numpy()
+ time_max = var_data.max(dim="cell_ids").to_numpy()
fig = go.Figure()
+ # Add min/max range first (background) - optional
+ if show_minmax:
+ fig.add_trace(
+ go.Scatter(
+ x=time_values,
+ y=time_min,
+ mode="lines",
+ line={"color": "lightgray", "width": 1, "dash": "dash"},
+ name="Min/Max Range",
+ showlegend=True,
+ )
+ )
+
+ fig.add_trace(
+ go.Scatter(
+ x=time_values,
+ y=time_max,
+ mode="lines",
+ fill="tonexty",
+ fillcolor="rgba(200, 200, 200, 0.1)",
+ line={"color": "lightgray", "width": 1, "dash": "dash"},
+ showlegend=False,
+ )
+ )
+
+ # Add std band - optional
+ if show_std:
+ fig.add_trace(
+ go.Scatter(
+ x=time_values,
+ y=time_mean - time_std,
+ mode="lines",
+ line={"width": 0},
+ showlegend=False,
+ hoverinfo="skip",
+ )
+ )
+
+ fig.add_trace(
+ go.Scatter(
+ x=time_values,
+ y=time_mean + time_std,
+ mode="lines",
+ fill="tonexty",
+ fillcolor="rgba(31, 119, 180, 0.2)",
+ line={"width": 0},
+ name="±1 Std",
+ )
+ )
+
+ # Add mean line on top
fig.add_trace(
go.Scatter(
- x=time_df["Time"],
- y=time_df["Value"],
+ x=time_values,
+ y=time_mean,
mode="lines+markers",
- name=selected_var,
- line={"width": 2},
+ name="Mean",
+ line={"color": "#1f77b4", "width": 2},
+ marker={"size": 4},
)
)
+ title_suffix = f" (Aggregation: {selected_agg})" if has_agg else ""
fig.update_layout(
- title=f"Temporal Trend of {selected_var} (Spatial Mean)",
+ title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
xaxis_title="Time",
yaxis_title=selected_var,
height=400,
@@ -431,19 +522,27 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
"""
st.subheader("🗺️ AlphaEarth Spatial Distribution")
- # Controls
- col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
+ # Year slider (full width)
+ years = sorted(ds["year"].values)
+ selected_year = st.slider(
+ "Year",
+ min_value=int(years[0]),
+ max_value=int(years[-1]),
+ value=int(years[-1]),
+ step=1,
+ key="alphaearth_year",
+ )
+
+ # Other controls
+ col1, col2, col3 = st.columns([2, 2, 1])
with col1:
- selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
-
- with col2:
selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
- with col3:
+ with col2:
selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
- with col4:
+ with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
# Extract data for selected parameters
@@ -475,6 +574,9 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+ # Set elevation based on normalized values
+ gdf_wgs84["elevation"] = normalized
+
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
@@ -484,24 +586,28 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
"properties": {
"value": float(row["value"]),
"fill_color": row["fill_color"],
+ "elevation": float(row["elevation"]),
},
}
geojson_data.append(feature)
- # Create pydeck layer
+ # Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
+ extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
+ get_elevation="properties.elevation",
+ elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
deck = pdk.Deck(
layers=[layer],
@@ -582,12 +688,16 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+ # Set elevation based on normalized values
+ gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
+
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"],
+ "elevation": float(row["elevation"]),
}
# Add all aggregation values if available
if len(ds["aggregations"]) > 1:
@@ -603,20 +713,23 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
}
geojson_data.append(feature)
- # Create pydeck layer
+ # Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
+ extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
+ get_elevation="properties.elevation",
+ elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ArcticDEM
if len(ds["aggregations"]) > 1:
@@ -663,37 +776,43 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims
+ # Top row: Variable, Aggregation (if applicable), and Opacity
if has_agg:
- col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
- else:
- col1, col2, col3 = st.columns([3, 3, 1])
-
- with col1:
- selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
-
- with col2:
- # Convert time to readable format
- time_values = pd.to_datetime(ds["time"].values)
- time_options = {str(t): t for t in time_values}
- selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
-
- if has_agg:
- with col3:
+ col1, col2, col3 = st.columns([2, 2, 1])
+ with col1:
+ selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
+ with col2:
selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
)
- with col4:
- opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
- else:
with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
+ else:
+ col1, col2 = st.columns([4, 1])
+ with col1:
+ selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
+ with col2:
+ opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
+
+ # Bottom row: Time slider (full width)
+ time_values = pd.to_datetime(ds["time"].values)
+ time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
+ selected_time_idx = st.slider(
+ "Time",
+ min_value=0,
+ max_value=len(time_values) - 1,
+ value=len(time_values) - 1,
+ format="",
+ key=f"era5_{temporal_type}_time_slider",
+ )
+ st.caption(f"Selected: {time_labels[selected_time_idx]}")
+ selected_time = time_values[selected_time_idx]
# Extract data for selected parameters
- time_val = time_options[selected_time]
if has_agg:
- data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
+ data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg)
else:
- data_values = ds[selected_var].sel(time=time_val)
+ data_values = ds[selected_var].sel(time=selected_time)
# Create GeoDataFrame
gdf = targets.copy()
@@ -707,7 +826,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
# Add all aggregation values for tooltip if has_agg
if has_agg and len(ds["aggregations"]) > 1:
for agg in ds["aggregations"].values:
- agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
+ agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}")
# Drop dimension columns to avoid conflicts
cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
if cols_to_drop:
@@ -734,12 +853,16 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+ # Set elevation based on normalized values
+ gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
+
# Create GeoJSON
geojson_data = []
for _, row in gdf_wgs84.iterrows():
properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"],
+ "elevation": float(row["elevation"]),
}
# Add all aggregation values if available
if has_agg and len(ds["aggregations"]) > 1:
@@ -755,20 +878,23 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
}
geojson_data.append(feature)
- # Create pydeck layer
+ # Create pydeck layer with 3D elevation
layer = pdk.Layer(
"GeoJsonLayer",
geojson_data,
opacity=opacity,
stroked=True,
filled=True,
+ extruded=True,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
+ get_elevation="properties.elevation",
+ elevation_scale=500000,
line_width_min_pixels=0.5,
pickable=True,
)
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ERA5
if has_agg and len(ds["aggregations"]) > 1:
diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py
index 72f31ff..88dc5f5 100644
--- a/src/entropice/dashboard/training_data_page.py
+++ b/src/entropice/dashboard/training_data_page.py
@@ -125,32 +125,61 @@ def render_training_data_page():
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
- # Display target information
- col1, col2 = st.columns(2)
+ # High-level summary metrics
+ col1, col2, col3 = st.columns(3)
with col1:
- st.metric(label="Target", value=stats["target"].replace("darts_", ""))
+ st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2:
- st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}")
+ st.metric(label="Total Features", value=f"{stats['total_features']:,}")
+ with col3:
+ st.metric(label="Data Sources", value=len(stats["members"]))
- # Display member statistics
- st.markdown("**Member Statistics:**")
+ # Detailed member statistics in expandable section
+ with st.expander("📦 Data Source Details", expanded=False):
+ for member, member_stats in stats["members"].items():
+ st.markdown(f"### {member}")
- for member, member_stats in stats["members"].items():
- with st.expander(f"📦 {member}", expanded=False):
- col1, col2 = st.columns(2)
- with col1:
- st.markdown(f"**Number of Features:** {member_stats['num_features']}")
- st.markdown(f"**Number of Variables:** {member_stats['num_variables']}")
- with col2:
- st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`")
+ # Create metrics for this member
+ metric_cols = st.columns(4)
+ with metric_cols[0]:
+ st.metric("Features", member_stats["num_features"])
+ with metric_cols[1]:
+ st.metric("Variables", member_stats["num_variables"])
+ with metric_cols[2]:
+ # Display dimensions in a more readable format
+ dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
+ st.metric("Shape", dim_str)
+ with metric_cols[3]:
+ # Calculate total data points
+ total_points = 1
+ for dim_size in member_stats["dimensions"].values():
+ total_points *= dim_size
+ st.metric("Data Points", f"{total_points:,}")
- # Display variables as a compact list
- st.markdown(f"**Variables ({member_stats['num_variables']}):**")
- vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]])
- st.markdown(vars_str)
+ # Show variables as colored badges
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
+ [
+ f'{v}'
+ for v in member_stats["variables"]
+ ]
+ )
+ st.markdown(vars_html, unsafe_allow_html=True)
- # Display total features
- st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}")
+ # Show dimension details
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+ f''
+ f"{dim_name}: {dim_size}"
+ for dim_name, dim_size in member_stats["dimensions"].items()
+ ]
+ )
+ st.markdown(dim_html, unsafe_allow_html=True)
+
+ st.markdown("---")
st.markdown("---")