diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py
index 34ae0cb..c0c1a56 100644
--- a/src/entropice/dashboard/sections/dataset_statistics.py
+++ b/src/entropice/dashboard/sections/dataset_statistics.py
@@ -1,5 +1,7 @@
"""Dataset Statistics Section for Entropice Dashboard."""
+from typing import Any
+
import pandas as pd
import streamlit as st
@@ -12,8 +14,10 @@ from entropice.dashboard.plots.dataset_statistics import (
create_sample_count_bar_chart,
)
from entropice.dashboard.utils.colors import get_palette
-from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
+from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
+from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
+ GridConfig,
GridLevel,
L2SourceDataset,
TargetDataset,
@@ -97,17 +101,13 @@ def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
st.dataframe(breakdown_df, hide_index=True, width="stretch")
-@st.fragment
-def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901
- """Render interactive detailed configuration explorer."""
- st.markdown(
- """
- Explore detailed statistics for a specific dataset configuration.
- Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
- """
- )
+def _render_grid_and_task_selection() -> tuple[GridConfig, TemporalMode, TargetDataset, Task]:
+ """Render grid, temporal mode, target, and task selection controls.
- # Grid selection
+ Returns:
+ Tuple of (grid_config, temporal_mode, target_dataset, task)
+
+ """
grid_options = [gc.display_name for gc in grid_configs]
col1, col2, col3, col4 = st.columns(4)
@@ -159,18 +159,19 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
# Find the selected grid config
selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
- # Get stats for the selected configuration
- mode_stats = dataset_stats[selected_grid_config.id]
+ return selected_grid_config, selected_temporal_mode, selected_target, selected_task
- # Check if the selected temporal mode has stats
- if selected_temporal_mode not in mode_stats:
- st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
- return
- stats = mode_stats[selected_temporal_mode]
- available_members = list(stats.members.keys())
+def _render_member_selection(available_members: list[L2SourceDataset]) -> list[L2SourceDataset]:
+ """Render data source member selection checkboxes.
- # Members selection
+ Args:
+ available_members: List of available member dataset names
+
+ Returns:
+ List of selected member datasets
+
+ """
st.markdown("#### Select Data Sources")
# Use columns for checkboxes
@@ -183,158 +184,496 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
selected_members.append(member) # type: ignore[arg-type]
- # Show results if at least one member is selected
- if selected_members:
- # Filter to selected members only
- selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+ return selected_members
- # Calculate total features for selected members
- total_features = sum(ms.feature_count for ms in selected_member_stats.values())
- # Get target stats if available
- if selected_target not in stats.target:
- st.warning(f"Target {selected_target} is not available for this configuration.")
- return
+@st.cache_data(ttl=3600)
+def _read_member_dataset_cached(
+ grid: str,
+ level: int,
+ temporal_mode: str | int,
+ member: str,
+):
+ """Read a member dataset with caching to avoid redundant loads.
- target_stats = stats.target[selected_target]
+ Args:
+ grid: Grid system name
+ level: Grid resolution level
+ temporal_mode: Temporal mode (feature, synopsis, or year)
+ member: Member dataset name
- # Check if task is available
- if selected_task not in target_stats:
- st.warning(
- f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
- )
- return
+ Returns:
+ Lazy-loaded xarray dataset, or None if reading fails
- task_stats = target_stats[selected_task]
+ """
+ try:
+ ensemble = DatasetEnsemble(
+ grid=grid, # type: ignore[arg-type]
+ level=level,
+ temporal_mode=temporal_mode, # type: ignore[arg-type]
+ )
+ return ensemble.read_member(member, lazy=True)
+ except Exception:
+ return None
- # High-level metrics
- st.markdown("#### Configuration Summary")
- col1, col2, col3, col4, col5 = st.columns(5)
- with col1:
- st.metric("Total Features", f"{total_features:,}")
- with col2:
- # Calculate minimum cells across all data sources (for inference capability)
- if selected_member_stats:
- min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
- else:
- min_cells = 0
- st.metric(
- "Inference Cells",
- f"{min_cells:,}",
- help="Minimum number of cells across all selected data sources",
- )
- with col3:
- st.metric("Data Sources", len(selected_members))
- with col4:
- st.metric("Training Samples", f"{task_stats.training_cells:,}")
- with col5:
- # Calculate total data points
- total_points = total_features * task_stats.training_cells
- st.metric("Total Data Points", f"{total_points:,}")
- # Task-specific statistics
- st.markdown("#### Task Statistics")
- task_col1, task_col2, task_col3 = st.columns(3)
- with task_col1:
- st.metric("Task Type", selected_task.replace("_", " ").title())
- with task_col2:
- st.metric("Coverage", f"{task_stats.coverage:.2f}%")
- with task_col3:
- if task_stats.class_counts:
- st.metric("Number of Classes", len(task_stats.class_counts))
- else:
- st.metric("Task Mode", "Regression")
+def _get_member_datasets(
+ temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
+) -> dict[L2SourceDataset, Any]:
+ """Read all member datasets once using cached function.
- # Class distribution for classification tasks
- if task_stats.class_distribution:
- st.markdown("#### Class Distribution")
- class_dist_df = pd.DataFrame(
+ Args:
+ temp_ensemble: Temporary ensemble to get grid/level/temporal_mode from
+ selected_members: List of selected member datasets
+
+ Returns:
+ Dictionary mapping members to their loaded datasets (or None if failed)
+
+ """
+ member_datasets = {}
+ for member in selected_members:
+ ds = _read_member_dataset_cached(
+ grid=temp_ensemble.grid,
+ level=temp_ensemble.level,
+ temporal_mode=temp_ensemble.temporal_mode,
+ member=member,
+ )
+ member_datasets[member] = ds
+ return member_datasets
+
+
+def _get_member_aggregations(member_datasets: dict[L2SourceDataset, Any]) -> dict[L2SourceDataset, list[str]]:
+ """Get available aggregation dimensions for each selected member.
+
+ Args:
+ member_datasets: Dictionary mapping members to their loaded datasets
+
+ Returns:
+ Dictionary mapping members to their aggregation dimension names
+
+ """
+ member_aggregations: dict[L2SourceDataset, list[str]] = {}
+ for member, member_ds in member_datasets.items():
+ if member_ds is None:
+ member_aggregations[member] = []
+ continue
+ # Find aggregation dimensions (only dimensions named "agg" or "aggregations")
+ agg_dims = [dim for dim in member_ds.dims if dim in ("agg", "aggregations")]
+ member_aggregations[member] = sorted(agg_dims)
+ return member_aggregations
+
+
+def _set_all_aggregations(
+ member_datasets: dict[L2SourceDataset, Any],
+ members_with_aggs: list[L2SourceDataset],
+ member_aggregations: dict[L2SourceDataset, list[str]],
+ selected: bool,
+):
+ """Set all aggregation checkboxes to selected or deselected state.
+
+ Args:
+ member_datasets: Dictionary mapping members to their loaded datasets
+ members_with_aggs: List of members that have aggregations
+ member_aggregations: Dictionary mapping members to aggregation dimensions
+ selected: True to select all, False to deselect all
+
+ """
+ for member in members_with_aggs:
+ member_ds = member_datasets.get(member)
+ if member_ds is None:
+ continue
+ agg_dims = member_aggregations.get(member, [])
+ for agg_dim in agg_dims:
+ if agg_dim in member_ds.dims:
+ agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
+ for val in agg_values:
+ st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = selected
+
+
+def _set_median_only_aggregations(
+ member_datasets: dict[L2SourceDataset, Any],
+ members_with_aggs: list[L2SourceDataset],
+ member_aggregations: dict[L2SourceDataset, list[str]],
+):
+ """Set only median aggregations to selected, deselect all others.
+
+ Args:
+ member_datasets: Dictionary mapping members to their loaded datasets
+ members_with_aggs: List of members that have aggregations
+ member_aggregations: Dictionary mapping members to aggregation dimensions
+
+ """
+ for member in members_with_aggs:
+ member_ds = member_datasets.get(member)
+ if member_ds is None:
+ continue
+ agg_dims = member_aggregations.get(member, [])
+ for agg_dim in agg_dims:
+ if agg_dim in member_ds.dims:
+ agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
+ for val in agg_values:
+ # Select only if value is or contains 'median'
+ is_median = str(val).lower() == "median" or "median" in str(val).lower()
+ st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = is_median
+
+
+def _render_aggregation_form(
+ member_datasets: dict[L2SourceDataset, Any],
+ members_with_aggs: list[L2SourceDataset],
+ member_aggregations: dict[L2SourceDataset, list[str]],
+) -> dict[str, dict[str, list]]:
+ """Render aggregation selection form with checkboxes.
+
+ Args:
+ member_datasets: Dictionary mapping members to their loaded datasets
+ members_with_aggs: List of members that have aggregations
+ member_aggregations: Dictionary mapping members to aggregation dimensions
+
+ Returns:
+ Dictionary mapping member names to dimension filters
+
+ """
+ dimension_filters: dict[str, dict[str, list]] = {}
+
+ if not members_with_aggs:
+ return dimension_filters
+
+ member_cols = st.columns(len(members_with_aggs))
+
+ for col_idx, member in enumerate(members_with_aggs):
+ with member_cols[col_idx]:
+ st.markdown(f"**{member}:**")
+ member_ds = member_datasets.get(member)
+ if member_ds is None:
+ continue
+ aggs = member_aggregations[member]
+
+ for agg_dim in aggs:
+ if agg_dim not in member_ds.dims:
+ continue
+
+ # Get coordinate values for this dimension
+ agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
+ # Convert to strings for display
+ agg_values = [str(v) for v in agg_values]
+
+ st.markdown(f"*{agg_dim}:*")
+ selected_vals = []
+
+ for val in agg_values:
+ if st.checkbox(
+ val,
+ value=True,
+ key=f"feature_agg_{member}_{agg_dim}_{val}",
+ help=f"Include {val} from {agg_dim}",
+ ):
+ selected_vals.append(val)
+
+ # Store selected values in dimension_filters
+ if selected_vals and len(selected_vals) < len(agg_values):
+ if member not in dimension_filters:
+ dimension_filters[member] = {}
+ dimension_filters[member][agg_dim] = selected_vals
+
+ return dimension_filters
+
+
+def _render_aggregation_selection(
+ temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
+) -> dict[str, dict[str, list]]:
+ """Render aggregation selection controls for selected members.
+
+ Args:
+ temp_ensemble: Temporary ensemble to inspect for aggregations
+ selected_members: List of selected member datasets
+
+ Returns:
+ Dictionary mapping member names to dimension filters
+
+ """
+ if not selected_members:
+ return {}
+
+ # Read all member datasets once (cached)
+ member_datasets = _get_member_datasets(temp_ensemble, selected_members)
+
+ # Get available aggregations for each selected member
+ member_aggregations = _get_member_aggregations(member_datasets)
+ members_with_aggs: list[L2SourceDataset] = [m for m in selected_members if member_aggregations.get(m)]
+
+ if not members_with_aggs:
+ return {}
+
+ # Add Select All / Deselect All buttons
+ st.markdown("#### Select Aggregations")
+ st.markdown(
+ "Select which spatial aggregations to use for each data source. "
+ "Different members may have different aggregation options depending on the grid configuration."
+ )
+
+ col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
+ with col_btn1:
+ if st.button("✅ Select All", use_container_width=True):
+ _set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=True)
+ with col_btn2:
+ if st.button("📊 Median Only", use_container_width=True):
+ _set_median_only_aggregations(member_datasets, members_with_aggs, member_aggregations)
+ with col_btn3:
+ if st.button("❌ Deselect All", use_container_width=True):
+ _set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=False)
+
+ # Render the form with checkboxes
+ with st.form("aggregation_selection_form"):
+ dimension_filters = _render_aggregation_form(member_datasets, members_with_aggs, member_aggregations)
+
+ # Submit button for the form
+ submitted = st.form_submit_button("Apply Aggregation Filters", type="primary")
+
+ if not submitted:
+ st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
+ st.stop()
+
+ return dimension_filters
+
+
+def _render_configuration_summary(
+ selected_members: list[L2SourceDataset],
+ selected_member_stats: dict[str, MemberStatistics],
+ selected_target: TargetDataset,
+ selected_task: Task,
+ selected_temporal_mode: TemporalMode,
+ stats: DatasetStatistics,
+):
+ """Render configuration summary with metrics and statistics.
+
+ Args:
+ selected_members: List of selected member datasets
+ selected_member_stats: Statistics for selected members
+ selected_target: Selected target dataset
+ selected_task: Selected task type
+ selected_temporal_mode: Selected temporal mode
+ stats: Overall dataset statistics
+
+ """
+ # Calculate total features for selected members
+ total_features = sum(ms.feature_count for ms in selected_member_stats.values())
+
+ # Get target stats if available
+ if selected_target not in stats.target:
+ st.warning(f"Target {selected_target} is not available for this configuration.")
+ return
+
+ target_stats = stats.target[selected_target]
+
+ # Check if task is available
+ if selected_task not in target_stats:
+ st.warning(
+ f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
+ )
+ return
+
+ task_stats = target_stats[selected_task]
+
+ # High-level metrics
+ st.markdown("#### Configuration Summary")
+ col1, col2, col3, col4, col5 = st.columns(5)
+ with col1:
+ st.metric("Total Features", f"{total_features:,}")
+ with col2:
+ # Calculate minimum cells across all data sources (for inference capability)
+ if selected_member_stats:
+ min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
+ else:
+ min_cells = 0
+ st.metric(
+ "Inference Cells",
+ f"{min_cells:,}",
+ help="Minimum number of cells across all selected data sources",
+ )
+ with col3:
+ st.metric("Data Sources", len(selected_members))
+ with col4:
+ st.metric("Training Samples", f"{task_stats.training_cells:,}")
+ with col5:
+ # Calculate total data points
+ total_points = total_features * task_stats.training_cells
+ st.metric("Total Data Points", f"{total_points:,}")
+
+ # Task-specific statistics
+ st.markdown("#### Task Statistics")
+ task_col1, task_col2, task_col3 = st.columns(3)
+ with task_col1:
+ st.metric("Task Type", selected_task.replace("_", " ").title())
+ with task_col2:
+ st.metric("Coverage", f"{task_stats.coverage:.2f}%")
+ with task_col3:
+ if task_stats.class_counts:
+ st.metric("Number of Classes", len(task_stats.class_counts))
+ else:
+ st.metric("Task Mode", "Regression")
+
+ # Class distribution for classification tasks
+ if task_stats.class_distribution:
+ st.markdown("#### Class Distribution")
+ class_dist_df = pd.DataFrame(
+ [
+ {
+ "Class": class_name,
+ "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
+ "Percentage": f"{pct:.2f}%",
+ }
+ for class_name, pct in task_stats.class_distribution.items()
+ ]
+ )
+ st.dataframe(class_dist_df, hide_index=True, width="stretch")
+
+ # Feature breakdown by source
+ st.markdown("#### Feature Breakdown by Data Source")
+
+ breakdown_data = []
+ for member, member_stats in selected_member_stats.items():
+ breakdown_data.append(
+ {
+ "Data Source": member,
+ "Number of Features": member_stats.feature_count,
+ "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
+ }
+ )
+
+ breakdown_df = pd.DataFrame(breakdown_data)
+
+ # Get all unique data sources and create color map
+ unique_members_for_color = sorted(selected_member_stats.keys())
+ source_color_map_raw = {}
+ for member in unique_members_for_color:
+ source = member.split("-")[0]
+ n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
+ palette = get_palette(source, n_colors=n_members + 2)
+ idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
+ source_color_map_raw[member] = palette[idx + 1]
+
+ # Create and display pie chart
+ fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
+ st.plotly_chart(fig, width="stretch")
+
+ # Show detailed table
+ st.dataframe(breakdown_df, hide_index=True, width="stretch")
+
+ # Detailed member information
+ with st.expander("📦 Detailed Source Information", expanded=False):
+ # Create detailed table
+ member_details_dict = {
+ member: {
+ "feature_count": ms.feature_count,
+ "variable_names": ms.variable_names,
+ "dimensions": ms.dimensions,
+ "size_bytes": ms.size_bytes,
+ }
+ for member, ms in selected_member_stats.items()
+ }
+ details_df = create_member_details_table(member_details_dict)
+ st.dataframe(details_df, hide_index=True, width="stretch")
+
+ # Individual member details
+ for member, member_stats in selected_member_stats.items():
+ st.markdown(f"### {member}")
+
+ # Variables
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
[
- {
- "Class": class_name,
- "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
- "Percentage": f"{pct:.2f}%",
- }
- for class_name, pct in task_stats.class_distribution.items()
+ f'{v}'
+ for v in member_stats.variable_names
]
)
- st.dataframe(class_dist_df, hide_index=True, width="stretch")
+ st.markdown(vars_html, unsafe_allow_html=True)
- # Feature breakdown by source
- st.markdown("#### Feature Breakdown by Data Source")
-
- breakdown_data = []
- for member, member_stats in selected_member_stats.items():
- breakdown_data.append(
- {
- "Data Source": member,
- "Number of Features": member_stats.feature_count,
- "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
- }
+ # Dimensions
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+ f''
+ f"{dim_name}: {dim_size:,}"
+ for dim_name, dim_size in member_stats.dimensions.items()
+ ]
)
+ st.markdown(dim_html, unsafe_allow_html=True)
- breakdown_df = pd.DataFrame(breakdown_data)
+ st.markdown("---")
- # Get all unique data sources and create color map
- unique_members_for_color = sorted(selected_member_stats.keys())
- source_color_map_raw = {}
- for member in unique_members_for_color:
- source = member.split("-")[0]
- n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
- palette = get_palette(source, n_colors=n_members + 2)
- idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
- source_color_map_raw[member] = palette[idx + 1]
- # Create and display pie chart
- fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
- st.plotly_chart(fig, width="stretch")
+@st.fragment
+def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
+ """Render interactive detailed configuration explorer.
- # Show detailed table
- st.dataframe(breakdown_df, hide_index=True, width="stretch")
+ Args:
+ all_stats: Pre-computed statistics cache for optimization
- # Detailed member information
- with st.expander("📦 Detailed Source Information", expanded=False):
- # Create detailed table
- member_details_dict = {
- member: {
- "feature_count": ms.feature_count,
- "variable_names": ms.variable_names,
- "dimensions": ms.dimensions,
- "size_bytes": ms.size_bytes,
- }
- for member, ms in selected_member_stats.items()
- }
- details_df = create_member_details_table(member_details_dict)
- st.dataframe(details_df, hide_index=True, width="stretch")
+ """
+ st.markdown(
+ """
+ Explore detailed statistics for a specific dataset configuration.
+ Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
+ """
+ )
- # Individual member details
- for member, member_stats in selected_member_stats.items():
- st.markdown(f"### {member}")
+ # Grid, temporal mode, target, and task selection
+ selected_grid_config, selected_temporal_mode, selected_target, selected_task = _render_grid_and_task_selection()
- # Variables
- st.markdown("**Variables:**")
- vars_html = " ".join(
- [
- f'{v}'
- for v in member_stats.variable_names
- ]
- )
- st.markdown(vars_html, unsafe_allow_html=True)
+ # Create dataset ensemble on-the-fly and compute statistics
+ temp_ensemble = DatasetEnsemble(
+ grid=selected_grid_config.grid,
+ level=selected_grid_config.level,
+ temporal_mode=selected_temporal_mode,
+ )
+ available_members = temp_ensemble.members
- # Dimensions
- st.markdown("**Dimensions:**")
- dim_html = " ".join(
- [
- f''
- f"{dim_name}: {dim_size:,}"
- for dim_name, dim_size in member_stats.dimensions.items()
- ]
- )
- st.markdown(dim_html, unsafe_allow_html=True)
+ # Member selection
+ selected_members = _render_member_selection(available_members)
- st.markdown("---")
+ # Aggregation selection
+ dimension_filters = _render_aggregation_selection(temp_ensemble, selected_members)
+
+ # Show results if at least one member is selected
+ if selected_members:
+ # Optimize: Use pre-computed stats if no dimensional filtering is needed
+ grid_level_key = f"{selected_grid_config.grid}{selected_grid_config.level}" # e.g., "hex3", "healpix6"
+ use_cached_stats = (
+ not dimension_filters # No dimension filters applied
+ and grid_level_key in all_stats # Grid level exists in cache
+ and selected_temporal_mode in all_stats[grid_level_key] # type: ignore[literal-required] # Temporal mode exists
+ )
+
+ if use_cached_stats:
+ # Use pre-computed statistics (much faster)
+ stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
+ # Filter to selected members only
+ selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+ else:
+ # Create the actual ensemble with selected members and dimension filters
+ ensemble = DatasetEnsemble(
+ grid=selected_grid_config.grid,
+ level=selected_grid_config.level,
+ temporal_mode=selected_temporal_mode,
+ members=selected_members,
+ dimension_filters=dimension_filters,
+ )
+ # Recompute stats with the filtered ensemble
+ stats = DatasetStatistics.from_ensemble(ensemble)
+ # Filter to selected members only
+ selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+
+ # Render configuration summary and statistics
+ _render_configuration_summary(
+ selected_members=selected_members,
+ selected_member_stats=selected_member_stats,
+ selected_target=selected_target,
+ selected_task=selected_task,
+ selected_temporal_mode=selected_temporal_mode,
+ stats=stats,
+ )
else:
st.info("👆 Select at least one data source to see feature statistics")
diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py
index 9e9ddaf..d061bc4 100644
--- a/src/entropice/dashboard/utils/stats.py
+++ b/src/entropice/dashboard/utils/stats.py
@@ -21,7 +21,6 @@ from entropice.utils.types import (
TargetDataset,
Task,
TemporalMode,
- all_l2_source_datasets,
all_target_datasets,
all_tasks,
all_temporal_modes,
@@ -43,7 +42,7 @@ class MemberStatistics:
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
"""Pre-compute the statistics for a specific dataset member."""
member_stats = {}
- for member in all_l2_source_datasets:
+ for member in e.members:
ds = e.read_member(member, lazy=True)
size_bytes = ds.nbytes
@@ -113,6 +112,29 @@ class DatasetStatistics:
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
+ @classmethod
+ def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
+ """Compute dataset statistics from a DatasetEnsemble."""
+ grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
+ total_cells = len(grid_gdf)
+ target_statistics = {}
+ for target in all_target_datasets:
+ if isinstance(e.temporal_mode, int) and target == "darts_mllabels":
+ # darts_mllabels does not support year-based temporal modes
+ continue
+ target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
+ member_statistics = MemberStatistics.compute(e)
+
+ total_features = sum(ms.feature_count for ms in member_statistics.values())
+ total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
+ return cls(
+ total_features=total_features,
+ total_cells=total_cells,
+ size_bytes=total_size_bytes,
+ members=member_statistics,
+ target=target_statistics,
+ )
+
@staticmethod
def get_target_sample_count_df(
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],