diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 6a54837..85e3671 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -18,7 +18,6 @@ from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.overview_page import render_overview_page from entropice.dashboard.views.training_analysis_page import render_training_analysis_page -from entropice.dashboard.views.training_data_page import render_training_data_page def main(): @@ -28,7 +27,6 @@ def main(): # Setup Navigation overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) data_page = st.Page(render_dataset_page, title="Dataset", icon="📊") - training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖") model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") @@ -38,8 +36,7 @@ def main(): { "Overview": [overview_page], "Data": [data_page], - "Training": [training_data_page, training_analysis_page, autogluon_page], - "Model State": [model_state_page], + "Experiments": [training_analysis_page, autogluon_page, model_state_page], "Inference": [inference_page], } ) diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py deleted file mode 100644 index cb0299a..0000000 --- a/src/entropice/dashboard/plots/source_data.py +++ /dev/null @@ -1,1054 +0,0 @@ -"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5).""" - -import geopandas as gpd -import numpy as np -import pandas as pd -import plotly.graph_objects as go -import pydeck as pdk -import streamlit as st -import xarray as xr 
- -from entropice.dashboard.utils.colors import get_cmap -from entropice.dashboard.utils.geometry import fix_hex_geometry - -# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations - - -def render_alphaearth_overview(ds: xr.Dataset): - """Render overview statistics for AlphaEarth embeddings data. - - Args: - ds: xarray Dataset containing AlphaEarth embeddings. - - """ - st.subheader("📊 Data Overview") - - # Key metrics - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric("Cells", f"{len(ds['cell_ids']):,}") - - with col2: - st.metric("Embedding Dims", f"{len(ds['band'])}") - - with col3: - years = sorted(ds["year"].values) - st.metric("Years", f"{min(years)}–{max(years)}") - - with col4: - st.metric("Aggregations", f"{len(ds['agg'])}") - - # Show aggregations as badges in an expander - with st.expander("ℹ️ Data Details", expanded=False): - st.markdown("**Spatial Aggregations:**") - aggs = ds["agg"].to_numpy() - aggs_html = " ".join( - [ - f'{a}' - for a in aggs - ] - ) - st.markdown(aggs_html, unsafe_allow_html=True) - - -@st.fragment -def render_alphaearth_plots(ds: xr.Dataset): - """Render interactive plots for AlphaEarth embeddings data. - - Args: - ds: xarray Dataset containing AlphaEarth embeddings. 
- - """ - st.markdown("---") - st.markdown("**Embedding Distribution by Band**") - - embeddings_data = ds["embeddings"] - - # Select year and aggregation for visualization - col1, col2 = st.columns(2) - with col1: - selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year") - with col2: - selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg") - - # Get data for selected year and aggregation - year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg) - - # Calculate statistics for each band - band_stats = [] - for band_idx in range(len(ds["band"])): - band_data = year_agg_data.isel(band=band_idx).values.flatten() - band_data = band_data[~np.isnan(band_data)] # Remove NaN values - - if len(band_data) > 0: - band_stats.append( - { - "Band": band_idx, - "Mean": float(np.mean(band_data)), - "Std": float(np.std(band_data)), - "Min": float(np.min(band_data)), - "25%": float(np.percentile(band_data, 25)), - "Median": float(np.median(band_data)), - "75%": float(np.percentile(band_data, 75)), - "Max": float(np.max(band_data)), - } - ) - - band_df = pd.DataFrame(band_stats) - - # Create plot showing distribution across bands - fig = go.Figure() - - # Add min/max range first (background) - fig.add_trace( - go.Scatter( - x=band_df["Band"], - y=band_df["Min"], - mode="lines", - line={"color": "lightgray", "width": 1, "dash": "dash"}, - name="Min/Max Range", - showlegend=True, - ) - ) - - fig.add_trace( - go.Scatter( - x=band_df["Band"], - y=band_df["Max"], - mode="lines", - fill="tonexty", - fillcolor="rgba(200, 200, 200, 0.1)", - line={"color": "lightgray", "width": 1, "dash": "dash"}, - showlegend=False, - ) - ) - - # Add std band - fig.add_trace( - go.Scatter( - x=band_df["Band"], - y=band_df["Mean"] - band_df["Std"], - mode="lines", - line={"width": 0}, - showlegend=False, - hoverinfo="skip", - ) - ) - - fig.add_trace( - go.Scatter( - x=band_df["Band"], - y=band_df["Mean"] + 
band_df["Std"], - mode="lines", - fill="tonexty", - fillcolor="rgba(31, 119, 180, 0.2)", - line={"width": 0}, - name="±1 Std", - ) - ) - - # Add mean line on top - fig.add_trace( - go.Scatter( - x=band_df["Band"], - y=band_df["Mean"], - mode="lines+markers", - name="Mean", - line={"color": "#1f77b4", "width": 2}, - marker={"size": 4}, - ) - ) - - fig.update_layout( - title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})", - xaxis_title="Band", - yaxis_title="Embedding Value", - height=450, - hovermode="x unified", - ) - - st.plotly_chart(fig, width="stretch") - - # Band statistics - with st.expander("📈 Statistics by Embedding Band", expanded=False): - st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:") - - # Calculate statistics for each band - band_stats = [] - for band_idx in range(min(10, len(ds["band"]))): # Show first 10 bands - band_data = embeddings_data.isel(band=band_idx) - band_stats.append( - { - "Band": int(band_idx), - "Mean": float(band_data.mean().values), - "Std": float(band_data.std().values), - "Min": float(band_data.min().values), - "Max": float(band_data.max().values), - } - ) - - band_df = pd.DataFrame(band_stats) - st.dataframe(band_df, width="stretch", hide_index=True) - - if len(ds["band"]) > 10: - st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions") - - -def render_arcticdem_overview(ds: xr.Dataset): - """Render overview statistics for ArcticDEM terrain data. - - Args: - ds: xarray Dataset containing ArcticDEM data. 
- - """ - st.subheader("📊 Data Overview") - - # Key metrics - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Cells", f"{len(ds['cell_ids']):,}") - - with col2: - st.metric("Variables", f"{len(ds.data_vars)}") - - with col3: - st.metric("Aggregations", f"{len(ds['aggregations'])}") - - # Show details in expander - with st.expander("ℹ️ Data Details", expanded=False): - st.markdown("**Spatial Aggregations:**") - aggs = ds["aggregations"].to_numpy() - aggs_html = " ".join( - [ - f'{a}' - for a in aggs - ] - ) - st.markdown(aggs_html, unsafe_allow_html=True) - - # Statistics by variable - st.markdown("---") - st.markdown("**📈 Variable Statistics**") - - var_stats = [] - for var_name in ds.data_vars: - var_data = ds[var_name] - var_stats.append( - { - "Variable": var_name, - "Mean": float(var_data.mean().values), - "Std": float(var_data.std().values), - "Min": float(var_data.min().values), - "Max": float(var_data.max().values), - "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values), - } - ) - - stats_df = pd.DataFrame(var_stats) - st.dataframe(stats_df, width="stretch", hide_index=True) - - -@st.fragment -def render_arcticdem_plots(ds: xr.Dataset): - """Render interactive plots for ArcticDEM terrain data. - - Args: - ds: xarray Dataset containing ArcticDEM data. 
- - """ - st.markdown("---") - st.markdown("**Variable Distributions**") - - variables = list(ds.data_vars) - - # Select a variable to visualize - selected_var = st.selectbox("Select variable to visualize", options=variables, key="arcticdem_var_select") - - if selected_var: - var_data = ds[selected_var] - - # Create histogram - fig = go.Figure() - - for agg in ds["aggregations"].to_numpy(): - agg_data = var_data.sel(aggregations=agg).to_numpy().flatten() - agg_data = agg_data[~np.isnan(agg_data)] - - fig.add_trace( - go.Histogram( - x=agg_data, - name=str(agg), - opacity=0.7, - nbinsx=50, - ) - ) - - fig.update_layout( - title=f"Distribution of {selected_var} by Aggregation", - xaxis_title=selected_var, - yaxis_title="Count", - barmode="overlay", - height=400, - ) - - st.plotly_chart(fig, width="stretch") - - -def render_era5_overview(ds: xr.Dataset, temporal_type: str): - """Render overview statistics for ERA5 climate data. - - Args: - ds: xarray Dataset containing ERA5 data. - temporal_type: One of 'yearly', 'seasonal', 'shoulder'. 
- - """ - st.subheader("📊 Data Overview") - - # Key metrics - has_agg = "aggregations" in ds.dims - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric("Cells", f"{len(ds['cell_ids']):,}") - - with col2: - st.metric("Variables", f"{len(ds.data_vars)}") - - with col3: - time_values = pd.to_datetime(ds["time"].values) - st.metric( - "Time Steps", - f"{time_values.min().strftime('%Y')}–{time_values.max().strftime('%Y')}", - ) - - with col4: - if has_agg: - st.metric("Aggregations", f"{len(ds['aggregations'])}") - else: - st.metric("Temporal Type", temporal_type.capitalize()) - - # Show details in expander - with st.expander("ℹ️ Data Details", expanded=False): - st.markdown(f"**Temporal Type:** {temporal_type.capitalize()}") - st.markdown( - f"**Date Range:** {time_values.min().strftime('%Y-%m-%d')} to {time_values.max().strftime('%Y-%m-%d')}" - ) - - if has_agg: - st.markdown("**Spatial Aggregations:**") - aggs = ds["aggregations"].to_numpy() - aggs_html = " ".join( - [ - f'{a}' - for a in aggs - ] - ) - st.markdown(aggs_html, unsafe_allow_html=True) - - # Statistics by variable - st.markdown("---") - st.markdown("**📈 Variable Statistics**") - - var_stats = [] - for var_name in ds.data_vars: - var_data = ds[var_name] - var_stats.append( - { - "Variable": var_name, - "Mean": float(var_data.mean().values), - "Std": float(var_data.std().values), - "Min": float(var_data.min().values), - "Max": float(var_data.max().values), - "Missing %": float((var_data.isnull().sum() / var_data.size * 100).values), - } - ) - - stats_df = pd.DataFrame(var_stats) - st.dataframe(stats_df, width="stretch", hide_index=True) - - -@st.fragment -def render_era5_plots(ds: xr.Dataset, temporal_type: str): - """Render interactive plots for ERA5 climate data. - - Args: - ds: xarray Dataset containing ERA5 data. - temporal_type: One of 'yearly', 'seasonal', 'shoulder'. 
- - """ - st.markdown("---") - st.markdown("**Temporal Trends**") - - variables = list(ds.data_vars) - has_agg = "aggregations" in ds.dims - - if has_agg: - col1, col2, col3 = st.columns([2, 2, 1]) - with col1: - selected_var = st.selectbox( - "Select variable to visualize", - options=variables, - key=f"era5_{temporal_type}_var_select", - ) - with col2: - selected_agg = st.selectbox( - "Aggregation", - options=ds["aggregations"].values, - key=f"era5_{temporal_type}_agg_select", - ) - with col3: - show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") - show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax") - else: - col1, col2 = st.columns([3, 1]) - with col1: - selected_var = st.selectbox( - "Select variable to visualize", - options=variables, - key=f"era5_{temporal_type}_var_select", - ) - with col2: - show_std = st.checkbox("Show ±1 Std", value=True, key=f"era5_{temporal_type}_show_std") - show_minmax = st.checkbox("Show Min/Max", value=False, key=f"era5_{temporal_type}_show_minmax") - - if selected_var: - var_data = ds[selected_var] - - # Calculate statistics over space for each time step - time_values = pd.to_datetime(ds["time"].to_numpy()) - - if has_agg: - # Select specific aggregation, then calculate stats over cells - var_data_agg = var_data.sel(aggregations=selected_agg) - time_mean = var_data_agg.mean(dim="cell_ids").to_numpy() - time_std = var_data_agg.std(dim="cell_ids").to_numpy() - time_min = var_data_agg.min(dim="cell_ids").to_numpy() - time_max = var_data_agg.max(dim="cell_ids").to_numpy() - else: - time_mean = var_data.mean(dim="cell_ids").to_numpy() - time_std = var_data.std(dim="cell_ids").to_numpy() - time_min = var_data.min(dim="cell_ids").to_numpy() - time_max = var_data.max(dim="cell_ids").to_numpy() - - fig = go.Figure() - - # Add min/max range first (background) - optional - if show_minmax: - fig.add_trace( - go.Scatter( - x=time_values, - y=time_min, - mode="lines", 
- line={"color": "lightgray", "width": 1, "dash": "dash"}, - name="Min/Max Range", - showlegend=True, - ) - ) - - fig.add_trace( - go.Scatter( - x=time_values, - y=time_max, - mode="lines", - fill="tonexty", - fillcolor="rgba(200, 200, 200, 0.1)", - line={"color": "lightgray", "width": 1, "dash": "dash"}, - showlegend=False, - ) - ) - - # Add std band - optional - if show_std: - fig.add_trace( - go.Scatter( - x=time_values, - y=time_mean - time_std, - mode="lines", - line={"width": 0}, - showlegend=False, - hoverinfo="skip", - ) - ) - - fig.add_trace( - go.Scatter( - x=time_values, - y=time_mean + time_std, - mode="lines", - fill="tonexty", - fillcolor="rgba(31, 119, 180, 0.2)", - line={"width": 0}, - name="±1 Std", - ) - ) - - # Add mean line on top - fig.add_trace( - go.Scatter( - x=time_values, - y=time_mean, - mode="lines+markers", - name="Mean", - line={"color": "#1f77b4", "width": 2}, - marker={"size": 4}, - ) - ) - - title_suffix = f" (Aggregation: {selected_agg})" if has_agg else "" - fig.update_layout( - title=f"Temporal Trend of {selected_var} (Spatial Statistics){title_suffix}", - xaxis_title="Time", - yaxis_title=selected_var, - height=400, - hovermode="x unified", - ) - - st.plotly_chart(fig, width="stretch") - - -@st.fragment -def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): - """Render interactive pydeck map for AlphaEarth embeddings. - - Args: - ds: xarray Dataset containing AlphaEarth data. - targets: GeoDataFrame with geometry for each cell. - grid: Grid type ('hex' or 'healpix'). 
- - """ - st.subheader("🗺️ AlphaEarth Spatial Distribution") - - # Year slider (full width) - years = sorted(ds["year"].values) - selected_year = st.slider( - "Year", - min_value=int(years[0]), - max_value=int(years[-1]), - value=int(years[-1]), - step=1, - key="alphaearth_year", - ) - - # Other controls - col1, col2, col3 = st.columns([2, 2, 1]) - - with col1: - selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg") - - with col2: - selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band") - - with col3: - opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity") - - # Extract data for selected parameters - data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band) - - # Create GeoDataFrame - gdf = targets.copy() - gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] - gdf = gdf.set_index("cell_id") - - # Add values - values_df = data_values.to_dataframe(name="value") - gdf = gdf.join(values_df, how="inner") - - # Convert to WGS84 first - gdf_wgs84 = gdf.to_crs("EPSG:4326") - - # Fix geometries after CRS conversion - if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) - - # Normalize values for color mapping - values = gdf_wgs84["value"].to_numpy() - vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers - normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) - - # Apply colormap - cmap = get_cmap("embeddings") - colors = [cmap(val) for val in normalized] - gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] - - # Set elevation based on normalized values - gdf_wgs84["elevation"] = normalized - - # Create GeoJSON - geojson_data = [] - for _, row in gdf_wgs84.iterrows(): - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": { - "value": float(row["value"]), - "fill_color": 
row["fill_color"], - "elevation": float(row["elevation"]), - }, - } - geojson_data.append(feature) - - # Create pydeck layer with 3D elevation - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=True, - get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - get_elevation="properties.elevation", - elevation_scale=500000, - line_width_min_pixels=0.5, - pickable=True, - ) - - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) - - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": "Value: {value}", - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - st.pydeck_chart(deck) - - # Show statistics - st.caption(f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values):.4f} | Std: {np.nanstd(values):.4f}") - - -@st.fragment -def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): - """Render interactive pydeck map for ArcticDEM terrain data. - - Args: - ds: xarray Dataset containing ArcticDEM data. - targets: GeoDataFrame with geometry for each cell. - grid: Grid type ('hex' or 'healpix'). 
- - """ - st.subheader("🗺️ ArcticDEM Spatial Distribution") - - # Controls - variables = list(ds.data_vars) - col1, col2, col3 = st.columns([3, 3, 1]) - - with col1: - selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var") - - with col2: - selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg") - - with col3: - opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity") - - # Extract data for selected parameters - data_values = ds[selected_var].sel(aggregations=selected_agg) - - # Create GeoDataFrame - gdf = targets.copy() - gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] - gdf = gdf.set_index("cell_id") - - # Add values - values_df = data_values.to_dataframe(name="value") - gdf = gdf.join(values_df, how="inner") - - # Add all aggregation values for tooltip - if len(ds["aggregations"]) > 1: - for agg in ds["aggregations"].values: - agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}") - # Drop the aggregations column to avoid conflicts - if "aggregations" in agg_data.columns: - agg_data = agg_data.drop(columns=["aggregations"]) - gdf = gdf.join(agg_data, how="left") - - # Convert to WGS84 first - gdf_wgs84 = gdf.to_crs("EPSG:4326") - - # Fix geometries after CRS conversion - if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) - - # Normalize values for color mapping - values = gdf_wgs84["value"].values - values_clean = values[~np.isnan(values)] - - if len(values_clean) > 0: - vmin, vmax = np.nanpercentile(values_clean, [2, 98]) - normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) - - # Apply colormap - cmap = get_cmap(f"arcticdem_{selected_var}") - colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] - gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] - - # Set elevation based on normalized values - gdf_wgs84["elevation"] 
= [val if not np.isnan(val) else 0 for val in normalized] - - # Create GeoJSON - geojson_data = [] - for _, row in gdf_wgs84.iterrows(): - properties = { - "value": float(row["value"]) if not np.isnan(row["value"]) else None, - "fill_color": row["fill_color"], - "elevation": float(row["elevation"]), - } - # Add all aggregation values if available - if len(ds["aggregations"]) > 1: - for agg in ds["aggregations"].values: - agg_col = f"agg_{agg}" - if agg_col in row.index: - properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None - - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": properties, - } - geojson_data.append(feature) - - # Create pydeck layer with 3D elevation - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=True, - get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - get_elevation="properties.elevation", - elevation_scale=500000, - line_width_min_pixels=0.5, - pickable=True, - ) - - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) - - # Build tooltip HTML for ArcticDEM - if len(ds["aggregations"]) > 1: - tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"] - tooltip_lines.append("All aggregations:
") - for agg in ds["aggregations"].values: - tooltip_lines.append(f"  {agg}: {{agg_{agg}}}
") - tooltip_html = "".join(tooltip_lines) - else: - tooltip_html = f"{selected_var}: {{value}}" - - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": tooltip_html, - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - st.pydeck_chart(deck) - - # Show statistics - st.caption( - f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | " - f"Std: {np.nanstd(values_clean):.2f}" - ) - else: - st.warning("No valid data available for selected parameters") - - -@st.fragment -def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str): - """Render interactive pydeck map for grid cell areas. - - Args: - grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio. - grid: Grid type ('hex' or 'healpix'). - - """ - st.subheader("🗺️ Grid Cell Areas Distribution") - - # Controls - col1, col2 = st.columns([3, 1]) - - with col1: - area_metric = st.selectbox( - "Area Metric", - options=["cell_area", "land_area", "water_area", "land_ratio"], - format_func=lambda x: x.replace("_", " ").title(), - key="areas_metric", - ) - - with col2: - opacity = st.slider( - "Opacity", - min_value=0.1, - max_value=1.0, - value=0.7, - step=0.1, - key="areas_map_opacity", - ) - - # Create GeoDataFrame - gdf = grid_gdf.copy() - - # Convert to WGS84 first - gdf_wgs84 = gdf.to_crs("EPSG:4326") - - # Fix geometries after CRS conversion - if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) - - # Get values for the selected metric - values = gdf_wgs84[area_metric].to_numpy() - - # Normalize values for color mapping - vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers - normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) - - # Apply colormap based on metric type - if area_metric == "land_ratio": - cmap = get_cmap("terrain") # Different 
colormap for ratio - else: - cmap = get_cmap("terrain") - - colors = [cmap(val) for val in normalized] - gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] - - # Set elevation based on normalized values for 3D visualization - gdf_wgs84["elevation"] = normalized - - # Create GeoJSON - geojson_data = [] - for _, row in gdf_wgs84.iterrows(): - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": { - "cell_area": f"{float(row['cell_area']):.2f}", - "land_area": f"{float(row['land_area']):.2f}", - "water_area": f"{float(row['water_area']):.2f}", - "land_ratio": f"{float(row['land_ratio']):.2%}", - "fill_color": row["fill_color"], - "elevation": float(row["elevation"]), - }, - } - geojson_data.append(feature) - - # Create pydeck layer with 3D elevation - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=True, - get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - get_elevation="properties.elevation", - elevation_scale=500000, - line_width_min_pixels=0.5, - pickable=True, - ) - - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) - - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": "Cell Area: {cell_area} km²
" - "Land Area: {land_area} km²
" - "Water Area: {water_area} km²
" - "Land Ratio: {land_ratio}", - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - st.pydeck_chart(deck) - - # Show statistics - st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}") - - # Show additional info - st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.") - - -@st.fragment -def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str): - """Render interactive pydeck map for ERA5 climate data. - - Args: - ds: xarray Dataset containing ERA5 data. - targets: GeoDataFrame with geometry for each cell. - grid: Grid type ('hex' or 'healpix'). - temporal_type: One of 'yearly', 'seasonal', 'shoulder'. - - """ - st.subheader("🗺️ ERA5 Spatial Distribution") - - # Controls - variables = list(ds.data_vars) - has_agg = "aggregations" in ds.dims - - # Top row: Variable, Aggregation (if applicable), and Opacity - if has_agg: - col1, col2, col3 = st.columns([2, 2, 1]) - with col1: - selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") - with col2: - selected_agg = st.selectbox( - "Aggregation", - options=ds["aggregations"].values, - key=f"era5_{temporal_type}_agg", - ) - with col3: - opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") - else: - col1, col2 = st.columns([4, 1]) - with col1: - selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var") - with col2: - opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity") - - # Bottom row: Time slider (full width) - time_values = pd.to_datetime(ds["time"].values) - time_labels = [t.strftime("%Y-%m-%d") for t in time_values] - selected_time_idx = st.slider( - "Time", - min_value=0, - max_value=len(time_values) - 1, - value=len(time_values) - 1, - 
format="", - key=f"era5_{temporal_type}_time_slider", - ) - st.caption(f"Selected: {time_labels[selected_time_idx]}") - selected_time = time_values[selected_time_idx] - - # Extract data for selected parameters - if has_agg: - data_values = ds[selected_var].sel(time=selected_time, aggregations=selected_agg) - else: - data_values = ds[selected_var].sel(time=selected_time) - - # Create GeoDataFrame - gdf = targets.copy() - gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)] - gdf = gdf.set_index("cell_id") - - # Add values - values_df = data_values.to_dataframe(name="value") - gdf = gdf.join(values_df, how="inner") - - # Add all aggregation values for tooltip if has_agg - if has_agg and len(ds["aggregations"]) > 1: - for agg in ds["aggregations"].values: - agg_data = ds[selected_var].sel(time=selected_time, aggregations=agg).to_dataframe(name=f"agg_{agg}") - # Drop dimension columns to avoid conflicts - cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns] - if cols_to_drop: - agg_data = agg_data.drop(columns=cols_to_drop) - gdf = gdf.join(agg_data, how="left") - - # Convert to WGS84 first - gdf_wgs84 = gdf.to_crs("EPSG:4326") - - # Fix geometries after CRS conversion - if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) - - # Normalize values for color mapping - values = gdf_wgs84["value"].values - values_clean = values[~np.isnan(values)] - - if len(values_clean) > 0: - vmin, vmax = np.nanpercentile(values_clean, [2, 98]) - normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) - - # Apply colormap - use variable-specific colors - cmap = get_cmap(f"era5_{selected_var}") - colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized] - gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] - - # Set elevation based on normalized values - gdf_wgs84["elevation"] = [val if not np.isnan(val) else 0 for val in normalized] - - # 
Create GeoJSON - geojson_data = [] - for _, row in gdf_wgs84.iterrows(): - properties = { - "value": float(row["value"]) if not np.isnan(row["value"]) else None, - "fill_color": row["fill_color"], - "elevation": float(row["elevation"]), - } - # Add all aggregation values if available - if has_agg and len(ds["aggregations"]) > 1: - for agg in ds["aggregations"].values: - agg_col = f"agg_{agg}" - if agg_col in row.index: - properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None - - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": properties, - } - geojson_data.append(feature) - - # Create pydeck layer with 3D elevation - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=True, - get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - get_elevation="properties.elevation", - elevation_scale=500000, - line_width_min_pixels=0.5, - pickable=True, - ) - - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) - - # Build tooltip HTML for ERA5 - if has_agg and len(ds["aggregations"]) > 1: - tooltip_lines = [f"{selected_var} (selected: {selected_agg}): {{value}}
"] - tooltip_lines.append("All aggregations:
") - for agg in ds["aggregations"].values: - tooltip_lines.append(f"  {agg}: {{agg_{agg}}}
") - tooltip_html = "".join(tooltip_lines) - else: - tooltip_html = f"{selected_var}: {{value}}" - - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": tooltip_html, - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - st.pydeck_chart(deck) - - # Show statistics - st.caption( - f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | " - f"Std: {np.nanstd(values_clean):.4f}" - ) - else: - st.warning("No valid data available for selected parameters") diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py deleted file mode 100644 index 966c64a..0000000 --- a/src/entropice/dashboard/plots/training_data.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Plotting functions for training data visualizations.""" - -import geopandas as gpd -import pandas as pd -import plotly.graph_objects as go -import pydeck as pdk -import streamlit as st - -from entropice.dashboard.utils.colors import get_palette -from entropice.dashboard.utils.geometry import fix_hex_geometry -from entropice.ml.dataset import CategoricalTrainingDataset - - -def render_all_distribution_histograms( - train_data_dict: dict[str, CategoricalTrainingDataset], -): - """Render histograms for all three tasks side by side. - - Args: - train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. 
- - """ - st.subheader("📊 Target Distribution by Task") - - # Create a 3-column layout for the three tasks - cols = st.columns(3) - - tasks = ["binary", "count", "density"] - task_titles = { - "binary": "Binary Classification", - "count": "Count Classification", - "density": "Density Classification", - } - - for idx, task in enumerate(tasks): - dataset = train_data_dict[task] - categories = dataset.y.binned.cat.categories.tolist() - colors = get_palette(task, len(categories)) - - with cols[idx]: - st.markdown(f"**{task_titles[task]}**") - - # Create histogram data - counts_df = pd.DataFrame( - { - "Category": categories, - "Train": [((dataset.y.binned == cat) & (dataset.split == "train")).sum() for cat in categories], - "Test": [((dataset.y.binned == cat) & (dataset.split == "test")).sum() for cat in categories], - } - ) - - # Create stacked bar chart - fig = go.Figure() - - fig.add_trace( - go.Bar( - name="Train", - x=counts_df["Category"], - y=counts_df["Train"], - marker_color=colors, - opacity=0.9, - text=counts_df["Train"], - textposition="inside", - textfont={"size": 10, "color": "white"}, - ) - ) - - fig.add_trace( - go.Bar( - name="Test", - x=counts_df["Category"], - y=counts_df["Test"], - marker_color=colors, - opacity=0.6, - text=counts_df["Test"], - textposition="inside", - textfont={"size": 10, "color": "white"}, - ) - ) - - fig.update_layout( - barmode="group", - height=400, - margin={"l": 20, "r": 20, "t": 20, "b": 20}, - showlegend=True, - legend={ - "orientation": "h", - "yanchor": "bottom", - "y": 1.02, - "xanchor": "right", - "x": 1, - }, - xaxis_title=None, - yaxis_title="Count", - xaxis={"tickangle": -45}, - ) - - st.plotly_chart(fig, width="stretch") - - # Show summary statistics - total = len(dataset) - train_pct = (dataset.split == "train").sum() / total * 100 - test_pct = (dataset.split == "test").sum() / total * 100 - - st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%") - - -def _assign_colors_by_mode(gdf, 
color_mode, dataset, selected_task): - """Assign colors to geodataframe based on the selected color mode. - - Args: - gdf: GeoDataFrame to add colors to - color_mode: One of 'target_class' or 'split' - dataset: CategoricalTrainingDataset - selected_task: Task name for color palette selection - - Returns: - GeoDataFrame with 'fill_color' column added - - """ - if color_mode == "target_class": - categories = dataset.y.binned.cat.categories.tolist() - colors_palette = get_palette(selected_task, len(categories)) - - # Create color mapping - color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)} - gdf["color"] = gdf["target_class"].map(color_map) - - # Convert hex colors to RGB - def hex_to_rgb(hex_color): - hex_color = hex_color.lstrip("#") - return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)] - - gdf["fill_color"] = gdf["color"].apply(hex_to_rgb) - - elif color_mode == "split": - split_colors = { - "train": [66, 135, 245], - "test": [245, 135, 66], - } # Blue # Orange - gdf["fill_color"] = gdf["split"].map(split_colors) - - return gdf - - -@st.fragment -def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]): - """Render a pydeck spatial map showing training data distribution with interactive controls. - - This is a Streamlit fragment that reruns independently when users interact with the - visualization controls (color mode and opacity), without re-running the entire page. - - Args: - train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. 
- - """ - st.subheader("🗺️ Spatial Distribution Map") - - # Create controls in columns - col1, col2 = st.columns([3, 1]) - - with col1: - vis_mode = st.selectbox( - "Visualization mode", - options=["binary", "count", "density", "split"], - format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split", - key="spatial_map_mode", - ) - - with col2: - opacity = st.slider( - "Opacity", - min_value=0.1, - max_value=1.0, - value=0.7, - step=0.1, - key="spatial_map_opacity", - ) - - # Determine which task dataset to use and color mode - if vis_mode == "split": - # Use binary dataset for split visualization - dataset = train_data_dict["binary"] - color_mode = "split" - selected_task = "binary" - else: - # Use the selected task - dataset = train_data_dict[vis_mode] - color_mode = "target_class" - selected_task = vis_mode - - # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame - gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment] - - # Fix antimeridian issues - gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry) - - # Add binned labels and split information from current dataset - gdf["target_class"] = dataset.y.binned.to_numpy() - gdf["split"] = dataset.split.to_numpy() - gdf["raw_value"] = dataset.z.to_numpy() - - # Add information from all three tasks for tooltip - gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy() - gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy() - gdf["count_raw"] = train_data_dict["count"].z.to_numpy() - gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy() - gdf["density_raw"] = train_data_dict["density"].z.to_numpy() - - # Convert to WGS84 for pydeck - gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326") # type: ignore[assignment] - - # Assign colors based on the selected mode - gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task) - - # Convert to GeoJSON format and add elevation for 3D 
visualization - geojson_data = [] - # Normalize raw values for elevation (only for count and density) - use_elevation = vis_mode in ["count", "density"] - if use_elevation: - raw_values = gdf_wgs84["raw_value"] - min_val, max_val = raw_values.min(), raw_values.max() - # Normalize to 0-1 range for better 3D visualization - if max_val > min_val: - gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0) - else: - gdf_wgs84["elevation"] = 0 - - for _, row in gdf_wgs84.iterrows(): - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": { - "target_class": str(row["target_class"]), - "split": str(row["split"]), - "raw_value": float(row["raw_value"]), - "fill_color": row["fill_color"], - "elevation": float(row["elevation"]) if use_elevation else 0, - "binary_label": str(row["binary_label"]), - "count_category": str(row["count_category"]), - "count_raw": int(row["count_raw"]), - "density_category": str(row["density_category"]), - "density_raw": f"{float(row['density_raw']):.4f}", - }, - } - geojson_data.append(feature) - - # Create pydeck layer - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=use_elevation, - wireframe=False, - get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - line_width_min_pixels=0.5, - get_elevation="properties.elevation" if use_elevation else 0, - elevation_scale=500000, # Scale normalized values (0-1) to 500km height - pickable=True, - ) - - # Set initial view state (centered on the Arctic) - # Adjust pitch and zoom based on whether we're using elevation - view_state = pdk.ViewState( - latitude=70, - longitude=0, - zoom=2 if not use_elevation else 1.5, - pitch=0 if not use_elevation else 45, - ) - - # Create deck - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": "Binary: {binary_label}
<br/>" - "Count Category: {count_category}<br/>
" - "Count Raw: {count_raw}<br/>
" - "Density Category: {density_category}<br/>
" - "Density Raw: {density_raw}<br/>
" - "Split: {split}", - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - # Render the map - st.pydeck_chart(deck) - - # Show info about 3D visualization - if use_elevation: - st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.") - - # Add legend - with st.expander("Legend", expanded=True): - if color_mode == "target_class": - st.markdown("**Target Classes:**") - categories = dataset.y.binned.cat.categories.tolist() - colors_palette = get_palette(selected_task, len(categories)) - intervals = dataset.y.intervals - - # For count and density tasks, show intervals - if selected_task in ["count", "density"]: - for i, cat in enumerate(categories): - color = colors_palette[i] - interval_min, interval_max = intervals[i] - - # Format interval display - if interval_min is None or interval_max is None: - interval_str = "" - elif selected_task == "count": - # Integer values for count - if interval_min == interval_max: - interval_str = f" ({int(interval_min)})" - else: - interval_str = f" ({int(interval_min)}-{int(interval_max)})" - else: # density - # Percentage values for density - if interval_min == interval_max: - interval_str = f" ({interval_min * 100:.4f}%)" - else: - interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)" - - st.markdown( - f'
' - f'
' - f"{cat}{interval_str}
", - unsafe_allow_html=True, - ) - else: - # Binary task: use original column layout - legend_cols = st.columns(len(categories)) - for i, cat in enumerate(categories): - with legend_cols[i]: - color = colors_palette[i] - st.markdown( - f'
' - f'
' - f"{cat}
", - unsafe_allow_html=True, - ) - if use_elevation: - st.markdown("---") - st.markdown("**Elevation (3D):**") - min_val = gdf_wgs84["raw_value"].min() - max_val = gdf_wgs84["raw_value"].max() - st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)") - elif color_mode == "split": - st.markdown("**Data Split:**") - legend_html = ( - '
' - '
' - '
' - "Train
" - '
' - '
' - "Test
" - ) - st.markdown(legend_html, unsafe_allow_html=True) diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py index 8c01478..b19305d 100644 --- a/src/entropice/dashboard/sections/dataset_statistics.py +++ b/src/entropice/dashboard/sections/dataset_statistics.py @@ -431,7 +431,7 @@ def _render_aggregation_selection( if not submitted: st.info("👆 Click 'Apply Aggregation Filters' to update the configuration") - st.stop() + return dimension_filters return dimension_filters diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py index a5088eb..c8a59a4 100644 --- a/src/entropice/dashboard/sections/experiment_results.py +++ b/src/entropice/dashboard/sections/experiment_results.py @@ -103,64 +103,63 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: ) # Expandable details for each result - st.subheader("Individual Experiment Details") + with st.expander("Show Individual Experiment Details", expanded=False): + for tr in filtered_results: + tr_info = tr.display_info + display_name = tr_info.get_display_name("model_first") + with st.expander(display_name): + col1, col2 = st.columns([1, 2]) - for tr in filtered_results: - tr_info = tr.display_info - display_name = tr_info.get_display_name("model_first") - with st.expander(display_name): - col1, col2 = st.columns([1, 2]) + with col1: + grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) + st.write( + "**Configuration:**\n" + f"- **Experiment:** {tr.experiment}\n" + f"- **Task:** {tr.settings.task}\n" + f"- **Target:** {tr.settings.target}\n" + f"- **Model:** {tr.settings.model}\n" + f"- **Grid:** {grid_config.display_name}\n" + f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n" + f"- **Temporal Mode:** {tr.settings.temporal_mode}\n" + f"- **Members:** {', '.join(tr.settings.members)}\n" + f"- **CV Splits:** 
{tr.settings.cv_splits}\n" + f"- **Classes:** {tr.settings.classes}\n" + ) - with col1: - grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) - st.write( - "**Configuration:**\n" - f"- **Experiment:** {tr.experiment}\n" - f"- **Task:** {tr.settings.task}\n" - f"- **Target:** {tr.settings.target}\n" - f"- **Model:** {tr.settings.model}\n" - f"- **Grid:** {grid_config.display_name}\n" - f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n" - f"- **Temporal Mode:** {tr.settings.temporal_mode}\n" - f"- **Members:** {', '.join(tr.settings.members)}\n" - f"- **CV Splits:** {tr.settings.cv_splits}\n" - f"- **Classes:** {tr.settings.classes}\n" - ) + file_str = "\n**Files:**\n" + for file in tr.files: + if file.name == "search_settings.toml": + file_str += f"- ⚙️ `{file.name}`\n" + elif file.name == "best_estimator_model.pkl": + file_str += f"- 🧮 `{file.name}`\n" + elif file.name == "search_results.parquet": + file_str += f"- 📊 `{file.name}`\n" + elif file.name == "predicted_probabilities.parquet": + file_str += f"- 🎯 `{file.name}`\n" + else: + file_str += f"- 📄 `{file.name}`\n" + st.write(file_str) + with col2: + st.write("**CV Score Summary:**") - file_str = "\n**Files:**\n" - for file in tr.files: - if file.name == "search_settings.toml": - file_str += f"- ⚙️ `{file.name}`\n" - elif file.name == "best_estimator_model.pkl": - file_str += f"- 🧮 `{file.name}`\n" - elif file.name == "search_results.parquet": - file_str += f"- 📊 `{file.name}`\n" - elif file.name == "predicted_probabilities.parquet": - file_str += f"- 🎯 `{file.name}`\n" + # Extract all test scores + metric_df = tr.get_metric_dataframe() + if metric_df is not None: + st.dataframe(metric_df, width="stretch", hide_index=True) else: - file_str += f"- 📄 `{file.name}`\n" - st.write(file_str) - with col2: - st.write("**CV Score Summary:**") + st.write("No test scores found in results.") - # Extract all test scores - metric_df = tr.get_metric_dataframe() - if metric_df is 
not None: - st.dataframe(metric_df, width="stretch", hide_index=True) - else: - st.write("No test scores found in results.") + # Show parameter space explored + if "initial_K" in tr.results.columns: # Common parameter + st.write("\n**Parameter Ranges Explored:**") + for param in ["initial_K", "eps_cl", "eps_e"]: + if param in tr.results.columns: + min_val = tr.results[param].min() + max_val = tr.results[param].max() + unique_vals = tr.results[param].nunique() + st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") - # Show parameter space explored - if "initial_K" in tr.results.columns: # Common parameter - st.write("\n**Parameter Ranges Explored:**") - for param in ["initial_K", "eps_cl", "eps_e"]: - if param in tr.results.columns: - min_val = tr.results[param].min() - max_val = tr.results[param].max() - unique_vals = tr.results[param].nunique() - st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") - - with st.expander("Show CV Results DataFrame"): + st.write("**CV Results DataFrame:**") st.dataframe(tr.results, width="stretch", hide_index=True) - st.write(f"\n**Path:** `{tr.path}`") + st.write(f"\n**Path:** `{tr.path}`") diff --git a/src/entropice/dashboard/sections/storage_statistics.py b/src/entropice/dashboard/sections/storage_statistics.py new file mode 100644 index 0000000..7ddf4e3 --- /dev/null +++ b/src/entropice/dashboard/sections/storage_statistics.py @@ -0,0 +1,163 @@ +"""Storage Statistics Section for Entropice Dashboard.""" + +import pandas as pd +import plotly.graph_objects as go +import streamlit as st + +from entropice.dashboard.utils.loaders import StorageInfo, load_storage_statistics +from entropice.utils.paths import DATA_DIR + + +def _format_bytes(bytes_value: int) -> str: + """Format bytes into human-readable string.""" + value = float(bytes_value) + for unit in ["B", "KB", "MB", "GB", "TB"]: + if value < 1024.0: + return f"{value:.2f} {unit}" + value /= 1024.0 + return f"{value:.2f} 
PB" + + +def _create_storage_bar_chart(storage_infos: list[StorageInfo]) -> go.Figure: + """Create a horizontal bar chart showing storage usage by subdirectory.""" + if not storage_infos: + return go.Figure() + + # Prepare data + names = [info.name for info in storage_infos] + sizes = [info.size_bytes / (1024**3) for info in storage_infos] # Convert to GB + file_counts = [info.file_count for info in storage_infos] + + # Create figure + fig = go.Figure() + + # Add bar trace + fig.add_trace( + go.Bar( + y=names, + x=sizes, + orientation="h", + text=[f"{s:.2f} GB" for s in sizes], + textposition="auto", + hovertemplate="%{y}
<br>Size: %{x:.2f} GB<br>
Files: %{customdata:,}", + customdata=file_counts, + marker={ + "color": sizes, + "colorscale": "Blues", + "showscale": False, + }, + ) + ) + + # Update layout + fig.update_layout( + title="Storage Usage by Subdirectory", + xaxis_title="Size (GB)", + yaxis_title="Directory", + height=max(400, len(names) * 40), # Dynamic height based on number of directories + showlegend=False, + margin={"l": 200, "r": 50, "t": 50, "b": 50}, + ) + + return fig + + +def _create_storage_pie_chart(storage_infos: list[StorageInfo]) -> go.Figure: + """Create a pie chart showing storage distribution.""" + if not storage_infos: + return go.Figure() + + # Prepare data + names = [info.name for info in storage_infos] + sizes = [info.size_bytes for info in storage_infos] + + # Create figure + fig = go.Figure( + data=[ + go.Pie( + labels=names, + values=sizes, + textinfo="label+percent", + hovertemplate="%{label}
<br>Size: %{customdata}<br>
%{percent}", + customdata=[info.display_size for info in storage_infos], + ) + ] + ) + + # Update layout + fig.update_layout( + title="Storage Distribution", + height=500, + ) + + return fig + + +def render_storage_statistics(): + """Render the storage statistics section showing disk usage for DATA_DIR subdirectories.""" + st.header("💾 Storage Statistics") + + st.markdown( + f""" + This section shows the disk usage of subdirectories in the data directory: + **`{DATA_DIR}`** + + Data is collected using [dust](https://github.com/bootandy/dust), a modern disk usage analyzer. + Statistics are cached for 5 minutes to reduce overhead. + """ + ) + + # Load storage statistics + with st.spinner("Analyzing storage usage..."): + storage_infos, total_size, total_files = load_storage_statistics() + + if not storage_infos: + st.warning("No storage data available. The data directory may be empty or inaccessible.") + return + + # Display summary metrics + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Storage Used", _format_bytes(total_size)) + with col2: + st.metric("Total Files", f"{total_files:,}") + with col3: + st.metric("Number of Subdirectories", len(storage_infos)) + + # Create tabs for different visualizations + tab1, tab2, tab3 = st.tabs(["📊 Bar Chart", "🥧 Pie Chart", "📋 Detailed Table"]) + + with tab1: + st.plotly_chart(_create_storage_bar_chart(storage_infos), use_container_width=True) + + with tab2: + st.plotly_chart(_create_storage_pie_chart(storage_infos), use_container_width=True) + + with tab3: + # Create DataFrame for detailed view + df = pd.DataFrame( + [ + { + "Directory": info.name, + "Size": info.display_size, + "Size (Bytes)": info.size_bytes, + "Files": info.file_count, + "Percentage": f"{(info.size_bytes / total_size * 100):.2f}%", + } + for info in storage_infos + ] + ) + + st.dataframe( + df[["Directory", "Size", "Files", "Percentage"]], + use_container_width=True, + hide_index=True, + ) + + # Add download button for detailed data + 
st.download_button( + label="📥 Download Storage Statistics (CSV)", + data=df.to_csv(index=False), + file_name="entropice_storage_statistics.csv", + mime="text/csv", + ) diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index c0b0c1d..fe34802 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -1,6 +1,8 @@ """Data utilities for Entropice dashboard.""" +import json import pickle +import subprocess from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -252,3 +254,148 @@ def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Ta for task in all_tasks: train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task) return train_data_dict + + +@dataclass +class StorageInfo: + """Storage information for a directory.""" + + name: str + size_bytes: int + file_count: int + display_size: str + + +def _parse_size_to_bytes(size_str: str) -> int: + """Convert dust's human-readable size string to bytes. + + Examples: "92K" -> 92*1024, "1.5M" -> 1.5*1024*1024, "928B" -> 928 + """ + size_str = size_str.strip().upper() + if not size_str: + return 0 + + # Extract numeric part and unit + numeric_part = "" + unit = "" + for char in size_str: + if char.isdigit() or char == ".": + numeric_part += char + else: + unit += char + + try: + value = float(numeric_part) if numeric_part else 0 + except ValueError: + return 0 + + # Convert based on unit + unit = unit.strip() + multipliers = { + "B": 1, + "K": 1024, + "M": 1024**2, + "G": 1024**3, + "T": 1024**4, + "P": 1024**5, + } + + return int(value * multipliers.get(unit, 1)) + + +def _run_dust_command(data_dir: Path, for_files: bool = False) -> dict | None: + """Run dust command and return parsed JSON output. 
+ + Args: + data_dir: Directory to analyze + for_files: If True, count files (-f flag); if False, count disk space + + Returns: + Parsed JSON dict or None if command failed + + """ + cmd = ["dust", "-j", "-d", "1"] + if for_files: + cmd.append("-f") + cmd.append(str(data_dir)) + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + return None + return json.loads(result.stdout) + except (subprocess.TimeoutExpired, json.JSONDecodeError): + return None + + +def _build_file_counts_lookup(files_data: dict | None) -> dict[str, int]: + """Build lookup dict for file counts from dust JSON output.""" + file_counts = {} + if files_data and "children" in files_data: + for child in files_data["children"]: + name = Path(child["name"]).name + count_str = child.get("size", "0") + file_counts[name] = _parse_size_to_bytes(count_str) + return file_counts + + +@st.cache_data(ttl=300) # Cache for 5 minutes +def load_storage_statistics() -> tuple[list[StorageInfo], int, int]: + """Load storage statistics for DATA_DIR subdirectories using dust. 
+ + Returns: + Tuple of (subdirectory stats list, total size in bytes, total file count) + + """ + data_dir = entropice.utils.paths.DATA_DIR + + if not data_dir.exists(): + return [], 0, 0 + + try: + # Run dust for disk space and file counts + space_data = _run_dust_command(data_dir, for_files=False) + files_data = _run_dust_command(data_dir, for_files=True) + + if not space_data: + st.warning("Failed to get storage statistics from dust") + return [], 0, 0 + + # Build lookup dict for file counts + file_counts = _build_file_counts_lookup(files_data) + + # Extract subdirectory information from space data + storage_infos = [] + total_size = 0 + total_files = 0 + + if "children" in space_data: + for child in space_data["children"]: + full_path = child.get("name", "") + dir_name = Path(full_path).name + size_str = child.get("size", "0") + size_bytes = _parse_size_to_bytes(size_str) + file_count = file_counts.get(dir_name, 0) + + storage_infos.append( + StorageInfo( + name=dir_name, + size_bytes=size_bytes, + file_count=file_count, + display_size=size_str, + ) + ) + total_size += size_bytes + total_files += file_count + + # Sort by size descending + storage_infos.sort(key=lambda x: x.size_bytes, reverse=True) + + return storage_infos, total_size, total_files + + except FileNotFoundError: + st.error("dust command not found. 
Please install dust: https://github.com/bootandy/dust") + return [], 0, 0 + except Exception as e: + st.error(f"Error getting storage statistics: {e}") + return [], 0, 0 diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index 90e3499..dded4fd 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -8,6 +8,7 @@ from entropice.dashboard.sections.experiment_results import ( render_experiment_results, render_training_results_summary, ) +from entropice.dashboard.sections.storage_statistics import render_storage_statistics from entropice.dashboard.utils.loaders import load_all_training_results from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics @@ -52,5 +53,10 @@ def render_overview_page(): render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df) + st.divider() + + # Render storage statistics section + render_storage_statistics() + st.balloons() stopwatch.summary() diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py deleted file mode 100644 index a7d94d7..0000000 --- a/src/entropice/dashboard/views/training_data_page.py +++ /dev/null @@ -1,481 +0,0 @@ -"""Training Data page: Visualization of training data distributions.""" - -from typing import cast - -import streamlit as st -from stopuhr import stopwatch - -from entropice.dashboard.plots.source_data import ( - render_alphaearth_map, - render_alphaearth_overview, - render_alphaearth_plots, - render_arcticdem_map, - render_arcticdem_overview, - render_arcticdem_plots, - render_areas_map, - render_era5_map, - render_era5_overview, - render_era5_plots, -) -from entropice.dashboard.plots.training_data import ( - render_all_distribution_histograms, - render_spatial_map, -) -from entropice.dashboard.utils.loaders import 
load_all_training_data, load_source_data -from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble -from entropice.spatial import grids -from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs - - -def render_dataset_configuration_sidebar(): - """Render dataset configuration selector in sidebar with form. - - Stores the selected ensemble in session state when form is submitted. - """ - with st.sidebar.form("dataset_config_form"): - st.header("Dataset Configuration") - - # Grid selection - grid_options = [gc.display_name for gc in grid_configs] - - grid_level_combined = st.selectbox( - "Grid Configuration", - options=grid_options, - index=0, - help="Select the grid system and resolution level", - ) - - # Find the selected grid config - selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined) - - # Target feature selection - target = st.selectbox( - "Target Feature", - options=["darts_rts", "darts_mllabels"], - index=0, - help="Select the target variable for training", - ) - - # Members selection - st.subheader("Dataset Members") - - all_members = cast( - list[L2SourceDataset], - ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"], - ) - selected_members: list[L2SourceDataset] = [] - - for member in all_members: - if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): - selected_members.append(member) # type: ignore[arg-type] - - # Form submit button - load_button = st.form_submit_button( - "Load Dataset", - type="primary", - use_container_width=True, - disabled=len(selected_members) == 0, - ) - - # Create DatasetEnsemble only when form is submitted - if load_button: - ensemble = DatasetEnsemble( - grid=selected_grid_config.grid, - level=selected_grid_config.level, - target=cast(TargetDataset, target), - members=selected_members, - ) - # Store ensemble in session state - st.session_state["dataset_ensemble"] = 
ensemble - st.session_state["dataset_loaded"] = True - - -def render_dataset_statistics(ensemble: DatasetEnsemble): - """Render dataset statistics and configuration overview. - - Args: - ensemble: The dataset ensemble configuration. - - """ - st.markdown("### 📊 Dataset Configuration") - - # Display current configuration in columns - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric(label="Grid Type", value=ensemble.grid.upper()) - - with col2: - st.metric(label="Grid Level", value=ensemble.level) - - with col3: - st.metric(label="Target Feature", value=ensemble.target.replace("darts_", "")) - - with col4: - st.metric(label="Members", value=len(ensemble.members)) - - # Display members in an expandable section - with st.expander("🗂️ Dataset Members", expanded=False): - members_cols = st.columns(len(ensemble.members)) - for idx, member in enumerate(ensemble.members): - with members_cols[idx]: - st.markdown(f"✓ **{member}**") - - # Display dataset ID in a styled container - st.info(f"**Dataset ID:** `{ensemble.id()}`") - - # Display detailed dataset statistics - st.markdown("---") - st.markdown("### 📈 Dataset Statistics") - - with st.spinner("Computing dataset statistics..."): - stats = ensemble.get_stats() - - # High-level summary metrics - col1, col2, col3 = st.columns(3) - with col1: - st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}") - with col2: - st.metric(label="Total Features", value=f"{stats['total_features']:,}") - with col3: - st.metric(label="Data Sources", value=len(stats["members"])) - - # Detailed member statistics in expandable section - with st.expander("📦 Data Source Details", expanded=False): - for member, member_stats in stats["members"].items(): - st.markdown(f"### {member}") - - # Create metrics for this member - metric_cols = st.columns(4) - with metric_cols[0]: - st.metric("Features", member_stats["num_features"]) - with metric_cols[1]: - st.metric("Variables", member_stats["num_variables"]) - with 
metric_cols[2]: - # Display dimensions in a more readable format - dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr] - st.metric("Shape", dim_str) - with metric_cols[3]: - # Calculate total data points - total_points = 1 - for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr] - total_points *= dim_size - st.metric("Data Points", f"{total_points:,}") - - # Show variables as colored badges - st.markdown("**Variables:**") - vars_html = " ".join( - [ - f'{v}' - for v in member_stats["variables"] # type: ignore[union-attr] - ] - ) - st.markdown(vars_html, unsafe_allow_html=True) - - # Show dimension details - st.markdown("**Dimensions:**") - dim_html = " ".join( - [ - f'' - f"{dim_name}: {dim_size}" - for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr] - ] - ) - st.markdown(dim_html, unsafe_allow_html=True) - - st.markdown("---") - - -def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]): - """Render target labels distribution and spatial visualization. - - Args: - ensemble: The dataset ensemble configuration. - train_data_dict: Pre-loaded training data for all tasks. 
- - """ - st.markdown("### Target Labels Distribution and Spatial Visualization") - - # Calculate total samples (use binary as reference) - total_samples = len(train_data_dict["binary"]) - train_samples = (train_data_dict["binary"].split == "train").sum().item() - test_samples = (train_data_dict["binary"].split == "test").sum().item() - - st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks") - - # Render distribution histograms - st.markdown("---") - render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type] - - st.markdown("---") - - # Render spatial map - binary_dataset = train_data_dict["binary"] - assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" - - render_spatial_map(train_data_dict) - - -def render_areas_view(ensemble: DatasetEnsemble, grid_gdf): - """Render grid cell areas and land/water distribution. - - Args: - ensemble: The dataset ensemble configuration. - grid_gdf: Pre-loaded grid GeoDataFrame. - - """ - st.markdown("### Grid Cell Areas and Land/Water Distribution") - - st.markdown( - "This visualization shows the spatial distribution of cell areas, land areas, " - "water areas, and land ratio across the grid. The grid has been filtered to " - "include only cells in the permafrost region (>50° latitude, <85° latitude) " - "with >10% land coverage." 
- ) - - st.success( - f"Loaded {len(grid_gdf)} grid cells with areas ranging from " - f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²" - ) - - # Show summary statistics - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Cells", f"{len(grid_gdf):,}") - with col2: - st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²") - with col3: - st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}") - with col4: - total_land = grid_gdf["land_area"].sum() - st.metric("Total Land Area", f"{total_land:,.0f} km²") - - st.markdown("---") - - # Check if we should skip map rendering for performance - if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " - "due to performance considerations." - ) - else: - render_areas_map(grid_gdf, ensemble.grid) - - -def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets): - """Render AlphaEarth embeddings analysis. - - Args: - ensemble: The dataset ensemble configuration. - alphaearth_ds: Pre-loaded AlphaEarth dataset. - targets: Pre-loaded targets GeoDataFrame. - - """ - st.markdown("### AlphaEarth Embeddings Analysis") - - st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells") - - render_alphaearth_overview(alphaearth_ds) - render_alphaearth_plots(alphaearth_ds) - - st.markdown("---") - - # Check if we should skip map rendering for performance - if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " - "due to performance considerations." 
- ) - else: - render_alphaearth_map(alphaearth_ds, targets, ensemble.grid) - - -def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets): - """Render ArcticDEM terrain analysis. - - Args: - ensemble: The dataset ensemble configuration. - arcticdem_ds: Pre-loaded ArcticDEM dataset. - targets: Pre-loaded targets GeoDataFrame. - - """ - st.markdown("### ArcticDEM Terrain Analysis") - - st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells") - - render_arcticdem_overview(arcticdem_ds) - render_arcticdem_plots(arcticdem_ds) - - st.markdown("---") - - # Check if we should skip map rendering for performance - if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " - "due to performance considerations." - ) - else: - render_arcticdem_map(arcticdem_ds, targets, ensemble.grid) - - -@st.fragment -def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets): - """Render ERA5 climate data analysis. - - Args: - ensemble: The dataset ensemble configuration. - era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples. - targets: Pre-loaded targets GeoDataFrame. 
- - """ - st.markdown("### ERA5 Climate Data Analysis") - - # Let user select which ERA5 temporal aggregation to view - era5_options = { - "ERA5-yearly": "Yearly", - "ERA5-seasonal": "Seasonal (Winter/Summer)", - "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)", - } - - available_era5 = {k: v for k, v in era5_options.items() if k in era5_data} - - selected_era5 = st.selectbox( - "Select ERA5 temporal aggregation", - options=list(available_era5.keys()), - format_func=lambda x: available_era5[x], - key="era5_temporal_select", - ) - - if selected_era5 and selected_era5 in era5_data: - era5_ds, temporal_type = era5_data[selected_era5] - - render_era5_overview(era5_ds, temporal_type) - render_era5_plots(era5_ds, temporal_type) - - st.markdown("---") - - # Check if we should skip map rendering for performance - if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): - st.warning( - "🗡️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " - "due to performance considerations." - ) - else: - render_era5_map(era5_ds, targets, ensemble.grid, temporal_type) - - -def render_training_data_page(): - """Render the Training Data page of the dashboard.""" - st.title("🎯 Training Data") - - st.markdown( - """ - Explore and visualize the training data for RTS prediction models. - Configure your dataset by selecting grid configuration, target dataset, - and data sources in the sidebar, then click "Load Dataset" to begin. 
- """ - ) - - # Render sidebar configuration - render_dataset_configuration_sidebar() - - # Check if dataset is loaded in session state - if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state: - st.info( - "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data" - ) - return - - # Get ensemble from session state - ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"] - - st.divider() - - # Load all necessary data once - with st.spinner("Loading dataset..."): - # Load training data for all tasks - train_data_dict = load_all_training_data(ensemble) - - # Load grid data - grid_gdf = grids.open(ensemble.grid, ensemble.level) - - # Load targets (needed by all source data views) - targets = ensemble._read_target() - - # Load AlphaEarth data if in members - alphaearth_ds = None - if "AlphaEarth" in ensemble.members: - alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth") - - # Load ArcticDEM data if in members - arcticdem_ds = None - if "ArcticDEM" in ensemble.members: - arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM") - - # Load ERA5 data for all temporal aggregations in members - era5_data = {} - era5_members = [m for m in ensemble.members if m.startswith("ERA5")] - for era5_member in era5_members: - era5_ds, _ = load_source_data(ensemble, era5_member) - temporal_type = era5_member.split("-")[1] # 'yearly', 'seasonal', or 'shoulder' - era5_data[era5_member] = (era5_ds, temporal_type) - - st.success( - f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features" - ) - - # Render dataset statistics - render_dataset_statistics(ensemble) - - st.markdown("---") - - # Create tabs for different data views - tab_names = ["📊 Labels", "📐 Areas"] - - # Add tabs for each member based on what's in the ensemble - if "AlphaEarth" in ensemble.members: - tab_names.append("🌍 AlphaEarth") - if "ArcticDEM" in 
ensemble.members: - tab_names.append("🏔️ ArcticDEM") - - # Check for ERA5 members - if era5_members: - tab_names.append("🌡️ ERA5") - - tabs = st.tabs(tab_names) - - # Track current tab index - tab_idx = 0 - - # Labels tab - with tabs[tab_idx]: - render_labels_view(ensemble, train_data_dict) - tab_idx += 1 - - # Areas tab - with tabs[tab_idx]: - render_areas_view(ensemble, grid_gdf) - tab_idx += 1 - - # AlphaEarth tab - if "AlphaEarth" in ensemble.members: - with tabs[tab_idx]: - render_alphaearth_view(ensemble, alphaearth_ds, targets) - tab_idx += 1 - - # ArcticDEM tab - if "ArcticDEM" in ensemble.members: - with tabs[tab_idx]: - render_arcticdem_view(ensemble, arcticdem_ds, targets) - tab_idx += 1 - - # ERA5 tab (combining all temporal variants) - if era5_members: - with tabs[tab_idx]: - render_era5_view(ensemble, era5_data, targets) - - # Show balloons once after all tabs are rendered - st.balloons() - stopwatch.summary()