Improve the dataset configuration explorer
This commit is contained in:
parent
c9c6af8370
commit
a6e9a91692
2 changed files with 513 additions and 152 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
"""Dataset Statistics Section for Entropice Dashboard."""
|
"""Dataset Statistics Section for Entropice Dashboard."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
|
@ -12,8 +14,10 @@ from entropice.dashboard.plots.dataset_statistics import (
|
||||||
create_sample_count_bar_chart,
|
create_sample_count_bar_chart,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.utils.types import (
|
from entropice.utils.types import (
|
||||||
|
GridConfig,
|
||||||
GridLevel,
|
GridLevel,
|
||||||
L2SourceDataset,
|
L2SourceDataset,
|
||||||
TargetDataset,
|
TargetDataset,
|
||||||
|
|
@ -97,17 +101,13 @@ def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
|
||||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
def _render_grid_and_task_selection() -> tuple[GridConfig, TemporalMode, TargetDataset, Task]:
|
||||||
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901
|
"""Render grid, temporal mode, target, and task selection controls.
|
||||||
"""Render interactive detailed configuration explorer."""
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
Explore detailed statistics for a specific dataset configuration.
|
|
||||||
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Grid selection
|
Returns:
|
||||||
|
Tuple of (grid_config, temporal_mode, target_dataset, task)
|
||||||
|
|
||||||
|
"""
|
||||||
grid_options = [gc.display_name for gc in grid_configs]
|
grid_options = [gc.display_name for gc in grid_configs]
|
||||||
|
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
@ -159,18 +159,19 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
|
||||||
# Find the selected grid config
|
# Find the selected grid config
|
||||||
selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
|
selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
|
||||||
|
|
||||||
# Get stats for the selected configuration
|
return selected_grid_config, selected_temporal_mode, selected_target, selected_task
|
||||||
mode_stats = dataset_stats[selected_grid_config.id]
|
|
||||||
|
|
||||||
# Check if the selected temporal mode has stats
|
|
||||||
if selected_temporal_mode not in mode_stats:
|
|
||||||
st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
|
|
||||||
return
|
|
||||||
|
|
||||||
stats = mode_stats[selected_temporal_mode]
|
def _render_member_selection(available_members: list[L2SourceDataset]) -> list[L2SourceDataset]:
|
||||||
available_members = list(stats.members.keys())
|
"""Render data source member selection checkboxes.
|
||||||
|
|
||||||
# Members selection
|
Args:
|
||||||
|
available_members: List of available member dataset names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of selected member datasets
|
||||||
|
|
||||||
|
"""
|
||||||
st.markdown("#### Select Data Sources")
|
st.markdown("#### Select Data Sources")
|
||||||
|
|
||||||
# Use columns for checkboxes
|
# Use columns for checkboxes
|
||||||
|
|
@ -183,11 +184,277 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
|
||||||
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
||||||
selected_members.append(member) # type: ignore[arg-type]
|
selected_members.append(member) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Show results if at least one member is selected
|
return selected_members
|
||||||
if selected_members:
|
|
||||||
# Filter to selected members only
|
|
||||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data(ttl=3600)
|
||||||
|
def _read_member_dataset_cached(
|
||||||
|
grid: str,
|
||||||
|
level: int,
|
||||||
|
temporal_mode: str | int,
|
||||||
|
member: str,
|
||||||
|
):
|
||||||
|
"""Read a member dataset with caching to avoid redundant loads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid: Grid system name
|
||||||
|
level: Grid resolution level
|
||||||
|
temporal_mode: Temporal mode (feature, synopsis, or year)
|
||||||
|
member: Member dataset name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Lazy-loaded xarray dataset, or None if reading fails
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ensemble = DatasetEnsemble(
|
||||||
|
grid=grid, # type: ignore[arg-type]
|
||||||
|
level=level,
|
||||||
|
temporal_mode=temporal_mode, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
return ensemble.read_member(member, lazy=True)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_member_datasets(
|
||||||
|
temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
|
||||||
|
) -> dict[L2SourceDataset, Any]:
|
||||||
|
"""Read all member datasets once using cached function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp_ensemble: Temporary ensemble to get grid/level/temporal_mode from
|
||||||
|
selected_members: List of selected member datasets
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping members to their loaded datasets (or None if failed)
|
||||||
|
|
||||||
|
"""
|
||||||
|
member_datasets = {}
|
||||||
|
for member in selected_members:
|
||||||
|
ds = _read_member_dataset_cached(
|
||||||
|
grid=temp_ensemble.grid,
|
||||||
|
level=temp_ensemble.level,
|
||||||
|
temporal_mode=temp_ensemble.temporal_mode,
|
||||||
|
member=member,
|
||||||
|
)
|
||||||
|
member_datasets[member] = ds
|
||||||
|
return member_datasets
|
||||||
|
|
||||||
|
|
||||||
|
def _get_member_aggregations(member_datasets: dict[L2SourceDataset, Any]) -> dict[L2SourceDataset, list[str]]:
|
||||||
|
"""Get available aggregation dimensions for each selected member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping members to their aggregation dimension names
|
||||||
|
|
||||||
|
"""
|
||||||
|
member_aggregations: dict[L2SourceDataset, list[str]] = {}
|
||||||
|
for member, member_ds in member_datasets.items():
|
||||||
|
if member_ds is None:
|
||||||
|
member_aggregations[member] = []
|
||||||
|
continue
|
||||||
|
# Find aggregation dimensions (only dimensions named "agg" or "aggregations")
|
||||||
|
agg_dims = [dim for dim in member_ds.dims if dim in ("agg", "aggregations")]
|
||||||
|
member_aggregations[member] = sorted(agg_dims)
|
||||||
|
return member_aggregations
|
||||||
|
|
||||||
|
|
||||||
|
def _set_all_aggregations(
|
||||||
|
member_datasets: dict[L2SourceDataset, Any],
|
||||||
|
members_with_aggs: list[L2SourceDataset],
|
||||||
|
member_aggregations: dict[L2SourceDataset, list[str]],
|
||||||
|
selected: bool,
|
||||||
|
):
|
||||||
|
"""Set all aggregation checkboxes to selected or deselected state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
members_with_aggs: List of members that have aggregations
|
||||||
|
member_aggregations: Dictionary mapping members to aggregation dimensions
|
||||||
|
selected: True to select all, False to deselect all
|
||||||
|
|
||||||
|
"""
|
||||||
|
for member in members_with_aggs:
|
||||||
|
member_ds = member_datasets.get(member)
|
||||||
|
if member_ds is None:
|
||||||
|
continue
|
||||||
|
agg_dims = member_aggregations.get(member, [])
|
||||||
|
for agg_dim in agg_dims:
|
||||||
|
if agg_dim in member_ds.dims:
|
||||||
|
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
|
||||||
|
for val in agg_values:
|
||||||
|
st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = selected
|
||||||
|
|
||||||
|
|
||||||
|
def _set_median_only_aggregations(
|
||||||
|
member_datasets: dict[L2SourceDataset, Any],
|
||||||
|
members_with_aggs: list[L2SourceDataset],
|
||||||
|
member_aggregations: dict[L2SourceDataset, list[str]],
|
||||||
|
):
|
||||||
|
"""Set only median aggregations to selected, deselect all others.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
members_with_aggs: List of members that have aggregations
|
||||||
|
member_aggregations: Dictionary mapping members to aggregation dimensions
|
||||||
|
|
||||||
|
"""
|
||||||
|
for member in members_with_aggs:
|
||||||
|
member_ds = member_datasets.get(member)
|
||||||
|
if member_ds is None:
|
||||||
|
continue
|
||||||
|
agg_dims = member_aggregations.get(member, [])
|
||||||
|
for agg_dim in agg_dims:
|
||||||
|
if agg_dim in member_ds.dims:
|
||||||
|
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
|
||||||
|
for val in agg_values:
|
||||||
|
# Select only if value is or contains 'median'
|
||||||
|
is_median = str(val).lower() == "median" or "median" in str(val).lower()
|
||||||
|
st.session_state[f"feature_agg_{member}_{agg_dim}_{val}"] = is_median
|
||||||
|
|
||||||
|
|
||||||
|
def _render_aggregation_form(
|
||||||
|
member_datasets: dict[L2SourceDataset, Any],
|
||||||
|
members_with_aggs: list[L2SourceDataset],
|
||||||
|
member_aggregations: dict[L2SourceDataset, list[str]],
|
||||||
|
) -> dict[str, dict[str, list]]:
|
||||||
|
"""Render aggregation selection form with checkboxes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member_datasets: Dictionary mapping members to their loaded datasets
|
||||||
|
members_with_aggs: List of members that have aggregations
|
||||||
|
member_aggregations: Dictionary mapping members to aggregation dimensions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping member names to dimension filters
|
||||||
|
|
||||||
|
"""
|
||||||
|
dimension_filters: dict[str, dict[str, list]] = {}
|
||||||
|
|
||||||
|
if not members_with_aggs:
|
||||||
|
return dimension_filters
|
||||||
|
|
||||||
|
member_cols = st.columns(len(members_with_aggs))
|
||||||
|
|
||||||
|
for col_idx, member in enumerate(members_with_aggs):
|
||||||
|
with member_cols[col_idx]:
|
||||||
|
st.markdown(f"**{member}:**")
|
||||||
|
member_ds = member_datasets.get(member)
|
||||||
|
if member_ds is None:
|
||||||
|
continue
|
||||||
|
aggs = member_aggregations[member]
|
||||||
|
|
||||||
|
for agg_dim in aggs:
|
||||||
|
if agg_dim not in member_ds.dims:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get coordinate values for this dimension
|
||||||
|
agg_values = member_ds.coords[agg_dim].to_numpy().tolist()
|
||||||
|
# Convert to strings for display
|
||||||
|
agg_values = [str(v) for v in agg_values]
|
||||||
|
|
||||||
|
st.markdown(f"*{agg_dim}:*")
|
||||||
|
selected_vals = []
|
||||||
|
|
||||||
|
for val in agg_values:
|
||||||
|
if st.checkbox(
|
||||||
|
val,
|
||||||
|
value=True,
|
||||||
|
key=f"feature_agg_{member}_{agg_dim}_{val}",
|
||||||
|
help=f"Include {val} from {agg_dim}",
|
||||||
|
):
|
||||||
|
selected_vals.append(val)
|
||||||
|
|
||||||
|
# Store selected values in dimension_filters
|
||||||
|
if selected_vals and len(selected_vals) < len(agg_values):
|
||||||
|
if member not in dimension_filters:
|
||||||
|
dimension_filters[member] = {}
|
||||||
|
dimension_filters[member][agg_dim] = selected_vals
|
||||||
|
|
||||||
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def _render_aggregation_selection(
|
||||||
|
temp_ensemble: DatasetEnsemble, selected_members: list[L2SourceDataset]
|
||||||
|
) -> dict[str, dict[str, list]]:
|
||||||
|
"""Render aggregation selection controls for selected members.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp_ensemble: Temporary ensemble to inspect for aggregations
|
||||||
|
selected_members: List of selected member datasets
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping member names to dimension filters
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not selected_members:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Read all member datasets once (cached)
|
||||||
|
member_datasets = _get_member_datasets(temp_ensemble, selected_members)
|
||||||
|
|
||||||
|
# Get available aggregations for each selected member
|
||||||
|
member_aggregations = _get_member_aggregations(member_datasets)
|
||||||
|
members_with_aggs: list[L2SourceDataset] = [m for m in selected_members if member_aggregations.get(m)]
|
||||||
|
|
||||||
|
if not members_with_aggs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Add Select All / Deselect All buttons
|
||||||
|
st.markdown("#### Select Aggregations")
|
||||||
|
st.markdown(
|
||||||
|
"Select which spatial aggregations to use for each data source. "
|
||||||
|
"Different members may have different aggregation options depending on the grid configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
|
||||||
|
with col_btn1:
|
||||||
|
if st.button("✅ Select All", use_container_width=True):
|
||||||
|
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=True)
|
||||||
|
with col_btn2:
|
||||||
|
if st.button("📊 Median Only", use_container_width=True):
|
||||||
|
_set_median_only_aggregations(member_datasets, members_with_aggs, member_aggregations)
|
||||||
|
with col_btn3:
|
||||||
|
if st.button("❌ Deselect All", use_container_width=True):
|
||||||
|
_set_all_aggregations(member_datasets, members_with_aggs, member_aggregations, selected=False)
|
||||||
|
|
||||||
|
# Render the form with checkboxes
|
||||||
|
with st.form("aggregation_selection_form"):
|
||||||
|
dimension_filters = _render_aggregation_form(member_datasets, members_with_aggs, member_aggregations)
|
||||||
|
|
||||||
|
# Submit button for the form
|
||||||
|
submitted = st.form_submit_button("Apply Aggregation Filters", type="primary")
|
||||||
|
|
||||||
|
if not submitted:
|
||||||
|
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def _render_configuration_summary(
|
||||||
|
selected_members: list[L2SourceDataset],
|
||||||
|
selected_member_stats: dict[str, MemberStatistics],
|
||||||
|
selected_target: TargetDataset,
|
||||||
|
selected_task: Task,
|
||||||
|
selected_temporal_mode: TemporalMode,
|
||||||
|
stats: DatasetStatistics,
|
||||||
|
):
|
||||||
|
"""Render configuration summary with metrics and statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selected_members: List of selected member datasets
|
||||||
|
selected_member_stats: Statistics for selected members
|
||||||
|
selected_target: Selected target dataset
|
||||||
|
selected_task: Selected task type
|
||||||
|
selected_temporal_mode: Selected temporal mode
|
||||||
|
stats: Overall dataset statistics
|
||||||
|
|
||||||
|
"""
|
||||||
# Calculate total features for selected members
|
# Calculate total features for selected members
|
||||||
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||||
|
|
||||||
|
|
@ -335,6 +602,78 @@ def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
st.markdown(dim_html, unsafe_allow_html=True)
|
||||||
|
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
|
||||||
|
def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||||
|
"""Render interactive detailed configuration explorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_stats: Pre-computed statistics cache for optimization
|
||||||
|
|
||||||
|
"""
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
Explore detailed statistics for a specific dataset configuration.
|
||||||
|
Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Grid, temporal mode, target, and task selection
|
||||||
|
selected_grid_config, selected_temporal_mode, selected_target, selected_task = _render_grid_and_task_selection()
|
||||||
|
|
||||||
|
# Create dataset ensemble on-the-fly and compute statistics
|
||||||
|
temp_ensemble = DatasetEnsemble(
|
||||||
|
grid=selected_grid_config.grid,
|
||||||
|
level=selected_grid_config.level,
|
||||||
|
temporal_mode=selected_temporal_mode,
|
||||||
|
)
|
||||||
|
available_members = temp_ensemble.members
|
||||||
|
|
||||||
|
# Member selection
|
||||||
|
selected_members = _render_member_selection(available_members)
|
||||||
|
|
||||||
|
# Aggregation selection
|
||||||
|
dimension_filters = _render_aggregation_selection(temp_ensemble, selected_members)
|
||||||
|
|
||||||
|
# Show results if at least one member is selected
|
||||||
|
if selected_members:
|
||||||
|
# Optimize: Use pre-computed stats if no dimensional filtering is needed
|
||||||
|
grid_level_key = f"{selected_grid_config.grid}{selected_grid_config.level}" # e.g., "hex3", "healpix6"
|
||||||
|
use_cached_stats = (
|
||||||
|
not dimension_filters # No dimension filters applied
|
||||||
|
and grid_level_key in all_stats # Grid level exists in cache
|
||||||
|
and selected_temporal_mode in all_stats[grid_level_key] # type: ignore[literal-required] # Temporal mode exists
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cached_stats:
|
||||||
|
# Use pre-computed statistics (much faster)
|
||||||
|
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
|
||||||
|
# Filter to selected members only
|
||||||
|
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||||
|
else:
|
||||||
|
# Create the actual ensemble with selected members and dimension filters
|
||||||
|
ensemble = DatasetEnsemble(
|
||||||
|
grid=selected_grid_config.grid,
|
||||||
|
level=selected_grid_config.level,
|
||||||
|
temporal_mode=selected_temporal_mode,
|
||||||
|
members=selected_members,
|
||||||
|
dimension_filters=dimension_filters,
|
||||||
|
)
|
||||||
|
# Recompute stats with the filtered ensemble
|
||||||
|
stats = DatasetStatistics.from_ensemble(ensemble)
|
||||||
|
# Filter to selected members only
|
||||||
|
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||||
|
|
||||||
|
# Render configuration summary and statistics
|
||||||
|
_render_configuration_summary(
|
||||||
|
selected_members=selected_members,
|
||||||
|
selected_member_stats=selected_member_stats,
|
||||||
|
selected_target=selected_target,
|
||||||
|
selected_task=selected_task,
|
||||||
|
selected_temporal_mode=selected_temporal_mode,
|
||||||
|
stats=stats,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
st.info("👆 Select at least one data source to see feature statistics")
|
st.info("👆 Select at least one data source to see feature statistics")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@ from entropice.utils.types import (
|
||||||
TargetDataset,
|
TargetDataset,
|
||||||
Task,
|
Task,
|
||||||
TemporalMode,
|
TemporalMode,
|
||||||
all_l2_source_datasets,
|
|
||||||
all_target_datasets,
|
all_target_datasets,
|
||||||
all_tasks,
|
all_tasks,
|
||||||
all_temporal_modes,
|
all_temporal_modes,
|
||||||
|
|
@ -43,7 +42,7 @@ class MemberStatistics:
|
||||||
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
|
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
|
||||||
"""Pre-compute the statistics for a specific dataset member."""
|
"""Pre-compute the statistics for a specific dataset member."""
|
||||||
member_stats = {}
|
member_stats = {}
|
||||||
for member in all_l2_source_datasets:
|
for member in e.members:
|
||||||
ds = e.read_member(member, lazy=True)
|
ds = e.read_member(member, lazy=True)
|
||||||
size_bytes = ds.nbytes
|
size_bytes = ds.nbytes
|
||||||
|
|
||||||
|
|
@ -113,6 +112,29 @@ class DatasetStatistics:
|
||||||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||||
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_ensemble(cls, e: DatasetEnsemble) -> "DatasetStatistics":
|
||||||
|
"""Compute dataset statistics from a DatasetEnsemble."""
|
||||||
|
grid_gdf = entropice.spatial.grids.open(e.grid, e.level) # Ensure grid is registered
|
||||||
|
total_cells = len(grid_gdf)
|
||||||
|
target_statistics = {}
|
||||||
|
for target in all_target_datasets:
|
||||||
|
if isinstance(e.temporal_mode, int) and target == "darts_mllabels":
|
||||||
|
# darts_mllabels does not support year-based temporal modes
|
||||||
|
continue
|
||||||
|
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
|
||||||
|
member_statistics = MemberStatistics.compute(e)
|
||||||
|
|
||||||
|
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
||||||
|
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
|
||||||
|
return cls(
|
||||||
|
total_features=total_features,
|
||||||
|
total_cells=total_cells,
|
||||||
|
size_bytes=total_size_bytes,
|
||||||
|
members=member_statistics,
|
||||||
|
target=target_statistics,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_target_sample_count_df(
|
def get_target_sample_count_df(
|
||||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue