From 8338efb31e25442ad89bf39e78ac63b5dfaa25d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Fri, 19 Dec 2025 15:36:45 +0100 Subject: [PATCH] Finalize the Training Data Page --- src/entropice/dashboard/plots/source_data.py | 318 ++++++++++++------ src/entropice/dashboard/training_data_page.py | 69 ++-- 2 files changed, 271 insertions(+), 116 deletions(-) diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py index cee846a..8406c69 100644 --- a/src/entropice/dashboard/plots/source_data.py +++ b/src/entropice/dashboard/plots/source_data.py @@ -29,32 +29,36 @@ def render_alphaearth_overview(ds: xr.Dataset): ds: xarray Dataset containing AlphaEarth embeddings. """ - st.subheader("πŸ“Š AlphaEarth Embeddings Statistics") + st.subheader("πŸ“Š Data Overview") - # Overall statistics + # Key metrics col1, col2, col3, col4 = st.columns(4) with col1: - st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + st.metric("Cells", f"{len(ds['cell_ids']):,}") with col2: - st.metric("Embedding Dimensions", f"{len(ds['band'])}") + st.metric("Embedding Dims", f"{len(ds['band'])}") with col3: - st.metric("Years Available", f"{len(ds['year'])}") + years = sorted(ds["year"].values) + st.metric("Years", f"{min(years)}–{max(years)}") with col4: st.metric("Aggregations", f"{len(ds['agg'])}") - # Show temporal coverage - st.markdown("**Temporal Coverage:**") - years = sorted(ds["year"].values) - st.write(f"Years: {min(years)} - {max(years)}") - - # Show aggregations - st.markdown("**Available Aggregations:**") - aggs = ds["agg"].to_numpy() - st.write(", ".join(str(a) for a in aggs)) + # Show aggregations as badges in an expander + with st.expander("ℹ️ Data Details", expanded=False): + st.markdown("**Spatial Aggregations:**") + aggs = ds["agg"].to_numpy() + aggs_html = " ".join( + [ + f'{a}' + for a in aggs + ] + ) + st.markdown(aggs_html, unsafe_allow_html=True) @st.fragment @@ -207,13 +211,13 @@ def render_arcticdem_overview(ds: xr.Dataset): ds: xarray Dataset containing ArcticDEM data. """ - st.subheader("πŸ”οΈ ArcticDEM Terrain Statistics") + st.subheader("πŸ“Š Data Overview") - # Overall statistics + # Key metrics col1, col2, col3 = st.columns(3) with col1: - st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + st.metric("Cells", f"{len(ds['cell_ids']):,}") with col2: st.metric("Variables", f"{len(ds.data_vars)}") @@ -221,19 +225,22 @@ def render_arcticdem_overview(ds: xr.Dataset): with col3: st.metric("Aggregations", f"{len(ds['aggregations'])}") - # Show available variables - st.markdown("**Available Variables:**") - variables = list(ds.data_vars) - st.write(", ".join(variables)) - - # Show aggregations - st.markdown("**Available Aggregations:**") - aggs = ds["aggregations"].to_numpy() - st.write(", ".join(str(a) for a in aggs)) + # Show details in expander + with st.expander("ℹ️ Data Details", expanded=False): + st.markdown("**Spatial Aggregations:**") + aggs = ds["aggregations"].to_numpy() + aggs_html = " ".join( + [ + f'{a}' + for a in aggs + ] + ) + st.markdown(aggs_html, unsafe_allow_html=True) # Statistics by variable st.markdown("---") - st.markdown("**Variable Statistics (across all aggregations)**") + st.markdown("**πŸ“ˆ Variable Statistics**") var_stats = [] for var_name in ds.data_vars: @@ -307,45 +314,50 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str): temporal_type: One of 'yearly', 'seasonal', 'shoulder'. """ - st.subheader(f"🌑️ ERA5 Climate Statistics ({temporal_type.capitalize()})") + st.subheader("πŸ“Š Data Overview") - # Overall statistics + # Key metrics has_agg = "aggregations" in ds.dims col1, col2, col3, col4 = st.columns(4) with col1: - st.metric("Total Cells", f"{len(ds['cell_ids']):,}") + st.metric("Cells", f"{len(ds['cell_ids']):,}") with col2: st.metric("Variables", f"{len(ds.data_vars)}") with col3: - st.metric("Time Steps", f"{len(ds['time'])}") + time_values = pd.to_datetime(ds["time"].values) + st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}") with col4: if has_agg: st.metric("Aggregations", f"{len(ds['aggregations'])}") else: - st.metric("Aggregations", "1") + st.metric("Temporal Type", temporal_type.capitalize()) - # Show available variables - st.markdown("**Available Variables:**") - variables = list(ds.data_vars) - st.write(", ".join(variables)) + # Show details in expander + with st.expander("ℹ️ Data Details", expanded=False): + st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}") + st.markdown( + f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}" + ) - # Show temporal range - st.markdown("**Temporal Range:**") - time_values = pd.to_datetime(ds["time"].values) - st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}") - - if has_agg: - st.markdown("**Available Aggregations:**") - aggs = ds["aggregations"].to_numpy() - st.write(", ".join(str(a) for a in aggs)) + if has_agg: + st.markdown("**Spatial Aggregations:**") + aggs = ds["aggregations"].to_numpy() + aggs_html = " ".join( + [ + f'{a}' + for a in aggs + ] + ) + st.markdown(aggs_html, unsafe_allow_html=True) # Statistics by variable st.markdown("---") - st.markdown("**Variable Statistics (across all time steps and aggregations)**") + st.markdown("**πŸ“ˆ Variable Statistics**") var_stats = [] for var_name in ds.data_vars: @@ -380,36 +392,115 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str): variables = list(ds.data_vars) has_agg = "aggregations" in ds.dims - selected_var = st.selectbox( - "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" - ) + if has_agg: + col1, col2, col3 = st.columns([2, 2, 1]) + with col1: + selected_var = st.selectbox( + "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" + ) + with col2: + selected_agg = st.selectbox( + "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select" + ) + with col3: + show_std = st.checkbox("Show Β±1 Std", value=True, key=f"era5_{temporal_type}_show_std") + show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax") + else: + col1, col2 = st.columns([3, 1]) + with col1: + selected_var = st.selectbox( + "Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" + ) + with col2: + show_std = st.checkbox("Show Β±1 Std", value=True, key=f"era5_{temporal_type}_show_std") + show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax") if selected_var: var_data = ds[selected_var] - # Calculate mean over space for each time step - if has_agg: - # Average over aggregations first, then over cells - time_series = var_data.mean(dim=["cell_ids", "aggregations"]) - else: - time_series = var_data.mean(dim="cell_ids") + # Calculate statistics over space for each time step + time_values = pd.to_datetime(ds["time"].to_numpy()) - time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()}) + if has_agg: + # Select specific aggregation, then calculate stats over cells + var_data_agg = var_data.sel(aggregations=selected_agg) + time_mean = var_data_agg.mean(dim="cell_ids").to_numpy() + time_std = var_data_agg.std(dim="cell_ids").to_numpy() + time_min = var_data_agg.min(dim="cell_ids").to_numpy() + time_max = var_data_agg.max(dim="cell_ids").to_numpy() + else: + time_mean = var_data.mean(dim="cell_ids").to_numpy() + time_std = var_data.std(dim="cell_ids").to_numpy() + time_min = var_data.min(dim="cell_ids").to_numpy() + time_max = var_data.max(dim="cell_ids").to_numpy() fig = go.Figure() + # Add min/max range first (background) - optional + if show_minmax: + fig.add_trace( + go.Scatter( + x=time_values, + y=time_min, + mode="lines", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + name="Min/Max Range", + showlegend=True, + ) + ) + + fig.add_trace( + go.Scatter( + x=time_values, + y=time_max, + mode="lines", + fill="tonexty", + fillcolor="rgba(200, 200, 200, 0.1)", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + showlegend=False, + ) + ) + + # Add std band - optional + if show_std: + fig.add_trace( + go.Scatter( + x=time_values, + y=time_mean - time_std, + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=time_values, + y=time_mean + time_std, + mode="lines", + fill="tonexty", + fillcolor="rgba(31, 119, 180, 0.2)", + line={"width": 0}, + name="Β±1 Std", + ) + ) + + # Add mean line on top fig.add_trace( go.Scatter( - x=time_df["Time"], - y=time_df["Value"], + x=time_values, + y=time_mean, mode="lines+markers", - name=selected_var, - line={"width": 2}, + name="Mean", + line={"color": "#1f77b4", "width": 2}, + marker={"size": 4}, ) ) + title_suffix = f" (Aggregation: {selected_agg})" if has_agg else "" fig.update_layout( - title=f"Temporal Trend of {selected_var} (Spatial Mean)", + title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}", xaxis_title="Time", yaxis_title=selected_var, height=400, @@ -431,19 +522,27 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): """ st.subheader("πŸ—ΊοΈ AlphaEarth Spatial Distribution") - # Controls - col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) + # Year slider (full width) + years = sorted(ds["year"].values) + selected_year = st.slider( + "Year", + min_value=int(years[0]), + max_value=int(years[-1]), + value=int(years[-1]), + step=1, + key="alphaearth_year", + ) + + # Other controls + col1, col2, col3 = st.columns([2, 2, 1]) with col1: - selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year") - - with col2: selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg") - with col3: + with col2: selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band") - with col4: + with col3: opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity") # Extract data for selected parameters @@ -475,6 +574,9 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): colors = [cmap(val) for val in normalized] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + # Set elevation based on normalized values + gdf_wgs84["elevation"] = normalized + # Create GeoJSON geojson_data = [] for _, row in gdf_wgs84.iterrows(): @@ -484,24 +586,28 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): "properties": { "value": float(row["value"]), "fill_color": row["fill_color"], + "elevation": float(row["elevation"]), }, } geojson_data.append(feature) - # Create pydeck layer + # Create pydeck layer with 3D elevation layer = pdk.Layer( "GeoJsonLayer", geojson_data, opacity=opacity, stroked=True, filled=True, + extruded=True, get_fill_color="properties.fill_color", get_line_color=[80, 80, 80], + get_elevation="properties.elevation", + elevation_scale=500000, line_width_min_pixels=0.5, pickable=True, ) - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) deck = pdk.Deck( layers=[layer], @@ -582,12 +688,16 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + # Set elevation based on normalized values + gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized] + # Create GeoJSON geojson_data = [] for _, row in gdf_wgs84.iterrows(): properties = { "value": float(row["value"]) if not np.isnan(row["value"]) else None, "fill_color": row["fill_color"], + "elevation": float(row["elevation"]), } # Add all aggregation values if available if len(ds["aggregations"]) > 1: @@ -603,20 +713,23 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): } geojson_data.append(feature) - # Create pydeck layer + # Create pydeck layer with 3D elevation layer = pdk.Layer( "GeoJsonLayer", geojson_data, opacity=opacity, stroked=True, filled=True, + extruded=True, get_fill_color="properties.fill_color", get_line_color=[80, 80, 80], + get_elevation="properties.elevation", + elevation_scale=500000, line_width_min_pixels=0.5, pickable=True, ) - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) # Build tooltip HTML for ArcticDEM if len(ds["aggregations"]) > 1: @@ -663,37 +776,43 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor variables = list(ds.data_vars) has_agg = "aggregations" in ds.dims + # Top row: Variable, Aggregation (if applicable), and Opacity if has_agg: - col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) - else: - col1, col2, col3 = st.columns([3, 3, 1]) - - with col1: - selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") - - with col2: - # Convert time to readable format - time_values = pd.to_datetime(ds["time"].values) - time_options = {str(t): t for t in time_values} - selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time") - - if has_agg: - with col3: + col1, col2, col3 = st.columns([2, 2, 1]) + with col1: + selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") + with col2: selected_agg = st.selectbox( "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg" ) - with col4: - opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") - else: with col3: opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") + else: + col1, col2 = st.columns([4, 1]) + with col1: + selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") + with col2: + opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") + + # Bottom row: Time slider (full width) + time_values = pd.to_datetime(ds["time"].values) + time_labels = [t.strftime("%Y-%m-%d") for t in time_values] + selected_time_idx = st.slider( + "Time", + min_value=0, + max_value=len(time_values) - 1, + value=len(time_values) - 1, + format="", + key=f"era5_{temporal_type}_time_slider", + ) + st.caption(f"Selected: {time_labels[selected_time_idx]}") + selected_time = time_values[selected_time_idx] # Extract data for selected parameters - time_val = time_options[selected_time] if has_agg: - data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg) + data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg) else: - data_values = ds[selected_var].sel(time=time_val) + data_values = ds[selected_var].sel(time=selected_time) # Create GeoDataFrame gdf = targets.copy() @@ -707,7 +826,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor # Add all aggregation values for tooltip if has_agg if has_agg and len(ds["aggregations"]) > 1: for agg in ds["aggregations"].values: - agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}") + agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}") # Drop dimension columns to avoid conflicts cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns] if cols_to_drop: @@ -734,12 +853,16 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + # Set elevation based on normalized values + gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized] + # Create GeoJSON geojson_data = [] for _, row in gdf_wgs84.iterrows(): properties = { "value": float(row["value"]) if not np.isnan(row["value"]) else None, "fill_color": row["fill_color"], + "elevation": float(row["elevation"]), } # Add all aggregation values if available if has_agg and len(ds["aggregations"]) > 1: @@ -755,20 +878,23 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor } geojson_data.append(feature) - # Create pydeck layer + # Create pydeck layer with 3D elevation layer = pdk.Layer( "GeoJsonLayer", geojson_data, opacity=opacity, stroked=True, filled=True, + extruded=True, get_fill_color="properties.fill_color", get_line_color=[80, 80, 80], + get_elevation="properties.elevation", + elevation_scale=500000, line_width_min_pixels=0.5, pickable=True, ) - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) # Build tooltip HTML for ERA5 if has_agg and len(ds["aggregations"]) > 1: diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py index 72f31ff..88dc5f5 100644 --- a/src/entropice/dashboard/training_data_page.py +++ b/src/entropice/dashboard/training_data_page.py @@ -125,32 +125,61 @@ def render_training_data_page(): with st.spinner("Computing dataset statistics..."): stats = ensemble.get_stats() - # Display target information - col1, col2 = st.columns(2) + # High-level summary metrics + col1, col2, col3 = st.columns(3) with col1: - st.metric(label="Target", value=stats["target"].replace("darts_", "")) + st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}") with col2: - st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}") + st.metric(label="Total Features", value=f"{stats['total_features']:,}") + with col3: + st.metric(label="Data Sources", value=len(stats["members"])) - # Display member statistics - st.markdown("**Member Statistics:**") + # Detailed member statistics in expandable section + with st.expander("πŸ“¦ Data Source Details", expanded=False): + for member, member_stats in stats["members"].items(): + st.markdown(f"### {member}") - for member, member_stats in stats["members"].items(): - with st.expander(f"πŸ“¦ {member}", expanded=False): - col1, col2 = st.columns(2) - with col1: - st.markdown(f"**Number of Features:** {member_stats['num_features']}") - st.markdown(f"**Number of Variables:** {member_stats['num_variables']}") - with col2: - st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`") + # Create metrics for this member + metric_cols = st.columns(4) + with metric_cols[0]: + st.metric("Features", member_stats["num_features"]) + with metric_cols[1]: + st.metric("Variables", member_stats["num_variables"]) + with metric_cols[2]: + # Display dimensions in a more readable format + dim_str = " Γ— ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) + st.metric("Shape", dim_str) + with metric_cols[3]: + # Calculate total data points + total_points = 1 + for dim_size in member_stats["dimensions"].values(): + total_points *= dim_size + st.metric("Data Points", f"{total_points:,}") - # Display variables as a compact list - st.markdown(f"**Variables ({member_stats['num_variables']}):**") - vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]]) - st.markdown(vars_str) + # Show variables as colored badges + st.markdown("**Variables:**") + vars_html = " ".join( + [ + f'{v}' + for v in member_stats["variables"] + ] + ) + st.markdown(vars_html, unsafe_allow_html=True) - # Display total features - st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}") + # Show dimension details + st.markdown("**Dimensions:**") + dim_html = " ".join( + [ + f'' + f"{dim_name}: {dim_size}" + for dim_name, dim_size in member_stats["dimensions"].items() + ] + ) + st.markdown(dim_html, unsafe_allow_html=True) + + st.markdown("---") st.markdown("---")