From 4fecac535c42e80bafd08add95c28010edf06cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 4 Jan 2026 21:35:21 +0100 Subject: [PATCH] Split overview plots --- .../plots/hyperparameter_analysis.py | 14 +- src/entropice/dashboard/plots/inference.py | 15 +- src/entropice/dashboard/plots/overview.py | 244 ++++++++++++++++++ src/entropice/dashboard/plots/source_data.py | 20 +- .../dashboard/plots/training_data.py | 15 +- src/entropice/dashboard/utils/geometry.py | 37 +++ .../dashboard/views/overview_page.py | 134 ++-------- 7 files changed, 316 insertions(+), 163 deletions(-) create mode 100644 src/entropice/dashboard/plots/overview.py create mode 100644 src/entropice/dashboard/utils/geometry.py diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index 79d1867..4c9100e 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -3,16 +3,15 @@ from pathlib import Path import altair as alt -import antimeridian import geopandas as gpd import matplotlib.colors as mcolors import numpy as np import pandas as pd import pydeck as pdk import streamlit as st -from shapely.geometry import shape from entropice.dashboard.utils.colors import get_cmap, get_palette +from entropice.dashboard.utils.geometry import fix_hex_geometry from entropice.ml.dataset import DatasetEnsemble @@ -1152,15 +1151,6 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1 st.dataframe(display_df, hide_index=True, width="stretch") -def _fix_hex_geometry(geom): - """Fix hexagon geometry crossing the antimeridian.""" - try: - return shape(antimeridian.fix_shape(geom)) - except ValueError as e: - st.error(f"Error fixing geometry: {e}") - return geom - - @st.fragment def render_confusion_matrix_map(result_path: Path, settings: dict): """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN). @@ -1288,7 +1278,7 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): # Fix antimeridian issues for hex grids if grid == "hex": - display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry) + display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry) # Assign colors based on confusion category if task == "binary": diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py index 29804f7..4713cdc 100644 --- a/src/entropice/dashboard/plots/inference.py +++ b/src/entropice/dashboard/plots/inference.py @@ -5,23 +5,12 @@ import pandas as pd import plotly.graph_objects as go import pydeck as pdk import streamlit as st -from shapely.geometry import shape from entropice.dashboard.utils.colors import get_palette +from entropice.dashboard.utils.geometry import fix_hex_geometry from entropice.dashboard.utils.loaders import TrainingResult -def _fix_hex_geometry(geom): - """Fix hexagon geometry crossing the antimeridian.""" - import antimeridian - - try: - return shape(antimeridian.fix_shape(geom)) - except ValueError as e: - st.error(f"Error fixing geometry: {e}") - return geom - - def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str): """Render summary statistics about inference results. @@ -249,7 +238,7 @@ def render_inference_map(result: TrainingResult): # Fix antimeridian issues for hex grids if grid == "hex": - display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry) + display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry) # Assign colors based on predicted class colors_palette = get_palette(task, len(all_classes)) diff --git a/src/entropice/dashboard/plots/overview.py b/src/entropice/dashboard/plots/overview.py new file mode 100644 index 0000000..0770f01 --- /dev/null +++ b/src/entropice/dashboard/plots/overview.py @@ -0,0 +1,244 @@ +"""Visualization functions for the overview page. + +This module contains reusable plotting functions for dataset analysis visualizations, +including sample counts, feature counts, and dataset statistics. +""" + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go + + +def create_sample_count_heatmap( + pivot_df: pd.DataFrame, + target: str, + colorscale: list[str] | None = None, +) -> go.Figure: + """Create heatmap showing sample counts by grid and task. + + Args: + pivot_df: Pivoted dataframe with Grid as index, Task as columns, and sample counts as values. + target: Target dataset name for the title. + colorscale: Optional color palette for the heatmap. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the heatmap visualization. + + """ + fig = px.imshow( + pivot_df, + labels={ + "x": "Task", + "y": "Grid Configuration", + "color": "Sample Count", + }, + x=pivot_df.columns, + y=pivot_df.index, + color_continuous_scale=colorscale, + aspect="auto", + title=f"Target: {target}", + ) + + # Add text annotations + fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10) + fig.update_layout(height=400) + + return fig + + +def create_sample_count_bar_chart( + sample_df: pd.DataFrame, + task_colors: list[str] | None = None, +) -> go.Figure: + """Create bar chart showing sample counts by grid, target, and task. + + Args: + sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage). + task_colors: Optional color palette for tasks. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the bar chart visualization. + + """ + fig = px.bar( + sample_df, + x="Grid", + y="Samples (Coverage)", + color="Task", + facet_col="Target", + barmode="group", + title="Sample Counts by Grid Configuration and Target Dataset", + labels={ + "Grid": "Grid Configuration", + "Samples (Coverage)": "Number of Samples", + }, + color_discrete_sequence=task_colors, + height=500, + ) + + # Update facet labels to be cleaner + fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1])) + fig.update_xaxes(tickangle=-45) + + return fig + + +def create_feature_count_stacked_bar( + breakdown_df: pd.DataFrame, + source_colors: list[str] | None = None, +) -> go.Figure: + """Create stacked bar chart showing feature counts by data source. + + Args: + breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features. + source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the stacked bar chart visualization. + + """ + fig = px.bar( + breakdown_df, + x="Grid", + y="Number of Features", + color="Data Source", + barmode="stack", + title="Total Features by Data Source Across Grid Configurations", + labels={ + "Grid": "Grid Configuration", + "Number of Features": "Number of Features", + }, + color_discrete_sequence=source_colors, + text_auto=False, + ) + + fig.update_layout(height=500, xaxis_tickangle=-45) + + return fig + + +def create_inference_cells_bar( + comparison_df: pd.DataFrame, + grid_colors: list[str] | None = None, +) -> go.Figure: + """Create bar chart for inference cells by grid configuration. + + Args: + comparison_df: DataFrame with columns: Grid, Inference Cells. + grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the bar chart visualization. + + """ + fig = px.bar( + comparison_df, + x="Grid", + y="Inference Cells", + color="Grid", + title="Inference Cells by Grid Configuration", + labels={ + "Grid": "Grid Configuration", + "Inference Cells": "Number of Cells", + }, + color_discrete_sequence=grid_colors, + text="Inference Cells", + ) + + fig.update_traces(texttemplate="%{text:,}", textposition="outside") + fig.update_layout(xaxis_tickangle=-45, showlegend=False) + + return fig + + +def create_total_samples_bar( + comparison_df: pd.DataFrame, + grid_colors: list[str] | None = None, +) -> go.Figure: + """Create bar chart for total samples by grid configuration. + + Args: + comparison_df: DataFrame with columns: Grid, Total Samples. + grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the bar chart visualization. + + """ + fig = px.bar( + comparison_df, + x="Grid", + y="Total Samples", + color="Grid", + title="Total Samples by Grid Configuration", + labels={ + "Grid": "Grid Configuration", + "Total Samples": "Number of Samples", + }, + color_discrete_sequence=grid_colors, + text="Total Samples", + ) + + fig.update_traces(texttemplate="%{text:,}", textposition="outside") + fig.update_layout(xaxis_tickangle=-45, showlegend=False) + + return fig + + +def create_feature_breakdown_donut( + grid_data: pd.DataFrame, + grid_config: str, + source_colors: list[str] | None = None, +) -> go.Figure: + """Create donut chart for feature breakdown by data source for a specific grid. + + Args: + grid_data: DataFrame with columns: Data Source, Number of Features. + grid_config: Grid configuration name for the title. + source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the donut chart visualization. + + """ + fig = px.pie( + grid_data, + names="Data Source", + values="Number of Features", + title=grid_config, + hole=0.4, + color_discrete_sequence=source_colors, + ) + + fig.update_traces(textposition="inside", textinfo="percent") + fig.update_layout(showlegend=True, height=350) + + return fig + + +def create_feature_distribution_pie( + breakdown_df: pd.DataFrame, + source_colors: list[str] | None = None, +) -> go.Figure: + """Create pie chart for feature distribution by data source. + + Args: + breakdown_df: DataFrame with columns: Data Source, Number of Features. + source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the pie chart visualization. + + """ + fig = px.pie( + breakdown_df, + names="Data Source", + values="Number of Features", + title="Feature Distribution by Data Source", + hole=0.4, + color_discrete_sequence=source_colors, + ) + + fig.update_traces(textposition="inside", textinfo="percent+label") + + return fig diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py index 12a74c9..cb0299a 100644 --- a/src/entropice/dashboard/plots/source_data.py +++ b/src/entropice/dashboard/plots/source_data.py @@ -1,6 +1,5 @@ """Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5).""" -import antimeridian import geopandas as gpd import numpy as np import pandas as pd @@ -8,22 +7,13 @@ import plotly.graph_objects as go import pydeck as pdk import streamlit as st import xarray as xr -from shapely.geometry import shape from entropice.dashboard.utils.colors import get_cmap +from entropice.dashboard.utils.geometry import fix_hex_geometry # TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations -def _fix_hex_geometry(geom): - """Fix hexagon geometry crossing the antimeridian.""" - try: - return shape(antimeridian.fix_shape(geom)) - except ValueError as e: - st.error(f"Error fixing geometry: {e}") - return geom - - def render_alphaearth_overview(ds: xr.Dataset): """Render overview statistics for AlphaEarth embeddings data. @@ -573,7 +563,7 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): # Fix geometries after CRS conversion if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) # Normalize values for color mapping values = gdf_wgs84["value"].to_numpy() @@ -687,7 +677,7 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): # Fix geometries after CRS conversion if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) # Normalize values for color mapping values = gdf_wgs84["value"].values @@ -816,7 +806,7 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str): # Fix geometries after CRS conversion if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) # Get values for the selected metric values = gdf_wgs84[area_metric].to_numpy() @@ -975,7 +965,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor # Fix geometries after CRS conversion if grid == "hex": - gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) # Normalize values for color mapping values = gdf_wgs84["value"].values diff --git a/src/entropice/dashboard/plots/training_data.py b/src/entropice/dashboard/plots/training_data.py index 73ee446..966c64a 100644 --- a/src/entropice/dashboard/plots/training_data.py +++ b/src/entropice/dashboard/plots/training_data.py @@ -5,9 +5,9 @@ import pandas as pd import plotly.graph_objects as go import pydeck as pdk import streamlit as st -from shapely.geometry import shape from entropice.dashboard.utils.colors import get_palette +from entropice.dashboard.utils.geometry import fix_hex_geometry from entropice.ml.dataset import CategoricalTrainingDataset @@ -105,17 +105,6 @@ def render_all_distribution_histograms( st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%") -def _fix_hex_geometry(geom): - """Fix hexagon geometry crossing the antimeridian.""" - import antimeridian - - try: - return shape(antimeridian.fix_shape(geom)) - except ValueError as e: - st.error(f"Error fixing geometry: {e}") - return geom - - def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task): """Assign colors to geodataframe based on the selected color mode. @@ -204,7 +193,7 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]): gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment] # Fix antimeridian issues - gdf["geometry"] = gdf["geometry"].apply(_fix_hex_geometry) + gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry) # Add binned labels and split information from current dataset gdf["target_class"] = dataset.y.binned.to_numpy() diff --git a/src/entropice/dashboard/utils/geometry.py b/src/entropice/dashboard/utils/geometry.py new file mode 100644 index 0000000..359db95 --- /dev/null +++ b/src/entropice/dashboard/utils/geometry.py @@ -0,0 +1,37 @@ +"""Geometry utilities for dashboard visualizations.""" + +import streamlit as st +from shapely.geometry import shape + +try: + import antimeridian +except ImportError: + antimeridian = None + + +def fix_hex_geometry(geom): + """Fix hexagon geometry crossing the antimeridian. + + This function handles geometries that cross the antimeridian (180° longitude) + which can cause rendering issues in visualization libraries. Uses the antimeridian + library to split and fix such geometries. + + Args: + geom: A geometry object (typically from shapely or geojson). + + Returns: + Fixed geometry object with antimeridian issues resolved. + + Note: + If the antimeridian library is not available or an error occurs, + returns the original geometry unchanged. + + """ + if antimeridian is None: + return geom + + try: + return shape(antimeridian.fix_shape(geom)) + except ValueError as e: + st.error(f"Error fixing geometry: {e}") + return geom diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index 8cf243b..8c0d841 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -4,10 +4,18 @@ from datetime import datetime from typing import cast import pandas as pd -import plotly.express as px import streamlit as st from stopuhr import stopwatch +from entropice.dashboard.plots.overview import ( + create_feature_breakdown_donut, + create_feature_count_stacked_bar, + create_feature_distribution_pie, + create_inference_cells_bar, + create_sample_count_bar_chart, + create_sample_count_heatmap, + create_total_samples_bar, +) from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.loaders import load_all_training_results from entropice.dashboard.utils.stats import ( @@ -67,54 +75,20 @@ def render_sample_count_overview(): # Get color palette for sample counts sample_colors = get_palette(f"sample_counts_{target}", n_colors=10) - fig = px.imshow( - pivot_df, - labels={ - "x": "Task", - "y": "Grid Configuration", - "color": "Sample Count", - }, - x=pivot_df.columns, - y=pivot_df.index, - color_continuous_scale=sample_colors, - aspect="auto", - title=f"Target: {target}", - ) - - # Add text annotations - fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10) - - fig.update_layout(height=400) + # Create and display heatmap + fig = create_sample_count_heatmap(pivot_df, target, colorscale=sample_colors) st.plotly_chart(fig, width="stretch") with tab2: st.markdown("### Sample Counts Bar Chart") st.markdown("Showing counts of samples with coverage") - # Create a faceted bar chart showing both targets side by side # Get color palette for tasks n_tasks = sample_df["Task"].nunique() task_colors = get_palette("task_types", n_colors=n_tasks) - fig = px.bar( - sample_df, - x="Grid", - y="Samples (Coverage)", - color="Task", - facet_col="Target", - barmode="group", - title="Sample Counts by Grid Configuration and Target Dataset", - labels={ - "Grid": "Grid Configuration", - "Samples (Coverage)": "Number of Samples", - }, - color_discrete_sequence=task_colors, - height=500, - ) - - # Update facet labels to be cleaner - fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1])) - fig.update_xaxes(tickangle=-45) + # Create and display bar chart + fig = create_sample_count_bar_chart(sample_df, task_colors=task_colors) st.plotly_chart(fig, width="stretch") with tab3: @@ -153,65 +127,22 @@ def render_feature_count_comparison(): n_sources = len(unique_sources) source_colors = get_palette("data_sources", n_colors=n_sources) - # Create stacked bar chart - fig = px.bar( - breakdown_df, - x="Grid", - y="Number of Features", - color="Data Source", - barmode="stack", - title="Total Features by Data Source Across Grid Configurations", - labels={ - "Grid": "Grid Configuration", - "Number of Features": "Number of Features", - }, - color_discrete_sequence=source_colors, - text_auto=False, - ) - - fig.update_layout(height=500, xaxis_tickangle=-45) + # Create and display stacked bar chart + fig = create_feature_count_stacked_bar(breakdown_df, source_colors=source_colors) st.plotly_chart(fig, width="stretch") # Add secondary metrics col1, col2 = st.columns(2) - with col1: - # Get color palette for grid configs - n_grids = len(comparison_df) - grid_colors = get_palette("grid_configs", n_colors=n_grids) + # Get color palette for grid configs + n_grids = len(comparison_df) + grid_colors = get_palette("grid_configs", n_colors=n_grids) - fig_cells = px.bar( - comparison_df, - x="Grid", - y="Inference Cells", - color="Grid", - title="Inference Cells by Grid Configuration", - labels={ - "Grid": "Grid Configuration", - "Inference Cells": "Number of Cells", - }, - color_discrete_sequence=grid_colors, - text="Inference Cells", - ) - fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside") - fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False) + with col1: + fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors) st.plotly_chart(fig_cells, width="stretch") with col2: - fig_samples = px.bar( - comparison_df, - x="Grid", - y="Total Samples", - color="Grid", - title="Total Samples by Grid Configuration", - labels={ - "Grid": "Grid Configuration", - "Total Samples": "Number of Samples", - }, - color_discrete_sequence=grid_colors, - text="Total Samples", - ) - fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside") - fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False) + fig_samples = create_total_samples_bar(comparison_df, grid_colors=grid_colors) st.plotly_chart(fig_samples, width="stretch") with comp_tab2: @@ -238,16 +169,7 @@ def render_feature_count_comparison(): grid_data = breakdown_df[breakdown_df["Grid"] == grid_config] with cols[col_idx]: - fig = px.pie( - grid_data, - names="Data Source", - values="Number of Features", - title=grid_config, - hole=0.4, - color_discrete_sequence=source_colors, - ) - fig.update_traces(textposition="inside", textinfo="percent") - fig.update_layout(showlegend=True, height=350) + fig = create_feature_breakdown_donut(grid_data, grid_config, source_colors=source_colors) st.plotly_chart(fig, width="stretch") with comp_tab3: @@ -383,16 +305,8 @@ def render_feature_count_explorer(): n_sources = len(breakdown_df) source_colors = get_palette("data_sources", n_colors=n_sources) - # Create pie chart - fig = px.pie( - breakdown_df, - names="Data Source", - values="Number of Features", - title="Feature Distribution by Data Source", - hole=0.4, - color_discrete_sequence=source_colors, - ) - fig.update_traces(textposition="inside", textinfo="percent+label") + # Create and display pie chart + fig = create_feature_distribution_pie(breakdown_df, source_colors=source_colors) st.plotly_chart(fig, width="stretch") # Show detailed table