Improve the dataset configuration explorer

This commit is contained in:
Tobias Hölzer 2026-01-16 21:35:33 +01:00
parent c9c6af8370
commit a6e9a91692
2 changed files with 513 additions and 152 deletions

View file

@ -1,5 +1,7 @@
"""Dataset Statistics Section for Entropice Dashboard.""" """Dataset Statistics Section for Entropice Dashboard."""
from typing import Any
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
@ -12,8 +14,10 @@ from entropice.dashboard.plots.dataset_statistics import (
create_sample_count_bar_chart, create_sample_count_bar_chart,
) )
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import ( from entropice.utils.types import (
GridConfig,
GridLevel, GridLevel,
L2SourceDataset, L2SourceDataset,
TargetDataset, TargetDataset,
@ -97,17 +101,13 @@ def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
st.dataframe(breakdown_df, hide_index=True, width="stretch") st.dataframe(breakdown_df, hide_index=True, width="stretch")
@st.fragment def _render_grid_and_task_selection() -> tuple[GridConfig, TemporalMode, TargetDataset, Task]:
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901 """Render grid, temporal mode, target, and task selection controls.
"""Render interactive detailed configuration explorer."""
st.markdown(
"""
Explore detailed statistics for a specific dataset configuration.
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
"""
)
# Grid selection Returns:
Tuple of (grid_config, temporal_mode, target_dataset, task)
"""
grid_options = [gc.display_name for gc in grid_configs] grid_options = [gc.display_name for gc in grid_configs]
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
@ -159,18 +159,19 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
# Find the selected grid config # Find the selected grid config
selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display) selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
# Get stats for the selected configuration return selected_grid_config, selected_temporal_mode, selected_target, selected_task
mode_stats = dataset_stats[selected_grid_config.id]
# Check if the selected temporal mode has stats
if selected_temporal_mode not in mode_stats:
st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
return
stats = mode_stats[selected_temporal_mode] def _render_member_selection(available_members: list[L2SourceDataset]) -> list[L2SourceDataset]:
available_members = list(stats.members.keys()) """Render data source member selection checkboxes.
# Members selection Args:
available_members: List of available member dataset names
Returns:
List of selected member datasets
"""
st.markdown("#### Select Data Sources") st.markdown("#### Select Data Sources")
# Use columns for checkboxes # Use columns for checkboxes
@ -183,11 +184,277 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"): if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
selected_members.append(member) # type: ignore[arg-type] selected_members.append(member) # type: ignore[arg-type]
# Show results if at least one member is selected return selected_members
if selected_members:
# Filter to selected members only
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
@st.cache_data(ttl=3600)
def _read_member_dataset_cached(
grid: str,
level: int,
temporal_mode: str | int,
member: str,
):
"""Read a member dataset with caching to avoid redundant loads.
Args:
grid: Grid system name
level: Grid resolution level
temporal_mode: Temporal mode (feature, synopsis, or year)
member: Member dataset name
Returns:
Lazy-loaded xarray dataset, or None if reading fails
"""
try:
ensemble = DatasetEnsemble(
grid=grid, # type: ignore[arg-type]
level=level,
temporal_mode=temporal_mode, # type: ignore[arg-type]
)
return ensemble.read_member(member, lazy=True)
except Exception:
return None
def _get_member_datasets(
temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
) -> dict[L2SourceDataset, Any]:
"""Read all member datasets once using cached function.
Args:
temp_ensemble: Temporary ensemble to get grid/level/temporal_mode from
selected_members: List of selected member datasets
Returns:
Dictionary mapping members to their loaded datasets (or None if failed)
"""
member_datasets = {}
for member in selected_members:
ds = _read_member_dataset_cached(
grid=temp_ensemble.grid,
level=temp_ensemble.level,
temporal_mode=temp_ensemble.temporal_mode,
member=member,
)
member_datasets[member] = ds
return member_datasets
def _get_member_aggregations(member_datasets: dict[L2SourceDataset, Any]) -> dict[L2SourceDataset, list[str]]:
"""Get available aggregation dimensions for each selected member.
Args:
member_datasets: Dictionary mapping members to their loaded datasets
Returns:
Dictionary mapping members to their aggregation dimension names
"""
member_aggregations: dict[L2SourceDataset, list[str]] = {}
for member, member_ds in member_datasets.items():
if member_ds is None:
member_aggregations[member] = []
continue
# Find aggregation dimensions (only dimensions named "agg" or "aggregations")
agg_dims = [dim for dim in member_ds.dims if dim in ("agg", "aggregations")]
member_aggregations[member] = sorted(agg_dims)
return member_aggregations
def _set_all_aggregations(
member_datasets: dict[L2SourceDataset, Any],
members_with_aggs: list[L2SourceDataset],
member_aggregations: dict[L2SourceDataset, list[str]],
selected: bool,
):
"""Set all aggregation checkboxes to selected or deselected state.
Args:
member_datasets: Dictionary mapping members to their loaded datasets
members_with_aggs: List of members that have aggregations
member_aggregations: Dictionary mapping members to aggregation dimensions
selected: True to select all, False to deselect all
"""
for member in members_with_aggs:
member_ds = member_datasets.get(member)
if member_ds is None:
continue
agg_dims = member_aggregations.get(member, [])
for agg_dim in agg_dims:
if agg_dim in member_ds.dims:
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
for val in agg_values:
st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = selected
def _set_median_only_aggregations(
member_datasets: dict[L2SourceDataset, Any],
members_with_aggs: list[L2SourceDataset],
member_aggregations: dict[L2SourceDataset, list[str]],
):
"""Set only median aggregations to selected, deselect all others.
Args:
member_datasets: Dictionary mapping members to their loaded datasets
members_with_aggs: List of members that have aggregations
member_aggregations: Dictionary mapping members to aggregation dimensions
"""
for member in members_with_aggs:
member_ds = member_datasets.get(member)
if member_ds is None:
continue
agg_dims = member_aggregations.get(member, [])
for agg_dim in agg_dims:
if agg_dim in member_ds.dims:
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
for val in agg_values:
# Select only if value is or contains 'median'
is_median = str(val).lower() == "median" or "median" in str(val).lower()
st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = is_median
def _render_aggregation_form(
member_datasets: dict[L2SourceDataset, Any],
members_with_aggs: list[L2SourceDataset],
member_aggregations: dict[L2SourceDataset, list[str]],
) -> dict[str, dict[str, list]]:
"""Render aggregation selection form with checkboxes.
Args:
member_datasets: Dictionary mapping members to their loaded datasets
members_with_aggs: List of members that have aggregations
member_aggregations: Dictionary mapping members to aggregation dimensions
Returns:
Dictionary mapping member names to dimension filters
"""
dimension_filters: dict[str, dict[str, list]] = {}
if not members_with_aggs:
return dimension_filters
member_cols = st.columns(len(members_with_aggs))
for col_idx, member in enumerate(members_with_aggs):
with member_cols[col_idx]:
st.markdown(f"**{member}:**")
member_ds = member_datasets.get(member)
if member_ds is None:
continue
aggs = member_aggregations[member]
for agg_dim in aggs:
if agg_dim not in member_ds.dims:
continue
# Get coordinate values for this dimension
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
# Convert to strings for display
agg_values = [str(v) for v in agg_values]
st.markdown(f"*{agg_dim}:*")
selected_vals = []
for val in agg_values:
if st.checkbox(
val,
value=True,
key=f"feature_agg_{member}_{agg_dim}_{val}",
help=f"Include {val} from {agg_dim}",
):
selected_vals.append(val)
# Store selected values in dimension_filters
if selected_vals and len(selected_vals) < len(agg_values):
if member not in dimension_filters:
dimension_filters[member] = {}
dimension_filters[member][agg_dim] = selected_vals
return dimension_filters
def _render_aggregation_selection(
temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
) -> dict[str, dict[str, list]]:
"""Render aggregation selection controls for selected members.
Args:
temp_ensemble: Temporary ensemble to inspect for aggregations
selected_members: List of selected member datasets
Returns:
Dictionary mapping member names to dimension filters
"""
if not selected_members:
return {}
# Read all member datasets once (cached)
member_datasets = _get_member_datasets(temp_ensemble, selected_members)
# Get available aggregations for each selected member
member_aggregations = _get_member_aggregations(member_datasets)
members_with_aggs: list[L2SourceDataset] = [m for m in selected_members if member_aggregations.get(m)]
if not members_with_aggs:
return {}
# Add Select All / Deselect All buttons
st.markdown("#### Select Aggregations")
st.markdown(
"Select which spatial aggregations to use for each data source. "
"Different members may have different aggregation options depending on the grid configuration."
)
col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
with col_btn1:
if st.button("✅ Select All", use_container_width=True):
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=True)
with col_btn2:
if st.button("📊 Median Only", use_container_width=True):
_set_median_only_aggregations(member_datasets, members_with_aggs, member_aggregations)
with col_btn3:
if st.button("❌ Deselect All", use_container_width=True):
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=False)
# Render the form with checkboxes
with st.form("aggregation_selection_form"):
dimension_filters = _render_aggregation_form(member_datasets, members_with_aggs, member_aggregations)
# Submit button for the form
submitted = st.form_submit_button("Apply Aggregation Filters", type="primary")
if not submitted:
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
st.stop()
return dimension_filters
def _render_configuration_summary(
selected_members: list[L2SourceDataset],
selected_member_stats: dict[str, MemberStatistics],
selected_target: TargetDataset,
selected_task: Task,
selected_temporal_mode: TemporalMode,
stats: DatasetStatistics,
):
"""Render configuration summary with metrics and statistics.
Args:
selected_members: List of selected member datasets
selected_member_stats: Statistics for selected members
selected_target: Selected target dataset
selected_task: Selected task type
selected_temporal_mode: Selected temporal mode
stats: Overall dataset statistics
"""
# Calculate total features for selected members # Calculate total features for selected members
total_features = sum(ms.feature_count for ms in selected_member_stats.values()) total_features = sum(ms.feature_count for ms in selected_member_stats.values())
@ -335,6 +602,78 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
st.markdown(dim_html, unsafe_allow_html=True) st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---") st.markdown("---")
@st.fragment
def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
"""Render interactive detailed configuration explorer.
Args:
all_stats: Pre-computed statistics cache for optimization
"""
st.markdown(
"""
Explore detailed statistics for a specific dataset configuration.
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
"""
)
# Grid, temporal mode, target, and task selection
selected_grid_config, selected_temporal_mode, selected_target, selected_task = _render_grid_and_task_selection()
# Create dataset ensemble on-the-fly and compute statistics
temp_ensemble = DatasetEnsemble(
grid=selected_grid_config.grid,
level=selected_grid_config.level,
temporal_mode=selected_temporal_mode,
)
available_members = temp_ensemble.members
# Member selection
selected_members = _render_member_selection(available_members)
# Aggregation selection
dimension_filters = _render_aggregation_selection(temp_ensemble, selected_members)
# Show results if at least one member is selected
if selected_members:
# Optimize: Use pre-computed stats if no dimensional filtering is needed
grid_level_key = f"{selected_grid_config.grid}{selected_grid_config.level}" # e.g., "hex3", "healpix6"
use_cached_stats = (
not dimension_filters # No dimension filters applied
and grid_level_key in all_stats # Grid level exists in cache
and selected_temporal_mode in all_stats[grid_level_key] # type: ignore[literal-required] # Temporal mode exists
)
if use_cached_stats:
# Use pre-computed statistics (much faster)
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
# Filter to selected members only
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
else:
# Create the actual ensemble with selected members and dimension filters
ensemble = DatasetEnsemble(
grid=selected_grid_config.grid,
level=selected_grid_config.level,
temporal_mode=selected_temporal_mode,
members=selected_members,
dimension_filters=dimension_filters,
)
# Recompute stats with the filtered ensemble
stats = DatasetStatistics.from_ensemble(ensemble)
# Filter to selected members only
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
# Render configuration summary and statistics
_render_configuration_summary(
selected_members=selected_members,
selected_member_stats=selected_member_stats,
selected_target=selected_target,
selected_task=selected_task,
selected_temporal_mode=selected_temporal_mode,
stats=stats,
)
else: else:
st.info("👆 Select at least one data source to see feature statistics") st.info("👆 Select at least one data source to see feature statistics")

View file

@ -21,7 +21,6 @@ from entropice.utils.types import (
TargetDataset, TargetDataset,
Task, Task,
TemporalMode, TemporalMode,
all_l2_source_datasets,
all_target_datasets, all_target_datasets,
all_tasks, all_tasks,
all_temporal_modes, all_temporal_modes,
@ -43,7 +42,7 @@ class MemberStatistics:
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]: def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
"""Pre-compute the statistics for a specific dataset member.""" """Pre-compute the statistics for a specific dataset member."""
member_stats = {} member_stats = {}
for member in all_l2_source_datasets: for member in e.members:
ds = e.read_member(member, lazy=True) ds = e.read_member(member, lazy=True)
size_bytes = ds.nbytes size_bytes = ds.nbytes
@ -113,6 +112,29 @@ class DatasetStatistics:
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
@classmethod
def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
"""Compute dataset statistics from a DatasetEnsemble."""
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
total_cells = len(grid_gdf)
target_statistics = {}
for target in all_target_datasets:
if isinstance(e.temporal_mode, int) and target == "darts_mllabels":
# darts_mllabels does not support year-based temporal modes
continue
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
member_statistics = MemberStatistics.compute(e)
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
return cls(
total_features=total_features,
total_cells=total_cells,
size_bytes=total_size_bytes,
members=member_statistics,
target=target_statistics,
)
@staticmethod @staticmethod
def get_target_sample_count_df( def get_target_sample_count_df(
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]], all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],