Finalize the Training Data Page

This commit is contained in:
Tobias Hölzer 2025-12-19 15:36:45 +01:00
parent 696bef39c2
commit 8338efb31e
2 changed files with 271 additions and 116 deletions

View file

@ -29,32 +29,36 @@ def render_alphaearth_overview(ds: xr.Dataset):
ds: xarray Dataset containing AlphaEarth embeddings. ds: xarray Dataset containing AlphaEarth embeddings.
""" """
st.subheader("📊 AlphaEarth Embeddings Statistics") st.subheader("📊 Data Overview")
# Overall statistics # Key metrics
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
with col1: with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}") st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2: with col2:
st.metric("Embedding Dimensions", f"{len(ds['band'])}") st.metric("Embedding Dims", f"{len(ds['band'])}")
with col3: with col3:
st.metric("Years Available", f"{len(ds['year'])}") years = sorted(ds["year"].values)
st.metric("Years", f"{min(years)}–{max(years)}")
with col4: with col4:
st.metric("Aggregations", f"{len(ds['agg'])}") st.metric("Aggregations", f"{len(ds['agg'])}")
# Show temporal coverage # Show aggregations as badges in an expander
st.markdown("**Temporal Coverage:**") with st.expander(" Data Details", expanded=False):
years = sorted(ds["year"].values) st.markdown("**Spatial Aggregations:**")
st.write(f"Years: {min(years)} - {max(years)}") aggs = ds["agg"].to_numpy()
aggs_html = " ".join(
# Show aggregations [
st.markdown("**Available Aggregations:**") f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
aggs = ds["agg"].to_numpy() f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
st.write(", ".join(str(a) for a in aggs)) for a in aggs
]
)
st.markdown(aggs_html, unsafe_allow_html=True)
@st.fragment @st.fragment
@ -207,13 +211,13 @@ def render_arcticdem_overview(ds: xr.Dataset):
ds: xarray Dataset containing ArcticDEM data. ds: xarray Dataset containing ArcticDEM data.
""" """
st.subheader("🏔️ ArcticDEM Terrain Statistics") st.subheader("📊 Data Overview")
# Overall statistics # Key metrics
col1, col2, col3 = st.columns(3) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}") st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2: with col2:
st.metric("Variables", f"{len(ds.data_vars)}") st.metric("Variables", f"{len(ds.data_vars)}")
@ -221,19 +225,22 @@ def render_arcticdem_overview(ds: xr.Dataset):
with col3: with col3:
st.metric("Aggregations", f"{len(ds['aggregations'])}") st.metric("Aggregations", f"{len(ds['aggregations'])}")
# Show available variables # Show details in expander
st.markdown("**Available Variables:**") with st.expander(" Data Details", expanded=False):
variables = list(ds.data_vars) st.markdown("**Spatial Aggregations:**")
st.write(", ".join(variables)) aggs = ds["aggregations"].to_numpy()
aggs_html = " ".join(
# Show aggregations [
st.markdown("**Available Aggregations:**") f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
aggs = ds["aggregations"].to_numpy() f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
st.write(", ".join(str(a) for a in aggs)) for a in aggs
]
)
st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable # Statistics by variable
st.markdown("---") st.markdown("---")
st.markdown("**Variable Statistics (across all aggregations)**") st.markdown("**📈 Variable Statistics**")
var_stats = [] var_stats = []
for var_name in ds.data_vars: for var_name in ds.data_vars:
@ -307,45 +314,50 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
temporal_type: One of 'yearly', 'seasonal', 'shoulder'. temporal_type: One of 'yearly', 'seasonal', 'shoulder'.
""" """
st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})") st.subheader("📊 Data Overview")
# Overall statistics # Key metrics
has_agg = "aggregations" in ds.dims has_agg = "aggregations" in ds.dims
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
with col1: with col1:
st.metric("Total Cells", f"{len(ds['cell_ids']):,}") st.metric("Cells", f"{len(ds['cell_ids']):,}")
with col2: with col2:
st.metric("Variables", f"{len(ds.data_vars)}") st.metric("Variables", f"{len(ds.data_vars)}")
with col3: with col3:
st.metric("Time Steps", f"{len(ds['time'])}") time_values = pd.to_datetime(ds["time"].values)
st.metric("Time Steps", f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}")
with col4: with col4:
if has_agg: if has_agg:
st.metric("Aggregations", f"{len(ds['aggregations'])}") st.metric("Aggregations", f"{len(ds['aggregations'])}")
else: else:
st.metric("Aggregations", "1") st.metric("Temporal Type", temporal_type.capitalize())
# Show available variables # Show details in expander
st.markdown("**Available Variables:**") with st.expander(" Data Details", expanded=False):
variables = list(ds.data_vars) st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}")
st.write(", ".join(variables)) st.markdown(
f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}"
)
# Show temporal range if has_agg:
st.markdown("**Temporal Range:**") st.markdown("**Spatial Aggregations:**")
time_values = pd.to_datetime(ds["time"].values) aggs = ds["aggregations"].to_numpy()
st.write(f"{time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}") aggs_html = " ".join(
[
if has_agg: f'<span style="background-color: #e8f5e9; color: #2e7d32; padding: 4px 10px; '
st.markdown("**Available Aggregations:**") f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{a}</span>'
aggs = ds["aggregations"].to_numpy() for a in aggs
st.write(", ".join(str(a) for a in aggs)) ]
)
st.markdown(aggs_html, unsafe_allow_html=True)
# Statistics by variable # Statistics by variable
st.markdown("---") st.markdown("---")
st.markdown("**Variable Statistics (across all time steps and aggregations)**") st.markdown("**📈 Variable Statistics**")
var_stats = [] var_stats = []
for var_name in ds.data_vars: for var_name in ds.data_vars:
@ -380,36 +392,115 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
variables = list(ds.data_vars) variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims has_agg = "aggregations" in ds.dims
selected_var = st.selectbox( if has_agg:
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select" col1, col2, col3 = st.columns([2, 2, 1])
) with col1:
selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
)
with col2:
selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg_select"
)
with col3:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
else:
col1, col2 = st.columns([3, 1])
with col1:
selected_var = st.selectbox(
"Select variable to visualize", options=variables, key=f"era5_{temporal_type}_var_select"
)
with col2:
show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std")
show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax")
if selected_var: if selected_var:
var_data = ds[selected_var] var_data = ds[selected_var]
# Calculate mean over space for each time step # Calculate statistics over space for each time step
if has_agg: time_values = pd.to_datetime(ds["time"].to_numpy())
# Average over aggregations first, then over cells
time_series = var_data.mean(dim=["cell_ids", "aggregations"])
else:
time_series = var_data.mean(dim="cell_ids")
time_df = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()}) if has_agg:
# Select specific aggregation, then calculate stats over cells
var_data_agg = var_data.sel(aggregations=selected_agg)
time_mean = var_data_agg.mean(dim="cell_ids").to_numpy()
time_std = var_data_agg.std(dim="cell_ids").to_numpy()
time_min = var_data_agg.min(dim="cell_ids").to_numpy()
time_max = var_data_agg.max(dim="cell_ids").to_numpy()
else:
time_mean = var_data.mean(dim="cell_ids").to_numpy()
time_std = var_data.std(dim="cell_ids").to_numpy()
time_min = var_data.min(dim="cell_ids").to_numpy()
time_max = var_data.max(dim="cell_ids").to_numpy()
fig = go.Figure() fig = go.Figure()
# Add min/max range first (background) - optional
if show_minmax:
fig.add_trace(
go.Scatter(
x=time_values,
y=time_min,
mode="lines",
line={"color": "lightgray", "width": 1, "dash": "dash"},
name="Min/Max Range",
showlegend=True,
)
)
fig.add_trace(
go.Scatter(
x=time_values,
y=time_max,
mode="lines",
fill="tonexty",
fillcolor="rgba(200, 200, 200, 0.1)",
line={"color": "lightgray", "width": 1, "dash": "dash"},
showlegend=False,
)
)
# Add std band - optional
if show_std:
fig.add_trace(
go.Scatter(
x=time_values,
y=time_mean - time_std,
mode="lines",
line={"width": 0},
showlegend=False,
hoverinfo="skip",
)
)
fig.add_trace(
go.Scatter(
x=time_values,
y=time_mean + time_std,
mode="lines",
fill="tonexty",
fillcolor="rgba(31, 119, 180, 0.2)",
line={"width": 0},
name="±1 Std",
)
)
# Add mean line on top
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=time_df["Time"], x=time_values,
y=time_df["Value"], y=time_mean,
mode="lines+markers", mode="lines+markers",
name=selected_var, name="Mean",
line={"width": 2}, line={"color": "#1f77b4", "width": 2},
marker={"size": 4},
) )
) )
title_suffix = f" (Aggregation: {selected_agg})" if has_agg else ""
fig.update_layout( fig.update_layout(
title=f"Temporal Trend of {selected_var} (Spatial Mean)", title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}",
xaxis_title="Time", xaxis_title="Time",
yaxis_title=selected_var, yaxis_title=selected_var,
height=400, height=400,
@ -431,19 +522,27 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
""" """
st.subheader("🗺️ AlphaEarth Spatial Distribution") st.subheader("🗺️ AlphaEarth Spatial Distribution")
# Controls # Year slider (full width)
col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) years = sorted(ds["year"].values)
selected_year = st.slider(
"Year",
min_value=int(years[0]),
max_value=int(years[-1]),
value=int(years[-1]),
step=1,
key="alphaearth_year",
)
# Other controls
col1, col2, col3 = st.columns([2, 2, 1])
with col1: with col1:
selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")
with col2:
selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg") selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")
with col3: with col2:
selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band") selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")
with col4: with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity") opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")
# Extract data for selected parameters # Extract data for selected parameters
@ -475,6 +574,9 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) for val in normalized] colors = [cmap(val) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = normalized
# Create GeoJSON # Create GeoJSON
geojson_data = [] geojson_data = []
for _, row in gdf_wgs84.iterrows(): for _, row in gdf_wgs84.iterrows():
@ -484,24 +586,28 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
"properties": { "properties": {
"value": float(row["value"]), "value": float(row["value"]),
"fill_color": row["fill_color"], "fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
}, },
} }
geojson_data.append(feature) geojson_data.append(feature)
# Create pydeck layer # Create pydeck layer with 3D elevation
layer = pdk.Layer( layer = pdk.Layer(
"GeoJsonLayer", "GeoJsonLayer",
geojson_data, geojson_data,
opacity=opacity, opacity=opacity,
stroked=True, stroked=True,
filled=True, filled=True,
extruded=True,
get_fill_color="properties.fill_color", get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80], get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5, line_width_min_pixels=0.5,
pickable=True, pickable=True,
) )
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
deck = pdk.Deck( deck = pdk.Deck(
layers=[layer], layers=[layer],
@ -582,12 +688,16 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
# Create GeoJSON # Create GeoJSON
geojson_data = [] geojson_data = []
for _, row in gdf_wgs84.iterrows(): for _, row in gdf_wgs84.iterrows():
properties = { properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None, "value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"], "fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
} }
# Add all aggregation values if available # Add all aggregation values if available
if len(ds["aggregations"]) > 1: if len(ds["aggregations"]) > 1:
@ -603,20 +713,23 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
} }
geojson_data.append(feature) geojson_data.append(feature)
# Create pydeck layer # Create pydeck layer with 3D elevation
layer = pdk.Layer( layer = pdk.Layer(
"GeoJsonLayer", "GeoJsonLayer",
geojson_data, geojson_data,
opacity=opacity, opacity=opacity,
stroked=True, stroked=True,
filled=True, filled=True,
extruded=True,
get_fill_color="properties.fill_color", get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80], get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5, line_width_min_pixels=0.5,
pickable=True, pickable=True,
) )
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ArcticDEM # Build tooltip HTML for ArcticDEM
if len(ds["aggregations"]) > 1: if len(ds["aggregations"]) > 1:
@ -663,37 +776,43 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
variables = list(ds.data_vars) variables = list(ds.data_vars)
has_agg = "aggregations" in ds.dims has_agg = "aggregations" in ds.dims
# Top row: Variable, Aggregation (if applicable), and Opacity
if has_agg: if has_agg:
col1, col2, col3, col4 = st.columns([2, 2, 2, 1]) col1, col2, col3 = st.columns([2, 2, 1])
else: with col1:
col1, col2, col3 = st.columns([3, 3, 1]) selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
with col1:
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
# Convert time to readable format
time_values = pd.to_datetime(ds["time"].values)
time_options = {str(t): t for t in time_values}
selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")
if has_agg:
with col3:
selected_agg = st.selectbox( selected_agg = st.selectbox(
"Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg" "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
) )
with col4:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
else:
with col3: with col3:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
else:
col1, col2 = st.columns([4, 1])
with col1:
selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")
with col2:
opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
# Bottom row: Time slider (full width)
time_values = pd.to_datetime(ds["time"].values)
time_labels = [t.strftime("%Y-%m-%d") for t in time_values]
selected_time_idx = st.slider(
"Time",
min_value=0,
max_value=len(time_values) - 1,
value=len(time_values) - 1,
format="",
key=f"era5_{temporal_type}_time_slider",
)
st.caption(f"Selected: {time_labels[selected_time_idx]}")
selected_time = time_values[selected_time_idx]
# Extract data for selected parameters # Extract data for selected parameters
time_val = time_options[selected_time]
if has_agg: if has_agg:
data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg) data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg)
else: else:
data_values = ds[selected_var].sel(time=time_val) data_values = ds[selected_var].sel(time=selected_time)
# Create GeoDataFrame # Create GeoDataFrame
gdf = targets.copy() gdf = targets.copy()
@ -707,7 +826,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
# Add all aggregation values for tooltip if has_agg # Add all aggregation values for tooltip if has_agg
if has_agg and len(ds["aggregations"]) > 1: if has_agg and len(ds["aggregations"]) > 1:
for agg in ds["aggregations"].values: for agg in ds["aggregations"].values:
agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}") agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}")
# Drop dimension columns to avoid conflicts # Drop dimension columns to avoid conflicts
cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns] cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
if cols_to_drop: if cols_to_drop:
@ -734,12 +853,16 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
# Set elevation based on normalized values
gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized]
# Create GeoJSON # Create GeoJSON
geojson_data = [] geojson_data = []
for _, row in gdf_wgs84.iterrows(): for _, row in gdf_wgs84.iterrows():
properties = { properties = {
"value": float(row["value"]) if not np.isnan(row["value"]) else None, "value": float(row["value"]) if not np.isnan(row["value"]) else None,
"fill_color": row["fill_color"], "fill_color": row["fill_color"],
"elevation": float(row["elevation"]),
} }
# Add all aggregation values if available # Add all aggregation values if available
if has_agg and len(ds["aggregations"]) > 1: if has_agg and len(ds["aggregations"]) > 1:
@ -755,20 +878,23 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
} }
geojson_data.append(feature) geojson_data.append(feature)
# Create pydeck layer # Create pydeck layer with 3D elevation
layer = pdk.Layer( layer = pdk.Layer(
"GeoJsonLayer", "GeoJsonLayer",
geojson_data, geojson_data,
opacity=opacity, opacity=opacity,
stroked=True, stroked=True,
filled=True, filled=True,
extruded=True,
get_fill_color="properties.fill_color", get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80], get_line_color=[80, 80, 80],
get_elevation="properties.elevation",
elevation_scale=500000,
line_width_min_pixels=0.5, line_width_min_pixels=0.5,
pickable=True, pickable=True,
) )
view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0) view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
# Build tooltip HTML for ERA5 # Build tooltip HTML for ERA5
if has_agg and len(ds["aggregations"]) > 1: if has_agg and len(ds["aggregations"]) > 1:

View file

@ -125,32 +125,61 @@ def render_training_data_page():
with st.spinner("Computing dataset statistics..."): with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats() stats = ensemble.get_stats()
# Display target information # High-level summary metrics
col1, col2 = st.columns(2) col1, col2, col3 = st.columns(3)
with col1: with col1:
st.metric(label="Target", value=stats["target"].replace("darts_", "")) st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2: with col2:
st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}") st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with col3:
st.metric(label="Data Sources", value=len(stats["members"]))
# Display member statistics # Detailed member statistics in expandable section
st.markdown("**Member Statistics:**") with st.expander("📦 Data Source Details", expanded=False):
for member, member_stats in stats["members"].items():
st.markdown(f"### {member}")
for member, member_stats in stats["members"].items(): # Create metrics for this member
with st.expander(f"📦 {member}", expanded=False): metric_cols = st.columns(4)
col1, col2 = st.columns(2) with metric_cols[0]:
with col1: st.metric("Features", member_stats["num_features"])
st.markdown(f"**Number of Features:** {member_stats['num_features']}") with metric_cols[1]:
st.markdown(f"**Number of Variables:** {member_stats['num_variables']}") st.metric("Variables", member_stats["num_variables"])
with col2: with metric_cols[2]:
st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`") # Display dimensions in a more readable format
dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
# Display variables as a compact list # Show variables as colored badges
st.markdown(f"**Variables ({member_stats['num_variables']}):**") st.markdown("**Variables:**")
vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]]) vars_html = " ".join(
st.markdown(vars_str) [
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"]
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Display total features # Show dimension details
st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}") st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
st.markdown("---") st.markdown("---")