Improve the dataset configuration explorer
This commit is contained in:
parent
c9c6af8370
commit
a6e9a91692
2 changed files with 513 additions and 152 deletions
|
|
@ -1,5 +1,7 @@
|
|||
"""Dataset Statistics Section for Entropice Dashboard."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
|
|
@ -12,8 +14,10 @@ from entropice.dashboard.plots.dataset_statistics import (
|
|||
create_sample_count_bar_chart,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
GridLevel,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
|
|
@ -97,17 +101,13 @@ def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
|
|||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901
|
||||
"""Render interactive detailed configuration explorer."""
|
||||
st.markdown(
|
||||
"""
|
||||
Explore detailed statistics for a specific dataset configuration.
|
||||
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
|
||||
"""
|
||||
)
|
||||
def _render_grid_and_task_selection() -> tuple[GridConfig, TemporalMode, TargetDataset, Task]:
|
||||
"""Render grid, temporal mode, target, and task selection controls.
|
||||
|
||||
# Grid selection
|
||||
Returns:
|
||||
Tuple of (grid_config, temporal_mode, target_dataset, task)
|
||||
|
||||
"""
|
||||
grid_options = [gc.display_name for gc in grid_configs]
|
||||
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
|
@ -159,18 +159,19 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
|
|||
# Find the selected grid config
|
||||
selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
|
||||
|
||||
# Get stats for the selected configuration
|
||||
mode_stats = dataset_stats[selected_grid_config.id]
|
||||
return selected_grid_config, selected_temporal_mode, selected_target, selected_task
|
||||
|
||||
# Check if the selected temporal mode has stats
|
||||
if selected_temporal_mode not in mode_stats:
|
||||
st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
|
||||
return
|
||||
|
||||
stats = mode_stats[selected_temporal_mode]
|
||||
available_members = list(stats.members.keys())
|
||||
def _render_member_selection(available_members: list[L2SourceDataset]) -> list[L2SourceDataset]:
|
||||
"""Render data source member selection checkboxes.
|
||||
|
||||
# Members selection
|
||||
Args:
|
||||
available_members: List of available member dataset names
|
||||
|
||||
Returns:
|
||||
List of selected member datasets
|
||||
|
||||
"""
|
||||
st.markdown("#### Select Data Sources")
|
||||
|
||||
# Use columns for checkboxes
|
||||
|
|
@ -183,158 +184,496 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
|
|||
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
||||
selected_members.append(member) # type: ignore[arg-type]
|
||||
|
||||
# Show results if at least one member is selected
|
||||
if selected_members:
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||
return selected_members
|
||||
|
||||
# Calculate total features for selected members
|
||||
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||
|
||||
# Get target stats if available
|
||||
if selected_target not in stats.target:
|
||||
st.warning(f"Target {selected_target} is not available for this configuration.")
|
||||
return
|
||||
@st.cache_data(ttl=3600)
|
||||
def _read_member_dataset_cached(
|
||||
grid: str,
|
||||
level: int,
|
||||
temporal_mode: str | int,
|
||||
member: str,
|
||||
):
|
||||
"""Read a member dataset with caching to avoid redundant loads.
|
||||
|
||||
target_stats = stats.target[selected_target]
|
||||
Args:
|
||||
grid: Grid system name
|
||||
level: Grid resolution level
|
||||
temporal_mode: Temporal mode (feature, synopsis, or year)
|
||||
member: Member dataset name
|
||||
|
||||
# Check if task is available
|
||||
if selected_task not in target_stats:
|
||||
st.warning(
|
||||
f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
|
||||
)
|
||||
return
|
||||
Returns:
|
||||
Lazy-loaded xarray dataset, or None if reading fails
|
||||
|
||||
task_stats = target_stats[selected_task]
|
||||
"""
|
||||
try:
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid, # type: ignore[arg-type]
|
||||
level=level,
|
||||
temporal_mode=temporal_mode, # type: ignore[arg-type]
|
||||
)
|
||||
return ensemble.read_member(member, lazy=True)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# High-level metrics
|
||||
st.markdown("#### Configuration Summary")
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
with col1:
|
||||
st.metric("Total Features", f"{total_features:,}")
|
||||
with col2:
|
||||
# Calculate minimum cells across all data sources (for inference capability)
|
||||
if selected_member_stats:
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
|
||||
else:
|
||||
min_cells = 0
|
||||
st.metric(
|
||||
"Inference Cells",
|
||||
f"{min_cells:,}",
|
||||
help="Minimum number of cells across all selected data sources",
|
||||
)
|
||||
with col3:
|
||||
st.metric("Data Sources", len(selected_members))
|
||||
with col4:
|
||||
st.metric("Training Samples", f"{task_stats.training_cells:,}")
|
||||
with col5:
|
||||
# Calculate total data points
|
||||
total_points = total_features * task_stats.training_cells
|
||||
st.metric("Total Data Points", f"{total_points:,}")
|
||||
|
||||
# Task-specific statistics
|
||||
st.markdown("#### Task Statistics")
|
||||
task_col1, task_col2, task_col3 = st.columns(3)
|
||||
with task_col1:
|
||||
st.metric("Task Type", selected_task.replace("_", " ").title())
|
||||
with task_col2:
|
||||
st.metric("Coverage", f"{task_stats.coverage:.2f}%")
|
||||
with task_col3:
|
||||
if task_stats.class_counts:
|
||||
st.metric("Number of Classes", len(task_stats.class_counts))
|
||||
else:
|
||||
st.metric("Task Mode", "Regression")
|
||||
def _get_member_datasets(
|
||||
temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
|
||||
) -> dict[L2SourceDataset, Any]:
|
||||
"""Read all member datasets once using cached function.
|
||||
|
||||
# Class distribution for classification tasks
|
||||
if task_stats.class_distribution:
|
||||
st.markdown("#### Class Distribution")
|
||||
class_dist_df = pd.DataFrame(
|
||||
Args:
|
||||
temp_ensemble: Temporary ensemble to get grid/level/temporal_mode from
|
||||
selected_members: List of selected member datasets
|
||||
|
||||
Returns:
|
||||
Dictionary mapping members to their loaded datasets (or None if failed)
|
||||
|
||||
"""
|
||||
member_datasets = {}
|
||||
for member in selected_members:
|
||||
ds = _read_member_dataset_cached(
|
||||
grid=temp_ensemble.grid,
|
||||
level=temp_ensemble.level,
|
||||
temporal_mode=temp_ensemble.temporal_mode,
|
||||
member=member,
|
||||
)
|
||||
member_datasets[member] = ds
|
||||
return member_datasets
|
||||
|
||||
|
||||
def _get_member_aggregations(member_datasets: dict[L2SourceDataset, Any]) -> dict[L2SourceDataset, list[str]]:
    """Get available aggregation dimensions for each selected member.

    Only dimensions literally named ``"agg"`` or ``"aggregations"`` are treated as aggregation
    dimensions. A member whose dataset is ``None`` is kept in the result with an empty list.

    Args:
        member_datasets: Dictionary mapping members to their loaded datasets (``None`` when no
            dataset is available for a member). Each dataset only needs to expose ``.dims``.

    Returns:
        Dictionary mapping every member to the sorted list of its aggregation dimension names.

    """
    aggregation_dim_names = ("agg", "aggregations")
    return {
        member: [] if member_ds is None else sorted(dim for dim in member_ds.dims if dim in aggregation_dim_names)
        for member, member_ds in member_datasets.items()
    }
|
||||
|
||||
|
||||
def _set_all_aggregations(
    member_datasets: dict[L2SourceDataset, Any],
    members_with_aggs: list[L2SourceDataset],
    member_aggregations: dict[L2SourceDataset, list[str]],
    selected: bool,
):
    """Force every aggregation checkbox into one common state.

    For each member, each of its aggregation dimensions and each coordinate value along that
    dimension, the boolean ``selected`` is written to ``st.session_state`` under the key
    ``feature_agg_{member}_{agg_dim}_{val}``.

    Args:
        member_datasets: Dictionary mapping members to their loaded datasets
        members_with_aggs: Members whose checkboxes should be updated
        member_aggregations: Dictionary mapping members to aggregation dimension names
        selected: True to tick every checkbox, False to clear every checkbox

    """
    for member in members_with_aggs:
        dataset = member_datasets.get(member)
        if dataset is not None:
            # Only dimensions that actually exist on the dataset carry coordinate values.
            present_dims = [agg_dim for agg_dim in member_aggregations.get(member, []) if agg_dim in dataset.dims]
            for agg_dim in present_dims:
                for val in dataset.coords[agg_dim].to_numpy().tolist():
                    st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = selected
|
||||
|
||||
|
||||
def _set_median_only_aggregations(
    member_datasets: dict[L2SourceDataset, Any],
    members_with_aggs: list[L2SourceDataset],
    member_aggregations: dict[L2SourceDataset, list[str]],
):
    """Set only median aggregations to selected, deselect all others.

    One boolean per aggregation coordinate value is written to ``st.session_state`` under the key
    ``feature_agg_{member}_{agg_dim}_{val}``. A value counts as a median aggregation when its
    string form contains ``"median"``, compared case-insensitively.

    Args:
        member_datasets: Dictionary mapping members to their loaded datasets
        members_with_aggs: List of members that have aggregations
        member_aggregations: Dictionary mapping members to aggregation dimensions

    """
    for member in members_with_aggs:
        member_ds = member_datasets.get(member)
        if member_ds is None:
            # No dataset for this member, hence no coordinate values to toggle.
            continue
        for agg_dim in member_aggregations.get(member, []):
            if agg_dim not in member_ds.dims:
                continue
            for val in member_ds.coords[agg_dim].to_numpy().tolist():
                # The substring test already covers the exact name "median", so no separate
                # equality check is needed.
                st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = "median" in str(val).lower()
|
||||
|
||||
|
||||
def _render_aggregation_form(
    member_datasets: dict[L2SourceDataset, Any],
    members_with_aggs: list[L2SourceDataset],
    member_aggregations: dict[L2SourceDataset, list[str]],
) -> dict[str, dict[str, list]]:
    """Render aggregation selection form with checkboxes.

    One column is created per member in ``members_with_aggs``. Inside each column a checkbox is
    rendered for every coordinate value of every aggregation dimension of that member, keyed as
    ``feature_agg_{member}_{agg_dim}_{val}``.

    A filter is recorded for a dimension only when a non-empty, strict subset of its values is
    ticked. "Everything ticked" and "nothing ticked" both leave the dimension unfiltered.

    Args:
        member_datasets: Dictionary mapping members to their loaded datasets
        members_with_aggs: List of members that have aggregations
        member_aggregations: Dictionary mapping members to aggregation dimensions

    Returns:
        Dictionary mapping member names to dimension filters
        (``{member: {agg_dim: [selected values as strings]}}``)

    """
    dimension_filters: dict[str, dict[str, list]] = {}

    if not members_with_aggs:
        # Nothing to lay out; also avoids asking Streamlit for zero columns.
        return dimension_filters

    member_cols = st.columns(len(members_with_aggs))

    for member_col, member in zip(member_cols, members_with_aggs):
        with member_col:
            st.markdown(f"**{member}:**")
            member_ds = member_datasets.get(member)
            if member_ds is None:
                continue

            for agg_dim in member_aggregations[member]:
                if agg_dim not in member_ds.dims:
                    continue

                # Coordinate values are shown (and returned) in their string form.
                # NOTE(review): filters therefore carry strings even for non-string coordinates;
                # confirm the consumer of `dimension_filters` expects that.
                agg_values = [str(v) for v in member_ds.coords[agg_dim].to_numpy().tolist()]

                st.markdown(f"*{agg_dim}:*")
                selected_vals = []

                for val in agg_values:
                    checkbox_key = f"feature_agg_{member}_{agg_dim}_{val}"
                    # The same keys are written through st.session_state by the preset helpers
                    # (_set_all_aggregations / _set_median_only_aggregations). Passing a
                    # non-default `value=` together with a Session-State-assigned key makes
                    # Streamlit warn about a duplicated default, so the "ticked by default"
                    # state is seeded in Session State instead.
                    if checkbox_key not in st.session_state:
                        st.session_state[checkbox_key] = True
                    if st.checkbox(val, key=checkbox_key, help=f"Include {val} from {agg_dim}"):
                        selected_vals.append(val)

                # Only a partial selection is an actual filter.
                if selected_vals and len(selected_vals) < len(agg_values):
                    dimension_filters.setdefault(member, {})[agg_dim] = selected_vals

    return dimension_filters
|
||||
|
||||
|
||||
def _render_aggregation_selection(
    temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
) -> dict[str, dict[str, list]]:
    """Render the aggregation picker for the currently selected members.

    The member datasets are read through ``_get_member_datasets``, their aggregation dimensions are
    looked up, and three preset buttons plus a checkbox form are shown. Until the form is
    submitted, an info message is displayed and the script run is halted via ``st.stop()``.

    Args:
        temp_ensemble: Temporary ensemble to inspect for aggregations
        selected_members: List of selected member datasets

    Returns:
        Dictionary mapping member names to dimension filters; empty when no member is selected or
        no selected member has aggregation dimensions

    """
    if not selected_members:
        return {}

    # Each member dataset is read a single time and reused by everything below.
    member_datasets = _get_member_datasets(temp_ensemble, selected_members)
    member_aggregations = _get_member_aggregations(member_datasets)

    members_with_aggs: list[L2SourceDataset] = [
        candidate for candidate in selected_members if member_aggregations.get(candidate)
    ]
    if not members_with_aggs:
        return {}

    st.markdown("#### Select Aggregations")
    st.markdown(
        "Select which spatial aggregations to use for each data source. "
        "Different members may have different aggregation options depending on the grid configuration."
    )

    # Preset buttons, left to right; the fourth (wide) column is a spacer and stays empty.
    presets = (
        (
            "✅ Select All",
            lambda: _set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=True),
        ),
        (
            "📊 Median Only",
            lambda: _set_median_only_aggregations(member_datasets, members_with_aggs, member_aggregations),
        ),
        (
            "❌ Deselect All",
            lambda: _set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=False),
        ),
    )
    for button_col, (label, apply_preset) in zip(st.columns([1, 1, 1, 3]), presets):
        with button_col:
            if st.button(label, use_container_width=True):
                apply_preset()

    # Checkbox form: the filters only take effect once the form is submitted.
    with st.form("aggregation_selection_form"):
        dimension_filters = _render_aggregation_form(member_datasets, members_with_aggs, member_aggregations)
        submitted = st.form_submit_button("Apply Aggregation Filters", type="primary")

    if submitted:
        return dimension_filters

    st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
    st.stop()

    return dimension_filters
|
||||
|
||||
|
||||
def _render_configuration_summary(
|
||||
selected_members: list[L2SourceDataset],
|
||||
selected_member_stats: dict[str, MemberStatistics],
|
||||
selected_target: TargetDataset,
|
||||
selected_task: Task,
|
||||
selected_temporal_mode: TemporalMode,
|
||||
stats: DatasetStatistics,
|
||||
):
|
||||
"""Render configuration summary with metrics and statistics.
|
||||
|
||||
Args:
|
||||
selected_members: List of selected member datasets
|
||||
selected_member_stats: Statistics for selected members
|
||||
selected_target: Selected target dataset
|
||||
selected_task: Selected task type
|
||||
selected_temporal_mode: Selected temporal mode
|
||||
stats: Overall dataset statistics
|
||||
|
||||
"""
|
||||
# Calculate total features for selected members
|
||||
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||
|
||||
# Get target stats if available
|
||||
if selected_target not in stats.target:
|
||||
st.warning(f"Target {selected_target} is not available for this configuration.")
|
||||
return
|
||||
|
||||
target_stats = stats.target[selected_target]
|
||||
|
||||
# Check if task is available
|
||||
if selected_task not in target_stats:
|
||||
st.warning(
|
||||
f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
|
||||
)
|
||||
return
|
||||
|
||||
task_stats = target_stats[selected_task]
|
||||
|
||||
# High-level metrics
|
||||
st.markdown("#### Configuration Summary")
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
with col1:
|
||||
st.metric("Total Features", f"{total_features:,}")
|
||||
with col2:
|
||||
# Calculate minimum cells across all data sources (for inference capability)
|
||||
if selected_member_stats:
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
|
||||
else:
|
||||
min_cells = 0
|
||||
st.metric(
|
||||
"Inference Cells",
|
||||
f"{min_cells:,}",
|
||||
help="Minimum number of cells across all selected data sources",
|
||||
)
|
||||
with col3:
|
||||
st.metric("Data Sources", len(selected_members))
|
||||
with col4:
|
||||
st.metric("Training Samples", f"{task_stats.training_cells:,}")
|
||||
with col5:
|
||||
# Calculate total data points
|
||||
total_points = total_features * task_stats.training_cells
|
||||
st.metric("Total Data Points", f"{total_points:,}")
|
||||
|
||||
# Task-specific statistics
|
||||
st.markdown("#### Task Statistics")
|
||||
task_col1, task_col2, task_col3 = st.columns(3)
|
||||
with task_col1:
|
||||
st.metric("Task Type", selected_task.replace("_", " ").title())
|
||||
with task_col2:
|
||||
st.metric("Coverage", f"{task_stats.coverage:.2f}%")
|
||||
with task_col3:
|
||||
if task_stats.class_counts:
|
||||
st.metric("Number of Classes", len(task_stats.class_counts))
|
||||
else:
|
||||
st.metric("Task Mode", "Regression")
|
||||
|
||||
# Class distribution for classification tasks
|
||||
if task_stats.class_distribution:
|
||||
st.markdown("#### Class Distribution")
|
||||
class_dist_df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"Class": class_name,
|
||||
"Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
|
||||
"Percentage": f"{pct:.2f}%",
|
||||
}
|
||||
for class_name, pct in task_stats.class_distribution.items()
|
||||
]
|
||||
)
|
||||
st.dataframe(class_dist_df, hide_index=True, width="stretch")
|
||||
|
||||
# Feature breakdown by source
|
||||
st.markdown("#### Feature Breakdown by Data Source")
|
||||
|
||||
breakdown_data = []
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
breakdown_data.append(
|
||||
{
|
||||
"Data Source": member,
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||
}
|
||||
)
|
||||
|
||||
breakdown_df = pd.DataFrame(breakdown_data)
|
||||
|
||||
# Get all unique data sources and create color map
|
||||
unique_members_for_color = sorted(selected_member_stats.keys())
|
||||
source_color_map_raw = {}
|
||||
for member in unique_members_for_color:
|
||||
source = member.split("-")[0]
|
||||
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
|
||||
palette = get_palette(source, n_colors=n_members + 2)
|
||||
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
|
||||
source_color_map_raw[member] = palette[idx + 1]
|
||||
|
||||
# Create and display pie chart
|
||||
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
|
||||
# Show detailed table
|
||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||
|
||||
# Detailed member information
|
||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||
# Create detailed table
|
||||
member_details_dict = {
|
||||
member: {
|
||||
"feature_count": ms.feature_count,
|
||||
"variable_names": ms.variable_names,
|
||||
"dimensions": ms.dimensions,
|
||||
"size_bytes": ms.size_bytes,
|
||||
}
|
||||
for member, ms in selected_member_stats.items()
|
||||
}
|
||||
details_df = create_member_details_table(member_details_dict)
|
||||
st.dataframe(details_df, hide_index=True, width="stretch")
|
||||
|
||||
# Individual member details
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
st.markdown(f"### {member}")
|
||||
|
||||
# Variables
|
||||
st.markdown("**Variables:**")
|
||||
vars_html = " ".join(
|
||||
[
|
||||
{
|
||||
"Class": class_name,
|
||||
"Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
|
||||
"Percentage": f"{pct:.2f}%",
|
||||
}
|
||||
for class_name, pct in task_stats.class_distribution.items()
|
||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||
for v in member_stats.variable_names
|
||||
]
|
||||
)
|
||||
st.dataframe(class_dist_df, hide_index=True, width="stretch")
|
||||
st.markdown(vars_html, unsafe_allow_html=True)
|
||||
|
||||
# Feature breakdown by source
|
||||
st.markdown("#### Feature Breakdown by Data Source")
|
||||
|
||||
breakdown_data = []
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
breakdown_data.append(
|
||||
{
|
||||
"Data Source": member,
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||
}
|
||||
# Dimensions
|
||||
st.markdown("**Dimensions:**")
|
||||
dim_html = " ".join(
|
||||
[
|
||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||
f"{dim_name}: {dim_size:,}</span>"
|
||||
for dim_name, dim_size in member_stats.dimensions.items()
|
||||
]
|
||||
)
|
||||
st.markdown(dim_html, unsafe_allow_html=True)
|
||||
|
||||
breakdown_df = pd.DataFrame(breakdown_data)
|
||||
st.markdown("---")
|
||||
|
||||
# Get all unique data sources and create color map
|
||||
unique_members_for_color = sorted(selected_member_stats.keys())
|
||||
source_color_map_raw = {}
|
||||
for member in unique_members_for_color:
|
||||
source = member.split("-")[0]
|
||||
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
|
||||
palette = get_palette(source, n_colors=n_members + 2)
|
||||
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
|
||||
source_color_map_raw[member] = palette[idx + 1]
|
||||
|
||||
# Create and display pie chart
|
||||
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
@st.fragment
|
||||
def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||
"""Render interactive detailed configuration explorer.
|
||||
|
||||
# Show detailed table
|
||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||
Args:
|
||||
all_stats: Pre-computed statistics cache for optimization
|
||||
|
||||
# Detailed member information
|
||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||
# Create detailed table
|
||||
member_details_dict = {
|
||||
member: {
|
||||
"feature_count": ms.feature_count,
|
||||
"variable_names": ms.variable_names,
|
||||
"dimensions": ms.dimensions,
|
||||
"size_bytes": ms.size_bytes,
|
||||
}
|
||||
for member, ms in selected_member_stats.items()
|
||||
}
|
||||
details_df = create_member_details_table(member_details_dict)
|
||||
st.dataframe(details_df, hide_index=True, width="stretch")
|
||||
"""
|
||||
st.markdown(
|
||||
"""
|
||||
Explore detailed statistics for a specific dataset configuration.
|
||||
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
|
||||
"""
|
||||
)
|
||||
|
||||
# Individual member details
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
st.markdown(f"### {member}")
|
||||
# Grid, temporal mode, target, and task selection
|
||||
selected_grid_config, selected_temporal_mode, selected_target, selected_task = _render_grid_and_task_selection()
|
||||
|
||||
# Variables
|
||||
st.markdown("**Variables:**")
|
||||
vars_html = " ".join(
|
||||
[
|
||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||
for v in member_stats.variable_names
|
||||
]
|
||||
)
|
||||
st.markdown(vars_html, unsafe_allow_html=True)
|
||||
# Create dataset ensemble on-the-fly and compute statistics
|
||||
temp_ensemble = DatasetEnsemble(
|
||||
grid=selected_grid_config.grid,
|
||||
level=selected_grid_config.level,
|
||||
temporal_mode=selected_temporal_mode,
|
||||
)
|
||||
available_members = temp_ensemble.members
|
||||
|
||||
# Dimensions
|
||||
st.markdown("**Dimensions:**")
|
||||
dim_html = " ".join(
|
||||
[
|
||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||
f"{dim_name}: {dim_size:,}</span>"
|
||||
for dim_name, dim_size in member_stats.dimensions.items()
|
||||
]
|
||||
)
|
||||
st.markdown(dim_html, unsafe_allow_html=True)
|
||||
# Member selection
|
||||
selected_members = _render_member_selection(available_members)
|
||||
|
||||
st.markdown("---")
|
||||
# Aggregation selection
|
||||
dimension_filters = _render_aggregation_selection(temp_ensemble, selected_members)
|
||||
|
||||
# Show results if at least one member is selected
|
||||
if selected_members:
|
||||
# Optimize: Use pre-computed stats if no dimensional filtering is needed
|
||||
grid_level_key = f"{selected_grid_config.grid}{selected_grid_config.level}" # e.g., "hex3", "healpix6"
|
||||
use_cached_stats = (
|
||||
not dimension_filters # No dimension filters applied
|
||||
and grid_level_key in all_stats # Grid level exists in cache
|
||||
and selected_temporal_mode in all_stats[grid_level_key] # type: ignore[literal-required] # Temporal mode exists
|
||||
)
|
||||
|
||||
if use_cached_stats:
|
||||
# Use pre-computed statistics (much faster)
|
||||
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||
else:
|
||||
# Create the actual ensemble with selected members and dimension filters
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=selected_grid_config.grid,
|
||||
level=selected_grid_config.level,
|
||||
temporal_mode=selected_temporal_mode,
|
||||
members=selected_members,
|
||||
dimension_filters=dimension_filters,
|
||||
)
|
||||
# Recompute stats with the filtered ensemble
|
||||
stats = DatasetStatistics.from_ensemble(ensemble)
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||
|
||||
# Render configuration summary and statistics
|
||||
_render_configuration_summary(
|
||||
selected_members=selected_members,
|
||||
selected_member_stats=selected_member_stats,
|
||||
selected_target=selected_target,
|
||||
selected_task=selected_task,
|
||||
selected_temporal_mode=selected_temporal_mode,
|
||||
stats=stats,
|
||||
)
|
||||
else:
|
||||
st.info("👆 Select at least one data source to see feature statistics")
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from entropice.utils.types import (
|
|||
TargetDataset,
|
||||
Task,
|
||||
TemporalMode,
|
||||
all_l2_source_datasets,
|
||||
all_target_datasets,
|
||||
all_tasks,
|
||||
all_temporal_modes,
|
||||
|
|
@ -43,7 +42,7 @@ class MemberStatistics:
|
|||
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
|
||||
"""Pre-compute the statistics for a specific dataset member."""
|
||||
member_stats = {}
|
||||
for member in all_l2_source_datasets:
|
||||
for member in e.members:
|
||||
ds = e.read_member(member, lazy=True)
|
||||
size_bytes = ds.nbytes
|
||||
|
||||
|
|
@ -113,6 +112,29 @@ class DatasetStatistics:
|
|||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
||||
|
||||
@classmethod
def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
    """Build the statistics object for one DatasetEnsemble.

    The grid is opened to obtain the total cell count, target statistics are computed for every
    target dataset (skipping ``darts_mllabels`` when the temporal mode is an integer year), and
    member statistics are computed for the ensemble's members. Feature count and byte size are
    summed over the members.

    Args:
        e: Ensemble whose grid, level, temporal mode and members define what is measured

    Returns:
        DatasetStatistics populated with totals, per-member and per-target statistics

    """
    # Opening the grid also ensures it is registered.
    total_cells = len(entropice.spatial.grids.open(e.grid, e.level))

    # darts_mllabels does not support year-based (integer) temporal modes, so it is left out there.
    target_statistics = {
        target: TargetStatistics.compute(e, target=target, total_cells=total_cells)
        for target in all_target_datasets
        if not (isinstance(e.temporal_mode, int) and target == "darts_mllabels")
    }

    member_statistics = MemberStatistics.compute(e)
    per_member = list(member_statistics.values())

    return cls(
        total_features=sum(ms.feature_count for ms in per_member),
        total_cells=total_cells,
        size_bytes=sum(ms.size_bytes for ms in per_member),
        members=member_statistics,
        target=target_statistics,
    )
|
||||
|
||||
@staticmethod
|
||||
def get_target_sample_count_df(
|
||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue