Split overview plots

This commit is contained in:
Tobias Hölzer 2026-01-04 21:35:21 +01:00
parent 393cc968cb
commit 4fecac535c
7 changed files with 316 additions and 163 deletions

View file

@ -3,16 +3,15 @@
from pathlib import Path from pathlib import Path
import altair as alt import altair as alt
import antimeridian
import geopandas as gpd import geopandas as gpd
import matplotlib.colors as mcolors import matplotlib.colors as mcolors
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pydeck as pdk import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_cmap, get_palette from entropice.dashboard.utils.colors import get_cmap, get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
@ -1152,15 +1151,6 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
st.dataframe(display_df, hide_index=True, width="stretch") st.dataframe(display_df, hide_index=True, width="stretch")
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:
st.error(f"Error fixing geometry: {e}")
return geom
@st.fragment @st.fragment
def render_confusion_matrix_map(result_path: Path, settings: dict): def render_confusion_matrix_map(result_path: Path, settings: dict):
"""Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN). """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).
@ -1288,7 +1278,7 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
# Fix antimeridian issues for hex grids # Fix antimeridian issues for hex grids
if grid == "hex": if grid == "hex":
display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry) display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Assign colors based on confusion category # Assign colors based on confusion category
if task == "binary": if task == "binary":

View file

@ -5,23 +5,12 @@ import pandas as pd
import plotly.graph_objects as go import plotly.graph_objects as go
import pydeck as pdk import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.dashboard.utils.loaders import TrainingResult from entropice.dashboard.utils.loaders import TrainingResult
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""
import antimeridian
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:
st.error(f"Error fixing geometry: {e}")
return geom
def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str): def render_inference_statistics(predictions_gdf: gpd.GeoDataFrame, task: str):
"""Render summary statistics about inference results. """Render summary statistics about inference results.
@ -249,7 +238,7 @@ def render_inference_map(result: TrainingResult):
# Fix antimeridian issues for hex grids # Fix antimeridian issues for hex grids
if grid == "hex": if grid == "hex":
display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(_fix_hex_geometry) display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Assign colors based on predicted class # Assign colors based on predicted class
colors_palette = get_palette(task, len(all_classes)) colors_palette = get_palette(task, len(all_classes))

View file

@ -0,0 +1,244 @@
"""Visualization functions for the overview page.
This module contains reusable plotting functions for dataset analysis visualizations,
including sample counts, feature counts, and dataset statistics.
"""
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
def create_sample_count_heatmap(
    pivot_df: pd.DataFrame,
    target: str,
    colorscale: list[str] | None = None,
) -> go.Figure:
    """Build a heatmap of sample counts per grid configuration and task.

    Args:
        pivot_df: Pivot table whose index holds the grid configurations, whose
            columns hold the tasks and whose cells hold the sample counts.
        target: Name of the target dataset; shown in the figure title.
        colorscale: Colors forwarded as the continuous color scale. ``None``
            is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The heatmap figure with the cell values drawn as text on top.
    """
    axis_labels = {
        "x": "Task",
        "y": "Grid Configuration",
        "color": "Sample Count",
    }
    heatmap = px.imshow(
        pivot_df,
        labels=axis_labels,
        x=pivot_df.columns,
        y=pivot_df.index,
        color_continuous_scale=colorscale,
        aspect="auto",
        title=f"Target: {target}",
    )

    # Print the raw cell values on the heatmap so exact numbers are readable
    # without hovering.
    heatmap.update_traces(
        text=pivot_df.values,
        texttemplate="%{text:,}",
        textfont_size=10,
    )
    heatmap.update_layout(height=400)
    return heatmap
def create_sample_count_bar_chart(
    sample_df: pd.DataFrame,
    task_colors: list[str] | None = None,
) -> go.Figure:
    """Build a grouped bar chart of sample counts, faceted by target dataset.

    Args:
        sample_df: Table providing the columns ``Grid``, ``Target``, ``Task``
            and ``Samples (Coverage)``.
        task_colors: Discrete color sequence used for the tasks. ``None`` is
            passed through unchanged, leaving the choice to Plotly.

    Returns:
        The bar chart figure with one facet column per target.
    """

    def _strip_facet_prefix(annotation):
        # Keep only the part after the last "=" of each annotation text, so a
        # facet title of the form "<column>=<value>" is reduced to "<value>".
        annotation.update(text=annotation.text.split("=")[-1])

    axis_labels = {
        "Grid": "Grid Configuration",
        "Samples (Coverage)": "Number of Samples",
    }
    chart = px.bar(
        sample_df,
        x="Grid",
        y="Samples (Coverage)",
        color="Task",
        facet_col="Target",
        barmode="group",
        title="Sample Counts by Grid Configuration and Target Dataset",
        labels=axis_labels,
        color_discrete_sequence=task_colors,
        height=500,
    )
    chart.for_each_annotation(_strip_facet_prefix)
    chart.update_xaxes(tickangle=-45)
    return chart
def create_feature_count_stacked_bar(
    breakdown_df: pd.DataFrame,
    source_colors: list[str] | None = None,
) -> go.Figure:
    """Build a stacked bar chart of feature counts split by data source.

    Args:
        breakdown_df: Table providing the columns ``Grid``, ``Data Source``
            and ``Number of Features``.
        source_colors: Discrete color sequence used for the data sources.
            ``None`` is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The stacked bar chart figure, one bar per grid configuration.
    """
    axis_labels = {
        "Grid": "Grid Configuration",
        "Number of Features": "Number of Features",
    }
    stacked = px.bar(
        breakdown_df,
        x="Grid",
        y="Number of Features",
        color="Data Source",
        barmode="stack",
        title="Total Features by Data Source Across Grid Configurations",
        labels=axis_labels,
        color_discrete_sequence=source_colors,
        text_auto=False,
    )
    stacked.update_layout(height=500, xaxis_tickangle=-45)
    return stacked
def create_inference_cells_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Build a bar chart of the number of inference cells per grid configuration.

    Args:
        comparison_df: Table providing the columns ``Grid`` and
            ``Inference Cells``.
        grid_colors: Discrete color sequence used for the grid configurations.
            ``None`` is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The bar chart figure with value labels outside the bars and the
        legend hidden.
    """
    value_column = "Inference Cells"
    bars = px.bar(
        comparison_df,
        x="Grid",
        y=value_column,
        color="Grid",
        title="Inference Cells by Grid Configuration",
        labels={
            "Grid": "Grid Configuration",
            value_column: "Number of Cells",
        },
        color_discrete_sequence=grid_colors,
        text=value_column,
    )
    # Every bar is already labelled on the x axis and colored by the same
    # column, so the legend is switched off.
    bars.update_traces(texttemplate="%{text:,}", textposition="outside")
    bars.update_layout(xaxis_tickangle=-45, showlegend=False)
    return bars
def create_total_samples_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Build a bar chart of the total number of samples per grid configuration.

    Args:
        comparison_df: Table providing the columns ``Grid`` and
            ``Total Samples``.
        grid_colors: Discrete color sequence used for the grid configurations.
            ``None`` is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The bar chart figure with value labels outside the bars and the
        legend hidden.
    """
    value_column = "Total Samples"
    bars = px.bar(
        comparison_df,
        x="Grid",
        y=value_column,
        color="Grid",
        title="Total Samples by Grid Configuration",
        labels={
            "Grid": "Grid Configuration",
            value_column: "Number of Samples",
        },
        color_discrete_sequence=grid_colors,
        text=value_column,
    )
    # Every bar is already labelled on the x axis and colored by the same
    # column, so the legend is switched off.
    bars.update_traces(texttemplate="%{text:,}", textposition="outside")
    bars.update_layout(xaxis_tickangle=-45, showlegend=False)
    return bars
def create_feature_breakdown_donut(
    grid_data: pd.DataFrame,
    grid_config: str,
    source_colors: list[str] | None = None,
) -> go.Figure:
    """Build a donut chart of the feature share per data source for one grid.

    Args:
        grid_data: Table providing the columns ``Data Source`` and
            ``Number of Features``.
        grid_config: Name of the grid configuration; used as the figure title.
        source_colors: Discrete color sequence used for the data sources.
            ``None`` is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The donut chart figure with percentages drawn inside the slices and
        the legend shown.
    """
    donut = px.pie(
        grid_data,
        names="Data Source",
        values="Number of Features",
        title=grid_config,
        hole=0.4,
        color_discrete_sequence=source_colors,
    )
    # The update methods return the figure itself, so the calls are chained.
    return donut.update_traces(
        textposition="inside",
        textinfo="percent",
    ).update_layout(showlegend=True, height=350)
def create_feature_distribution_pie(
    breakdown_df: pd.DataFrame,
    source_colors: list[str] | None = None,
) -> go.Figure:
    """Build a pie chart (with a central hole) of features per data source.

    Args:
        breakdown_df: Table providing the columns ``Data Source`` and
            ``Number of Features``.
        source_colors: Discrete color sequence used for the data sources.
            ``None`` is passed through unchanged, leaving the choice to Plotly.

    Returns:
        The chart figure with percentage and label drawn inside each slice.
    """
    distribution = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_sequence=source_colors,
    )
    distribution.update_traces(textposition="inside", textinfo="percent+label")
    return distribution

View file

@ -1,6 +1,5 @@
"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5).""" """Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
import antimeridian
import geopandas as gpd import geopandas as gpd
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -8,22 +7,13 @@ import plotly.graph_objects as go
import pydeck as pdk import pydeck as pdk
import streamlit as st import streamlit as st
import xarray as xr import xarray as xr
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_cmap from entropice.dashboard.utils.colors import get_cmap
from entropice.dashboard.utils.geometry import fix_hex_geometry
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations # TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:
st.error(f"Error fixing geometry: {e}")
return geom
def render_alphaearth_overview(ds: xr.Dataset): def render_alphaearth_overview(ds: xr.Dataset):
"""Render overview statistics for AlphaEarth embeddings data. """Render overview statistics for AlphaEarth embeddings data.
@ -573,7 +563,7 @@ def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
# Fix geometries after CRS conversion # Fix geometries after CRS conversion
if grid == "hex": if grid == "hex":
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Normalize values for color mapping # Normalize values for color mapping
values = gdf_wgs84["value"].to_numpy() values = gdf_wgs84["value"].to_numpy()
@ -687,7 +677,7 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
# Fix geometries after CRS conversion # Fix geometries after CRS conversion
if grid == "hex": if grid == "hex":
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Normalize values for color mapping # Normalize values for color mapping
values = gdf_wgs84["value"].values values = gdf_wgs84["value"].values
@ -816,7 +806,7 @@ def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
# Fix geometries after CRS conversion # Fix geometries after CRS conversion
if grid == "hex": if grid == "hex":
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Get values for the selected metric # Get values for the selected metric
values = gdf_wgs84[area_metric].to_numpy() values = gdf_wgs84[area_metric].to_numpy()
@ -975,7 +965,7 @@ def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, tempor
# Fix geometries after CRS conversion # Fix geometries after CRS conversion
if grid == "hex": if grid == "hex":
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Normalize values for color mapping # Normalize values for color mapping
values = gdf_wgs84["value"].values values = gdf_wgs84["value"].values

View file

@ -5,9 +5,9 @@ import pandas as pd
import plotly.graph_objects as go import plotly.graph_objects as go
import pydeck as pdk import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import CategoricalTrainingDataset from entropice.ml.dataset import CategoricalTrainingDataset
@ -105,17 +105,6 @@ def render_all_distribution_histograms(
st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%") st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""
import antimeridian
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:
st.error(f"Error fixing geometry: {e}")
return geom
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task): def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
"""Assign colors to geodataframe based on the selected color mode. """Assign colors to geodataframe based on the selected color mode.
@ -204,7 +193,7 @@ def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment] gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment]
# Fix antimeridian issues # Fix antimeridian issues
gdf["geometry"] = gdf["geometry"].apply(_fix_hex_geometry) gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
# Add binned labels and split information from current dataset # Add binned labels and split information from current dataset
gdf["target_class"] = dataset.y.binned.to_numpy() gdf["target_class"] = dataset.y.binned.to_numpy()

View file

@ -0,0 +1,37 @@
"""Geometry utilities for dashboard visualizations."""
import streamlit as st
from shapely.geometry import shape
try:
    import antimeridian
except ImportError:
    # Optional dependency: when it is missing, geometries are passed through
    # untouched instead of failing at import time.
    antimeridian = None


def fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    Geometries that cross the antimeridian (180° longitude) can cause
    rendering issues in visualization libraries. The geometry is handed to
    ``antimeridian.fix_shape`` and the result is converted back into a
    shapely geometry.

    Args:
        geom: A geometry object (typically from shapely or geojson), or
            ``None`` for a missing geometry.

    Returns:
        The fixed geometry. The input is returned unchanged when it is
        ``None``, when the antimeridian library is not installed, or when
        fixing raises a ``ValueError`` (which is additionally reported in the
        Streamlit UI).
    """
    # A missing geometry has nothing to fix. Forwarding None to fix_shape
    # would fail with an error other than the ValueError handled below and
    # abort the whole ``.apply(...)`` over the geometry column.
    if geom is None or antimeridian is None:
        return geom

    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        # Best effort: report the problem but keep rendering with the
        # original geometry.
        st.error(f"Error fixing geometry: {e}")
        return geom

View file

@ -4,10 +4,18 @@ from datetime import datetime
from typing import cast from typing import cast
import pandas as pd import pandas as pd
import plotly.express as px
import streamlit as st import streamlit as st
from stopuhr import stopwatch from stopuhr import stopwatch
from entropice.dashboard.plots.overview import (
create_feature_breakdown_donut,
create_feature_count_stacked_bar,
create_feature_distribution_pie,
create_inference_cells_bar,
create_sample_count_bar_chart,
create_sample_count_heatmap,
create_total_samples_bar,
)
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import ( from entropice.dashboard.utils.stats import (
@ -67,54 +75,20 @@ def render_sample_count_overview():
# Get color palette for sample counts # Get color palette for sample counts
sample_colors = get_palette(f"sample_counts_{target}", n_colors=10) sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
fig = px.imshow( # Create and display heatmap
pivot_df, fig = create_sample_count_heatmap(pivot_df, target, colorscale=sample_colors)
labels={
"x": "Task",
"y": "Grid Configuration",
"color": "Sample Count",
},
x=pivot_df.columns,
y=pivot_df.index,
color_continuous_scale=sample_colors,
aspect="auto",
title=f"Target: {target}",
)
# Add text annotations
fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
fig.update_layout(height=400)
st.plotly_chart(fig, width="stretch") st.plotly_chart(fig, width="stretch")
with tab2: with tab2:
st.markdown("### Sample Counts Bar Chart") st.markdown("### Sample Counts Bar Chart")
st.markdown("Showing counts of samples with coverage") st.markdown("Showing counts of samples with coverage")
# Create a faceted bar chart showing both targets side by side
# Get color palette for tasks # Get color palette for tasks
n_tasks = sample_df["Task"].nunique() n_tasks = sample_df["Task"].nunique()
task_colors = get_palette("task_types", n_colors=n_tasks) task_colors = get_palette("task_types", n_colors=n_tasks)
fig = px.bar( # Create and display bar chart
sample_df, fig = create_sample_count_bar_chart(sample_df, task_colors=task_colors)
x="Grid",
y="Samples (Coverage)",
color="Task",
facet_col="Target",
barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset",
labels={
"Grid": "Grid Configuration",
"Samples (Coverage)": "Number of Samples",
},
color_discrete_sequence=task_colors,
height=500,
)
# Update facet labels to be cleaner
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_xaxes(tickangle=-45)
st.plotly_chart(fig, width="stretch") st.plotly_chart(fig, width="stretch")
with tab3: with tab3:
@ -153,65 +127,22 @@ def render_feature_count_comparison():
n_sources = len(unique_sources) n_sources = len(unique_sources)
source_colors = get_palette("data_sources", n_colors=n_sources) source_colors = get_palette("data_sources", n_colors=n_sources)
# Create stacked bar chart # Create and display stacked bar chart
fig = px.bar( fig = create_feature_count_stacked_bar(breakdown_df, source_colors=source_colors)
breakdown_df,
x="Grid",
y="Number of Features",
color="Data Source",
barmode="stack",
title="Total Features by Data Source Across Grid Configurations",
labels={
"Grid": "Grid Configuration",
"Number of Features": "Number of Features",
},
color_discrete_sequence=source_colors,
text_auto=False,
)
fig.update_layout(height=500, xaxis_tickangle=-45)
st.plotly_chart(fig, width="stretch") st.plotly_chart(fig, width="stretch")
# Add secondary metrics # Add secondary metrics
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
with col1: # Get color palette for grid configs
# Get color palette for grid configs n_grids = len(comparison_df)
n_grids = len(comparison_df) grid_colors = get_palette("grid_configs", n_colors=n_grids)
grid_colors = get_palette("grid_configs", n_colors=n_grids)
fig_cells = px.bar( with col1:
comparison_df, fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
x="Grid",
y="Inference Cells",
color="Grid",
title="Inference Cells by Grid Configuration",
labels={
"Grid": "Grid Configuration",
"Inference Cells": "Number of Cells",
},
color_discrete_sequence=grid_colors,
text="Inference Cells",
)
fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
st.plotly_chart(fig_cells, width="stretch") st.plotly_chart(fig_cells, width="stretch")
with col2: with col2:
fig_samples = px.bar( fig_samples = create_total_samples_bar(comparison_df, grid_colors=grid_colors)
comparison_df,
x="Grid",
y="Total Samples",
color="Grid",
title="Total Samples by Grid Configuration",
labels={
"Grid": "Grid Configuration",
"Total Samples": "Number of Samples",
},
color_discrete_sequence=grid_colors,
text="Total Samples",
)
fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
st.plotly_chart(fig_samples, width="stretch") st.plotly_chart(fig_samples, width="stretch")
with comp_tab2: with comp_tab2:
@ -238,16 +169,7 @@ def render_feature_count_comparison():
grid_data = breakdown_df[breakdown_df["Grid"] == grid_config] grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
with cols[col_idx]: with cols[col_idx]:
fig = px.pie( fig = create_feature_breakdown_donut(grid_data, grid_config, source_colors=source_colors)
grid_data,
names="Data Source",
values="Number of Features",
title=grid_config,
hole=0.4,
color_discrete_sequence=source_colors,
)
fig.update_traces(textposition="inside", textinfo="percent")
fig.update_layout(showlegend=True, height=350)
st.plotly_chart(fig, width="stretch") st.plotly_chart(fig, width="stretch")
with comp_tab3: with comp_tab3:
@ -383,16 +305,8 @@ def render_feature_count_explorer():
n_sources = len(breakdown_df) n_sources = len(breakdown_df)
source_colors = get_palette("data_sources", n_colors=n_sources) source_colors = get_palette("data_sources", n_colors=n_sources)
# Create pie chart # Create and display pie chart
fig = px.pie( fig = create_feature_distribution_pie(breakdown_df, source_colors=source_colors)
breakdown_df,
names="Data Source",
values="Number of Features",
title="Feature Distribution by Data Source",
hole=0.4,
color_discrete_sequence=source_colors,
)
fig.update_traces(textposition="inside", textinfo="percent+label")
st.plotly_chart(fig, width="stretch") st.plotly_chart(fig, width="stretch")
# Show detailed table # Show detailed table