Fix training and overview page

This commit is contained in:
Tobias Hölzer 2026-01-16 20:33:10 +01:00
parent 4445834895
commit c9c6af8370
17 changed files with 1643 additions and 1125 deletions

View file

@ -68,7 +68,7 @@ dependencies = [
"pandas-stubs>=2.3.3.251201,<3",
"pytest>=9.0.2,<10",
"autogluon-tabular[all,mitra]>=1.5.0",
"shap>=0.50.0,<0.51",
"shap>=0.50.0,<0.51", "h5py>=3.15.1,<4",
]
[project.scripts]

View file

@ -5,12 +5,14 @@ Pages:
- Overview: List of available result directories with some summary statistics.
- Training Data: Visualization of training data distributions.
- Training Results Analysis: Analysis of training results and model performance.
- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
- Model State: Visualization of model state and features.
- Inference: Visualization of inference results.
"""
import streamlit as st
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
@ -26,13 +28,14 @@ def main():
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
"Training": [training_data_page, training_analysis_page],
"Training": [training_data_page, training_analysis_page, autogluon_page],
"Model State": [model_state_page],
"Inference": [inference_page],
}

View file

@ -0,0 +1,458 @@
"""Visualization functions for the overview page.
This module contains reusable plotting functions for dataset analysis visualizations,
including sample counts, feature counts, and dataset statistics.
"""
import math

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from entropice.dashboard.utils.colors import get_palette
def _get_colors_from_members(unique_members: list[str]) -> dict[str, str]:
    """Assign a stable color to every member, shading members of the same source.

    Members are named ``<source>-<suffix>``; every source gets its own palette
    and its members receive consecutive shades from it.

    Args:
        unique_members: Member names to color (order does not matter).

    Returns:
        Mapping from member name to a color string.
    """
    member_color_map: dict[str, str] = {}
    ordered = sorted(unique_members)
    for source in {m.split("-")[0] for m in ordered}:
        group = [m for m in ordered if m.split("-")[0] == source]
        # Request two extra colors so the very light/dark palette ends are skipped.
        shades = get_palette(source, n_colors=len(group) + 2)
        for position, member in enumerate(group, start=1):
            member_color_map[member] = shades[position]
    return member_color_map
def create_sample_count_bar_chart(sample_df: pd.DataFrame) -> go.Figure:
    """Bar chart of training sample counts, one subplot per target dataset.

    Args:
        sample_df: DataFrame with columns Grid, Target, Temporal Mode and
            Samples (Coverage).

    Returns:
        Plotly Figure with grouped bars, colored per temporal mode.
    """
    assert "Grid" in sample_df.columns
    assert "Target" in sample_df.columns
    assert "Temporal Mode" in sample_df.columns
    assert "Samples (Coverage)" in sample_df.columns
    targets = sorted(sample_df["Target"].unique())
    modes = sorted(sample_df["Temporal Mode"].unique())
    # Manual subplot construction gives full control over per-mode colors.
    fig = make_subplots(
        rows=1,
        cols=len(targets),
        subplot_titles=[f"Target: {target}" for target in targets],
        shared_yaxes=True,
    )
    palette = get_palette("Number of Samples", n_colors=len(modes))
    seen_modes: set = set()
    for col, target in enumerate(targets, start=1):
        per_target = sample_df[sample_df["Target"] == target]
        for mode_pos, mode in enumerate(modes):
            per_mode = per_target[per_target["Temporal Mode"] == mode]
            if len(per_mode) == 0:
                continue
            bar_color = palette[mode_pos] if palette and mode_pos < len(palette) else None
            # Each mode appears in the legend exactly once, on first occurrence.
            first_occurrence = mode not in seen_modes
            seen_modes.add(mode)
            fig.add_trace(
                go.Bar(
                    x=per_mode["Grid"],
                    y=per_mode["Samples (Coverage)"],
                    name=mode,
                    marker_color=bar_color,
                    legendgroup=mode,  # keeps colors consistent across subplots
                    showlegend=first_occurrence,
                ),
                row=1,
                col=col,
            )
    fig.update_layout(
        title_text="Training Sample Counts by Grid Configuration and Target Dataset",
        barmode="group",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
    return fig
def create_feature_count_stacked_bar(breakdown_df: pd.DataFrame) -> go.Figure:
    """Stacked bar chart of feature counts per member, one subplot per temporal mode.

    Args:
        breakdown_df: DataFrame with columns Grid, Temporal Mode, Member and
            Number of Features.

    Returns:
        Plotly Figure with one stacked bar per grid, segments colored per member.
    """
    assert "Grid" in breakdown_df.columns
    assert "Temporal Mode" in breakdown_df.columns
    assert "Member" in breakdown_df.columns
    assert "Number of Features" in breakdown_df.columns
    members = sorted(breakdown_df["Member"].unique())
    modes = sorted(breakdown_df["Temporal Mode"].unique())
    member_colors = _get_colors_from_members(members)
    # One column of subplots per temporal mode, sharing the feature-count axis.
    fig = make_subplots(
        rows=1,
        cols=len(modes),
        subplot_titles=[f"Temporal Mode: {mode}" for mode in modes],
        shared_yaxes=True,
    )
    for col, mode in enumerate(modes, start=1):
        per_mode = breakdown_df[breakdown_df["Temporal Mode"] == mode]
        for member in members:
            per_member = per_mode[per_mode["Member"] == member]
            fig.add_trace(
                go.Bar(
                    x=per_member["Grid"],
                    y=per_member["Number of Features"],
                    name=member,
                    marker_color=member_colors.get(member),
                    legendgroup=member,
                    showlegend=(col == 1),  # single legend entry per member
                ),
                row=1,
                col=col,
            )
    fig.update_layout(
        title_text="Input Features by Member Across Grid Configurations",
        barmode="stack",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Features", row=1, col=1)
    return fig
def create_inference_cells_bar(sample_df: pd.DataFrame) -> go.Figure:
    """Bar chart of inference-capable cell counts per grid configuration.

    Args:
        sample_df: DataFrame with at least the columns Grid and Inference Cells.

    Returns:
        Plotly Figure with one colored bar per grid configuration.
    """
    assert "Grid" in sample_df.columns
    assert "Inference Cells" in sample_df.columns
    palette: list[str] = get_palette("Number of Samples", n_colors=sample_df["Grid"].nunique())
    chart = px.bar(
        sample_df,
        x="Grid",
        y="Inference Cells",
        color="Grid",
        title="Spatial Coverage (Grid Cells with Complete Data)",
        labels={
            "Grid": "Grid Configuration",
            "Inference Cells": "Number of Cells",
        },
        color_discrete_sequence=palette,
        text="Inference Cells",
    )
    # Thousands-separated counts rendered above the bars.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, height=500, showlegend=False)
    return chart
def create_total_samples_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Bar chart of total binary-task training samples per grid configuration.

    Args:
        comparison_df: DataFrame with columns Grid and Total Samples.
        grid_colors: Optional palette for the grid bars; Plotly defaults when None.

    Returns:
        Plotly Figure with one bar per grid configuration.
    """
    chart = px.bar(
        comparison_df,
        x="Grid",
        y="Total Samples",
        color="Grid",
        title="Training Samples (Binary Task)",
        labels={
            "Grid": "Grid Configuration",
            "Total Samples": "Number of Samples",
        },
        color_discrete_sequence=grid_colors,
        text="Total Samples",
    )
    # Thousands-separated counts rendered above the bars.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, showlegend=False)
    return chart
def create_feature_breakdown_donut(
    grid_data: pd.DataFrame,
    grid_config: str,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Donut chart of feature counts per data source for one grid configuration.

    Args:
        grid_data: DataFrame with columns Data Source and Number of Features.
        grid_config: Grid configuration name, used as the chart title.
        source_color_map: Optional mapping of data source to color; Plotly
            defaults when None.

    Returns:
        Plotly Figure containing the donut chart.
    """
    donut = px.pie(
        grid_data,
        names="Data Source",
        values="Number of Features",
        title=grid_config,
        hole=0.4,  # hollow center turns the pie into a donut
        color_discrete_map=source_color_map,
        color="Data Source",
    )
    donut.update_traces(textposition="inside", textinfo="percent")
    donut.update_layout(showlegend=True, height=350)
    return donut
def create_feature_distribution_pie(
    breakdown_df: pd.DataFrame,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Donut-style pie chart of the overall feature split per data source.

    Args:
        breakdown_df: DataFrame with columns Data Source and Number of Features.
        source_color_map: Optional mapping of data source to color; Plotly
            defaults when None.

    Returns:
        Plotly Figure containing the pie chart.
    """
    pie = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_map=source_color_map,
        color="Data Source",
    )
    # Show both percentage and source label inside each slice.
    pie.update_traces(textposition="inside", textinfo="percent+label")
    pie.update_layout(height=400)
    return pie
def create_member_details_table(member_stats_dict: dict[str, dict]) -> pd.DataFrame:
    """Create a detailed table for member statistics.

    Args:
        member_stats_dict: Dictionary mapping member names to their statistics.
            Each stats dict must contain: feature_count, variable_names,
            dimensions (dimension name -> size), and size_bytes.

    Returns:
        DataFrame with one row per member and the columns Member, Features,
        Variables, Dimensions, Data Points (thousands-separated string) and
        Size (MB). Empty DataFrame for an empty input dict.
    """
    rows = []
    for member, stats in member_stats_dict.items():
        # Total data points is the product of all dimension sizes (1 if none).
        total_points = math.prod(stats["dimensions"].values())
        # Human-readable dimension summary, e.g. "cell_ids=1,024 x time=12".
        dim_str = " x ".join(f"{k}={v:,}" for k, v in stats["dimensions"].items())
        rows.append(
            {
                "Member": member,
                "Features": stats["feature_count"],
                "Variables": len(stats["variable_names"]),
                "Dimensions": dim_str,
                "Data Points": f"{total_points:,}",
                "Size (MB)": f"{stats['size_bytes'] / (1024 * 1024):.2f}",
            }
        )
    return pd.DataFrame(rows)
def create_feature_breakdown_donuts_grid(breakdown_df: pd.DataFrame) -> go.Figure:
    """Create grid of donut charts showing feature breakdown by member for each grid+temporal mode.

    One donut per (grid, temporal mode) cell: rows are grid configurations,
    columns are temporal modes. Row/column headers are drawn as paper-space
    annotations instead of subplot titles.

    Args:
        breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.

    Returns:
        Plotly Figure object containing the grid of donut charts.
    """
    assert "Grid" in breakdown_df.columns
    assert "Temporal Mode" in breakdown_df.columns
    assert "Member" in breakdown_df.columns
    assert "Number of Features" in breakdown_df.columns
    # Get unique grids and temporal modes (row/column order follows first appearance in the frame)
    grids = breakdown_df["Grid"].unique().tolist()
    temporal_modes = breakdown_df["Temporal Mode"].unique().tolist()
    unique_members = sorted(breakdown_df["Member"].unique())
    source_color_map = _get_colors_from_members(unique_members)
    n_grids = len(grids)
    n_modes = len(temporal_modes)
    # Create subplot grid (grids as rows, temporal modes as columns)
    # Don't use subplot_titles - we'll add custom annotations for row/column labels
    fig = make_subplots(
        rows=n_grids,
        cols=n_modes,
        specs=[[{"type": "pie"}] * n_modes for _ in range(n_grids)],
        horizontal_spacing=0.01,
        vertical_spacing=0.02,
    )
    # Get all members to ensure consistent coloring across all donuts
    # NOTE(review): identical to unique_members computed above; one of the two could be removed.
    all_members = sorted(breakdown_df["Member"].unique())
    for row_idx, grid in enumerate(grids, start=1):
        for col_idx, mode in enumerate(temporal_modes, start=1):
            subset = breakdown_df[(breakdown_df["Grid"] == grid) & (breakdown_df["Temporal Mode"] == mode)]
            if len(subset) == 0:
                # No data for this grid/mode combination - leave the cell empty.
                continue
            # Build labels, values, and colors in consistent order
            labels = []
            values = []
            colors_list = []
            for member in all_members:
                member_data = subset[subset["Member"] == member]
                if len(member_data) > 0:
                    labels.append(member)
                    values.append(member_data["Number of Features"].iloc[0])
                    if source_color_map:
                        colors_list.append(source_color_map[member])
            # Only show legend for first subplot
            show_legend = row_idx == 1 and col_idx == 1
            fig.add_trace(
                go.Pie(
                    labels=labels,
                    values=values,
                    name=f"{grid} - {mode}",
                    hole=0.4,
                    marker={"colors": colors_list} if colors_list else None,
                    textposition="inside",
                    textinfo="percent",
                    showlegend=show_legend,
                ),
                row=row_idx,
                col=col_idx,
            )
    # Calculate appropriate height based on number of rows
    height = max(500, n_grids * 400)
    fig.update_layout(
        title_text="Feature Breakdown by Member Across Grid Configurations and Temporal Modes",
        height=height,
        showlegend=True,
        margin={"l": 100, "r": 50, "t": 150, "b": 50},  # Add left margin for row labels
    )
    # Calculate subplot positions in paper coordinates
    # Account for spacing between subplots
    # NOTE(review): these must stay in sync with the horizontal_spacing/vertical_spacing
    # values passed to make_subplots above, or the header annotations will drift.
    h_spacing = 0.01
    v_spacing = 0.02
    # Calculate width and height of each subplot in paper coordinates
    subplot_width = (1.0 - h_spacing * (n_modes - 1)) / n_modes
    subplot_height = (1.0 - v_spacing * (n_grids - 1)) / n_grids
    # Add column headers at the top (temporal modes)
    for col_idx, mode in enumerate(temporal_modes):
        # Calculate x position for this column's center
        x_pos = col_idx * (subplot_width + h_spacing) + subplot_width / 2
        fig.add_annotation(
            text=f"<b>{mode}</b>",
            xref="paper",
            yref="paper",
            x=x_pos,
            y=1.01,
            showarrow=False,
            xanchor="center",
            yanchor="bottom",
            font={"size": 18},
        )
    # Add row headers on the left (grid configurations)
    for row_idx, grid in enumerate(grids):
        # Calculate y position for this row's center
        # Note: y coordinates go from bottom to top, but rows go from top to bottom
        y_pos = 1.0 - (row_idx * (subplot_height + v_spacing) + subplot_height / 2)
        fig.add_annotation(
            text=f"<b>{grid}</b>",
            xref="paper",
            yref="paper",
            x=-0.01,
            y=y_pos,
            showarrow=False,
            xanchor="right",
            yanchor="middle",
            font={"size": 18},
            textangle=-90,
        )
    return fig

View file

@ -1,276 +0,0 @@
"""Visualization functions for the overview page.
This module contains reusable plotting functions for dataset analysis visualizations,
including sample counts, feature counts, and dataset statistics.
"""
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
def create_sample_count_heatmap(
    pivot_df: pd.DataFrame,
    target: str,
    colorscale: list[str] | None = None,
) -> go.Figure:
    """Heatmap of training sample counts, grids as rows and tasks as columns.

    Args:
        pivot_df: Pivot table with Grid as index, Task as columns and sample
            counts as values.
        target: Target dataset name, shown in the chart title.
        colorscale: Optional continuous color scale; Plotly defaults when None.

    Returns:
        Plotly Figure containing the annotated heatmap.
    """
    axis_labels = {
        "x": "Task",
        "y": "Grid Configuration",
        "color": "Sample Count",
    }
    heatmap = px.imshow(
        pivot_df,
        labels=axis_labels,
        x=pivot_df.columns,
        y=pivot_df.index,
        color_continuous_scale=colorscale,
        aspect="auto",
        title=f"Target: {target}",
    )
    # Overlay the raw counts (thousands-separated) on each cell.
    heatmap.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
    heatmap.update_layout(height=400)
    return heatmap
def create_sample_count_bar_chart(
    sample_df: pd.DataFrame,
    target_color_maps: dict[str, list[str]] | None = None,
) -> go.Figure:
    """Bar chart of sample counts per grid, one subplot per target dataset.

    Args:
        sample_df: DataFrame with columns Grid, Target, Task, Samples (Coverage).
        target_color_maps: Optional mapping of target name ("rts", "mllabels")
            to a color palette; Plotly defaults when None.

    Returns:
        Plotly Figure with grouped bars, colored per task.
    """
    # Imported locally so subplot support is only loaded when this chart is built.
    from plotly.subplots import make_subplots

    targets = sorted(sample_df["Target"].unique())
    tasks = sorted(sample_df["Task"].unique())
    fig = make_subplots(
        rows=1,
        cols=len(targets),
        subplot_titles=[f"Target: {target}" for target in targets],
        shared_yaxes=True,
    )
    for col, target in enumerate(targets, start=1):
        per_target = sample_df[sample_df["Target"] == target]
        # Pick the palette registered for this target, if any.
        palette = target_color_maps.get(target, None) if target_color_maps else None
        for task_pos, task in enumerate(tasks):
            per_task = per_target[per_target["Task"] == task]
            bar_color = palette[task_pos] if palette and task_pos < len(palette) else None
            fig.add_trace(
                go.Bar(
                    x=per_task["Grid"],
                    y=per_task["Samples (Coverage)"],
                    name=task,
                    marker_color=bar_color,
                    legendgroup=task,  # keeps colors consistent across subplots
                    showlegend=(col == 1),  # legend entries only from the first subplot
                ),
                row=1,
                col=col,
            )
    fig.update_layout(
        title_text="Training Sample Counts by Grid Configuration and Target Dataset",
        barmode="group",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
    return fig
def create_feature_count_stacked_bar(
    breakdown_df: pd.DataFrame,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Stacked bar chart of input feature counts per data source and grid.

    Args:
        breakdown_df: DataFrame with columns Grid, Data Source, Number of Features.
        source_color_map: Optional mapping of data source to color; Plotly
            defaults when None.

    Returns:
        Plotly Figure with one stacked bar per grid configuration.
    """
    chart = px.bar(
        breakdown_df,
        x="Grid",
        y="Number of Features",
        color="Data Source",
        barmode="stack",
        title="Input Features by Data Source Across Grid Configurations",
        labels={
            "Grid": "Grid Configuration",
            "Number of Features": "Number of Features",
        },
        color_discrete_map=source_color_map,
        text_auto=False,
    )
    chart.update_layout(height=500, xaxis_tickangle=-45)
    return chart
def create_inference_cells_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Bar chart of inference-capable cell counts per grid configuration.

    Args:
        comparison_df: DataFrame with columns Grid and Inference Cells.
        grid_colors: Optional palette for the grid bars; Plotly defaults when None.

    Returns:
        Plotly Figure with one bar per grid configuration.
    """
    cells_fig = px.bar(
        comparison_df,
        x="Grid",
        y="Inference Cells",
        color="Grid",
        title="Spatial Coverage (Grid Cells with Complete Data)",
        labels={
            "Grid": "Grid Configuration",
            "Inference Cells": "Number of Cells",
        },
        color_discrete_sequence=grid_colors,
        text="Inference Cells",
    )
    # Thousands-separated counts rendered above the bars.
    cells_fig.update_traces(texttemplate="%{text:,}", textposition="outside")
    cells_fig.update_layout(xaxis_tickangle=-45, showlegend=False)
    return cells_fig
def create_total_samples_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Bar chart of total binary-task training samples per grid configuration.

    Args:
        comparison_df: DataFrame with columns Grid and Total Samples.
        grid_colors: Optional palette for the grid bars; Plotly defaults when None.

    Returns:
        Plotly Figure with one bar per grid configuration.
    """
    totals_fig = px.bar(
        comparison_df,
        x="Grid",
        y="Total Samples",
        color="Grid",
        title="Training Samples (Binary Task)",
        labels={
            "Grid": "Grid Configuration",
            "Total Samples": "Number of Samples",
        },
        color_discrete_sequence=grid_colors,
        text="Total Samples",
    )
    # Thousands-separated counts rendered above the bars.
    totals_fig.update_traces(texttemplate="%{text:,}", textposition="outside")
    totals_fig.update_layout(xaxis_tickangle=-45, showlegend=False)
    return totals_fig
def create_feature_breakdown_donut(
    grid_data: pd.DataFrame,
    grid_config: str,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Donut chart of feature counts per data source for one grid configuration.

    Args:
        grid_data: DataFrame with columns Data Source and Number of Features.
        grid_config: Grid configuration name, used as the chart title.
        source_color_map: Optional mapping of data source to color; Plotly
            defaults when None.

    Returns:
        Plotly Figure containing the donut chart.
    """
    breakdown_donut = px.pie(
        grid_data,
        names="Data Source",
        values="Number of Features",
        title=grid_config,
        hole=0.4,  # hollow center turns the pie into a donut
        color_discrete_map=source_color_map,
        color="Data Source",
    )
    breakdown_donut.update_traces(textposition="inside", textinfo="percent")
    breakdown_donut.update_layout(showlegend=True, height=350)
    return breakdown_donut
def create_feature_distribution_pie(
    breakdown_df: pd.DataFrame,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Donut-style pie chart of the overall feature split per data source.

    Args:
        breakdown_df: DataFrame with columns Data Source and Number of Features.
        source_color_map: Optional mapping of data source to color; Plotly
            defaults when None.

    Returns:
        Plotly Figure containing the pie chart.
    """
    distribution_pie = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_map=source_color_map,
        color="Data Source",
    )
    # Show both percentage and source label inside each slice.
    distribution_pie.update_traces(textposition="inside", textinfo="percent+label")
    return distribution_pie

View file

@ -0,0 +1,376 @@
"""Dataset Statistics Section for Entropice Dashboard."""
import pandas as pd
import streamlit as st
from entropice.dashboard.plots.dataset_statistics import (
create_feature_breakdown_donuts_grid,
create_feature_count_stacked_bar,
create_feature_distribution_pie,
create_inference_cells_bar,
create_member_details_table,
create_sample_count_bar_chart,
)
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
from entropice.utils.types import (
GridLevel,
L2SourceDataset,
TargetDataset,
Task,
TemporalMode,
all_l2_source_datasets,
grid_configs,
)
type DatasetStatsCache = dict[GridLevel, dict[TemporalMode, DatasetStatistics]]
def render_training_samples_tab(sample_df: pd.DataFrame):
    """Render overview of sample counts per task+target+grid+level combination.

    Args:
        sample_df: Sample-count table including the columns Samples (Coverage)
            and Coverage %. The input frame is not modified.
    """
    st.markdown(
        """
        This visualization shows the number of available training samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )
    # Create and display bar chart
    fig = create_sample_count_bar_chart(sample_df)
    st.plotly_chart(fig, width="stretch")
    # Display full table with formatting
    st.markdown("#### Detailed Sample Counts")
    # Bug fix: format a copy so the caller's DataFrame is not mutated in place
    # (the columns used to be overwritten with display strings).
    display_df = sample_df.copy()
    # Format numbers with commas
    display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
    # Format coverage as percentage with 2 decimal places
    display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
    st.dataframe(display_df, hide_index=True, width="stretch")
def render_dataset_characteristics_tab(
    breakdown_df: pd.DataFrame, sample_df: pd.DataFrame, comparison_df: pd.DataFrame
):
    """Render static comparison of feature counts across all grid configurations.

    Args:
        breakdown_df: Feature counts per grid / temporal mode / member.
        sample_df: Inference cell counts per grid configuration.
        comparison_df: Summary table comparing all grid configurations.
    """
    st.markdown(
        """
        Comparing dataset characteristics for all grid configurations with all data sources enabled.
        - **Features**: Total number of input features from all data sources
        - **Spatial Coverage**: Number of grid cells with complete data coverage
        """
    )
    # Stacked feature counts first, then spatial coverage, then the raw table.
    st.plotly_chart(create_feature_count_stacked_bar(breakdown_df), width="stretch")
    st.plotly_chart(create_inference_cells_bar(sample_df), width="stretch")
    st.markdown("#### Detailed Comparison Table")
    st.dataframe(comparison_df, hide_index=True, width="stretch")
def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
    """Render feature breakdown by member across grid configurations and temporal modes.

    Args:
        breakdown_df: Feature counts per grid / temporal mode / member.
    """
    st.markdown(
        """
        Visualizing feature contribution of each data source (member) across all grid configurations
        and temporal modes. Each donut chart represents the feature breakdown for a specific
        grid-temporal mode combination.
        """
    )
    # Donut grid first, followed by the underlying numbers.
    st.plotly_chart(create_feature_breakdown_donuts_grid(breakdown_df), width="stretch")
    st.markdown("#### Detailed Breakdown Table")
    st.dataframe(breakdown_df, hide_index=True, width="stretch")
@st.fragment
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache):  # noqa: C901
    """Render interactive detailed configuration explorer.

    Lets the user pick a grid configuration, temporal mode, target dataset,
    task, and a subset of data sources, then shows summary metrics, task
    statistics, class distribution, and a per-source feature breakdown for
    that selection. Early-returns with a warning when the selected
    combination has no statistics.

    Args:
        dataset_stats: Cache of pre-computed statistics, keyed by grid level
            and then temporal mode.
    """
    st.markdown(
        """
        Explore detailed statistics for a specific dataset configuration.
        Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
        """
    )
    # Grid selection
    grid_options = [gc.display_name for gc in grid_configs]
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        selected_grid_display = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )
    with col2:
        # Create temporal mode options with proper formatting
        temporal_mode_options = ["feature", "synopsis"] + [str(year) for year in [2018, 2019, 2020, 2021, 2022, 2023]]
        temporal_mode_display = st.selectbox(
            "Temporal Mode",
            options=temporal_mode_options,
            index=0,
            help="Select the temporal mode (feature, synopsis, or specific year)",
            key="feature_temporal_mode_select",
        )
        # Convert back to proper type: year strings become ints, mode names stay strings.
        selected_temporal_mode: TemporalMode
        if temporal_mode_display.isdigit():
            selected_temporal_mode = int(temporal_mode_display)  # type: ignore[assignment]
        else:
            selected_temporal_mode = temporal_mode_display  # pyright: ignore[reportAssignmentType]
    with col3:
        # NOTE(review): other pages mention "darts_rts"; confirm "darts_v1" is the intended option name here.
        selected_target: TargetDataset = st.selectbox(
            "Target Dataset",
            options=["darts_v1", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )
    with col4:
        selected_task: Task = st.selectbox(
            "Task",
            options=["binary", "count", "density", "count_regimes", "density_regimes"],
            index=0,
            help="Select the task type",
            key="feature_task_select",
        )
    # Find the selected grid config (display names are assumed unique).
    selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
    # Get stats for the selected configuration
    # NOTE(review): raises KeyError if the grid id is missing from the cache — presumably every
    # grid config is always present; confirm against the loader.
    mode_stats = dataset_stats[selected_grid_config.id]
    # Check if the selected temporal mode has stats
    if selected_temporal_mode not in mode_stats:
        st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
        return
    stats = mode_stats[selected_temporal_mode]
    available_members = list(stats.members.keys())
    # Members selection
    st.markdown("#### Select Data Sources")
    # Use columns for checkboxes; members present in the stats are pre-checked.
    cols = st.columns(len(all_l2_source_datasets))
    selected_members: list[L2SourceDataset] = []
    for idx, member in enumerate(all_l2_source_datasets):
        with cols[idx]:
            default_value = member in available_members
            if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
                selected_members.append(member)  # type: ignore[arg-type]
    # Show results if at least one member is selected
    if selected_members:
        # Filter to selected members only
        selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
        # Calculate total features for selected members
        total_features = sum(ms.feature_count for ms in selected_member_stats.values())
        # Get target stats if available
        if selected_target not in stats.target:
            st.warning(f"Target {selected_target} is not available for this configuration.")
            return
        target_stats = stats.target[selected_target]
        # Check if task is available
        if selected_task not in target_stats:
            st.warning(
                f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
            )
            return
        task_stats = target_stats[selected_task]
        # High-level metrics
        st.markdown("#### Configuration Summary")
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Total Features", f"{total_features:,}")
        with col2:
            # Calculate minimum cells across all data sources (for inference capability)
            if selected_member_stats:
                min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
            else:
                min_cells = 0
            st.metric(
                "Inference Cells",
                f"{min_cells:,}",
                help="Minimum number of cells across all selected data sources",
            )
        with col3:
            st.metric("Data Sources", len(selected_members))
        with col4:
            st.metric("Training Samples", f"{task_stats.training_cells:,}")
        with col5:
            # Calculate total data points
            total_points = total_features * task_stats.training_cells
            st.metric("Total Data Points", f"{total_points:,}")
        # Task-specific statistics
        st.markdown("#### Task Statistics")
        task_col1, task_col2, task_col3 = st.columns(3)
        with task_col1:
            st.metric("Task Type", selected_task.replace("_", " ").title())
        with task_col2:
            st.metric("Coverage", f"{task_stats.coverage:.2f}%")
        with task_col3:
            # Class counts present => classification; otherwise treated as regression.
            if task_stats.class_counts:
                st.metric("Number of Classes", len(task_stats.class_counts))
            else:
                st.metric("Task Mode", "Regression")
        # Class distribution for classification tasks
        if task_stats.class_distribution:
            st.markdown("#### Class Distribution")
            class_dist_df = pd.DataFrame(
                [
                    {
                        "Class": class_name,
                        "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
                        "Percentage": f"{pct:.2f}%",
                    }
                    for class_name, pct in task_stats.class_distribution.items()
                ]
            )
            st.dataframe(class_dist_df, hide_index=True, width="stretch")
        # Feature breakdown by source
        st.markdown("#### Feature Breakdown by Data Source")
        breakdown_data = []
        for member, member_stats in selected_member_stats.items():
            breakdown_data.append(
                {
                    "Data Source": member,
                    "Number of Features": member_stats.feature_count,
                    "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
                }
            )
        breakdown_df = pd.DataFrame(breakdown_data)
        # Get all unique data sources and create color map
        # NOTE(review): duplicates the member-color logic in the plots module
        # (_get_colors_from_members); consider sharing a single helper.
        unique_members_for_color = sorted(selected_member_stats.keys())
        source_color_map_raw = {}
        for member in unique_members_for_color:
            source = member.split("-")[0]
            n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
            palette = get_palette(source, n_colors=n_members + 2)
            idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
            source_color_map_raw[member] = palette[idx + 1]
        # Create and display pie chart
        fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
        st.plotly_chart(fig, width="stretch")
        # Show detailed table
        st.dataframe(breakdown_df, hide_index=True, width="stretch")
        # Detailed member information
        with st.expander("📦 Detailed Source Information", expanded=False):
            # Create detailed table
            member_details_dict = {
                member: {
                    "feature_count": ms.feature_count,
                    "variable_names": ms.variable_names,
                    "dimensions": ms.dimensions,
                    "size_bytes": ms.size_bytes,
                }
                for member, ms in selected_member_stats.items()
            }
            details_df = create_member_details_table(member_details_dict)
            st.dataframe(details_df, hide_index=True, width="stretch")
            # Individual member details
            for member, member_stats in selected_member_stats.items():
                st.markdown(f"### {member}")
                # Variables rendered as styled inline badges.
                st.markdown("**Variables:**")
                vars_html = " ".join(
                    [
                        f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                        for v in member_stats.variable_names
                    ]
                )
                st.markdown(vars_html, unsafe_allow_html=True)
                # Dimensions rendered as styled inline badges.
                st.markdown("**Dimensions:**")
                dim_html = " ".join(
                    [
                        f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                        f"{dim_name}: {dim_size:,}</span>"
                        for dim_name, dim_size in member_stats.dimensions.items()
                    ]
                )
                st.markdown(dim_html, unsafe_allow_html=True)
                st.markdown("---")
    else:
        st.info("👆 Select at least one data source to see feature statistics")
def render_dataset_statistics(all_stats: DatasetStatsCache | None = None):
    """Render the dataset statistics section with sample and feature counts.

    Args:
        all_stats: Pre-computed dataset statistics per grid-level combination.
            If None, the cached default statistics are loaded.
    """
    st.header("📈 Dataset Statistics")
    # Bug fix: previously the passed-in `all_stats` was unconditionally shadowed
    # by a reload, making the parameter dead. Only load when nothing was given.
    if all_stats is None:
        all_stats = load_all_default_dataset_statistics()
    training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
    feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
    comparison_df = DatasetStatistics.get_comparison_df(all_stats)
    inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
    # Create tabs for different analysis views
    analysis_tabs = st.tabs(
        [
            "📊 Training Samples",
            "📈 Dataset Characteristics",
            "🔍 Feature Breakdown",
            "⚙️ Configuration Explorer",
        ]
    )
    with analysis_tabs[0]:
        st.subheader("Training Samples by Configuration")
        render_training_samples_tab(training_sample_df)
    with analysis_tabs[1]:
        st.subheader("Dataset Characteristics Across Grid Configurations")
        render_dataset_characteristics_tab(feature_breakdown_df, inference_sample_df, comparison_df)
    with analysis_tabs[2]:
        st.subheader("Feature Breakdown by Data Source")
        render_feature_breakdown_tab(feature_breakdown_df)
    with analysis_tabs[3]:
        st.subheader("Interactive Configuration Explorer")
        render_configuration_explorer_tab(all_stats)

View file

@ -0,0 +1,158 @@
"""Experiment Results Section for the Entropice Dashboard."""
from datetime import datetime
import streamlit as st
from entropice.dashboard.utils.loaders import TrainingResult
from entropice.utils.types import (
GridConfig,
)
def render_training_results_summary(training_results: list[TrainingResult]):
    """Render summary metrics for training results.

    Args:
        training_results: Loaded training results, sorted by creation time
            (newest first).
    """
    st.header("📊 Training Results Summary")
    # Bug fix: indexing training_results[0] below raised IndexError when no
    # training runs exist yet.
    if not training_results:
        st.info("No training results available yet.")
        return
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        experiments = {tr.experiment for tr in training_results if tr.experiment}
        st.metric("Experiments", len(experiments))
    with col2:
        st.metric("Total Runs", len(training_results))
    with col3:
        models = {tr.settings.model for tr in training_results}
        st.metric("Model Types", len(models))
    with col4:
        latest = training_results[0]  # Already sorted by creation time
        latest_datetime = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d %H:%M")
        st.metric("Latest Run", latest_datetime)
@st.fragment
def render_experiment_results(training_results: list[TrainingResult]):  # noqa: C901
    """Render detailed experiment results table and expandable details.

    Shows four select-box filters (experiment, task, model, grid), a summary
    table of the filtered results, and one expander per result with its
    configuration, output files, CV score summary and explored parameter ranges.

    Args:
        training_results: All loaded training results; narrowed by the widget
            selections before display.
    """
    st.header("🎯 Experiment Results")
    # Filters
    # Collect the distinct option values present in the loaded results.
    experiments = sorted({tr.experiment for tr in training_results if tr.experiment})
    tasks = sorted({tr.settings.task for tr in training_results})
    models = sorted({tr.settings.model for tr in training_results})
    grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results})
    # Create filter columns
    filter_cols = st.columns(4)
    with filter_cols[0]:
        if experiments:
            selected_experiment = st.selectbox(
                "Experiment",
                options=["All", *experiments],
                index=0,
            )
        else:
            # No experiment names present — behave as if "All" were selected.
            selected_experiment = "All"
    with filter_cols[1]:
        selected_task = st.selectbox(
            "Task",
            options=["All", *tasks],
            index=0,
        )
    with filter_cols[2]:
        selected_model = st.selectbox(
            "Model",
            options=["All", *models],
            index=0,
        )
    with filter_cols[3]:
        selected_grid = st.selectbox(
            "Grid",
            options=["All", *grids],
            index=0,
        )
    # Apply filters
    # Each non-"All" selection narrows the result list further.
    filtered_results = training_results
    if selected_experiment != "All":
        filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment]
    if selected_task != "All":
        filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task]
    if selected_model != "All":
        filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model]
    if selected_grid != "All":
        filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid]
    st.subheader("Results Table")
    summary_df = TrainingResult.to_dataframe(filtered_results)
    # Display with color coding for best scores
    st.dataframe(
        summary_df,
        width="stretch",
        hide_index=True,
    )
    # Expandable details for each result
    st.subheader("Individual Experiment Details")
    for tr in filtered_results:
        tr_info = tr.display_info
        display_name = tr_info.get_display_name("model_first")
        with st.expander(display_name):
            col1, col2 = st.columns([1, 2])
            with col1:
                grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
                st.write("**Configuration:**")
                st.write(f"- **Experiment:** {tr.experiment}")
                st.write(f"- **Task:** {tr.settings.task}")
                st.write(f"- **Model:** {tr.settings.model}")
                st.write(f"- **Grid:** {grid_config.display_name}")
                st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}")
                st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}")
                st.write(f"- **Members:** {', '.join(tr.settings.members)}")
                st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
                st.write(f"- **Classes:** {tr.settings.classes}")
                st.write("\n**Files:**")
                # Prefix each known output file with a matching emoji marker.
                for file in tr.files:
                    if file.name == "search_settings.toml":
                        st.write(f"- ⚙️ `{file.name}`")
                    elif file.name == "best_estimator_model.pkl":
                        st.write(f"- 🧮 `{file.name}`")
                    elif file.name == "search_results.parquet":
                        st.write(f"- 📊 `{file.name}`")
                    elif file.name == "predicted_probabilities.parquet":
                        st.write(f"- 🎯 `{file.name}`")
                    else:
                        st.write(f"- 📄 `{file.name}`")
            with col2:
                st.write("**CV Score Summary:**")
                # Extract all test scores
                metric_df = tr.get_metric_dataframe()
                if metric_df is not None:
                    st.dataframe(metric_df, width="stretch", hide_index=True)
                else:
                    st.write("No test scores found in results.")
                # Show parameter space explored
                if "initial_K" in tr.results.columns:  # Common parameter
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ["initial_K", "eps_cl", "eps_e"]:
                        if param in tr.results.columns:
                            min_val = tr.results[param].min()
                            max_val = tr.results[param].max()
                            unique_vals = tr.results[param].nunique()
                            st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
            with st.expander("Show CV Results DataFrame"):
                st.dataframe(tr.results, width="stretch", hide_index=True)
            st.write(f"\n**Path:** `{tr.path}`")
View file

@ -90,6 +90,14 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
A list of hex color strings.
"""
# Hardcode some common variables to specific colormaps
if variable == "ERA5":
cmap = load_cmap(name="blue_material").resampled(n_colors)
elif variable == "ArcticDEM":
cmap = load_cmap(name="deep_purple_material").resampled(n_colors)
elif variable == "AlphaEarth":
cmap = load_cmap(name="green_material").resampled(n_colors)
else:
cmap = get_cmap(variable).resampled(n_colors)
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
return colors

View file

@ -6,7 +6,6 @@ from datetime import datetime
from pathlib import Path
import antimeridian
import geopandas as gpd
import pandas as pd
import streamlit as st
import toml
@ -16,9 +15,8 @@ from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.ml.training import TrainingSettings
from entropice.utils.types import L2SourceDataset, Task
from entropice.utils.types import GridConfig
def _fix_hex_geometry(geom):
@ -32,22 +30,29 @@ def _fix_hex_geometry(geom):
@dataclass
class TrainingResult:
"""Wrapper for training result data and metadata."""
path: Path
experiment: str
settings: TrainingSettings
results: pd.DataFrame
metrics: dict[str, float]
confusion_matrix: xr.DataArray
train_metrics: dict[str, float]
test_metrics: dict[str, float]
combined_metrics: dict[str, float]
confusion_matrix: xr.Dataset | None
created_at: float
available_metrics: list[str]
files: list[Path]
@classmethod
def from_path(cls, result_path: Path) -> "TrainingResult":
def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
metrics_file = result_path / "test_metrics.toml"
metrics_file = result_path / "metrics.toml"
confusion_matrix_file = result_path / "confusion_matrix.nc"
all_files = list(result_path.iterdir())
if not result_file.exists():
raise FileNotFoundError(f"Missing results file in {result_path}")
if not settings_file.exists():
@ -56,28 +61,46 @@ class TrainingResult:
raise FileNotFoundError(f"Missing predictions file in {result_path}")
if not metrics_file.exists():
raise FileNotFoundError(f"Missing metrics file in {result_path}")
if not confusion_matrix_file.exists():
raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
created_at = result_path.stat().st_ctime
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
settings_dict = toml.load(settings_file)["settings"]
# Handle backward compatibility: add missing fields with defaults
if "classes" not in settings_dict:
settings_dict["classes"] = None
if "param_grid" not in settings_dict:
settings_dict["param_grid"] = {}
if "cv_splits" not in settings_dict:
settings_dict["cv_splits"] = 5
if "metrics" not in settings_dict:
settings_dict["metrics"] = []
settings = TrainingSettings(**settings_dict)
results = pd.read_parquet(result_file)
metrics = toml.load(metrics_file)["test_metrics"]
confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
metrics = toml.load(metrics_file)
if not confusion_matrix_file.exists():
confusion_matrix = None
else:
confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf")
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
return cls(
path=result_path,
experiment=experiment_name or "N/A",
settings=settings,
results=results,
metrics=metrics,
train_metrics=metrics["train_metrics"],
test_metrics=metrics["test_metrics"],
combined_metrics=metrics["combined_metrics"],
confusion_matrix=confusion_matrix,
created_at=created_at,
available_metrics=available_metrics,
files=all_files,
)
@property
def display_info(self) -> TrainingResultDisplayInfo:
"""Get display information for the training result."""
return TrainingResultDisplayInfo(
task=self.settings.task,
model=self.settings.model,
@ -126,9 +149,70 @@ class TrainingResult:
st.error(f"Error loading predictions: {e}")
return None
def get_metric_dataframe(self) -> pd.DataFrame | None:
"""Get a DataFrame of available metrics for this training result."""
metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")]
if not metric_cols:
return None
metric_data = []
for col in metric_cols:
metric_name = col.replace("mean_test_", "").replace("neg_", "").title()
metrics = self.results[col]
# Check if the metric is negative
if col.startswith("mean_test_neg_"):
task_multiplier = 1 if self.settings.task != "density" else 100
task_multiplier *= -1
metrics = metrics * task_multiplier
metric_data.append(
{
"Metric": metric_name,
"Best": f"{metrics.max():.4f}",
"Mean": f"{metrics.mean():.4f}",
"Std": f"{metrics.std():.4f}",
"Worst": f"{metrics.min():.4f}",
}
)
return pd.DataFrame(metric_data)
def _get_best_metric_name(self) -> str:
"""Get the primary metric name for a given task."""
match self.settings.task:
case "binary":
return "f1"
case "count_regimes" | "density_regimes":
return "f1_weighted"
case _: # regression tasks
return "r2"
@staticmethod
def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame:
"""Convert a list of TrainingResult objects to a DataFrame for display."""
records = []
for tr in training_results:
info = tr.display_info
best_metric_name = tr._get_best_metric_name()
record = {
"Experiment": tr.experiment if tr.experiment else "N/A",
"Task": info.task,
"Model": info.model,
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
"Score-Metric": best_metric_name.title(),
"Best Models Score (Train-Set)": tr.train_metrics.get(best_metric_name),
"Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
"Trials": len(tr.results),
"Path": str(tr.path.name),
}
records.append(record)
return pd.DataFrame.from_records(records)
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
"""Load all training results from the results directory."""
results_dir = entropice.utils.paths.RESULTS_DIR
training_results: list[TrainingResult] = []
for result_path in results_dir.iterdir():
@ -136,50 +220,22 @@ def load_all_training_results() -> list[TrainingResult]:
continue
try:
training_result = TrainingResult.from_path(result_path)
except FileNotFoundError as e:
st.warning(f"Skipping incomplete training result: {e}")
continue
training_results.append(training_result)
except FileNotFoundError as e:
is_experiment_dir = False
for experiment_path in result_path.iterdir():
if not experiment_path.is_dir():
continue
try:
experiment_name = experiment_path.name
training_result = TrainingResult.from_path(experiment_path, experiment_name)
training_results.append(training_result)
is_experiment_dir = True
except FileNotFoundError as e2:
st.warning(f"Skipping incomplete training result: {e2}")
if not is_experiment_dir:
st.warning(f"Skipping incomplete training result: {e}")
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
def load_all_training_data(
    e: DatasetEnsemble,
) -> dict[Task, CategoricalTrainingDataset]:
    """Build the categorical train/test datasets for every supported task.

    Args:
        e: DatasetEnsemble to load the data from.

    Returns:
        Mapping from task name ('binary', 'count', 'density') to its
        CategoricalTrainingDataset, all materialized on the CPU.
    """
    # Create the joined feature/target table once, then split it per task.
    dataset = e.create(filter_target_col=e.covcol)
    task_names = ("binary", "count", "density")
    return {task: e._cat_and_split(dataset, task, device="cpu") for task in task_names}
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Load the raw data of one source member together with the target table.

    Args:
        e: DatasetEnsemble to read from.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple of (raw xarray.Dataset for the source, GeoDataFrame of targets).
    """
    target_gdf = e._read_target()
    # lazy=False: materialize the member data fully, not only its metadata.
    member_ds = e._read_member(source, target_gdf, lazy=False)
    return member_ds, target_gdf

View file

@ -3,29 +3,28 @@
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
"""
import pickle
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Literal
import geopandas as gpd
import pandas as pd
import streamlit as st
import xarray as xr
from stopuhr import stopwatch
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
Grid,
GridLevel,
L2SourceDataset,
TargetDataset,
Task,
TemporalMode,
all_l2_source_datasets,
all_target_datasets,
all_tasks,
all_temporal_modes,
grid_configs,
)
@ -41,90 +40,63 @@ class MemberStatistics:
size_bytes: int # Size of this member's data on disk in bytes
@classmethod
def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
if member == "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
era5_agg = member.split("-")[1]
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level) # ty:ignore[invalid-argument-type]
elif member == "ArcticDEM":
store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
else:
raise NotImplementedError(f"Member {member} not implemented.")
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
"""Pre-compute the statistics for a specific dataset member."""
member_stats = {}
for member in all_l2_source_datasets:
ds = e.read_member(member, lazy=True)
size_bytes = ds.nbytes
size_bytes = store.stat().st_size
ds = xr.open_zarr(store, consolidated=False)
# Delete all coordinates which are not in the dimension
for coord in ds.coords:
if coord not in ds.dims:
ds = ds.drop_vars(coord)
n_cols_member = len(ds.data_vars)
for dim in ds.sizes:
if dim != "cell_ids":
n_cols_member *= ds.sizes[dim]
return cls(
member_stats[member] = cls(
feature_count=n_cols_member,
variable_names=list(ds.data_vars),
dimensions=dict(ds.sizes),
coordinates=list(ds.coords),
size_bytes=size_bytes,
)
return member_stats
@dataclass(frozen=True)
class TargetStatistics:
"""Statistics for a specific target dataset."""
training_cells: dict[Task, int] # Number of cells used for training
coverage: dict[Task, float] # Percentage of total cells covered by training data
class_counts: dict[Task, dict[str, int]] # Class name to count mapping per task
class_distribution: dict[Task, dict[str, float]] # Class name to percentage mapping per task
size_bytes: int # Size of the target dataset on disk in bytes
training_cells: int # Number of cells used for training
coverage: float # Percentage of total cells covered by training data
class_counts: dict[str, int] | None # Class name to count mapping
class_distribution: dict[str, float] | None # Class name to percentage mapping
@classmethod
def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
if target == "darts_rts":
target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
elif target == "darts_mllabels":
target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
else:
raise NotImplementedError(f"Target {target} not implemented.")
target_gdf = gpd.read_parquet(target_store)
size_bytes = target_store.stat().st_size
training_cells: dict[Task, int] = {}
training_coverage: dict[Task, float] = {}
class_counts: dict[Task, dict[str, int]] = {}
class_distribution: dict[Task, dict[str, float]] = {}
def compute(cls, e: DatasetEnsemble, target: TargetDataset, total_cells: int) -> dict[Task, "TargetStatistics"]:
"""Pre-compute the statistics for a specific target dataset."""
target_stats = {}
for task in all_tasks:
task_col = taskcol[task][target]
cov_col = covcol[target]
targets = e.get_targets(target=target, task=task)
task_gdf = target_gdf[target_gdf[cov_col]]
training_cells[task] = len(task_gdf)
training_coverage[task] = len(task_gdf) / total_cells * 100
training_cells = len(targets)
training_coverage = len(targets) / total_cells * 100
model_labels = task_gdf[task_col].dropna()
if task == "binary":
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
elif task == "count":
binned = bin_values(model_labels.astype(int), task=task)
elif task == "density":
binned = bin_values(model_labels, task=task)
if task in ["count", "density"]:
class_counts = None
class_distribution = None
else:
raise ValueError("Invalid task.")
counts = binned.value_counts()
assert targets["y"].dtype == "category", "Classification tasks must have categorical target dtype."
counts = targets["y"].value_counts()
distribution = counts / counts.sum() * 100
class_counts[task] = counts.to_dict() # ty:ignore[invalid-assignment]
class_distribution[task] = distribution.to_dict()
return TargetStatistics(
class_counts = counts.to_dict()
class_distribution = distribution.to_dict()
target_stats[task] = cls(
training_cells=training_cells,
coverage=training_coverage,
class_counts=class_counts,
class_distribution=class_distribution,
size_bytes=size_bytes,
)
return target_stats
@dataclass(frozen=True)
@ -139,108 +111,154 @@ class DatasetStatistics:
total_cells: int # Total number of grid cells potentially covered
size_bytes: int # Size of the dataset on disk in bytes
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
@staticmethod
def get_sample_count_df(
all_stats: dict[GridLevel, "DatasetStatistics"],
def get_target_sample_count_df(
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
mode_stats = all_stats[grid_config.id]
for temporal_mode, stats in mode_stats.items():
for target_name, target_stats in stats.target.items():
for task in all_tasks:
training_cells = target_stats.training_cells[task]
coverage = target_stats.coverage[task]
# We can assume that all tasks have equal counts, since they only affect distribution, not presence
target_task_stats = target_stats["binary"]
n_samples = target_task_stats.training_cells
coverage = target_task_stats.coverage
for task, target_task_stats in target_stats.items():
assert n_samples == target_task_stats.training_cells, (
"Inconsistent sample counts across tasks for the same target."
)
assert coverage == target_task_stats.coverage, (
"Inconsistent coverage across tasks for the same target."
)
rows.append(
{
"Grid": grid_config.display_name,
"Target": target_name.replace("darts_", ""),
"Task": task.capitalize(),
"Samples (Coverage)": training_cells,
"Target": target_name.replace("_", " ").title(),
"Temporal Mode": str(temporal_mode).capitalize(),
"Samples (Coverage)": n_samples,
"Coverage %": coverage,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
@staticmethod
def get_feature_count_df(
all_stats: dict[GridLevel, "DatasetStatistics"],
def get_inference_sample_count_df(
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
data_sources = list(stats.members.keys())
# Determine minimum cells across all data sources
mode_stats = all_stats[grid_config.id]
# We can assume that all temporal modes have equal counts, since they only affect features, not samples
stats = mode_stats["feature"]
assert len(stats.members) > 0, "Dataset must have at least one member."
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
# Get sample count from first target dataset (darts_rts)
first_target = stats.target["darts_rts"]
total_samples = first_target.training_cells["binary"]
for temporal_mode, stats in mode_stats.items():
assert len(stats.members) > 0, "Dataset must have at least one member."
assert (
min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) == min_cells
), "Inconsistent inference cell counts across temporal modes."
rows.append(
{
"Grid": grid_config.display_name,
"Total Features": stats.total_features,
"Data Sources": len(data_sources),
"Inference Cells": min_cells,
"Total Samples": total_samples,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
@staticmethod
def get_comparison_df(
    all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
    """Build the detailed comparison table across grids and temporal modes.

    One row per (grid config, temporal mode) with formatted feature/cell
    counts and on-disk size, sorted by grid level then temporal mode.
    """
    records = []
    for cfg in grid_configs:
        for mode, stats in all_stats[cfg.id].items():
            # The smallest per-member cell count bounds the usable cells.
            smallest_cell_count = min(m.dimensions["cell_ids"] for m in stats.members.values())
            records.append(
                {
                    "Grid": cfg.display_name,
                    "Temporal Mode": str(mode).capitalize(),
                    "Total Features": f"{stats.total_features:,}",
                    "Total Cells": f"{stats.total_cells:,}",
                    "Min Cells": f"{smallest_cell_count:,}",
                    "Size (MB)": f"{stats.size_bytes / (1024 * 1024):.2f}",
                    "Grid_Level_Sort": cfg.sort_key,
                }
            )
    frame = pd.DataFrame(records)
    return frame.sort_values(["Grid_Level_Sort", "Temporal Mode"]).drop(columns="Grid_Level_Sort")
@staticmethod
def get_feature_breakdown_df(
all_stats: dict[GridLevel, "DatasetStatistics"],
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
mode_stats = all_stats[grid_config.id]
added_year_temporal_mode = False
for temporal_mode, stats in mode_stats.items():
if added_year_temporal_mode:
continue
if isinstance(temporal_mode, int):
# Only add one year-based temporal mode for the breakdown
added_year_temporal_mode = True
temporal_mode = "Year-Based"
for member_name, member_stats in stats.members.items():
rows.append(
{
"Grid": grid_config.display_name,
"Data Source": member_name,
"Temporal Mode": temporal_mode.capitalize(),
"Member": member_name,
"Number of Features": member_stats.feature_count,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
return (
pd.DataFrame(rows)
.sort_values(["Grid_Level_Sort", "Temporal Mode", "Member"])
.drop(columns="Grid_Level_Sort")
)
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
# @st.cache_data # ty:ignore[invalid-argument-type]
def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode, DatasetStatistics]]:
"""Precompute dataset statistics for all grid-level combinations and temporal modes."""
cache_file = entropice.utils.paths.get_dataset_stats_cache()
if cache_file.exists():
with open(cache_file, "rb") as f:
dataset_stats = pickle.load(f)
return dataset_stats
dataset_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
for grid_config in grid_configs:
dataset_stats[grid_config.id] = {}
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
total_cells = len(grid_gdf)
assert total_cells > 0, "Grid must contain at least one cell."
target_statistics: dict[TargetDataset, TargetStatistics] = {}
for temporal_mode in all_temporal_modes:
e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode)
target_statistics = {}
for target in all_target_datasets:
target_statistics[target] = TargetStatistics.compute(
grid=grid_config.grid,
level=grid_config.level,
target=target,
total_cells=total_cells,
)
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
for member in all_l2_source_datasets:
member_statistics[member] = MemberStatistics.compute(
grid=grid_config.grid, level=grid_config.level, member=member
)
if isinstance(temporal_mode, int) and target == "darts_mllabels":
# darts_mllabels does not support year-based temporal modes
continue
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
member_statistics = MemberStatistics.compute(e)
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
ts.size_bytes for ts in target_statistics.values()
)
dataset_stats[grid_config.id] = DatasetStatistics(
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
dataset_stats[grid_config.id][temporal_mode] = DatasetStatistics(
total_features=total_features,
total_cells=total_cells,
size_bytes=total_size_bytes,
@ -248,95 +266,12 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
target=target_statistics,
)
with open(cache_file, "wb") as f:
pickle.dump(dataset_stats, f)
return dataset_stats
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Per-member feature statistics of a loaded ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Percentage of features from this member in the ensemble
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: "gpd.GeoDataFrame",
        member: "L2SourceDataset",
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute the statistics of one source member's columns in ``dataset``.

        Args:
            dataset: Joined ensemble table containing all members' feature columns.
            member: Source member whose columns should be summarized.
            n_features: Total feature count of the ensemble (denominator for p_features).

        Raises:
            NotImplementedError: If the member's column prefix is unknown.
        """
        if member == "AlphaEarth":
            column_mask = dataset.columns.str.startswith("embeddings_")
        elif member == "ArcticDEM":
            column_mask = dataset.columns.str.startswith("arcticdem_")
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_cols = dataset.columns.str.startswith("era5_")
            # "era5_<var>_<agg>" has two underscores; seasonal/shoulder columns
            # carry one extra period token ("era5_<var>_<period>_<agg>").
            cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
            cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
            cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
            if member == "ERA5-yearly":
                column_mask = cols_with_three_splits
            elif member == "ERA5-seasonal":
                column_mask = cols_with_summer_winter & cols_with_four_splits
            elif member == "ERA5-shoulder":
                column_mask = ~cols_with_summer_winter & cols_with_four_splits
        else:
            raise NotImplementedError(f"Member {member} not implemented.")
        # Bug fix: `dataset[column_mask]` treats a column-length boolean array
        # as a ROW indexer (raising ValueError unless n_rows == n_cols).
        # Select the member's columns explicitly instead.
        member_dataset = dataset.loc[:, column_mask]
        size_bytes_member = member_dataset.memory_usage(deep=True).sum()
        n_features_member = len(member_dataset.columns)
        # Guard against division by zero for an (unexpected) empty ensemble.
        p_features_member = n_features_member / n_features if n_features else 0.0
        n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
        return cls(
            n_features=n_features_member,
            p_features=p_features_member,
            n_nanrows=n_rows_with_nan_member,
            size_bytes=size_bytes_member,
        )
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of the dataset on disk in bytes
    members: "dict[L2SourceDataset, EnsembleMemberStatistics]"  # Statistics per source dataset member

    @classmethod
    def compute(
        cls, ensemble: "DatasetEnsemble", dataset: "gpd.GeoDataFrame | None" = None
    ) -> "EnsembleDatasetStatistics":
        """Compute ensemble-wide and per-member statistics for a loaded dataset.

        Args:
            ensemble: DatasetEnsemble describing the member composition.
            dataset: Already-materialized ensemble table. If None, it is created
                from the ensemble (filtered on its coverage column).
        """
        # Bug fix: `dataset or ensemble.create(...)` evaluates bool(dataset),
        # which raises "truth value of a DataFrame is ambiguous" whenever a
        # dataset is actually passed in. Check for None explicitly instead.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        # Assert that no column is all-NaN
        assert not dataset.isna().all("index").any(), "Some input columns are all NaN"
        size_bytes = dataset.memory_usage(deep=True).sum()
        n_features = len(dataset.columns)
        n_cells = len(dataset)
        # Number of rows which contain any NaN
        n_rows_with_nan = dataset.isna().any(axis=1).sum()
        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {}
        for member in ensemble.members:
            member_statistics[member] = EnsembleMemberStatistics.compute(
                dataset=dataset,
                member=member,
                n_features=n_features,
            )
        return cls(
            n_features=n_features,
            n_cells=n_cells,
            n_nanrows=n_rows_with_nan,
            size_bytes=size_bytes,
            members=member_statistics,
        )
@dataclass(frozen=True)
class TrainingDatasetStatistics:
n_samples: int # Total number of samples in the dataset
@ -346,14 +281,14 @@ class TrainingDatasetStatistics:
n_test_samples: int # Number of cells used for testing
train_test_ratio: float # Ratio of training to test samples
class_labels: list[str] # Ordered list of class labels
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class
n_classes: int # Number of classes
class_labels: list[str] | None # Ordered list of class labels
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] | None # Min/max raw values
n_classes: int | None # Number of classes
training_class_counts: dict[str, int] # Class counts in training set
training_class_distribution: dict[str, float] # Class percentages in training set
test_class_counts: dict[str, int] # Class counts in test set
test_class_distribution: dict[str, float] # Class percentages in test set
training_class_counts: dict[str, int] | None # Class counts in training set
training_class_distribution: dict[str, float] | None # Class percentages in training set
test_class_counts: dict[str, int] | None # Class counts in test set
test_class_distribution: dict[str, float] | None # Class percentages in test set
raw_value_min: float # Minimum raw target value
raw_value_max: float # Maximum raw target value
@ -361,7 +296,7 @@ class TrainingDatasetStatistics:
raw_value_median: float # Median raw target value
raw_value_std: float # Standard deviation of raw target values
imbalance_ratio: float # Smallest class count / largest class count (overall)
imbalance_ratio: float | None # Smallest class count / largest class count (overall)
size_bytes: int # Total memory usage of features in bytes
@classmethod
@ -369,58 +304,64 @@ class TrainingDatasetStatistics:
cls,
ensemble: DatasetEnsemble,
task: Task,
dataset: gpd.GeoDataFrame | None = None,
target: TargetDataset,
) -> "TrainingDatasetStatistics":
dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
training_dataset = ensemble.create_training_set(task=task, target=target)
# Sample counts
n_samples = len(categorical_dataset)
n_training_samples = len(categorical_dataset.y.train)
n_test_samples = len(categorical_dataset.y.test)
n_samples = len(training_dataset)
n_training_samples = (training_dataset.split == "train").sum()
n_test_samples = (training_dataset.split == "test").sum()
train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
# Feature statistics
n_features = len(categorical_dataset.X.data.columns)
feature_names = list(categorical_dataset.X.data.columns)
size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()
n_features = len(training_dataset.features.columns)
feature_names = training_dataset.features.columns.tolist()
size_bytes = training_dataset.features.memory_usage(deep=True).sum()
# Class information
class_labels = categorical_dataset.y.labels
class_intervals = categorical_dataset.y.intervals
class_labels = training_dataset.target_labels
if class_labels is None:
class_intervals = None
n_classes = None
training_class_counts = None
training_class_distribution = None
test_class_counts = None
test_class_distribution = None
imbalance_ratio = None
else:
class_intervals = training_dataset.target_intervals
n_classes = len(class_labels)
# Training class distribution
train_y_series = pd.Series(categorical_dataset.y.train)
train_counts = train_y_series.value_counts().sort_index()
train_y = training_dataset.targets[training_dataset.split == "train"]["y"]
test_y = training_dataset.targets[training_dataset.split == "test"]["y"]
train_counts = train_y.value_counts().sort_index()
training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
train_total = sum(training_class_counts.values())
training_class_distribution = {
k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
}
# Test class distribution
test_y_series = pd.Series(categorical_dataset.y.test)
test_counts = test_y_series.value_counts().sort_index()
test_counts = test_y.value_counts().sort_index()
test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
test_total = sum(test_class_counts.values())
test_class_distribution = {
k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
}
# Imbalance ratio (smallest class / largest class across both splits)
all_counts = list(train_counts.values()) + list(test_class_counts.values())
nonzero_counts = [c for c in all_counts if c > 0]
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
# Raw value statistics
raw_values = categorical_dataset.y.raw_values
raw_values = training_dataset.targets["z"]
raw_value_min = float(raw_values.min())
raw_value_max = float(raw_values.max())
raw_value_mean = float(raw_values.mean())
raw_value_median = float(raw_values.median())
raw_value_std = float(raw_values.std())
# Imbalance ratio (smallest class / largest class across both splits)
all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
nonzero_counts = [c for c in all_counts if c > 0]
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
return cls(
n_samples=n_samples,
n_features=n_features,

View file

@ -1,508 +1,17 @@
"""Overview page: List of available result directories with some summary statistics."""
from datetime import datetime
from typing import cast
import pandas as pd
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.overview import (
create_feature_breakdown_donut,
create_feature_count_stacked_bar,
create_feature_distribution_pie,
create_inference_cells_bar,
create_sample_count_bar_chart,
from entropice.dashboard.sections.dataset_statistics import render_dataset_statistics
from entropice.dashboard.sections.experiment_results import (
render_experiment_results,
render_training_results_summary,
)
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import (
DatasetStatistics,
load_all_default_dataset_statistics,
)
from entropice.utils.types import (
GridConfig,
L2SourceDataset,
TargetDataset,
grid_configs,
)
def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    Reads the cached dataset statistics, draws a bar chart of available
    training samples per configuration, and shows a formatted detail table.
    """
    st.markdown(
        """
    This visualization shows the number of available training samples for each combination of:
    - **Task**: binary, count, density
    - **Target Dataset**: darts_rts, darts_mllabels
    - **Grid System**: hex, healpix
    - **Grid Level**: varying by grid type
    """
    )

    # Get sample count DataFrame from cache
    all_stats = load_all_default_dataset_statistics()
    sample_df = DatasetStatistics.get_sample_count_df(all_stats)

    # One palette per target-dataset family, sized by the number of tasks
    n_tasks = sample_df["Task"].nunique()
    target_color_maps = {
        "rts": get_palette("task_types", n_colors=n_tasks),
        "mllabels": get_palette("data_sources", n_colors=n_tasks),
    }

    # Create and display bar chart.
    # width="stretch" replaces the deprecated use_container_width=True,
    # consistent with the other charts/tables in this module.
    fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
    st.plotly_chart(fig, width="stretch")

    # Display full table with formatting
    st.markdown("#### Detailed Sample Counts")
    display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
    # Format counts with thousands separators and coverage with two decimals
    display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
    display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
    st.dataframe(display_df, hide_index=True, width="stretch")
def render_feature_count_comparison():
    """Render static comparison of feature counts across all grid configurations.

    Shows a stacked bar chart of features per data source, a bar chart of
    inference cell counts per grid, and a formatted comparison table.
    """
    st.markdown(
        """
    Comparing dataset characteristics for all grid configurations with all data sources enabled.
    - **Features**: Total number of input features from all data sources
    - **Spatial Coverage**: Number of grid cells with complete data coverage
    """
    )

    # Get data from cache
    all_stats = load_all_default_dataset_statistics()
    comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
    breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
    breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")

    # Get all unique data sources and create a stable color map
    unique_sources = sorted(breakdown_df["Data Source"].unique())
    n_sources = len(unique_sources)
    source_color_list = get_palette("data_sources", n_colors=n_sources)
    source_color_map = dict(zip(unique_sources, source_color_list))

    # Stacked bar chart of feature counts per data source.
    # width="stretch" replaces the deprecated use_container_width=True,
    # consistent with the other charts/tables in this module.
    fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
    st.plotly_chart(fig, width="stretch")

    # Spatial coverage (inference cells) per grid configuration
    n_grids = len(comparison_df)
    grid_colors = get_palette("grid_configs", n_colors=n_grids)
    fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
    st.plotly_chart(fig_cells, width="stretch")

    # Display full comparison table with formatting
    st.markdown("#### Detailed Comparison Table")
    display_df = comparison_df[
        [
            "Grid",
            "Total Features",
            "Data Sources",
            "Inference Cells",
        ]
    ].copy()
    # Format numbers with commas
    for col in ["Total Features", "Inference Cells"]:
        display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
    st.dataframe(display_df, hide_index=True, width="stretch")
@st.fragment
def render_feature_count_explorer():
    """Render interactive detailed configuration explorer using fragments.

    Lets the user pick a grid configuration, a target dataset and a subset of
    data sources, then shows feature counts, breakdown charts and per-source
    details. Runs as a Streamlit fragment so interactions only rerun this
    section, not the whole page.
    """
    st.markdown("Select specific grid configuration and data sources for detailed statistics")
    # Grid selection
    grid_options = [gc.display_name for gc in grid_configs]
    col1, col2 = st.columns(2)
    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )
    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )
    # Find the selected grid config (display names are unique per config)
    selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
    # Get available members from the cached stats
    all_stats = load_all_default_dataset_statistics()
    stats = all_stats[selected_grid_config.id]
    available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
    # Members selection; checkboxes default to "on" for sources present in the stats
    st.markdown("#### Select Data Sources")
    all_members = cast(
        list[L2SourceDataset],
        ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
    )
    # Use columns for checkboxes
    cols = st.columns(len(all_members))
    selected_members: list[L2SourceDataset] = []
    for idx, member in enumerate(all_members):
        with cols[idx]:
            default_value = member in available_members
            if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
                selected_members.append(cast(L2SourceDataset, member))
    # Show results if at least one member is selected
    if selected_members:
        # Get statistics from cache (already loaded)
        grid_stats = all_stats[selected_grid_config.id]
        # Filter to selected members only
        selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
        # Calculate total features for selected members
        total_features = sum(ms.feature_count for ms in selected_member_stats.values())
        # Get target stats
        target_stats = grid_stats.target[cast(TargetDataset, target)]
        # High-level metrics
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Total Features", f"{total_features:,}")
        with col2:
            # Calculate minimum cells across all data sources (for inference capability)
            # NOTE(review): this takes the MINIMUM cell count across sources, but the
            # help text below says "union" — confirm which semantics is intended.
            min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
            st.metric(
                "Inference Cells",
                f"{min_cells:,}",
                help="Number of union of cells across all data sources",
            )
        with col3:
            st.metric("Data Sources", len(selected_members))
        with col4:
            # Use binary task training cells as sample count
            st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
        with col5:
            # Calculate total data points (features x training cells of the binary task)
            total_points = total_features * target_stats.training_cells["binary"]
            st.metric("Total Data Points", f"{total_points:,}")
        # Feature breakdown by source
        st.markdown("#### Feature Breakdown by Data Source")
        breakdown_data = []
        for member, member_stats in selected_member_stats.items():
            breakdown_data.append(
                {
                    "Data Source": member,
                    "Number of Features": member_stats.feature_count,
                    "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
                }
            )
        breakdown_df = pd.DataFrame(breakdown_data)
        # Get all unique data sources and create color map
        unique_sources = sorted(breakdown_df["Data Source"].unique())
        n_sources = len(unique_sources)
        source_color_list = get_palette("data_sources", n_colors=n_sources)
        source_color_map = dict(zip(unique_sources, source_color_list))
        # Create and display pie chart
        fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
        st.plotly_chart(fig, width="stretch")
        # Show detailed table
        st.dataframe(breakdown_df, hide_index=True, width="stretch")
        # Detailed member information
        with st.expander("📦 Detailed Source Information", expanded=False):
            for member, member_stats in selected_member_stats.items():
                st.markdown(f"### {member}")
                metric_cols = st.columns(4)
                with metric_cols[0]:
                    st.metric("Features", member_stats.feature_count)
                with metric_cols[1]:
                    st.metric("Variables", len(member_stats.variable_names))
                with metric_cols[2]:
                    dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
                    st.metric("Shape", dim_str)
                with metric_cols[3]:
                    # Product of all dimension sizes = raw data points for this source
                    total_points = 1
                    for dim_size in member_stats.dimensions.values():
                        total_points *= dim_size
                    st.metric("Data Points", f"{total_points:,}")
                # Variables rendered as inline "pill" badges
                st.markdown("**Variables:**")
                vars_html = " ".join(
                    [
                        f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                        for v in member_stats.variable_names
                    ]
                )
                st.markdown(vars_html, unsafe_allow_html=True)
                # Dimensions rendered as inline "pill" badges
                st.markdown("**Dimensions:**")
                dim_html = " ".join(
                    [
                        f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                        f"{dim_name}: {dim_size}</span>"
                        for dim_name, dim_size in member_stats.dimensions.items()
                    ]
                )
                st.markdown(dim_html, unsafe_allow_html=True)
                # Size on disk
                size_mb = member_stats.size_bytes / (1024 * 1024)
                st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
                st.markdown("---")
    else:
        st.info("👆 Select at least one data source to see feature statistics")
def _render_breakdown_donuts(breakdown_df, source_color_map):
    """Render feature-breakdown donut charts per resolution: hex left, healpix right."""
    for res in ["sparse", "low", "medium"]:
        cols = st.columns(2)
        for col, grid_type in zip(cols, ("hex", "healpix")):
            with col:
                for gc in grid_configs:
                    if gc.res != res or gc.grid != grid_type:
                        continue
                    grid_display = gc.display_name
                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
                    fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
                    st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")


def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts.

    Four tabs: training samples, dataset characteristics, per-source feature
    breakdown (donut charts) and the interactive configuration explorer.
    """
    st.header("📈 Dataset Analysis")
    # Create tabs for different analysis views
    analysis_tabs = st.tabs(
        [
            "📊 Training Samples",
            "📈 Dataset Characteristics",
            "🔍 Feature Breakdown",
            "⚙️ Configuration Explorer",
        ]
    )
    with analysis_tabs[0]:
        st.subheader("Training Samples by Configuration")
        render_sample_count_overview()
    with analysis_tabs[1]:
        st.subheader("Dataset Characteristics Across Grid Configurations")
        render_feature_count_comparison()
    with analysis_tabs[2]:
        st.subheader("Feature Breakdown by Data Source")
        # Get data from cache
        all_stats = load_all_default_dataset_statistics()
        breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
        breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
        # Stable color map over all data sources
        unique_sources = sorted(breakdown_df["Data Source"].unique())
        n_sources = len(unique_sources)
        source_color_list = get_palette("data_sources", n_colors=n_sources)
        source_color_map = dict(zip(unique_sources, source_color_list))
        st.markdown("Showing percentage contribution of each data source across all grid configurations")
        _render_breakdown_donuts(breakdown_df, source_color_map)
    with analysis_tabs[3]:
        st.subheader("Interactive Configuration Explorer")
        render_feature_count_explorer()
def render_training_results_summary(training_results):
    """Render summary metrics for training results."""
    st.header("📊 Training Results Summary")

    # Distinct values across all runs for the headline metrics
    unique_tasks = {result.settings.task for result in training_results}
    unique_grids = {result.settings.grid for result in training_results}
    unique_models = {result.settings.model for result in training_results}

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Tasks", len(unique_tasks))
    with col2:
        st.metric("Grid Types", len(unique_grids))
    with col3:
        st.metric("Model Types", len(unique_models))
    with col4:
        # training_results is already sorted by creation time, newest first
        most_recent = training_results[0]
        latest_date = datetime.fromtimestamp(most_recent.created_at).strftime("%Y-%m-%d")
        st.metric("Latest Run", latest_date)
def render_experiment_results(training_results):
    """Render detailed experiment results table and expandable details.

    Builds a summary table (one row per training run, with its best score on
    a primary metric) and, below it, an expander per run with configuration,
    all test metrics and the explored parameter ranges.
    """
    st.header("🎯 Experiment Results")
    st.subheader("Results Table")
    # Build a summary dataframe
    summary_data = []
    for tr in training_results:
        # Extract best scores from the results dataframe (cross-validation
        # test scores are stored in "mean_test_<metric>" columns)
        score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
        best_scores = {}
        for col in score_cols:
            metric_name = col.replace("mean_test_", "")
            best_score = tr.results[col].max()
            best_scores[metric_name] = best_score
        # Get primary metric (usually the first one or accuracy)
        primary_metric = (
            "accuracy"
            if "mean_test_accuracy" in tr.results.columns
            else score_cols[0].replace("mean_test_", "")
            if score_cols
            else "N/A"
        )
        primary_score = best_scores.get(primary_metric, 0.0)
        # NOTE(review): the "Best <metric>" column name depends on the run's
        # primary metric, so runs with different metrics produce different
        # columns (NaN-padded) in the summary table — confirm this is intended.
        summary_data.append(
            {
                "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
                "Task": tr.settings.task,
                "Grid": tr.settings.grid,
                "Level": tr.settings.level,
                "Model": tr.settings.model,
                f"Best {primary_metric.title()}": f"{primary_score:.4f}",
                "Trials": len(tr.results),
                "Path": str(tr.path.name),
            }
        )
    summary_df = pd.DataFrame(summary_data)
    # Display with color coding for best scores
    st.dataframe(
        summary_df,
        width="stretch",
        hide_index=True,
    )
    st.divider()
    # Expandable details for each result
    st.subheader("Individual Experiment Details")
    for tr in training_results:
        display_name = (
            f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
        )
        with st.expander(display_name):
            col1, col2 = st.columns([1, 2])
            with col1:
                st.write("**Configuration:**")
                st.write(f"- **Task:** {tr.settings.task}")
                st.write(f"- **Grid:** {tr.settings.grid}")
                st.write(f"- **Level:** {tr.settings.level}")
                st.write(f"- **Model:** {tr.settings.model}")
                st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
                st.write(f"- **Classes:** {tr.settings.classes}")
                # Static list of the artifact files a run directory contains
                st.write("\n**Files:**")
                st.write("- 📊 search_results.parquet")
                st.write("- 🧮 best_estimator_state.nc")
                st.write("- 🎯 predicted_probabilities.parquet")
                st.write("- ⚙️ search_settings.toml")
            with col2:
                st.write("**Best Scores:**")
                # Extract all test scores
                score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
                if score_cols:
                    metric_data = []
                    for col in score_cols:
                        metric_name = col.replace("mean_test_", "").title()
                        best_score = tr.results[col].max()
                        mean_score = tr.results[col].mean()
                        std_score = tr.results[col].std()
                        metric_data.append(
                            {
                                "Metric": metric_name,
                                "Best": f"{best_score:.4f}",
                                "Mean": f"{mean_score:.4f}",
                                "Std": f"{std_score:.4f}",
                            }
                        )
                    metric_df = pd.DataFrame(metric_data)
                    st.dataframe(metric_df, width="stretch", hide_index=True)
                else:
                    st.write("No test scores found in results.")
                # Show parameter space explored
                if "initial_K" in tr.results.columns:  # Common parameter
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ["initial_K", "eps_cl", "eps_e"]:
                        if param in tr.results.columns:
                            min_val = tr.results[param].min()
                            max_val = tr.results[param].max()
                            unique_vals = tr.results[param].nunique()
                            st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
                st.write(f"\n**Path:** `{tr.path}`")
def render_overview_page():
@ -537,7 +46,8 @@ def render_overview_page():
st.divider()
# Render dataset analysis section
render_dataset_analysis()
all_stats = load_all_default_dataset_statistics()
render_dataset_statistics(all_stats)
st.balloons()
stopwatch.summary()

View file

@ -48,7 +48,7 @@ class AutoGluonSettings:
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
"""Combined settings for AutoGluon training."""
classes: list[str]
classes: list[str] | None
problem_type: str
@ -228,7 +228,7 @@ def autogluon_train(
combined_settings = AutoGluonTrainingSettings(
**asdict(settings),
**asdict(dataset_ensemble),
classes=training_data.y.labels,
classes=training_data.target_labels,
problem_type=problem_type,
)
settings_file = results_dir / "training_settings.toml"

View file

@ -78,7 +78,7 @@ def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
pivcols = set(collapsed.index.names) - {"cell_ids"}
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
collapsed.columns = ["_".join(v) for v in collapsed.columns]
collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns]
if use_dummy:
collapsed = collapsed.dropna(how="all")
return collapsed
@ -378,7 +378,8 @@ class DatasetEnsemble:
case ("darts_v1" | "darts_v2", int()):
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
targets = xr.open_zarr(target_store, consolidated=False)
targets = targets.sel(year=self.temporal_mode)
case ("darts_mllabels", str()): # Years are not supported
target_store = entropice.utils.paths.get_darts_file(
grid=self.grid, level=self.level, version="mllabels"
@ -467,6 +468,10 @@ class DatasetEnsemble:
# Apply the temporal mode
match (member.split("-"), self.temporal_mode):
case (["ArcticDEM"], _):
pass # No temporal dimension
case (_, "feature"):
pass
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
ds_mean = ds.mean(dim="year")
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
@ -475,12 +480,8 @@ class DatasetEnsemble:
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
)
ds = xr.merge([ds_mean, ds_trend])
case (["ArcticDEM"], "synopsis"):
pass # No temporal dimension
case (_, int() as year):
ds = ds.sel(year=year, drop=True)
case (_, "feature"):
pass
case _:
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
@ -556,7 +557,7 @@ class DatasetEnsemble:
cell_ids: pd.Series,
era5_agg: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=True)
era5_df = _collapse_to_dataframe(era5)
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
@ -565,7 +566,10 @@ class DatasetEnsemble:
@stopwatch("Preparing AlphaEarth Embeddings")
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=True)
# For non-synopsis modes there is only a single data variable "embeddings"
if self.temporal_mode != "synopsis":
embeddings = embeddings["embeddings"]
embeddings_df = _collapse_to_dataframe(embeddings)
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
@ -574,7 +578,7 @@ class DatasetEnsemble:
@stopwatch("Preparing ArcticDEM")
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=False)
if len(arcticdem["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
@ -620,7 +624,7 @@ class DatasetEnsemble:
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
# @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
def create_training_set(
self,
task: Task,

View file

@ -3,6 +3,7 @@
import pickle
from dataclasses import asdict, dataclass
from functools import partial
from pathlib import Path
import cyclopts
import numpy as np
@ -115,17 +116,19 @@ class TrainingSettings(DatasetEnsemble, CVSettings):
def random_cv(
dataset_ensemble: DatasetEnsemble,
settings: CVSettings = CVSettings(),
):
experiment: str | None = None,
) -> Path:
"""Perform random cross-validation on the training dataset.
Args:
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
settings (CVSettings): The cross-validation settings.
experiment (str | None): Optional experiment name for results directory.
"""
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
use_array_api = settings.model == "espa"
device = "torch" if use_array_api else "cuda"
use_array_api = settings.model != "xgboost"
device = "torch" if settings.model == "espa" else "cuda"
set_config(array_api_dispatch=use_array_api)
print("Creating training data...")
@ -173,7 +176,8 @@ def random_cv(
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
results_dir = get_cv_results_dir(
"random_search",
experiment=experiment,
name="random_search",
grid=dataset_ensemble.grid,
level=dataset_ensemble.level,
task=settings.task,
@ -283,6 +287,7 @@ def random_cv(
stopwatch.summary()
print("Done.")
return results_dir
if __name__ == "__main__":

View file

@ -139,7 +139,14 @@ def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
return cache_dir / f"cells_{cells_hash}.parquet"
def get_dataset_stats_cache() -> Path:
cache_dir = DATASET_ENSEMBLES_DIR / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "dataset_stats.pckl"
def get_cv_results_dir(
experiment: str | None,
name: str,
grid: Grid,
level: int,
@ -147,7 +154,12 @@ def get_cv_results_dir(
) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
if experiment is not None:
experiment_dir = RESULTS_DIR / experiment
experiment_dir.mkdir(parents=True, exist_ok=True)
else:
experiment_dir = RESULTS_DIR
results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True)
return results_dir

View file

@ -15,12 +15,12 @@ type GridLevel = Literal[
"healpix9",
"healpix10",
]
type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
type TargetDataset = Literal["darts_v1", "darts_mllabels"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
# TODO: Consider implementing a "timeseries" temporal mode
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
type Model = Literal["espa", "xgboost", "rf", "knn"]
type Stage = Literal["train", "inference", "visualization"]
@ -37,8 +37,9 @@ class GridConfig:
sort_key: str
@classmethod
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
def from_grid_level(cls, grid_level: GridLevel | tuple[Literal["hex", "healpix"], int]) -> "GridConfig":
"""Create a GridConfig from a GridLevel string."""
if isinstance(grid_level, str):
if grid_level.startswith("hex"):
grid = "hex"
level = int(grid_level[3:])
@ -47,7 +48,9 @@ class GridConfig:
level = int(grid_level[7:])
else:
raise ValueError(f"Invalid grid level: {grid_level}")
else:
grid, level = grid_level
grid_level: GridLevel = f"{grid}{level}" # ty:ignore[invalid-assignment]
display_name = f"{grid.capitalize()}-{level}"
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
@ -74,8 +77,8 @@ class GridConfig:
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
all_target_datasets: list[TargetDataset] = ["darts_v1", "darts_mllabels"]
all_l2_source_datasets: list[L2SourceDataset] = [
"ArcticDEM",
"ERA5-shoulder",

222
tests/test_training.py Normal file
View file

@ -0,0 +1,222 @@
"""Tests for training.py module, specifically random_cv function.
This test suite validates the random_cv training function across all model-task
combinations using a minimal hex level 3 grid with synopsis temporal mode.
Test Coverage:
- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
- Device handling for each model type (torch/CUDA/cuML compatibility)
- Multi-label target dataset support
- Temporal mode configuration (synopsis)
- Output file creation and validation
Running Tests:
# Run all training tests (18 tests total, ~3 iterations each)
pixi run pytest tests/test_training.py -v
# Run only device handling tests
pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v
# Run a specific model-task combination
pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v
Note: Tests use minimal iterations (3) and level 3 grid for speed.
Full production runs use higher iteration counts (100-2000).
"""
import shutil
import pytest
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.training import CVSettings, random_cv
from entropice.utils.types import Model, Task
@pytest.fixture(scope="module")
def test_ensemble():
    """Provide a small DatasetEnsemble shared by every test in this module.

    Hex grid at level 3 with the synopsis temporal mode keeps dataset
    creation fast enough for the test suite.
    """
    ensemble = DatasetEnsemble(
        grid="hex",
        level=3,
        temporal_mode="synopsis",
        members=["AlphaEarth"],  # a single member keeps the feature matrix small
        add_lonlat=True,
    )
    return ensemble
@pytest.fixture
def cleanup_results():
    """Yield a registration callback; remove registered result dirs on teardown.

    Tests call the yielded function with each results directory they create;
    after the test finishes, every registered directory is deleted.
    """
    registered = []

    def _register(results_dir):
        """Remember a directory so it is removed after the test."""
        registered.append(results_dir)
        return results_dir

    yield _register

    # Teardown: delete only what this test actually created
    for path in registered:
        if path.exists():
            shutil.rmtree(path)
# Model-task combinations to test
# Note: Not all combinations make sense, but we test all to ensure robustness
MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
TASKS: list[Task] = ["binary", "count", "density"]
class TestRandomCV:
"""Test suite for random_cv function."""
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("task", TASKS)
def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
    """Test random_cv with all model-task combinations.

    This test runs 3 iterations for each combination to verify:
    - The function completes without errors
    - Device handling works correctly for each model type
    - All output files are created
    """
    # Use darts_v1 as the primary target for all tests
    settings = CVSettings(
        n_iter=3,
        task=task,
        target="darts_v1",
        model=model,
    )
    # Run the cross-validation and get the results directory
    results_dir = random_cv(
        dataset_ensemble=test_ensemble,
        settings=settings,
        experiment="test_training",
    )
    cleanup_results(results_dir)
    # Verify results directory was created
    assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"
    # Verify all expected output files exist
    expected_files = [
        "search_settings.toml",
        "best_estimator_model.pkl",
        "search_results.parquet",
        "metrics.toml",
        "predicted_probabilities.parquet",
    ]
    # Only classification produces a confusion matrix; of the tested tasks
    # that is "binary" (plain "count"/"density" are regression).
    if task == "binary":
        expected_files.append("confusion_matrix.nc")
    # Models with an exportable estimator state also write a state file
    if model in ["espa", "xgboost", "rf"]:
        expected_files.append("best_estimator_state.nc")
    for filename in expected_files:
        filepath = results_dir / filename
        # Name the missing file in the failure message (the original message
        # had lost the filename placeholder)
        assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"
@pytest.mark.parametrize("model", MODELS)
def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
"""Test that device handling works correctly for each model type.
Different models require different device configurations:
- espa: Uses torch with array API dispatch
- xgboost: Uses CUDA without array API dispatch
- rf/knn: GPU-accelerated via cuML
"""
settings = CVSettings(
n_iter=3,
task="binary", # Simple binary task for device testing
target="darts_v1",
model=model,
)
# This should complete without device-related errors
try:
results_dir = random_cv(
dataset_ensemble=test_ensemble,
settings=settings,
experiment="test_training",
)
cleanup_results(results_dir)
except RuntimeError as e:
# Check if error is device-related
error_msg = str(e).lower()
device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"]
if any(keyword in error_msg for keyword in device_keywords):
pytest.fail(f"Device handling error for {model=}: {e}")
else:
# Re-raise non-device errors
raise
def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
"""Test random_cv with multi-label target dataset."""
settings = CVSettings(
n_iter=3,
task="binary",
target="darts_mllabels",
model="espa",
)
# Run the cross-validation and get the results directory
results_dir = random_cv(
dataset_ensemble=test_ensemble,
settings=settings,
experiment="test_training",
)
cleanup_results(results_dir)
# Verify results were created
assert results_dir.exists(), "Results directory not created"
assert (results_dir / "search_settings.toml").exists()
def test_temporal_mode_synopsis(self, cleanup_results):
"""Test that temporal_mode='synopsis' is correctly used."""
import toml
ensemble = DatasetEnsemble(
grid="hex",
level=3,
temporal_mode="synopsis",
members=["AlphaEarth"],
add_lonlat=True,
)
settings = CVSettings(
n_iter=3,
task="binary",
target="darts_v1",
model="espa",
)
# This should use synopsis mode (all years aggregated)
results_dir = random_cv(
dataset_ensemble=ensemble,
settings=settings,
experiment="test_training",
)
cleanup_results(results_dir)
# Verify the settings were stored correctly
assert results_dir.exists(), "Results directory not created"
with open(results_dir / "search_settings.toml") as f:
stored_settings = toml.load(f)
assert stored_settings["settings"]["temporal_mode"] == "synopsis"

View file

@ -0,0 +1,38 @@
import cyclopts
from rich import pretty, print, traceback
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import all_temporal_modes, grid_configs
# Pretty-print objects and render rich tracebacks in terminal output.
pretty.install()
traceback.install()
# Command-line application; commands are registered via decorators below.
cli = cyclopts.App()
def _gather_ensemble_stats(e: DatasetEnsemble):
# Get a small sample of the cell ids
sample_cell_ids = e.cell_ids[:5]
features = e.make_features(sample_cell_ids)
print(
f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]"
)
print(f"Number of feature columns: {len(features.columns)}")
for member in ["arcticdem", "embeddings", "era5"]:
member_feature_cols = [col for col in features.columns if col.startswith(f"{member}_")]
print(f" - {member.capitalize()} feature columns: {len(member_feature_cols)}")
print()
@cli.default()
def validate_datasets():
    """Instantiate every grid/temporal-mode combination and print its stats."""
    for grid_config in grid_configs:
        for mode in all_temporal_modes:
            ensemble = DatasetEnsemble(
                grid=grid_config.grid,
                level=grid_config.level,
                temporal_mode=mode,
            )
            _gather_ensemble_stats(ensemble)
# Allow running this module directly as a script.
if __name__ == "__main__":
    cli()