diff --git a/src/entropice/dashboard/plots/embeddings.py b/src/entropice/dashboard/plots/embeddings.py new file mode 100644 index 0000000..3166df0 --- /dev/null +++ b/src/entropice/dashboard/plots/embeddings.py @@ -0,0 +1,376 @@ +"""Render the AlphaEarth visualization tab.""" + +import geopandas as gpd +import matplotlib.colors as mcolors +import numpy as np +import plotly.graph_objects as go +import pydeck as pdk +import xarray as xr + +from entropice.dashboard.utils.colors import get_cmap, hex_to_rgb +from entropice.dashboard.utils.geometry import fix_hex_geometry + + +def create_embedding_map( + embedding_values: xr.DataArray, + grid_gdf: gpd.GeoDataFrame, + make_3d_map: bool, +) -> pdk.Deck: + """Create a spatial distribution map for AlphaEarth embeddings. + + Args: + embedding_values (xr.DataArray): DataArray containing the already filtered AlphaEarth embeddings. + grid_gdf (gpd.GeoDataFrame): GeoDataFrame containing grid cell geometries. + make_3d_map (bool): Whether to render the map in 3D (extruded) or 2D. + + Returns: + pdk.Deck: A PyDeck map visualization of the AlphaEarth embeddings. + + """ + # Subsample if too many cells for performance + n_cells = len(embedding_values["cell_ids"]) + if n_cells > 100000: + rng = np.random.default_rng(42) # Fixed seed for reproducibility + cell_indices = rng.choice(n_cells, size=100000, replace=False) + embedding_values = embedding_values.isel(cell_ids=cell_indices) + + # Create a copy to avoid modifying the original + gdf = grid_gdf.copy().to_crs("EPSG:4326") + + # Convert to DataFrame for easier merging + embedding_df = embedding_values.to_dataframe(name="embedding_value") + + # Reset index if cell_id is already the index + if gdf.index.name == "cell_id": + gdf = gdf.reset_index() + + # Filter grid to only cells that have embedding data + gdf = gdf[gdf["cell_id"].isin(embedding_df.index)] + gdf = gdf.set_index("cell_id") + + # Merge embedding values with grid geometries + gdf = gdf.join(embedding_df, how="inner") + + # Convert to WGS84 for pydeck + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix antimeridian issues for hex cells + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) + + # Get colormap for embeddings + cmap = get_cmap("AlphaEarth") + + # Normalize the embedding values to [0, 1] for color mapping + # Use percentiles to avoid outliers + values = gdf_wgs84["embedding_value"].values + vmin, vmax = np.nanpercentile(values, [2, 98]) + + if vmax > vmin: + normalized_values = np.clip((values - vmin) / (vmax - vmin), 0, 1) + else: + normalized_values = np.zeros_like(values) + + # Map normalized values to colors + colors = [cmap(val) for val in normalized_values] + rgb_colors = [hex_to_rgb(mcolors.to_hex(color)) for color in colors] + gdf_wgs84["fill_color"] = rgb_colors + + # Store embedding value for tooltip + gdf_wgs84["embedding_value_display"] = values + + # Store normalized values for elevation (if 3D) + gdf_wgs84["elevation"] = normalized_values + + # Convert to GeoJSON format + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "fill_color": row["fill_color"], + "embedding_value": float(row["embedding_value_display"]), + "elevation": float(row["elevation"]) if make_3d_map else 0, + }, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=0.7, + stroked=True, + filled=True, + extruded=make_3d_map, + wireframe=False, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + get_elevation="properties.elevation" if make_3d_map else 0, + elevation_scale=500000, # Scale normalized values (0-1) to 500km height + pickable=True, + ) + + # Set initial view state (centered on the Arctic) + # Adjust pitch and zoom based on whether we're using 3D + view_state = pdk.ViewState( + latitude=70, + longitude=0, + zoom=2 if not make_3d_map else 1.5, + pitch=0 if not make_3d_map else 45, + ) + + # Build tooltip HTML + tooltip_html = "Embedding Value: {embedding_value}" + + # Create deck + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + return deck + + +def create_embedding_trend_plot(embedding_values: xr.DataArray) -> go.Figure: + """Create a trend plot for AlphaEarth embeddings over time. + + Contains a line plot with shaded areas representing the 10th to 90th percentiles. + Min and Max values are marked through a dashed line. + + Args: + embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings with a 'year' dimension. + + Returns: + go.Figure: A Plotly figure showing the trend of embeddings over time. + + """ + # Subsample if too many cells for performance + n_cells = len(embedding_values["cell_ids"]) + if n_cells > 10000: + rng = np.random.default_rng(42) # Fixed seed for reproducibility + cell_indices = rng.choice(n_cells, size=10000, replace=False) + embedding_values = embedding_values.isel(cell_ids=cell_indices) + + # Calculate statistics over space (cell_ids) for each year + years = embedding_values["year"].values + + # Compute statistics across cells for each year + mean_values = embedding_values.mean(dim="cell_ids").to_numpy() + min_values = embedding_values.min(dim="cell_ids").to_numpy() + max_values = embedding_values.max(dim="cell_ids").to_numpy() + p10_values = embedding_values.quantile(0.10, dim="cell_ids").to_numpy() + p90_values = embedding_values.quantile(0.90, dim="cell_ids").to_numpy() + + fig = go.Figure() + + # Add min/max range first (background) - dashed lines + fig.add_trace( + go.Scatter( + x=years, + y=min_values, + mode="lines", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + name="Min/Max Range", + showlegend=True, + ) + ) + + fig.add_trace( + go.Scatter( + x=years, + y=max_values, + mode="lines", + fill="tonexty", + fillcolor="rgba(200, 200, 200, 0.1)", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + showlegend=False, + ) + ) + + # Add 10th-90th percentile band + fig.add_trace( + go.Scatter( + x=years, + y=p10_values, + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=years, + y=p90_values, + mode="lines", + fill="tonexty", + fillcolor="rgba(76, 175, 80, 0.2)", # Green shade for AlphaEarth + line={"width": 0}, + name="10th-90th Percentile", + ) + ) + + # Add mean line on top + fig.add_trace( + go.Scatter( + x=years, + y=mean_values, + mode="lines+markers", + name="Mean", + line={"color": "#2E7D32", "width": 2}, # Darker green for AlphaEarth + marker={"size": 6}, + ) + ) + + fig.update_layout( + title="Embedding Values Over Time (Spatial Statistics)", + xaxis_title="Year", + yaxis_title="Embedding Value", + yaxis={"zeroline": True, "zerolinewidth": 2, "zerolinecolor": "gray"}, + height=450, + hovermode="x unified", + legend={ + "orientation": "h", + "yanchor": "bottom", + "y": 1.02, + "xanchor": "right", + "x": 1, + }, + ) + + # Format x-axis to show years as integers + fig.update_xaxes(dtick=1) + + return fig + + +def create_embedding_distribution_plot(embedding_values: xr.DataArray) -> go.Figure: + """Create a distribution plot showing the min/max, 10th/90th percentiles, and mean of AlphaEarth bands. + + Args: + embedding_values (xr.DataArray): DataArray containing the AlphaEarth embeddings. + + Returns: + go.Figure: A Plotly figure showing the distribution of embeddings. + + """ + # Subsample if too many cells for performance + n_cells = len(embedding_values["cell_ids"]) + if n_cells > 10000: + rng = np.random.default_rng(42) # Fixed seed for reproducibility + cell_indices = rng.choice(n_cells, size=10000, replace=False) + embedding_values = embedding_values.isel(cell_ids=cell_indices) + + # Get band dimension + bands = embedding_values["band"].values + + # Calculate statistics for each band across all cells + band_stats = [] + for band in bands: + band_data = embedding_values.sel(band=band).values.flatten() + # Remove NaN values + band_data = band_data[~np.isnan(band_data)] + + if len(band_data) > 0: + band_stats.append( + { + "Band": str(band), + "Mean": float(np.mean(band_data)), + "Min": float(np.min(band_data)), + "Max": float(np.max(band_data)), + "P10": float(np.percentile(band_data, 10)), + "P90": float(np.percentile(band_data, 90)), + } + ) + + # Create DataFrame from statistics + import pandas as pd + + band_df = pd.DataFrame(band_stats) + + fig = go.Figure() + + # Add min/max range first (background) - dashed lines + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Min"], + mode="lines", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + name="Min/Max Range", + showlegend=True, + ) + ) + + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Max"], + mode="lines", + fill="tonexty", + fillcolor="rgba(200, 200, 200, 0.1)", + line={"color": "lightgray", "width": 1, "dash": "dash"}, + showlegend=False, + ) + ) + + # Add 10th-90th percentile band + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["P10"], + mode="lines", + line={"width": 0}, + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["P90"], + mode="lines", + fill="tonexty", + fillcolor="rgba(76, 175, 80, 0.2)", # Green shade for AlphaEarth + line={"width": 0}, + name="10th-90th Percentile", + ) + ) + + # Add mean line on top + fig.add_trace( + go.Scatter( + x=band_df["Band"], + y=band_df["Mean"], + mode="lines+markers", + name="Mean", + line={"color": "#2E7D32", "width": 2}, # Darker green for AlphaEarth + marker={"size": 4}, + ) + ) + + fig.update_layout( + title="Embedding Distribution by Band (Spatial Statistics)", + xaxis_title="Band", + yaxis_title="Embedding Value", + height=450, + hovermode="x unified", + legend={ + "orientation": "h", + "yanchor": "bottom", + "y": 1.02, + "xanchor": "right", + "x": 1, + }, + ) + + return fig diff --git a/src/entropice/dashboard/sections/alphaearth.py b/src/entropice/dashboard/sections/alphaearth.py new file mode 100644 index 0000000..58cf9c4 --- /dev/null +++ b/src/entropice/dashboard/sections/alphaearth.py @@ -0,0 +1,238 @@ +"""AlphaEarth embeddings dashboard section.""" + +import geopandas as gpd +import matplotlib.colors as mcolors +import numpy as np +import streamlit as st +import xarray as xr + +from entropice.dashboard.plots.embeddings import ( + create_embedding_distribution_plot, + create_embedding_map, + create_embedding_trend_plot, +) +from entropice.dashboard.sections.dataset_statistics import render_member_details +from entropice.dashboard.utils.colors import get_cmap +from entropice.dashboard.utils.stats import MemberStatistics + + +def _get_band_agg_options(embeddings: xr.Dataset): + """Get band and aggregation selection options from user.""" + bands = embeddings["band"].values.tolist() + aggregations = embeddings["agg"].values.tolist() + + cols = st.columns([2, 2]) + with cols[0]: + band = st.selectbox( + "Select Embedding Band", + options=bands, + index=0, + help="Select which embedding band to visualize on the map.", + key="embedding_band_select", + ) + with cols[1]: + aggregation = st.selectbox( + "Select Aggregation Method", + options=aggregations, + index=0, + help="Select the aggregation method for the embeddings to visualize.", + key="embedding_agg_select", + ) + + return band, aggregation + + +@st.fragment +def _render_embedding_map(embedding_values: xr.DataArray, grid_gdf: gpd.GeoDataFrame): + st.subheader("AlphaEarth Embedding Map") + + st.markdown( + """ + This interactive map visualizes the spatial distribution of the selected embedding band across + the Arctic region. Each grid cell is colored according to its embedding value, revealing spatial + patterns in the satellite imagery features. High-resolution embeddings can indicate areas with + distinctive characteristics that may be relevant for RTS detection. + + **Map controls:** + - **Hover** over cells to see exact embedding values + - **3D mode**: Elevation represents embedding magnitude - higher areas have larger values + - **Rotate** (3D mode): Hold Ctrl/Cmd and drag to rotate the view + - **Zoom/Pan**: Scroll to zoom, click and drag to pan + """ + ) + + cols = st.columns([4, 1]) + with cols[0]: + if "year" in embedding_values.dims or "year" in embedding_values.coords: + year_values = embedding_values["year"].values.tolist() + year = st.slider( + "Select Year", + min_value=int(min(year_values)), + max_value=int(max(year_values)), + value=int(max(year_values)), + step=1, + help="Select the year for which to visualize the embeddings.", + ) + embedding_values = embedding_values.sel(year=year) + with cols[1]: + make_3d_map = st.checkbox("3D Map", value=True) + + # Check if subsampling will occur + n_cells = len(embedding_values["cell_ids"]) + if n_cells > 100000: + st.info(f"πŸ—ΊοΈ **Map subsampled:** Displaying 100,000 randomly selected cells out of {n_cells:,} for performance.") + + map_deck = create_embedding_map( + embedding_values=embedding_values, + grid_gdf=grid_gdf, + make_3d_map=make_3d_map, + ) + + st.pydeck_chart(map_deck, width="stretch") + + # Add legend + with st.expander("Legend", expanded=True): + st.markdown("**Embedding Value**") + + # Get the actual values to show accurate min/max (same as in the map function) + values = embedding_values.values.flatten() + values = values[~np.isnan(values)] + vmin, vmax = np.nanpercentile(values, [2, 98]) + + vmin_str = f"{vmin:.4f}" + vmax_str = f"{vmax:.4f}" + + # Get the same colormap used in the map + cmap = get_cmap("AlphaEarth") + # Sample 4 colors from the colormap to create the gradient + gradient_colors = [mcolors.to_hex(cmap(i)) for i in [0.0, 0.33, 0.67, 1.0]] + gradient_css = ", ".join(gradient_colors) + + # Create a simple gradient legend + st.markdown( + f'
' + f'{vmin_str}' + f'
' + f'{vmax_str}' + f"
", + unsafe_allow_html=True, + ) + + st.caption( + "Color intensity represents embedding values from low (green) to high (yellow). " + "Values are normalized using the 2nd-98th percentile range to avoid outliers." + ) + + if make_3d_map: + st.markdown("---") + st.markdown("**3D Elevation:**") + st.caption( + "Height represents normalized embedding values. Rotate the map by holding Ctrl/Cmd and dragging." + ) + + +def _render_trend(embeddin_values: xr.DataArray): + st.subheader("AlphaEarth Embedding Trends Over Time") + + st.markdown( + """ + This visualization shows how embedding values have changed over time across the study area. + The plot aggregates spatial statistics (mean, percentiles) for each year, revealing temporal + patterns in the satellite imagery embeddings that may correlate with environmental changes. + + **Understanding the plot:** + - **Mean line** (dark green): Average embedding value across all grid cells for each year + - **10th-90th percentile band** (light green): Range containing 80% of the values, showing + typical variation + - **Min/Max range** (gray): Full extent of values, highlighting outliers + """ + ) + + # Show dataset filtering info + band_val = embeddin_values["band"].values.item() if embeddin_values["band"].size == 1 else "multiple" + agg_val = embeddin_values["agg"].values.item() if embeddin_values["agg"].size == 1 else "multiple" + st.caption( + f"πŸ“Š **Dataset selection:** Band `{band_val}`, Aggregation `{agg_val}` " + f"({len(embeddin_values['year'])} years, {len(embeddin_values['cell_ids']):,} cells)" + ) + + # Check if subsampling will occur + n_cells = len(embeddin_values["cell_ids"]) + if n_cells > 10000: + st.info( + f"πŸ“Š **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} " + "for performance. Statistics remain representative." + ) + + fig = create_embedding_trend_plot(embedding_values=embeddin_values) + st.plotly_chart(fig, width="stretch") + + +def _render_distribution(embeddin_values: xr.DataArray): + st.subheader("AlphaEarth Embedding Distribution") + + st.markdown( + """ + This plot shows the statistical distribution of embedding values across all 64 embedding + dimensions (bands). AlphaEarth embeddings are learned representations from satellite imagery, + with each band capturing different aspects of the landscape (e.g., vegetation, terrain, ice + cover, land use). + + **Understanding the plot:** + - **X-axis**: Embedding bands (A00-A63), each representing a learned feature from satellite + imagery + - **Mean line** (dark green): Average value across all grid cells for each band + - **10th-90th percentile band** (light green): Central distribution of values, excluding + outliers + - **Min/Max range** (gray): Full value range showing extreme values + + Different bands may capture different landscape features - bands with higher variance often + represent more spatially heterogeneous characteristics. + """ + ) + + # Show dataset filtering info + agg_val = embeddin_values["agg"].values.item() if embeddin_values["agg"].size == 1 else "multiple" + n_bands = len(embeddin_values["band"]) + n_cells = len(embeddin_values["cell_ids"]) + st.caption(f"πŸ“Š **Dataset selection:** Aggregation `{agg_val}` ({n_bands} bands, {n_cells:,} cells)") + + # Check if subsampling will occur + if n_cells > 10000: + st.info( + f"πŸ“Š **Dataset subsampled:** Using 10,000 randomly selected cells out of {n_cells:,} " + "for performance. Statistics remain representative." + ) + + fig = create_embedding_distribution_plot(embedding_values=embeddin_values) + st.plotly_chart(fig, width="stretch") + + +@st.fragment +def render_alphaearth_tab(embeddings: xr.Dataset, grid_gdf: gpd.GeoDataFrame, member_stats: MemberStatistics): + """Render the AlphaEarth visualization tab. + + Args: + embeddings: The AlphaEarth dataset member, lazily loaded. + grid_gdf: GeoDataFrame with grid cell geometries + member_stats: Statistics for the AlphaEarth member. + + """ + # Render different visualizations + with st.expander("AlphaEarth Embedding Statistics", expanded=True): + render_member_details("AlphaEarth", member_stats) + + st.divider() + + band, aggregation = _get_band_agg_options(embeddings) + embedding_values = embeddings["embeddings"].sel(agg=aggregation).compute() + + _render_distribution(embedding_values) + st.divider() + + if "year" in embedding_values.dims or "year" in embedding_values.coords: + _render_trend(embedding_values.sel(band=band)) + st.divider() + + _render_embedding_map(embedding_values.sel(band=band), grid_gdf) diff --git a/src/entropice/dashboard/sections/areas.py b/src/entropice/dashboard/sections/areas.py index dd4c732..5745679 100644 --- a/src/entropice/dashboard/sections/areas.py +++ b/src/entropice/dashboard/sections/areas.py @@ -23,7 +23,7 @@ def _render_area_map(grid_gdf: gpd.GeoDataFrame): key="metric", ) with cols[1]: - make_3d_map = cast(bool, st.checkbox("3D Map", value=True, key="area_3d_map")) + make_3d_map = cast(bool, st.checkbox("3D Map", value=True)) map_deck = create_grid_areas_map(grid_gdf, metric, make_3d_map) st.pydeck_chart(map_deck) diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py index 7105402..e0aa834 100644 --- a/src/entropice/dashboard/sections/dataset_statistics.py +++ b/src/entropice/dashboard/sections/dataset_statistics.py @@ -436,6 +436,42 @@ def _render_aggregation_selection( return dimension_filters +def render_member_details(member: str, member_stats: MemberStatistics): + """Render detailed information for a single member. + + Displays variables and dimensions with styled badges. + + Args: + member: Member dataset name + member_stats: Statistics for the member + + """ + st.markdown(f"### {member}") + + # Variables + st.markdown("**Variables:**") + vars_html = " ".join( + [ + f'{v}' + for v in member_stats.variable_names + ] + ) + st.markdown(vars_html, unsafe_allow_html=True) + + # Dimensions + st.markdown("**Dimensions:**") + dim_html = " ".join( + [ + f'' + f"{dim_name}: {dim_size:,}" + for dim_name, dim_size in member_stats.dimensions.items() + ] + ) + st.markdown(dim_html, unsafe_allow_html=True) + + def render_ensemble_details( selected_members: list[L2SourceDataset], selected_member_stats: dict[L2SourceDataset, MemberStatistics], @@ -502,33 +538,9 @@ def render_ensemble_details( st.dataframe(details_df, hide_index=True, width="stretch") # Individual member details - for member, member_stats in selected_member_stats.items(): - st.markdown(f"### {member}") - - # Variables - st.markdown("**Variables:**") - vars_html = " ".join( - [ - f'{v}' - for v in member_stats.variable_names - ] - ) - st.markdown(vars_html, unsafe_allow_html=True) - - # Dimensions - st.markdown("**Dimensions:**") - dim_html = " ".join( - [ - f'' - f"{dim_name}: {dim_size:,}" - for dim_name, dim_size in member_stats.dimensions.items() - ] - ) - st.markdown(dim_html, unsafe_allow_html=True) - - st.markdown("---") + for member, stats in selected_member_stats.items(): + render_member_details(member, stats) + st.divider() def _render_configuration_summary( diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py index 04b3176..b627f61 100644 --- a/src/entropice/dashboard/sections/experiment_results.py +++ b/src/entropice/dashboard/sections/experiment_results.py @@ -52,6 +52,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: "Experiment", options=["All", *experiments], index=0, + key="exp_results_experiment", ) else: selected_experiment = "All" @@ -61,6 +62,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: "Task", options=["All", *tasks], index=0, + key="exp_results_task", ) with filter_cols[2]: @@ -68,6 +70,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: "Model", options=["All", *models], index=0, + key="exp_results_model", ) with filter_cols[3]: @@ -75,6 +78,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: "Grid", options=["All", *grids], index=0, + key="exp_results_grid", ) # Apply filters diff --git a/src/entropice/dashboard/sections/targets.py b/src/entropice/dashboard/sections/targets.py index 8c41cd1..1903c0f 100644 --- a/src/entropice/dashboard/sections/targets.py +++ b/src/entropice/dashboard/sections/targets.py @@ -170,6 +170,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS "Select Target Dataset", options=sorted(train_data_dict.keys()), index=0, + key="target_map_dataset", ), ) with cols[1]: @@ -179,6 +180,7 @@ def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingS "Select Task", options=sorted(train_data_dict[selected_target].keys()), index=0, + key="target_map_task", ), ) with cols[2]: diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index 83c9796..cc3419c 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -15,8 +15,9 @@ from shapely.geometry import shape import entropice.spatial.grids import entropice.utils.paths from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo +from entropice.ml.dataset import DatasetEnsemble, TrainingSet from entropice.ml.training import TrainingSettings -from entropice.utils.types import GridConfig +from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks def _fix_hex_geometry(geom): @@ -239,3 +240,13 @@ def load_all_training_results() -> list[TrainingResult]: # Sort by creation time (most recent first) training_results.sort(key=lambda tr: tr.created_at, reverse=True) return training_results + + +def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Task, TrainingSet]]: + """Load training sets for all target-task combinations in the ensemble.""" + train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {} + for target in all_target_datasets: + train_data_dict[target] = {} + for task in all_tasks: + train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task) + return train_data_dict diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py index d061bc4..e88b83c 100644 --- a/src/entropice/dashboard/utils/stats.py +++ b/src/entropice/dashboard/utils/stats.py @@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass from typing import Literal import pandas as pd +import xarray as xr from stopuhr import stopwatch import entropice.spatial.grids @@ -39,11 +40,19 @@ class MemberStatistics: size_bytes: int # Size of this member's data on disk in bytes @classmethod - def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]: + def compute( + cls, + e: DatasetEnsemble, + member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None, + ) -> dict[L2SourceDataset, "MemberStatistics"]: """Pre-compute the statistics for a specific dataset member.""" + member_datasets = member_datasets or {} member_stats = {} for member in e.members: - ds = e.read_member(member, lazy=True) + if member in member_datasets: + ds = member_datasets[member] + else: + ds = e.read_member(member, lazy=True) size_bytes = ds.nbytes n_cols_member = len(ds.data_vars) @@ -113,7 +122,11 @@ class DatasetStatistics: target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task @classmethod - def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics": + def from_ensemble( + cls, + e: DatasetEnsemble, + member_datasets: dict[L2SourceDataset, xr.Dataset] | None = None, + ) -> "DatasetStatistics": """Compute dataset statistics from a DatasetEnsemble.""" grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered total_cells = len(grid_gdf) @@ -123,7 +136,7 @@ class DatasetStatistics: # darts_mllabels does not support year-based temporal modes continue target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells) - member_statistics = MemberStatistics.compute(e) + member_statistics = MemberStatistics.compute(e, member_datasets=member_datasets) total_features = sum(ms.feature_count for ms in member_statistics.values()) total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) diff --git a/src/entropice/dashboard/views/dataset_page.py b/src/entropice/dashboard/views/dataset_page.py index e2b0a19..032511d 100644 --- a/src/entropice/dashboard/views/dataset_page.py +++ b/src/entropice/dashboard/views/dataset_page.py @@ -3,21 +3,20 @@ from typing import cast import streamlit as st +import xarray as xr from stopuhr import stopwatch +from entropice.dashboard.sections.alphaearth import render_alphaearth_tab from entropice.dashboard.sections.areas import render_area_information_tab from entropice.dashboard.sections.dataset_statistics import render_ensemble_details from entropice.dashboard.sections.targets import render_target_information_tab +from entropice.dashboard.utils.loaders import load_training_sets from entropice.dashboard.utils.stats import DatasetStatistics -from entropice.ml.dataset import DatasetEnsemble, TrainingSet +from entropice.ml.dataset import DatasetEnsemble from entropice.utils.types import ( GridConfig, L2SourceDataset, - TargetDataset, - Task, TemporalMode, - all_target_datasets, - all_tasks, grid_configs, ) @@ -38,6 +37,7 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble: options=grid_options, index=0, help="Select the grid system and resolution level", + key="dataset_page_grid", ) # Find the selected grid config @@ -48,12 +48,13 @@ def render_dataset_configuration_sidebar() -> DatasetEnsemble: "Temporal Mode", options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]), index=0, - format_func=lambda x: "Synopsis (all years)" + format_func=lambda x: "Synopsis (mean + trend)" if x == "synopsis" else "Years-as-Features" if x == "feature" else f"Year {x}", help="Select temporal mode: 'synopsis' for temporal features or specific year", + key="dataset_page_temporal_mode", ) # Members selection @@ -108,23 +109,20 @@ def render_dataset_page(): st.divider() + member_datasets = cast( + dict[L2SourceDataset, xr.Dataset], + {member: ensemble.read_member(member, lazy=True) for member in ensemble.members}, + ) # Render dataset statistics section - stats = DatasetStatistics.from_ensemble(ensemble) + stats = DatasetStatistics.from_ensemble(ensemble, member_datasets=member_datasets) render_ensemble_details(ensemble.members, stats.members) st.divider() # Load data and precompute visualizations - # First, load for all task - target combinations the training data - train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {} - for target in all_target_datasets: - train_data_dict[target] = {} - for task in all_tasks: - train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task) # Preload the grid GeoDataFrame grid_gdf = ensemble.read_grid() - era5_members = [m for m in ensemble.members if m.startswith("ERA5")] # Create tabs for different data views tab_names = ["🎯 Targets", "πŸ“ Areas"] # Add tabs for each member based on what's in the ensemble @@ -132,21 +130,27 @@ def render_dataset_page(): tab_names.append("🌍 AlphaEarth") if "ArcticDEM" in ensemble.members: tab_names.append("πŸ”οΈ ArcticDEM") + era5_members = [m for m in ensemble.members if m.startswith("ERA5")] if era5_members: tab_names.append("🌑️ ERA5") tabs = st.tabs(tab_names) with tabs[0]: st.header("🎯 Target Labels Visualization") - if False: #! debug + if False: # ! DEBUG + train_data_dict = load_training_sets(ensemble) render_target_information_tab(train_data_dict) with tabs[1]: st.header("πŸ“ Areas Visualization") - render_area_information_tab(grid_gdf) + if False: # ! DEBUG + render_area_information_tab(grid_gdf) tab_index = 2 if "AlphaEarth" in ensemble.members: with tabs[tab_index]: st.header("🌍 AlphaEarth Visualization") + alphaearth_ds = member_datasets["AlphaEarth"] + alphaearth_stats = stats.members["AlphaEarth"] + render_alphaearth_tab(alphaearth_ds, grid_gdf, alphaearth_stats) tab_index += 1 if "ArcticDEM" in ensemble.members: with tabs[tab_index]: