diff --git a/pyproject.toml b/pyproject.toml
index 6615db7..fd9fe67 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ dependencies = [
"pandas-stubs>=2.3.3.251201,<3",
"pytest>=9.0.2,<10",
"autogluon-tabular[all,mitra]>=1.5.0",
- "shap>=0.50.0,<0.51",
+ "shap>=0.50.0,<0.51", "h5py>=3.15.1,<4",
]
[project.scripts]
diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 713f7d0..6eea4d9 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -5,12 +5,14 @@ Pages:
- Overview: List of available result directories with some summary statistics.
- Training Data: Visualization of training data distributions.
- Training Results Analysis: Analysis of training results and model performance.
+- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
- Model State: Visualization of model state and features.
- Inference: Visualization of inference results.
"""
import streamlit as st
+from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
@@ -26,13 +28,14 @@ def main():
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
+ autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
- "Training": [training_data_page, training_analysis_page],
+ "Training": [training_data_page, training_analysis_page, autogluon_page],
"Model State": [model_state_page],
"Inference": [inference_page],
}
diff --git a/src/entropice/dashboard/plots/dataset_statistics.py b/src/entropice/dashboard/plots/dataset_statistics.py
new file mode 100644
index 0000000..a4685c5
--- /dev/null
+++ b/src/entropice/dashboard/plots/dataset_statistics.py
@@ -0,0 +1,458 @@
+"""Visualization functions for dataset statistics.
+
+This module contains reusable plotting functions for dataset analysis visualizations,
+including sample counts, feature counts, and dataset statistics.
+"""
+
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from entropice.dashboard.utils.colors import get_palette
+
+
+def _get_colors_from_members(unique_members: list[str]) -> dict[str, str]:
+ unique_members = sorted(unique_members)
+ unique_sources = {member.split("-")[0] for member in unique_members}
+ source_member_map = {
+ source: [member for member in unique_members if member.split("-")[0] == source] for source in unique_sources
+ }
+ source_color_map = {}
+ for source, members in source_member_map.items():
+ n_members = len(members) + 2 # Avoid very light/dark colors
+ palette = get_palette(source, n_colors=n_members)
+ for idx, member in enumerate(members, start=1):
+ source_color_map[member] = palette[idx]
+ return source_color_map
+
+
+def create_sample_count_bar_chart(sample_df: pd.DataFrame) -> go.Figure:
+ """Create bar chart showing sample counts by grid, target and temporal mode.
+
+ Args:
+        sample_df: DataFrame with columns: Grid, Target, Temporal Mode, Samples (Coverage).
+
+ Returns:
+ Plotly Figure object containing the bar chart visualization.
+
+ """
+ assert "Grid" in sample_df.columns
+ assert "Target" in sample_df.columns
+ assert "Temporal Mode" in sample_df.columns
+ assert "Samples (Coverage)" in sample_df.columns
+
+ targets = sorted(sample_df["Target"].unique())
+ modes = sorted(sample_df["Temporal Mode"].unique())
+
+ # Create subplots manually to have better control over colors
+ fig = make_subplots(
+ rows=1,
+ cols=len(targets),
+ subplot_titles=[f"Target: {target}" for target in targets],
+ shared_yaxes=True,
+ )
+ # Get color palette for this plot
+ colors = get_palette("Number of Samples", n_colors=len(modes))
+
+ # Track which modes have been added to legend
+ modes_in_legend = set()
+
+ for col_idx, target in enumerate(targets, start=1):
+ target_data = sample_df[sample_df["Target"] == target]
+
+ for mode_idx, mode in enumerate(modes):
+ mode_data = target_data[target_data["Temporal Mode"] == mode]
+ if len(mode_data) == 0:
+ continue
+ color = colors[mode_idx] if colors and mode_idx < len(colors) else None
+
+ # Show legend only the first time each mode appears
+ show_in_legend = mode not in modes_in_legend
+ if show_in_legend:
+ modes_in_legend.add(mode)
+
+ # Create a unique legendgroup per mode so colors are consistent
+ fig.add_trace(
+ go.Bar(
+ x=mode_data["Grid"],
+ y=mode_data["Samples (Coverage)"],
+ name=mode,
+ marker_color=color,
+ legendgroup=mode, # Group by mode name
+ showlegend=show_in_legend,
+ ),
+ row=1,
+ col=col_idx,
+ )
+
+ fig.update_layout(
+ title_text="Training Sample Counts by Grid Configuration and Target Dataset",
+ barmode="group",
+ height=500,
+ showlegend=True,
+ )
+
+ fig.update_xaxes(tickangle=-45)
+ fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
+
+ return fig
+
+
+def create_feature_count_stacked_bar(breakdown_df: pd.DataFrame) -> go.Figure:
+ """Create stacked bar chart showing feature counts by member.
+
+ Args:
+ breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.
+
+ Returns:
+ Plotly Figure object containing the stacked bar chart visualization.
+
+ """
+ assert "Grid" in breakdown_df.columns
+ assert "Temporal Mode" in breakdown_df.columns
+ assert "Member" in breakdown_df.columns
+ assert "Number of Features" in breakdown_df.columns
+
+ unique_members = sorted(breakdown_df["Member"].unique())
+ temporal_modes = sorted(breakdown_df["Temporal Mode"].unique())
+ source_color_map = _get_colors_from_members(unique_members)
+
+ # Create subplots for each temporal mode
+ fig = make_subplots(
+ rows=1,
+ cols=len(temporal_modes),
+ subplot_titles=[f"Temporal Mode: {mode}" for mode in temporal_modes],
+ shared_yaxes=True,
+ )
+
+ for col_idx, mode in enumerate(temporal_modes, start=1):
+ mode_data = breakdown_df[breakdown_df["Temporal Mode"] == mode]
+
+ for member in unique_members:
+ member_data = mode_data[mode_data["Member"] == member]
+ color = source_color_map.get(member)
+
+ fig.add_trace(
+ go.Bar(
+ x=member_data["Grid"],
+ y=member_data["Number of Features"],
+ name=member,
+ marker_color=color,
+ legendgroup=member,
+ showlegend=(col_idx == 1),
+ ),
+ row=1,
+ col=col_idx,
+ )
+
+ fig.update_layout(
+ title_text="Input Features by Member Across Grid Configurations",
+ barmode="stack",
+ height=500,
+ showlegend=True,
+ )
+ fig.update_xaxes(tickangle=-45)
+ fig.update_yaxes(title_text="Number of Features", row=1, col=1)
+
+ return fig
+
+
+def create_inference_cells_bar(sample_df: pd.DataFrame) -> go.Figure:
+ """Create bar chart for inference cells by grid configuration.
+
+ Args:
+        sample_df: DataFrame with columns: Grid, Inference Cells.
+
+ Returns:
+ Plotly Figure object containing the bar chart visualization.
+
+ """
+ assert "Grid" in sample_df.columns
+ assert "Inference Cells" in sample_df.columns
+
+ n_grids = sample_df["Grid"].nunique()
+ mode_colors: list[str] = get_palette("Number of Samples", n_colors=n_grids)
+
+ fig = px.bar(
+ sample_df,
+ x="Grid",
+ y="Inference Cells",
+ color="Grid",
+ title="Spatial Coverage (Grid Cells with Complete Data)",
+ labels={
+ "Grid": "Grid Configuration",
+ "Inference Cells": "Number of Cells",
+ },
+ color_discrete_sequence=mode_colors,
+ text="Inference Cells",
+ )
+
+ fig.update_traces(texttemplate="%{text:,}", textposition="outside")
+ fig.update_layout(xaxis_tickangle=-45, height=500, showlegend=False)
+
+ return fig
+
+
+def create_total_samples_bar(
+ comparison_df: pd.DataFrame,
+ grid_colors: list[str] | None = None,
+) -> go.Figure:
+ """Create bar chart for total samples by grid configuration.
+
+ Args:
+ comparison_df: DataFrame with columns: Grid, Total Samples.
+ grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
+
+ Returns:
+ Plotly Figure object containing the bar chart visualization.
+
+ """
+ fig = px.bar(
+ comparison_df,
+ x="Grid",
+ y="Total Samples",
+ color="Grid",
+ title="Training Samples (Binary Task)",
+ labels={
+ "Grid": "Grid Configuration",
+ "Total Samples": "Number of Samples",
+ },
+ color_discrete_sequence=grid_colors,
+ text="Total Samples",
+ )
+
+ fig.update_traces(texttemplate="%{text:,}", textposition="outside")
+ fig.update_layout(xaxis_tickangle=-45, showlegend=False)
+
+ return fig
+
+
+def create_feature_breakdown_donut(
+ grid_data: pd.DataFrame,
+ grid_config: str,
+ source_color_map: dict[str, str] | None = None,
+) -> go.Figure:
+ """Create donut chart for feature breakdown by data source for a specific grid.
+
+ Args:
+ grid_data: DataFrame with columns: Data Source, Number of Features.
+ grid_config: Grid configuration name for the title.
+ source_color_map: Optional dictionary mapping data source names to specific colors.
+ If None, uses default Plotly colors.
+
+ Returns:
+ Plotly Figure object containing the donut chart visualization.
+
+ """
+ fig = px.pie(
+ grid_data,
+ names="Data Source",
+ values="Number of Features",
+ title=grid_config,
+ hole=0.4,
+ color_discrete_map=source_color_map,
+ color="Data Source",
+ )
+
+ fig.update_traces(textposition="inside", textinfo="percent")
+ fig.update_layout(showlegend=True, height=350)
+
+ return fig
+
+
+def create_feature_distribution_pie(
+ breakdown_df: pd.DataFrame,
+ source_color_map: dict[str, str] | None = None,
+) -> go.Figure:
+ """Create pie chart for feature distribution by data source.
+
+ Args:
+ breakdown_df: DataFrame with columns: Data Source, Number of Features.
+ source_color_map: Optional dictionary mapping data source names to specific colors.
+ If None, uses default Plotly colors.
+
+ Returns:
+ Plotly Figure object containing the pie chart visualization.
+
+ """
+ fig = px.pie(
+ breakdown_df,
+ names="Data Source",
+ values="Number of Features",
+ title="Feature Distribution by Data Source",
+ hole=0.4,
+ color_discrete_map=source_color_map,
+ color="Data Source",
+ )
+
+ fig.update_traces(textposition="inside", textinfo="percent+label")
+ fig.update_layout(height=400)
+
+ return fig
+
+
+def create_member_details_table(member_stats_dict: dict[str, dict]) -> pd.DataFrame:
+ """Create a detailed table for member statistics.
+
+ Args:
+ member_stats_dict: Dictionary mapping member names to their statistics.
+ Each stats dict should contain: feature_count, variable_names, dimensions, size_bytes.
+
+ Returns:
+ DataFrame with detailed member information.
+
+ """
+ rows = []
+ for member, stats in member_stats_dict.items():
+ # Calculate total data points
+ total_points = 1
+ for dim_size in stats["dimensions"].values():
+ total_points *= dim_size
+
+ # Format dimension string
+ dim_str = " x ".join([f"{k}={v:,}" for k, v in stats["dimensions"].items()])
+
+ rows.append(
+ {
+ "Member": member,
+ "Features": stats["feature_count"],
+ "Variables": len(stats["variable_names"]),
+ "Dimensions": dim_str,
+ "Data Points": f"{total_points:,}",
+ "Size (MB)": f"{stats['size_bytes'] / (1024 * 1024):.2f}",
+ }
+ )
+
+ return pd.DataFrame(rows)
+
+
+def create_feature_breakdown_donuts_grid(breakdown_df: pd.DataFrame) -> go.Figure:
+ """Create grid of donut charts showing feature breakdown by member for each grid+temporal mode.
+
+ Args:
+ breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.
+
+ Returns:
+ Plotly Figure object containing the grid of donut charts.
+
+ """
+ assert "Grid" in breakdown_df.columns
+ assert "Temporal Mode" in breakdown_df.columns
+ assert "Member" in breakdown_df.columns
+ assert "Number of Features" in breakdown_df.columns
+
+ # Get unique grids and temporal modes
+ grids = breakdown_df["Grid"].unique().tolist()
+ temporal_modes = breakdown_df["Temporal Mode"].unique().tolist()
+ unique_members = sorted(breakdown_df["Member"].unique())
+ source_color_map = _get_colors_from_members(unique_members)
+
+ n_grids = len(grids)
+ n_modes = len(temporal_modes)
+
+ # Create subplot grid (grids as rows, temporal modes as columns)
+ # Don't use subplot_titles - we'll add custom annotations for row/column labels
+ fig = make_subplots(
+ rows=n_grids,
+ cols=n_modes,
+ specs=[[{"type": "pie"}] * n_modes for _ in range(n_grids)],
+ horizontal_spacing=0.01,
+ vertical_spacing=0.02,
+ )
+
+ # Get all members to ensure consistent coloring across all donuts
+ all_members = sorted(breakdown_df["Member"].unique())
+
+ for row_idx, grid in enumerate(grids, start=1):
+ for col_idx, mode in enumerate(temporal_modes, start=1):
+ subset = breakdown_df[(breakdown_df["Grid"] == grid) & (breakdown_df["Temporal Mode"] == mode)]
+
+ if len(subset) == 0:
+ continue
+
+ # Build labels, values, and colors in consistent order
+ labels = []
+ values = []
+ colors_list = []
+
+ for member in all_members:
+ member_data = subset[subset["Member"] == member]
+ if len(member_data) > 0:
+ labels.append(member)
+ values.append(member_data["Number of Features"].iloc[0])
+ if source_color_map:
+ colors_list.append(source_color_map[member])
+
+ # Only show legend for first subplot
+ show_legend = row_idx == 1 and col_idx == 1
+
+ fig.add_trace(
+ go.Pie(
+ labels=labels,
+ values=values,
+ name=f"{grid} - {mode}",
+ hole=0.4,
+ marker={"colors": colors_list} if colors_list else None,
+ textposition="inside",
+ textinfo="percent",
+ showlegend=show_legend,
+ ),
+ row=row_idx,
+ col=col_idx,
+ )
+
+ # Calculate appropriate height based on number of rows
+ height = max(500, n_grids * 400)
+
+ fig.update_layout(
+ title_text="Feature Breakdown by Member Across Grid Configurations and Temporal Modes",
+ height=height,
+ showlegend=True,
+ margin={"l": 100, "r": 50, "t": 150, "b": 50}, # Add left margin for row labels
+ )
+
+ # Calculate subplot positions in paper coordinates
+ # Account for spacing between subplots
+ h_spacing = 0.01
+ v_spacing = 0.02
+
+ # Calculate width and height of each subplot in paper coordinates
+ subplot_width = (1.0 - h_spacing * (n_modes - 1)) / n_modes
+ subplot_height = (1.0 - v_spacing * (n_grids - 1)) / n_grids
+
+ # Add column headers at the top (temporal modes)
+ for col_idx, mode in enumerate(temporal_modes):
+ # Calculate x position for this column's center
+ x_pos = col_idx * (subplot_width + h_spacing) + subplot_width / 2
+ fig.add_annotation(
+ text=f"{mode}",
+ xref="paper",
+ yref="paper",
+ x=x_pos,
+ y=1.01,
+ showarrow=False,
+ xanchor="center",
+ yanchor="bottom",
+ font={"size": 18},
+ )
+
+ # Add row headers on the left (grid configurations)
+ for row_idx, grid in enumerate(grids):
+ # Calculate y position for this row's center
+ # Note: y coordinates go from bottom to top, but rows go from top to bottom
+ y_pos = 1.0 - (row_idx * (subplot_height + v_spacing) + subplot_height / 2)
+ fig.add_annotation(
+ text=f"{grid}",
+ xref="paper",
+ yref="paper",
+ x=-0.01,
+ y=y_pos,
+ showarrow=False,
+ xanchor="right",
+ yanchor="middle",
+ font={"size": 18},
+ textangle=-90,
+ )
+
+ return fig
diff --git a/src/entropice/dashboard/plots/overview.py b/src/entropice/dashboard/plots/overview.py
deleted file mode 100644
index 91de690..0000000
--- a/src/entropice/dashboard/plots/overview.py
+++ /dev/null
@@ -1,276 +0,0 @@
-"""Visualization functions for the overview page.
-
-This module contains reusable plotting functions for dataset analysis visualizations,
-including sample counts, feature counts, and dataset statistics.
-"""
-
-import pandas as pd
-import plotly.express as px
-import plotly.graph_objects as go
-
-
-def create_sample_count_heatmap(
- pivot_df: pd.DataFrame,
- target: str,
- colorscale: list[str] | None = None,
-) -> go.Figure:
- """Create heatmap showing sample counts by grid and task.
-
- Args:
- pivot_df: Pivoted dataframe with Grid as index, Task as columns, and sample counts as values.
- target: Target dataset name for the title.
- colorscale: Optional color palette for the heatmap. If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the heatmap visualization.
-
- """
- fig = px.imshow(
- pivot_df,
- labels={
- "x": "Task",
- "y": "Grid Configuration",
- "color": "Sample Count",
- },
- x=pivot_df.columns,
- y=pivot_df.index,
- color_continuous_scale=colorscale,
- aspect="auto",
- title=f"Target: {target}",
- )
-
- # Add text annotations
- fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
- fig.update_layout(height=400)
-
- return fig
-
-
-def create_sample_count_bar_chart(
- sample_df: pd.DataFrame,
- target_color_maps: dict[str, list[str]] | None = None,
-) -> go.Figure:
- """Create bar chart showing sample counts by grid, target, and task.
-
- Args:
- sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
- target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes.
- If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the bar chart visualization.
-
- """
- # Create subplots manually to have better control over colors
- from plotly.subplots import make_subplots
-
- targets = sorted(sample_df["Target"].unique())
- tasks = sorted(sample_df["Task"].unique())
-
- fig = make_subplots(
- rows=1,
- cols=len(targets),
- subplot_titles=[f"Target: {target}" for target in targets],
- shared_yaxes=True,
- )
-
- for col_idx, target in enumerate(targets, 1):
- target_data = sample_df[sample_df["Target"] == target]
- # Get color palette for this target
- colors = target_color_maps.get(target, None) if target_color_maps else None
-
- for task_idx, task in enumerate(tasks):
- task_data = target_data[target_data["Task"] == task]
- color = colors[task_idx] if colors and task_idx < len(colors) else None
-
- # Create a unique legendgroup per task so colors are consistent
- fig.add_trace(
- go.Bar(
- x=task_data["Grid"],
- y=task_data["Samples (Coverage)"],
- name=task,
- marker_color=color,
- legendgroup=task, # Group by task name
- showlegend=(col_idx == 1), # Only show legend for first subplot
- ),
- row=1,
- col=col_idx,
- )
-
- fig.update_layout(
- title_text="Training Sample Counts by Grid Configuration and Target Dataset",
- barmode="group",
- height=500,
- showlegend=True,
- )
-
- fig.update_xaxes(tickangle=-45)
- fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
-
- return fig
-
-
-def create_feature_count_stacked_bar(
- breakdown_df: pd.DataFrame,
- source_color_map: dict[str, str] | None = None,
-) -> go.Figure:
- """Create stacked bar chart showing feature counts by data source.
-
- Args:
- breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features.
- source_color_map: Optional dictionary mapping data source names to specific colors.
- If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the stacked bar chart visualization.
-
- """
- fig = px.bar(
- breakdown_df,
- x="Grid",
- y="Number of Features",
- color="Data Source",
- barmode="stack",
- title="Input Features by Data Source Across Grid Configurations",
- labels={
- "Grid": "Grid Configuration",
- "Number of Features": "Number of Features",
- },
- color_discrete_map=source_color_map,
- text_auto=False,
- )
-
- fig.update_layout(height=500, xaxis_tickangle=-45)
-
- return fig
-
-
-def create_inference_cells_bar(
- comparison_df: pd.DataFrame,
- grid_colors: list[str] | None = None,
-) -> go.Figure:
- """Create bar chart for inference cells by grid configuration.
-
- Args:
- comparison_df: DataFrame with columns: Grid, Inference Cells.
- grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the bar chart visualization.
-
- """
- fig = px.bar(
- comparison_df,
- x="Grid",
- y="Inference Cells",
- color="Grid",
- title="Spatial Coverage (Grid Cells with Complete Data)",
- labels={
- "Grid": "Grid Configuration",
- "Inference Cells": "Number of Cells",
- },
- color_discrete_sequence=grid_colors,
- text="Inference Cells",
- )
-
- fig.update_traces(texttemplate="%{text:,}", textposition="outside")
- fig.update_layout(xaxis_tickangle=-45, showlegend=False)
-
- return fig
-
-
-def create_total_samples_bar(
- comparison_df: pd.DataFrame,
- grid_colors: list[str] | None = None,
-) -> go.Figure:
- """Create bar chart for total samples by grid configuration.
-
- Args:
- comparison_df: DataFrame with columns: Grid, Total Samples.
- grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the bar chart visualization.
-
- """
- fig = px.bar(
- comparison_df,
- x="Grid",
- y="Total Samples",
- color="Grid",
- title="Training Samples (Binary Task)",
- labels={
- "Grid": "Grid Configuration",
- "Total Samples": "Number of Samples",
- },
- color_discrete_sequence=grid_colors,
- text="Total Samples",
- )
-
- fig.update_traces(texttemplate="%{text:,}", textposition="outside")
- fig.update_layout(xaxis_tickangle=-45, showlegend=False)
-
- return fig
-
-
-def create_feature_breakdown_donut(
- grid_data: pd.DataFrame,
- grid_config: str,
- source_color_map: dict[str, str] | None = None,
-) -> go.Figure:
- """Create donut chart for feature breakdown by data source for a specific grid.
-
- Args:
- grid_data: DataFrame with columns: Data Source, Number of Features.
- grid_config: Grid configuration name for the title.
- source_color_map: Optional dictionary mapping data source names to specific colors.
- If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the donut chart visualization.
-
- """
- fig = px.pie(
- grid_data,
- names="Data Source",
- values="Number of Features",
- title=grid_config,
- hole=0.4,
- color_discrete_map=source_color_map,
- color="Data Source",
- )
-
- fig.update_traces(textposition="inside", textinfo="percent")
- fig.update_layout(showlegend=True, height=350)
-
- return fig
-
-
-def create_feature_distribution_pie(
- breakdown_df: pd.DataFrame,
- source_color_map: dict[str, str] | None = None,
-) -> go.Figure:
- """Create pie chart for feature distribution by data source.
-
- Args:
- breakdown_df: DataFrame with columns: Data Source, Number of Features.
- source_color_map: Optional dictionary mapping data source names to specific colors.
- If None, uses default Plotly colors.
-
- Returns:
- Plotly Figure object containing the pie chart visualization.
-
- """
- fig = px.pie(
- breakdown_df,
- names="Data Source",
- values="Number of Features",
- title="Feature Distribution by Data Source",
- hole=0.4,
- color_discrete_map=source_color_map,
- color="Data Source",
- )
-
- fig.update_traces(textposition="inside", textinfo="percent+label")
-
- return fig
diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py
new file mode 100644
index 0000000..34ae0cb
--- /dev/null
+++ b/src/entropice/dashboard/sections/dataset_statistics.py
@@ -0,0 +1,376 @@
+"""Dataset Statistics Section for Entropice Dashboard."""
+
+import pandas as pd
+import streamlit as st
+
+from entropice.dashboard.plots.dataset_statistics import (
+ create_feature_breakdown_donuts_grid,
+ create_feature_count_stacked_bar,
+ create_feature_distribution_pie,
+ create_inference_cells_bar,
+ create_member_details_table,
+ create_sample_count_bar_chart,
+)
+from entropice.dashboard.utils.colors import get_palette
+from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
+from entropice.utils.types import (
+ GridLevel,
+ L2SourceDataset,
+ TargetDataset,
+ Task,
+ TemporalMode,
+ all_l2_source_datasets,
+ grid_configs,
+)
+
+type DatasetStatsCache = dict[GridLevel, dict[TemporalMode, DatasetStatistics]]
+
+
+def render_training_samples_tab(sample_df: pd.DataFrame):
+    """Render overview of sample counts per target+grid+temporal-mode combination."""
+ st.markdown(
+ """
+ This visualization shows the number of available training samples for each combination of:
+ - **Task**: binary, count, density
+ - **Target Dataset**: darts_rts, darts_mllabels
+ - **Grid System**: hex, healpix
+ - **Grid Level**: varying by grid type
+ """
+ )
+
+ # Create and display bar chart
+ fig = create_sample_count_bar_chart(sample_df)
+ st.plotly_chart(fig, width="stretch")
+
+ # Display full table with formatting
+ st.markdown("#### Detailed Sample Counts")
+ # Format numbers with commas
+ sample_df["Samples (Coverage)"] = sample_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
+ # Format coverage as percentage with 2 decimal places
+ sample_df["Coverage %"] = sample_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
+
+ st.dataframe(sample_df, hide_index=True, width="stretch")
+
+
+def render_dataset_characteristics_tab(
+ breakdown_df: pd.DataFrame, sample_df: pd.DataFrame, comparison_df: pd.DataFrame
+):
+ """Render static comparison of feature counts across all grid configurations."""
+ st.markdown(
+ """
+ Comparing dataset characteristics for all grid configurations with all data sources enabled.
+ - **Features**: Total number of input features from all data sources
+ - **Spatial Coverage**: Number of grid cells with complete data coverage
+ """
+ )
+
+ # Get data from cache
+
+ # Create and display stacked bar chart
+ fig = create_feature_count_stacked_bar(breakdown_df)
+ st.plotly_chart(fig, width="stretch")
+
+ # Add spatial coverage metric
+ fig_cells = create_inference_cells_bar(sample_df)
+ st.plotly_chart(fig_cells, width="stretch")
+
+ # Display full comparison table with formatting
+ st.markdown("#### Detailed Comparison Table")
+ st.dataframe(comparison_df, hide_index=True, width="stretch")
+
+
+def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
+ """Render feature breakdown by member across grid configurations and temporal modes."""
+ st.markdown(
+ """
+ Visualizing feature contribution of each data source (member) across all grid configurations
+ and temporal modes. Each donut chart represents the feature breakdown for a specific
+ grid-temporal mode combination.
+ """
+ )
+ # Create and display the grid of donut charts
+ fig = create_feature_breakdown_donuts_grid(breakdown_df)
+ st.plotly_chart(fig, width="stretch")
+
+ # Display full breakdown table with formatting
+ st.markdown("#### Detailed Breakdown Table")
+ st.dataframe(breakdown_df, hide_index=True, width="stretch")
+
+
+@st.fragment
+def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901
+ """Render interactive detailed configuration explorer."""
+ st.markdown(
+ """
+ Explore detailed statistics for a specific dataset configuration.
+ Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
+ """
+ )
+
+ # Grid selection
+ grid_options = [gc.display_name for gc in grid_configs]
+
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ selected_grid_display = st.selectbox(
+ "Grid Configuration",
+ options=grid_options,
+ index=0,
+ help="Select the grid system and resolution level",
+ key="feature_grid_select",
+ )
+
+ with col2:
+ # Create temporal mode options with proper formatting
+ temporal_mode_options = ["feature", "synopsis"] + [str(year) for year in [2018, 2019, 2020, 2021, 2022, 2023]]
+ temporal_mode_display = st.selectbox(
+ "Temporal Mode",
+ options=temporal_mode_options,
+ index=0,
+ help="Select the temporal mode (feature, synopsis, or specific year)",
+ key="feature_temporal_mode_select",
+ )
+ # Convert back to proper type
+ selected_temporal_mode: TemporalMode
+ if temporal_mode_display.isdigit():
+ selected_temporal_mode = int(temporal_mode_display) # type: ignore[assignment]
+ else:
+ selected_temporal_mode = temporal_mode_display # pyright: ignore[reportAssignmentType]
+
+ with col3:
+ selected_target: TargetDataset = st.selectbox(
+ "Target Dataset",
+ options=["darts_v1", "darts_mllabels"],
+ index=0,
+ help="Select the target dataset",
+ key="feature_target_select",
+ )
+
+ with col4:
+ selected_task: Task = st.selectbox(
+ "Task",
+ options=["binary", "count", "density", "count_regimes", "density_regimes"],
+ index=0,
+ help="Select the task type",
+ key="feature_task_select",
+ )
+
+ # Find the selected grid config
+ selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)
+
+ # Get stats for the selected configuration
+ mode_stats = dataset_stats[selected_grid_config.id]
+
+ # Check if the selected temporal mode has stats
+ if selected_temporal_mode not in mode_stats:
+ st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
+ return
+
+ stats = mode_stats[selected_temporal_mode]
+ available_members = list(stats.members.keys())
+
+ # Members selection
+ st.markdown("#### Select Data Sources")
+
+ # Use columns for checkboxes
+ cols = st.columns(len(all_l2_source_datasets))
+ selected_members: list[L2SourceDataset] = []
+
+ for idx, member in enumerate(all_l2_source_datasets):
+ with cols[idx]:
+ default_value = member in available_members
+ if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
+ selected_members.append(member) # type: ignore[arg-type]
+
+ # Show results if at least one member is selected
+ if selected_members:
+ # Filter to selected members only
+ selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+
+ # Calculate total features for selected members
+ total_features = sum(ms.feature_count for ms in selected_member_stats.values())
+
+ # Get target stats if available
+ if selected_target not in stats.target:
+ st.warning(f"Target {selected_target} is not available for this configuration.")
+ return
+
+ target_stats = stats.target[selected_target]
+
+ # Check if task is available
+ if selected_task not in target_stats:
+ st.warning(
+ f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
+ )
+ return
+
+ task_stats = target_stats[selected_task]
+
+ # High-level metrics
+ st.markdown("#### Configuration Summary")
+ col1, col2, col3, col4, col5 = st.columns(5)
+ with col1:
+ st.metric("Total Features", f"{total_features:,}")
+ with col2:
+ # Calculate minimum cells across all data sources (for inference capability)
+ if selected_member_stats:
+ min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
+ else:
+ min_cells = 0
+ st.metric(
+ "Inference Cells",
+ f"{min_cells:,}",
+ help="Minimum number of cells across all selected data sources",
+ )
+ with col3:
+ st.metric("Data Sources", len(selected_members))
+ with col4:
+ st.metric("Training Samples", f"{task_stats.training_cells:,}")
+ with col5:
+ # Calculate total data points
+ total_points = total_features * task_stats.training_cells
+ st.metric("Total Data Points", f"{total_points:,}")
+
+ # Task-specific statistics
+ st.markdown("#### Task Statistics")
+ task_col1, task_col2, task_col3 = st.columns(3)
+ with task_col1:
+ st.metric("Task Type", selected_task.replace("_", " ").title())
+ with task_col2:
+ st.metric("Coverage", f"{task_stats.coverage:.2f}%")
+ with task_col3:
+ if task_stats.class_counts:
+ st.metric("Number of Classes", len(task_stats.class_counts))
+ else:
+ st.metric("Task Mode", "Regression")
+
+ # Class distribution for classification tasks
+ if task_stats.class_distribution:
+ st.markdown("#### Class Distribution")
+ class_dist_df = pd.DataFrame(
+ [
+ {
+ "Class": class_name,
+ "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
+ "Percentage": f"{pct:.2f}%",
+ }
+ for class_name, pct in task_stats.class_distribution.items()
+ ]
+ )
+ st.dataframe(class_dist_df, hide_index=True, width="stretch")
+
+ # Feature breakdown by source
+ st.markdown("#### Feature Breakdown by Data Source")
+
+ breakdown_data = []
+ for member, member_stats in selected_member_stats.items():
+ breakdown_data.append(
+ {
+ "Data Source": member,
+ "Number of Features": member_stats.feature_count,
+ "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
+ }
+ )
+
+ breakdown_df = pd.DataFrame(breakdown_data)
+
+ # Get all unique data sources and create color map
+ unique_members_for_color = sorted(selected_member_stats.keys())
+ source_color_map_raw = {}
+ for member in unique_members_for_color:
+ source = member.split("-")[0]
+ n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
+ palette = get_palette(source, n_colors=n_members + 2)
+ idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
+ source_color_map_raw[member] = palette[idx + 1]
+
+ # Create and display pie chart
+ fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
+ st.plotly_chart(fig, width="stretch")
+
+ # Show detailed table
+ st.dataframe(breakdown_df, hide_index=True, width="stretch")
+
+ # Detailed member information
+ with st.expander("📦 Detailed Source Information", expanded=False):
+ # Create detailed table
+ member_details_dict = {
+ member: {
+ "feature_count": ms.feature_count,
+ "variable_names": ms.variable_names,
+ "dimensions": ms.dimensions,
+ "size_bytes": ms.size_bytes,
+ }
+ for member, ms in selected_member_stats.items()
+ }
+ details_df = create_member_details_table(member_details_dict)
+ st.dataframe(details_df, hide_index=True, width="stretch")
+
+ # Individual member details
+ for member, member_stats in selected_member_stats.items():
+ st.markdown(f"### {member}")
+
+ # Variables
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
+ [
+                f'<code>{v}</code>'  # NOTE(review): original inline HTML was stripped in transit; reconstructed minimally — confirm against rendered dashboard
+ for v in member_stats.variable_names
+ ]
+ )
+ st.markdown(vars_html, unsafe_allow_html=True)
+
+ # Dimensions
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+                f"<code>"  # NOTE(review): original inline HTML was stripped in transit; reconstructed minimally — confirm against rendered dashboard
+                f"{dim_name}: {dim_size:,}</code>"
+ for dim_name, dim_size in member_stats.dimensions.items()
+ ]
+ )
+ st.markdown(dim_html, unsafe_allow_html=True)
+
+ st.markdown("---")
+ else:
+ st.info("👆 Select at least one data source to see feature statistics")
+
+
+def render_dataset_statistics(all_stats: DatasetStatsCache):
+ """Render the dataset statistics section with sample and feature counts."""
+ st.header("📈 Dataset Statistics")
+
+    # NOTE(review): this reload shadows the `all_stats` parameter — remove either the parameter or this call.
+    all_stats: DatasetStatsCache = load_all_default_dataset_statistics()
+ training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
+ feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
+ comparison_df = DatasetStatistics.get_comparison_df(all_stats)
+ inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
+
+ # Create tabs for different analysis views
+ analysis_tabs = st.tabs(
+ [
+ "📊 Training Samples",
+ "📈 Dataset Characteristics",
+ "🔍 Feature Breakdown",
+ "⚙️ Configuration Explorer",
+ ]
+ )
+
+ with analysis_tabs[0]:
+ st.subheader("Training Samples by Configuration")
+ render_training_samples_tab(training_sample_df)
+
+ with analysis_tabs[1]:
+ st.subheader("Dataset Characteristics Across Grid Configurations")
+ render_dataset_characteristics_tab(feature_breakdown_df, inference_sample_df, comparison_df)
+
+ with analysis_tabs[2]:
+ st.subheader("Feature Breakdown by Data Source")
+ render_feature_breakdown_tab(feature_breakdown_df)
+
+ with analysis_tabs[3]:
+ st.subheader("Interactive Configuration Explorer")
+ render_configuration_explorer_tab(all_stats)
diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py
new file mode 100644
index 0000000..04b3176
--- /dev/null
+++ b/src/entropice/dashboard/sections/experiment_results.py
@@ -0,0 +1,158 @@
+"""Experiment Results Section for the Entropice Dashboard."""
+
+from datetime import datetime
+
+import streamlit as st
+
+from entropice.dashboard.utils.loaders import TrainingResult
+from entropice.utils.types import (
+ GridConfig,
+)
+
+
+def render_training_results_summary(training_results: list[TrainingResult]):
+ """Render summary metrics for training results."""
+ st.header("📊 Training Results Summary")
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ experiments = {tr.experiment for tr in training_results if tr.experiment}
+ st.metric("Experiments", len(experiments))
+
+ with col2:
+ st.metric("Total Runs", len(training_results))
+
+ with col3:
+ models = {tr.settings.model for tr in training_results}
+ st.metric("Model Types", len(models))
+
+ with col4:
+ latest = training_results[0] # Already sorted by creation time
+ latest_datetime = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d %H:%M")
+ st.metric("Latest Run", latest_datetime)
+
+
+@st.fragment
+def render_experiment_results(training_results: list[TrainingResult]): # noqa: C901
+ """Render detailed experiment results table and expandable details."""
+ st.header("🎯 Experiment Results")
+
+ # Filters
+ experiments = sorted({tr.experiment for tr in training_results if tr.experiment})
+ tasks = sorted({tr.settings.task for tr in training_results})
+ models = sorted({tr.settings.model for tr in training_results})
+ grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results})
+
+ # Create filter columns
+ filter_cols = st.columns(4)
+
+ with filter_cols[0]:
+ if experiments:
+ selected_experiment = st.selectbox(
+ "Experiment",
+ options=["All", *experiments],
+ index=0,
+ )
+ else:
+ selected_experiment = "All"
+
+ with filter_cols[1]:
+ selected_task = st.selectbox(
+ "Task",
+ options=["All", *tasks],
+ index=0,
+ )
+
+ with filter_cols[2]:
+ selected_model = st.selectbox(
+ "Model",
+ options=["All", *models],
+ index=0,
+ )
+
+ with filter_cols[3]:
+ selected_grid = st.selectbox(
+ "Grid",
+ options=["All", *grids],
+ index=0,
+ )
+
+ # Apply filters
+ filtered_results = training_results
+ if selected_experiment != "All":
+ filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment]
+ if selected_task != "All":
+ filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task]
+ if selected_model != "All":
+ filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model]
+ if selected_grid != "All":
+ filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid]
+
+ st.subheader("Results Table")
+
+ summary_df = TrainingResult.to_dataframe(filtered_results)
+ # Display with color coding for best scores
+ st.dataframe(
+ summary_df,
+ width="stretch",
+ hide_index=True,
+ )
+
+ # Expandable details for each result
+ st.subheader("Individual Experiment Details")
+
+ for tr in filtered_results:
+ tr_info = tr.display_info
+ display_name = tr_info.get_display_name("model_first")
+ with st.expander(display_name):
+ col1, col2 = st.columns([1, 2])
+
+ with col1:
+ grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
+ st.write("**Configuration:**")
+ st.write(f"- **Experiment:** {tr.experiment}")
+ st.write(f"- **Task:** {tr.settings.task}")
+ st.write(f"- **Model:** {tr.settings.model}")
+ st.write(f"- **Grid:** {grid_config.display_name}")
+ st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}")
+ st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}")
+ st.write(f"- **Members:** {', '.join(tr.settings.members)}")
+ st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
+ st.write(f"- **Classes:** {tr.settings.classes}")
+
+ st.write("\n**Files:**")
+ for file in tr.files:
+ if file.name == "search_settings.toml":
+ st.write(f"- ⚙️ `{file.name}`")
+ elif file.name == "best_estimator_model.pkl":
+ st.write(f"- 🧮 `{file.name}`")
+ elif file.name == "search_results.parquet":
+ st.write(f"- 📊 `{file.name}`")
+ elif file.name == "predicted_probabilities.parquet":
+ st.write(f"- 🎯 `{file.name}`")
+ else:
+ st.write(f"- 📄 `{file.name}`")
+ with col2:
+ st.write("**CV Score Summary:**")
+
+ # Extract all test scores
+ metric_df = tr.get_metric_dataframe()
+ if metric_df is not None:
+ st.dataframe(metric_df, width="stretch", hide_index=True)
+ else:
+ st.write("No test scores found in results.")
+
+ # Show parameter space explored
+ if "initial_K" in tr.results.columns: # Common parameter
+ st.write("\n**Parameter Ranges Explored:**")
+ for param in ["initial_K", "eps_cl", "eps_e"]:
+ if param in tr.results.columns:
+ min_val = tr.results[param].min()
+ max_val = tr.results[param].max()
+ unique_vals = tr.results[param].nunique()
+ st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
+
+ with st.expander("Show CV Results DataFrame"):
+ st.dataframe(tr.results, width="stretch", hide_index=True)
+
+ st.write(f"\n**Path:** `{tr.path}`")
diff --git a/src/entropice/dashboard/utils/colors.py b/src/entropice/dashboard/utils/colors.py
index 97b3838..aa1a758 100644
--- a/src/entropice/dashboard/utils/colors.py
+++ b/src/entropice/dashboard/utils/colors.py
@@ -90,7 +90,15 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
A list of hex color strings.
"""
- cmap = get_cmap(variable).resampled(n_colors)
+ # Hardcode some common variables to specific colormaps
+ if variable == "ERA5":
+ cmap = load_cmap(name="blue_material").resampled(n_colors)
+ elif variable == "ArcticDEM":
+ cmap = load_cmap(name="deep_purple_material").resampled(n_colors)
+ elif variable == "AlphaEarth":
+ cmap = load_cmap(name="green_material").resampled(n_colors)
+ else:
+ cmap = get_cmap(variable).resampled(n_colors)
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
return colors
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index 93f3c79..83c9796 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -6,7 +6,6 @@ from datetime import datetime
from pathlib import Path
import antimeridian
-import geopandas as gpd
import pandas as pd
import streamlit as st
import toml
@@ -16,9 +15,8 @@ from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
-from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.ml.training import TrainingSettings
-from entropice.utils.types import L2SourceDataset, Task
+from entropice.utils.types import GridConfig
def _fix_hex_geometry(geom):
@@ -32,22 +30,29 @@ def _fix_hex_geometry(geom):
@dataclass
class TrainingResult:
+ """Wrapper for training result data and metadata."""
+
path: Path
+ experiment: str
settings: TrainingSettings
results: pd.DataFrame
- metrics: dict[str, float]
- confusion_matrix: xr.DataArray
+ train_metrics: dict[str, float]
+ test_metrics: dict[str, float]
+ combined_metrics: dict[str, float]
+ confusion_matrix: xr.Dataset | None
created_at: float
available_metrics: list[str]
+ files: list[Path]
@classmethod
- def from_path(cls, result_path: Path) -> "TrainingResult":
+ def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
- metrics_file = result_path / "test_metrics.toml"
+ metrics_file = result_path / "metrics.toml"
confusion_matrix_file = result_path / "confusion_matrix.nc"
+ all_files = list(result_path.iterdir())
if not result_file.exists():
raise FileNotFoundError(f"Missing results file in {result_path}")
if not settings_file.exists():
@@ -56,28 +61,46 @@ class TrainingResult:
raise FileNotFoundError(f"Missing predictions file in {result_path}")
if not metrics_file.exists():
raise FileNotFoundError(f"Missing metrics file in {result_path}")
- if not confusion_matrix_file.exists():
- raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
created_at = result_path.stat().st_ctime
- settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
+ settings_dict = toml.load(settings_file)["settings"]
+
+ # Handle backward compatibility: add missing fields with defaults
+ if "classes" not in settings_dict:
+ settings_dict["classes"] = None
+ if "param_grid" not in settings_dict:
+ settings_dict["param_grid"] = {}
+ if "cv_splits" not in settings_dict:
+ settings_dict["cv_splits"] = 5
+ if "metrics" not in settings_dict:
+ settings_dict["metrics"] = []
+
+ settings = TrainingSettings(**settings_dict)
results = pd.read_parquet(result_file)
- metrics = toml.load(metrics_file)["test_metrics"]
- confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
+ metrics = toml.load(metrics_file)
+ if not confusion_matrix_file.exists():
+ confusion_matrix = None
+ else:
+ confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf")
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
return cls(
path=result_path,
+ experiment=experiment_name or "N/A",
settings=settings,
results=results,
- metrics=metrics,
+ train_metrics=metrics["train_metrics"],
+ test_metrics=metrics["test_metrics"],
+ combined_metrics=metrics["combined_metrics"],
confusion_matrix=confusion_matrix,
created_at=created_at,
available_metrics=available_metrics,
+ files=all_files,
)
@property
def display_info(self) -> TrainingResultDisplayInfo:
+ """Get display information for the training result."""
return TrainingResultDisplayInfo(
task=self.settings.task,
model=self.settings.model,
@@ -126,9 +149,70 @@ class TrainingResult:
st.error(f"Error loading predictions: {e}")
return None
+ def get_metric_dataframe(self) -> pd.DataFrame | None:
+ """Get a DataFrame of available metrics for this training result."""
+ metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")]
+ if not metric_cols:
+ return None
+ metric_data = []
+ for col in metric_cols:
+ metric_name = col.replace("mean_test_", "").replace("neg_", "").title()
+ metrics = self.results[col]
+ # Check if the metric is negative
+ if col.startswith("mean_test_neg_"):
+ task_multiplier = 1 if self.settings.task != "density" else 100
+ task_multiplier *= -1
+ metrics = metrics * task_multiplier
+
+ metric_data.append(
+ {
+ "Metric": metric_name,
+ "Best": f"{metrics.max():.4f}",
+ "Mean": f"{metrics.mean():.4f}",
+ "Std": f"{metrics.std():.4f}",
+ "Worst": f"{metrics.min():.4f}",
+ }
+ )
+
+ return pd.DataFrame(metric_data)
+
+ def _get_best_metric_name(self) -> str:
+ """Get the primary metric name for a given task."""
+ match self.settings.task:
+ case "binary":
+ return "f1"
+ case "count_regimes" | "density_regimes":
+ return "f1_weighted"
+ case _: # regression tasks
+ return "r2"
+
+ @staticmethod
+ def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame:
+ """Convert a list of TrainingResult objects to a DataFrame for display."""
+ records = []
+ for tr in training_results:
+ info = tr.display_info
+ best_metric_name = tr._get_best_metric_name()
+
+ record = {
+ "Experiment": tr.experiment if tr.experiment else "N/A",
+ "Task": info.task,
+ "Model": info.model,
+ "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
+ "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
+ "Score-Metric": best_metric_name.title(),
+ "Best Models Score (Train-Set)": tr.train_metrics.get(best_metric_name),
+ "Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
+ "Trials": len(tr.results),
+ "Path": str(tr.path.name),
+ }
+ records.append(record)
+ return pd.DataFrame.from_records(records)
+
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
+ """Load all training results from the results directory."""
results_dir = entropice.utils.paths.RESULTS_DIR
training_results: list[TrainingResult] = []
for result_path in results_dir.iterdir():
@@ -136,50 +220,22 @@ def load_all_training_results() -> list[TrainingResult]:
continue
try:
training_result = TrainingResult.from_path(result_path)
+ training_results.append(training_result)
except FileNotFoundError as e:
- st.warning(f"Skipping incomplete training result: {e}")
- continue
- training_results.append(training_result)
+ is_experiment_dir = False
+ for experiment_path in result_path.iterdir():
+ if not experiment_path.is_dir():
+ continue
+ try:
+ experiment_name = experiment_path.name
+ training_result = TrainingResult.from_path(experiment_path, experiment_name)
+ training_results.append(training_result)
+ is_experiment_dir = True
+ except FileNotFoundError as e2:
+ st.warning(f"Skipping incomplete training result: {e2}")
+ if not is_experiment_dir:
+ st.warning(f"Skipping incomplete training result: {e}")
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
-
-
-def load_all_training_data(
- e: DatasetEnsemble,
-) -> dict[Task, CategoricalTrainingDataset]:
- """Load training data for all three tasks.
-
- Args:
- e: DatasetEnsemble object.
-
- Returns:
- Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
-
- """
- dataset = e.create(filter_target_col=e.covcol)
- return {
- "binary": e._cat_and_split(dataset, "binary", device="cpu"),
- "count": e._cat_and_split(dataset, "count", device="cpu"),
- "density": e._cat_and_split(dataset, "density", device="cpu"),
- }
-
-
-def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
- """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
-
- Args:
- e: DatasetEnsemble object.
- source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
-
- Returns:
- xarray.Dataset with the raw data for the specified source.
-
- """
- targets = e._read_target()
-
- # Load the member data lazily to get metadata
- ds = e._read_member(source, targets, lazy=False)
-
- return ds, targets
diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py
index 004c7f5..9e9ddaf 100644
--- a/src/entropice/dashboard/utils/stats.py
+++ b/src/entropice/dashboard/utils/stats.py
@@ -3,29 +3,28 @@
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
"""
+import pickle
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Literal
-import geopandas as gpd
import pandas as pd
-import streamlit as st
-import xarray as xr
from stopuhr import stopwatch
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult
-from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
+from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
- Grid,
GridLevel,
L2SourceDataset,
TargetDataset,
Task,
+ TemporalMode,
all_l2_source_datasets,
all_target_datasets,
all_tasks,
+ all_temporal_modes,
grid_configs,
)
@@ -41,90 +40,63 @@ class MemberStatistics:
size_bytes: int # Size of this member's data on disk in bytes
@classmethod
- def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
- if member == "AlphaEarth":
- store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
- elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
- era5_agg = member.split("-")[1]
- store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level) # ty:ignore[invalid-argument-type]
- elif member == "ArcticDEM":
- store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
- else:
- raise NotImplementedError(f"Member {member} not implemented.")
+ def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
+        """Pre-compute statistics for every dataset member of the ensemble, keyed by member name."""
+ member_stats = {}
+ for member in all_l2_source_datasets:
+ ds = e.read_member(member, lazy=True)
+ size_bytes = ds.nbytes
- size_bytes = store.stat().st_size
- ds = xr.open_zarr(store, consolidated=False)
+ n_cols_member = len(ds.data_vars)
+ for dim in ds.sizes:
+ if dim != "cell_ids":
+ n_cols_member *= ds.sizes[dim]
- # Delete all coordinates which are not in the dimension
- for coord in ds.coords:
- if coord not in ds.dims:
- ds = ds.drop_vars(coord)
- n_cols_member = len(ds.data_vars)
- for dim in ds.sizes:
- if dim != "cell_ids":
- n_cols_member *= ds.sizes[dim]
-
- return cls(
- feature_count=n_cols_member,
- variable_names=list(ds.data_vars),
- dimensions=dict(ds.sizes),
- coordinates=list(ds.coords),
- size_bytes=size_bytes,
- )
+ member_stats[member] = cls(
+ feature_count=n_cols_member,
+ variable_names=list(ds.data_vars),
+ dimensions=dict(ds.sizes),
+ coordinates=list(ds.coords),
+ size_bytes=size_bytes,
+ )
+ return member_stats
@dataclass(frozen=True)
class TargetStatistics:
"""Statistics for a specific target dataset."""
- training_cells: dict[Task, int] # Number of cells used for training
- coverage: dict[Task, float] # Percentage of total cells covered by training data
- class_counts: dict[Task, dict[str, int]] # Class name to count mapping per task
- class_distribution: dict[Task, dict[str, float]] # Class name to percentage mapping per task
- size_bytes: int # Size of the target dataset on disk in bytes
+ training_cells: int # Number of cells used for training
+ coverage: float # Percentage of total cells covered by training data
+ class_counts: dict[str, int] | None # Class name to count mapping
+ class_distribution: dict[str, float] | None # Class name to percentage mapping
@classmethod
- def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
- if target == "darts_rts":
- target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
- elif target == "darts_mllabels":
- target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
- else:
- raise NotImplementedError(f"Target {target} not implemented.")
- target_gdf = gpd.read_parquet(target_store)
- size_bytes = target_store.stat().st_size
- training_cells: dict[Task, int] = {}
- training_coverage: dict[Task, float] = {}
- class_counts: dict[Task, dict[str, int]] = {}
- class_distribution: dict[Task, dict[str, float]] = {}
+ def compute(cls, e: DatasetEnsemble, target: TargetDataset, total_cells: int) -> dict[Task, "TargetStatistics"]:
+        """Pre-compute per-task statistics for a specific target dataset, keyed by task."""
+ target_stats = {}
for task in all_tasks:
- task_col = taskcol[task][target]
- cov_col = covcol[target]
+ targets = e.get_targets(target=target, task=task)
- task_gdf = target_gdf[target_gdf[cov_col]]
- training_cells[task] = len(task_gdf)
- training_coverage[task] = len(task_gdf) / total_cells * 100
+ training_cells = len(targets)
+ training_coverage = len(targets) / total_cells * 100
- model_labels = task_gdf[task_col].dropna()
- if task == "binary":
- binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
- elif task == "count":
- binned = bin_values(model_labels.astype(int), task=task)
- elif task == "density":
- binned = bin_values(model_labels, task=task)
+ if task in ["count", "density"]:
+ class_counts = None
+ class_distribution = None
else:
- raise ValueError("Invalid task.")
- counts = binned.value_counts()
- distribution = counts / counts.sum() * 100
- class_counts[task] = counts.to_dict() # ty:ignore[invalid-assignment]
- class_distribution[task] = distribution.to_dict()
- return TargetStatistics(
- training_cells=training_cells,
- coverage=training_coverage,
- class_counts=class_counts,
- class_distribution=class_distribution,
- size_bytes=size_bytes,
- )
+ assert targets["y"].dtype == "category", "Classification tasks must have categorical target dtype."
+ counts = targets["y"].value_counts()
+ distribution = counts / counts.sum() * 100
+ class_counts = counts.to_dict()
+ class_distribution = distribution.to_dict()
+ target_stats[task] = cls(
+ training_cells=training_cells,
+ coverage=training_coverage,
+ class_counts=class_counts,
+ class_distribution=class_distribution,
+ )
+ return target_stats
@dataclass(frozen=True)
@@ -139,204 +111,167 @@ class DatasetStatistics:
total_cells: int # Total number of grid cells potentially covered
size_bytes: int # Size of the dataset on disk in bytes
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
- target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
+ target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
@staticmethod
- def get_sample_count_df(
- all_stats: dict[GridLevel, "DatasetStatistics"],
+ def get_target_sample_count_df(
+ all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
- stats = all_stats[grid_config.id]
- for target_name, target_stats in stats.target.items():
- for task in all_tasks:
- training_cells = target_stats.training_cells[task]
- coverage = target_stats.coverage[task]
+ mode_stats = all_stats[grid_config.id]
+ for temporal_mode, stats in mode_stats.items():
+ for target_name, target_stats in stats.target.items():
+ # We can assume that all tasks have equal counts, since they only affect distribution, not presence
+ target_task_stats = target_stats["binary"]
+ n_samples = target_task_stats.training_cells
+ coverage = target_task_stats.coverage
+ for task, target_task_stats in target_stats.items():
+ assert n_samples == target_task_stats.training_cells, (
+ "Inconsistent sample counts across tasks for the same target."
+ )
+ assert coverage == target_task_stats.coverage, (
+ "Inconsistent coverage across tasks for the same target."
+ )
rows.append(
{
"Grid": grid_config.display_name,
- "Target": target_name.replace("darts_", ""),
- "Task": task.capitalize(),
- "Samples (Coverage)": training_cells,
+ "Target": target_name.replace("_", " ").title(),
+ "Temporal Mode": str(temporal_mode).capitalize(),
+ "Samples (Coverage)": n_samples,
"Coverage %": coverage,
"Grid_Level_Sort": grid_config.sort_key,
}
)
- return pd.DataFrame(rows)
+ return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
@staticmethod
- def get_feature_count_df(
- all_stats: dict[GridLevel, "DatasetStatistics"],
+ def get_inference_sample_count_df(
+ all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
-        """Convert feature count data to DataFrame."""
+        """Convert inference sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
- stats = all_stats[grid_config.id]
- data_sources = list(stats.members.keys())
-
- # Determine minimum cells across all data sources
+ mode_stats = all_stats[grid_config.id]
+ # We can assume that all temporal modes have equal counts, since they only affect features, not samples
+ stats = mode_stats["feature"]
+ assert len(stats.members) > 0, "Dataset must have at least one member."
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
- # Get sample count from first target dataset (darts_rts)
- first_target = stats.target["darts_rts"]
- total_samples = first_target.training_cells["binary"]
+ for temporal_mode, stats in mode_stats.items():
+ assert len(stats.members) > 0, "Dataset must have at least one member."
+ assert (
+ min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) == min_cells
+ ), "Inconsistent inference cell counts across temporal modes."
rows.append(
{
"Grid": grid_config.display_name,
- "Total Features": stats.total_features,
- "Data Sources": len(data_sources),
"Inference Cells": min_cells,
- "Total Samples": total_samples,
"Grid_Level_Sort": grid_config.sort_key,
}
)
- return pd.DataFrame(rows)
+ return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
+
+ @staticmethod
+ def get_comparison_df(
+ all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
+ ) -> pd.DataFrame:
+ """Convert comparison data to DataFrame for detailed table."""
+ rows = []
+ for grid_config in grid_configs:
+ mode_stats = all_stats[grid_config.id]
+ for temporal_mode, stats in mode_stats.items():
+ min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
+ rows.append(
+ {
+ "Grid": grid_config.display_name,
+ "Temporal Mode": str(temporal_mode).capitalize(),
+ "Total Features": f"{stats.total_features:,}",
+ "Total Cells": f"{stats.total_cells:,}",
+ "Min Cells": f"{min_cells:,}",
+ "Size (MB)": f"{stats.size_bytes / (1024 * 1024):.2f}",
+ "Grid_Level_Sort": grid_config.sort_key,
+ }
+ )
+ return pd.DataFrame(rows).sort_values(["Grid_Level_Sort", "Temporal Mode"]).drop(columns="Grid_Level_Sort")
@staticmethod
def get_feature_breakdown_df(
- all_stats: dict[GridLevel, "DatasetStatistics"],
+ all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for grid_config in grid_configs:
- stats = all_stats[grid_config.id]
- for member_name, member_stats in stats.members.items():
- rows.append(
- {
- "Grid": grid_config.display_name,
- "Data Source": member_name,
- "Number of Features": member_stats.feature_count,
- "Grid_Level_Sort": grid_config.sort_key,
- }
- )
- return pd.DataFrame(rows)
+ mode_stats = all_stats[grid_config.id]
+ added_year_temporal_mode = False
+            for temporal_mode, stats in mode_stats.items():
+                if isinstance(temporal_mode, int):
+                    if added_year_temporal_mode:
+                        # Only add one year-based temporal mode for the breakdown
+                        continue
+                    added_year_temporal_mode = True
+                    temporal_mode = "Year-Based"
+ for member_name, member_stats in stats.members.items():
+ rows.append(
+ {
+ "Grid": grid_config.display_name,
+ "Temporal Mode": temporal_mode.capitalize(),
+ "Member": member_name,
+ "Number of Features": member_stats.feature_count,
+ "Grid_Level_Sort": grid_config.sort_key,
+ }
+ )
+ return (
+ pd.DataFrame(rows)
+ .sort_values(["Grid_Level_Sort", "Temporal Mode", "Member"])
+ .drop(columns="Grid_Level_Sort")
+ )
-@st.cache_data
-def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
- dataset_stats: dict[GridLevel, DatasetStatistics] = {}
+# @st.cache_data # ty:ignore[invalid-argument-type]
+def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode, DatasetStatistics]]:
+ """Precompute dataset statistics for all grid-level combinations and temporal modes."""
+ cache_file = entropice.utils.paths.get_dataset_stats_cache()
+ if cache_file.exists():
+ with open(cache_file, "rb") as f:
+ dataset_stats = pickle.load(f)
+ return dataset_stats
+
+ dataset_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
for grid_config in grid_configs:
+ dataset_stats[grid_config.id] = {}
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
total_cells = len(grid_gdf)
assert total_cells > 0, "Grid must contain at least one cell."
- target_statistics: dict[TargetDataset, TargetStatistics] = {}
- for target in all_target_datasets:
- target_statistics[target] = TargetStatistics.compute(
- grid=grid_config.grid,
- level=grid_config.level,
- target=target,
+ for temporal_mode in all_temporal_modes:
+ e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode)
+ target_statistics = {}
+ for target in all_target_datasets:
+ if isinstance(temporal_mode, int) and target == "darts_mllabels":
+ # darts_mllabels does not support year-based temporal modes
+ continue
+ target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
+ member_statistics = MemberStatistics.compute(e)
+
+ total_features = sum(ms.feature_count for ms in member_statistics.values())
+ total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
+ dataset_stats[grid_config.id][temporal_mode] = DatasetStatistics(
+ total_features=total_features,
total_cells=total_cells,
- )
- member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
- for member in all_l2_source_datasets:
- member_statistics[member] = MemberStatistics.compute(
- grid=grid_config.grid, level=grid_config.level, member=member
+ size_bytes=total_size_bytes,
+ members=member_statistics,
+ target=target_statistics,
)
- total_features = sum(ms.feature_count for ms in member_statistics.values())
- total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
- ts.size_bytes for ts in target_statistics.values()
- )
- dataset_stats[grid_config.id] = DatasetStatistics(
- total_features=total_features,
- total_cells=total_cells,
- size_bytes=total_size_bytes,
- members=member_statistics,
- target=target_statistics,
- )
+ with open(cache_file, "wb") as f:
+ pickle.dump(dataset_stats, f)
return dataset_stats
-@dataclass(frozen=True)
-class EnsembleMemberStatistics:
- n_features: int # Number of features from this member in the ensemble
- p_features: float # Percentage of features from this member in the ensemble
- n_nanrows: int # Number of rows which contain any NaN
- size_bytes: int # Size of this member's data in the ensemble in bytes
-
- @classmethod
- def compute(
- cls,
- dataset: gpd.GeoDataFrame,
- member: L2SourceDataset,
- n_features: int,
- ) -> "EnsembleMemberStatistics":
- if member == "AlphaEarth":
- member_dataset = dataset[dataset.columns.str.startswith("embeddings_")]
- elif member == "ArcticDEM":
- member_dataset = dataset[dataset.columns.str.startswith("arcticdem_")]
- elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
- era5_cols = dataset.columns.str.startswith("era5_")
- cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
- cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
- cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
- if member == "ERA5-yearly":
- member_dataset = dataset[cols_with_three_splits]
- elif member == "ERA5-seasonal":
- member_dataset = dataset[cols_with_summer_winter & cols_with_four_splits]
- elif member == "ERA5-shoulder":
- member_dataset = dataset[~cols_with_summer_winter & cols_with_four_splits]
- else:
- raise NotImplementedError(f"Member {member} not implemented.")
- size_bytes_member = member_dataset.memory_usage(deep=True).sum()
- n_features_member = len(member_dataset.columns)
- p_features_member = n_features_member / n_features
- n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
- return EnsembleMemberStatistics(
- n_features=n_features_member,
- p_features=p_features_member,
- n_nanrows=n_rows_with_nan_member,
- size_bytes=size_bytes_member,
- )
-
-
-@dataclass(frozen=True)
-class EnsembleDatasetStatistics:
- """Statistics for a specified composition / ensemble at a specific grid and level.
-
- These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
- That way, the real number of features and cells after NaN filtering can be reported.
- """
-
- n_features: int # Number of features in the dataset
- n_cells: int # Number of grid cells covered
- n_nanrows: int # Number of rows which contain any NaN
- size_bytes: int # Size of the dataset on disk in bytes
- members: dict[L2SourceDataset, EnsembleMemberStatistics] # Statistics per source dataset member
-
- @classmethod
- def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
- dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
- # Assert that no column in all-nan
- assert not dataset.isna().all("index").any(), "Some input columns are all NaN"
-
- size_bytes = dataset.memory_usage(deep=True).sum()
- n_features = len(dataset.columns)
- n_cells = len(dataset)
-
- # Number of rows which contain any NaN
- n_rows_with_nan = dataset.isna().any(axis=1).sum()
-
- member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {}
- for member in ensemble.members:
- member_statistics[member] = EnsembleMemberStatistics.compute(
- dataset=dataset,
- member=member,
- n_features=n_features,
- )
- return cls(
- n_features=n_features,
- n_cells=n_cells,
- n_nanrows=n_rows_with_nan,
- size_bytes=size_bytes,
- members=member_statistics,
- )
-
-
@dataclass(frozen=True)
class TrainingDatasetStatistics:
n_samples: int # Total number of samples in the dataset
@@ -346,14 +281,14 @@ class TrainingDatasetStatistics:
n_test_samples: int # Number of cells used for testing
train_test_ratio: float # Ratio of training to test samples
- class_labels: list[str] # Ordered list of class labels
- class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class
- n_classes: int # Number of classes
+ class_labels: list[str] | None # Ordered list of class labels
+ class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] | None # Min/max raw values
+ n_classes: int | None # Number of classes
- training_class_counts: dict[str, int] # Class counts in training set
- training_class_distribution: dict[str, float] # Class percentages in training set
- test_class_counts: dict[str, int] # Class counts in test set
- test_class_distribution: dict[str, float] # Class percentages in test set
+ training_class_counts: dict[str, int] | None # Class counts in training set
+ training_class_distribution: dict[str, float] | None # Class percentages in training set
+ test_class_counts: dict[str, int] | None # Class counts in test set
+ test_class_distribution: dict[str, float] | None # Class percentages in test set
raw_value_min: float # Minimum raw target value
raw_value_max: float # Maximum raw target value
@@ -361,7 +296,7 @@ class TrainingDatasetStatistics:
raw_value_median: float # Median raw target value
raw_value_std: float # Standard deviation of raw target values
- imbalance_ratio: float # Smallest class count / largest class count (overall)
+ imbalance_ratio: float | None # Smallest class count / largest class count (overall)
size_bytes: int # Total memory usage of features in bytes
@classmethod
@@ -369,58 +304,64 @@ class TrainingDatasetStatistics:
cls,
ensemble: DatasetEnsemble,
task: Task,
- dataset: gpd.GeoDataFrame | None = None,
+ target: TargetDataset,
) -> "TrainingDatasetStatistics":
- dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
- categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
+ training_dataset = ensemble.create_training_set(task=task, target=target)
# Sample counts
- n_samples = len(categorical_dataset)
- n_training_samples = len(categorical_dataset.y.train)
- n_test_samples = len(categorical_dataset.y.test)
+ n_samples = len(training_dataset)
+ n_training_samples = (training_dataset.split == "train").sum()
+ n_test_samples = (training_dataset.split == "test").sum()
train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
# Feature statistics
- n_features = len(categorical_dataset.X.data.columns)
- feature_names = list(categorical_dataset.X.data.columns)
- size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()
+ n_features = len(training_dataset.features.columns)
+ feature_names = training_dataset.features.columns.tolist()
+ size_bytes = training_dataset.features.memory_usage(deep=True).sum()
# Class information
- class_labels = categorical_dataset.y.labels
- class_intervals = categorical_dataset.y.intervals
- n_classes = len(class_labels)
+ class_labels = training_dataset.target_labels
+ if class_labels is None:
+ class_intervals = None
+ n_classes = None
+ training_class_counts = None
+ training_class_distribution = None
+ test_class_counts = None
+ test_class_distribution = None
+ imbalance_ratio = None
+ else:
+ class_intervals = training_dataset.target_intervals
+ n_classes = len(class_labels)
- # Training class distribution
- train_y_series = pd.Series(categorical_dataset.y.train)
- train_counts = train_y_series.value_counts().sort_index()
- training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
- train_total = sum(training_class_counts.values())
- training_class_distribution = {
- k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
- }
+ train_y = training_dataset.targets[training_dataset.split == "train"]["y"]
+ test_y = training_dataset.targets[training_dataset.split == "test"]["y"]
- # Test class distribution
- test_y_series = pd.Series(categorical_dataset.y.test)
- test_counts = test_y_series.value_counts().sort_index()
- test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
- test_total = sum(test_class_counts.values())
- test_class_distribution = {
- k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
- }
+ train_counts = train_y.value_counts().sort_index()
+ training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
+ train_total = sum(training_class_counts.values())
+ training_class_distribution = {
+ k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
+ }
+ test_counts = test_y.value_counts().sort_index()
+ test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
+ test_total = sum(test_class_counts.values())
+ test_class_distribution = {
+ k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
+ }
+
+ # Imbalance ratio (smallest class / largest class across both splits)
+            all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
+ nonzero_counts = [c for c in all_counts if c > 0]
+ imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
# Raw value statistics
- raw_values = categorical_dataset.y.raw_values
+ raw_values = training_dataset.targets["z"]
raw_value_min = float(raw_values.min())
raw_value_max = float(raw_values.max())
raw_value_mean = float(raw_values.mean())
raw_value_median = float(raw_values.median())
raw_value_std = float(raw_values.std())
- # Imbalance ratio (smallest class / largest class across both splits)
- all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
- nonzero_counts = [c for c in all_counts if c > 0]
- imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
-
return cls(
n_samples=n_samples,
n_features=n_features,
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py
index 799a392..0a87a8e 100644
--- a/src/entropice/dashboard/views/overview_page.py
+++ b/src/entropice/dashboard/views/overview_page.py
@@ -1,508 +1,17 @@
"""Overview page: List of available result directories with some summary statistics."""
-from datetime import datetime
-from typing import cast
-
-import pandas as pd
import streamlit as st
from stopuhr import stopwatch
-from entropice.dashboard.plots.overview import (
- create_feature_breakdown_donut,
- create_feature_count_stacked_bar,
- create_feature_distribution_pie,
- create_inference_cells_bar,
- create_sample_count_bar_chart,
+from entropice.dashboard.sections.dataset_statistics import render_dataset_statistics
+from entropice.dashboard.sections.experiment_results import (
+ render_experiment_results,
+ render_training_results_summary,
)
-from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import (
- DatasetStatistics,
load_all_default_dataset_statistics,
)
-from entropice.utils.types import (
- GridConfig,
- L2SourceDataset,
- TargetDataset,
- grid_configs,
-)
-
-
-def render_sample_count_overview():
- """Render overview of sample counts per task+target+grid+level combination."""
- st.markdown(
- """
- This visualization shows the number of available training samples for each combination of:
- - **Task**: binary, count, density
- - **Target Dataset**: darts_rts, darts_mllabels
- - **Grid System**: hex, healpix
- - **Grid Level**: varying by grid type
- """
- )
-
- # Get sample count DataFrame from cache
- all_stats = load_all_default_dataset_statistics()
- sample_df = DatasetStatistics.get_sample_count_df(all_stats)
-
- # Get color palettes for each target dataset
- n_tasks = sample_df["Task"].nunique()
- target_color_maps = {
- "rts": get_palette("task_types", n_colors=n_tasks),
- "mllabels": get_palette("data_sources", n_colors=n_tasks),
- }
-
- # Create and display bar chart
- fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
- st.plotly_chart(fig, use_container_width=True)
-
- # Display full table with formatting
- st.markdown("#### Detailed Sample Counts")
- display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
-
- # Format numbers with commas
- display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
- # Format coverage as percentage with 2 decimal places
- display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
-
- st.dataframe(display_df, hide_index=True, use_container_width=True)
-
-
-def render_feature_count_comparison():
- """Render static comparison of feature counts across all grid configurations."""
- st.markdown(
- """
- Comparing dataset characteristics for all grid configurations with all data sources enabled.
- - **Features**: Total number of input features from all data sources
- - **Spatial Coverage**: Number of grid cells with complete data coverage
- """
- )
-
- # Get data from cache
- all_stats = load_all_default_dataset_statistics()
- comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
- breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
- breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
-
- # Get all unique data sources and create color map
- unique_sources = sorted(breakdown_df["Data Source"].unique())
- n_sources = len(unique_sources)
- source_color_list = get_palette("data_sources", n_colors=n_sources)
- source_color_map = dict(zip(unique_sources, source_color_list))
-
- # Create and display stacked bar chart
- fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
- st.plotly_chart(fig, use_container_width=True)
-
- # Add spatial coverage metric
- n_grids = len(comparison_df)
- grid_colors = get_palette("grid_configs", n_colors=n_grids)
-
- fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
- st.plotly_chart(fig_cells, use_container_width=True)
-
- # Display full comparison table with formatting
- st.markdown("#### Detailed Comparison Table")
- display_df = comparison_df[
- [
- "Grid",
- "Total Features",
- "Data Sources",
- "Inference Cells",
- ]
- ].copy()
-
- # Format numbers with commas
- for col in ["Total Features", "Inference Cells"]:
- display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
-
- st.dataframe(display_df, hide_index=True, use_container_width=True)
-
-
-@st.fragment
-def render_feature_count_explorer():
- """Render interactive detailed configuration explorer using fragments."""
- st.markdown("Select specific grid configuration and data sources for detailed statistics")
-
- # Grid selection
- grid_options = [gc.display_name for gc in grid_configs]
-
- col1, col2 = st.columns(2)
-
- with col1:
- grid_level_combined = st.selectbox(
- "Grid Configuration",
- options=grid_options,
- index=0,
- help="Select the grid system and resolution level",
- key="feature_grid_select",
- )
-
- with col2:
- target = st.selectbox(
- "Target Dataset",
- options=["darts_rts", "darts_mllabels"],
- index=0,
- help="Select the target dataset",
- key="feature_target_select",
- )
-
- # Find the selected grid config
- selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
-
- # Get available members from the stats
- all_stats = load_all_default_dataset_statistics()
- stats = all_stats[selected_grid_config.id]
- available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
-
- # Members selection
- st.markdown("#### Select Data Sources")
-
- all_members = cast(
- list[L2SourceDataset],
- ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
- )
-
- # Use columns for checkboxes
- cols = st.columns(len(all_members))
- selected_members: list[L2SourceDataset] = []
-
- for idx, member in enumerate(all_members):
- with cols[idx]:
- default_value = member in available_members
- if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
- selected_members.append(cast(L2SourceDataset, member))
-
- # Show results if at least one member is selected
- if selected_members:
- # Get statistics from cache (already loaded)
- grid_stats = all_stats[selected_grid_config.id]
-
- # Filter to selected members only
- selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
-
- # Calculate total features for selected members
- total_features = sum(ms.feature_count for ms in selected_member_stats.values())
-
- # Get target stats
- target_stats = grid_stats.target[cast(TargetDataset, target)]
-
- # High-level metrics
- col1, col2, col3, col4, col5 = st.columns(5)
- with col1:
- st.metric("Total Features", f"{total_features:,}")
- with col2:
- # Calculate minimum cells across all data sources (for inference capability)
- min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
- st.metric(
- "Inference Cells",
- f"{min_cells:,}",
- help="Number of union of cells across all data sources",
- )
- with col3:
- st.metric("Data Sources", len(selected_members))
- with col4:
- # Use binary task training cells as sample count
- st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
- with col5:
- # Calculate total data points
- total_points = total_features * target_stats.training_cells["binary"]
- st.metric("Total Data Points", f"{total_points:,}")
-
- # Feature breakdown by source
- st.markdown("#### Feature Breakdown by Data Source")
-
- breakdown_data = []
- for member, member_stats in selected_member_stats.items():
- breakdown_data.append(
- {
- "Data Source": member,
- "Number of Features": member_stats.feature_count,
- "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
- }
- )
-
- breakdown_df = pd.DataFrame(breakdown_data)
-
- # Get all unique data sources and create color map
- unique_sources = sorted(breakdown_df["Data Source"].unique())
- n_sources = len(unique_sources)
- source_color_list = get_palette("data_sources", n_colors=n_sources)
- source_color_map = dict(zip(unique_sources, source_color_list))
-
- # Create and display pie chart
- fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
- st.plotly_chart(fig, width="stretch")
-
- # Show detailed table
- st.dataframe(breakdown_df, hide_index=True, width="stretch")
-
- # Detailed member information
- with st.expander("📦 Detailed Source Information", expanded=False):
- for member, member_stats in selected_member_stats.items():
- st.markdown(f"### {member}")
-
- metric_cols = st.columns(4)
- with metric_cols[0]:
- st.metric("Features", member_stats.feature_count)
- with metric_cols[1]:
- st.metric("Variables", len(member_stats.variable_names))
- with metric_cols[2]:
- dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
- st.metric("Shape", dim_str)
- with metric_cols[3]:
- total_points = 1
- for dim_size in member_stats.dimensions.values():
- total_points *= dim_size
- st.metric("Data Points", f"{total_points:,}")
-
- # Variables
- st.markdown("**Variables:**")
- vars_html = " ".join(
- [
- f'{v}'
- for v in member_stats.variable_names
- ]
- )
- st.markdown(vars_html, unsafe_allow_html=True)
-
- # Dimensions
- st.markdown("**Dimensions:**")
- dim_html = " ".join(
- [
- f''
- f"{dim_name}: {dim_size}"
- for dim_name, dim_size in member_stats.dimensions.items()
- ]
- )
- st.markdown(dim_html, unsafe_allow_html=True)
-
- # Size on disk
- size_mb = member_stats.size_bytes / (1024 * 1024)
- st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
-
- st.markdown("---")
- else:
- st.info("👆 Select at least one data source to see feature statistics")
-
-
-def render_dataset_analysis():
- """Render the dataset analysis section with sample and feature counts."""
- st.header("📈 Dataset Analysis")
-
- # Create tabs for different analysis views
- analysis_tabs = st.tabs(
- [
- "📊 Training Samples",
- "📈 Dataset Characteristics",
- "🔍 Feature Breakdown",
- "⚙️ Configuration Explorer",
- ]
- )
-
- with analysis_tabs[0]:
- st.subheader("Training Samples by Configuration")
- render_sample_count_overview()
-
- with analysis_tabs[1]:
- st.subheader("Dataset Characteristics Across Grid Configurations")
- render_feature_count_comparison()
-
- with analysis_tabs[2]:
- st.subheader("Feature Breakdown by Data Source")
- # Get data from cache
- all_stats = load_all_default_dataset_statistics()
- comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
- breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
- breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
-
- # Get all unique data sources and create color map
- unique_sources = sorted(breakdown_df["Data Source"].unique())
- n_sources = len(unique_sources)
- source_color_list = get_palette("data_sources", n_colors=n_sources)
- source_color_map = dict(zip(unique_sources, source_color_list))
-
- st.markdown("Showing percentage contribution of each data source across all grid configurations")
-
- # Sparse Resolution girds
- for res in ["sparse", "low", "medium"]:
- cols = st.columns(2)
- with cols[0]:
- grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"]
- for gc in grid_configs_res:
- grid_display = gc.display_name
- grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
- fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
- st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
- with cols[1]:
- grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"]
- for gc in grid_configs_res:
- grid_display = gc.display_name
- grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
- fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
- st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
-
- # Create donut charts for each grid configuration
- # num_grids = len(comparison_df)
- # cols_per_row = 3
- # num_rows = (num_grids + cols_per_row - 1) // cols_per_row
-
- # for row_idx in range(num_rows):
- # cols = st.columns(cols_per_row)
- # for col_idx in range(cols_per_row):
- # grid_idx = row_idx * cols_per_row + col_idx
- # if grid_idx < num_grids:
- # grid_config = comparison_df.iloc[grid_idx]["Grid"]
- # grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
-
- # with cols[col_idx]:
- # fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map)
- # st.plotly_chart(fig, use_container_width=True)
-
- with analysis_tabs[3]:
- st.subheader("Interactive Configuration Explorer")
- render_feature_count_explorer()
-
-
-def render_training_results_summary(training_results):
- """Render summary metrics for training results."""
- st.header("📊 Training Results Summary")
- col1, col2, col3, col4 = st.columns(4)
-
- with col1:
- tasks = {tr.settings.task for tr in training_results}
- st.metric("Tasks", len(tasks))
-
- with col2:
- grids = {tr.settings.grid for tr in training_results}
- st.metric("Grid Types", len(grids))
-
- with col3:
- models = {tr.settings.model for tr in training_results}
- st.metric("Model Types", len(models))
-
- with col4:
- latest = training_results[0] # Already sorted by creation time
- latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
- st.metric("Latest Run", latest_date)
-
-
-def render_experiment_results(training_results):
- """Render detailed experiment results table and expandable details."""
- st.header("🎯 Experiment Results")
- st.subheader("Results Table")
-
- # Build a summary dataframe
- summary_data = []
- for tr in training_results:
- # Extract best scores from the results dataframe
- score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
-
- best_scores = {}
- for col in score_cols:
- metric_name = col.replace("mean_test_", "")
- best_score = tr.results[col].max()
- best_scores[metric_name] = best_score
-
- # Get primary metric (usually the first one or accuracy)
- primary_metric = (
- "accuracy"
- if "mean_test_accuracy" in tr.results.columns
- else score_cols[0].replace("mean_test_", "")
- if score_cols
- else "N/A"
- )
- primary_score = best_scores.get(primary_metric, 0.0)
-
- summary_data.append(
- {
- "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
- "Task": tr.settings.task,
- "Grid": tr.settings.grid,
- "Level": tr.settings.level,
- "Model": tr.settings.model,
- f"Best {primary_metric.title()}": f"{primary_score:.4f}",
- "Trials": len(tr.results),
- "Path": str(tr.path.name),
- }
- )
-
- summary_df = pd.DataFrame(summary_data)
-
- # Display with color coding for best scores
- st.dataframe(
- summary_df,
- width="stretch",
- hide_index=True,
- )
-
- st.divider()
-
- # Expandable details for each result
- st.subheader("Individual Experiment Details")
-
- for tr in training_results:
- display_name = (
- f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
- )
- with st.expander(display_name):
- col1, col2 = st.columns([1, 2])
-
- with col1:
- st.write("**Configuration:**")
- st.write(f"- **Task:** {tr.settings.task}")
- st.write(f"- **Grid:** {tr.settings.grid}")
- st.write(f"- **Level:** {tr.settings.level}")
- st.write(f"- **Model:** {tr.settings.model}")
- st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
- st.write(f"- **Classes:** {tr.settings.classes}")
-
- st.write("\n**Files:**")
- st.write("- 📊 search_results.parquet")
- st.write("- 🧮 best_estimator_state.nc")
- st.write("- 🎯 predicted_probabilities.parquet")
- st.write("- ⚙️ search_settings.toml")
-
- with col2:
- st.write("**Best Scores:**")
-
- # Extract all test scores
- score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
-
- if score_cols:
- metric_data = []
- for col in score_cols:
- metric_name = col.replace("mean_test_", "").title()
- best_score = tr.results[col].max()
- mean_score = tr.results[col].mean()
- std_score = tr.results[col].std()
-
- metric_data.append(
- {
- "Metric": metric_name,
- "Best": f"{best_score:.4f}",
- "Mean": f"{mean_score:.4f}",
- "Std": f"{std_score:.4f}",
- }
- )
-
- metric_df = pd.DataFrame(metric_data)
- st.dataframe(metric_df, width="stretch", hide_index=True)
- else:
- st.write("No test scores found in results.")
-
- # Show parameter space explored
- if "initial_K" in tr.results.columns: # Common parameter
- st.write("\n**Parameter Ranges Explored:**")
- for param in ["initial_K", "eps_cl", "eps_e"]:
- if param in tr.results.columns:
- min_val = tr.results[param].min()
- max_val = tr.results[param].max()
- unique_vals = tr.results[param].nunique()
- st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
-
- st.write(f"\n**Path:** `{tr.path}`")
def render_overview_page():
@@ -537,7 +46,8 @@ def render_overview_page():
st.divider()
# Render dataset analysis section
- render_dataset_analysis()
+ all_stats = load_all_default_dataset_statistics()
+ render_dataset_statistics(all_stats)
st.balloons()
stopwatch.summary()
diff --git a/src/entropice/ml/autogluon_training.py b/src/entropice/ml/autogluon_training.py
index 5070204..869d564 100644
--- a/src/entropice/ml/autogluon_training.py
+++ b/src/entropice/ml/autogluon_training.py
@@ -48,7 +48,7 @@ class AutoGluonSettings:
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
"""Combined settings for AutoGluon training."""
- classes: list[str]
+ classes: list[str] | None
problem_type: str
@@ -228,7 +228,7 @@ def autogluon_train(
combined_settings = AutoGluonTrainingSettings(
**asdict(settings),
**asdict(dataset_ensemble),
- classes=training_data.y.labels,
+ classes=training_data.target_labels,
problem_type=problem_type,
)
settings_file = results_dir / "training_settings.toml"
diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py
index d4bbfea..f74deb8 100644
--- a/src/entropice/ml/dataset.py
+++ b/src/entropice/ml/dataset.py
@@ -78,7 +78,7 @@ def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
pivcols = set(collapsed.index.names) - {"cell_ids"}
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
- collapsed.columns = ["_".join(v) for v in collapsed.columns]
+ collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns]
if use_dummy:
collapsed = collapsed.dropna(how="all")
return collapsed
@@ -378,7 +378,8 @@ class DatasetEnsemble:
case ("darts_v1" | "darts_v2", int()):
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
- targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
+ targets = xr.open_zarr(target_store, consolidated=False)
+ targets = targets.sel(year=self.temporal_mode)
case ("darts_mllabels", str()): # Years are not supported
target_store = entropice.utils.paths.get_darts_file(
grid=self.grid, level=self.level, version="mllabels"
@@ -467,6 +468,10 @@ class DatasetEnsemble:
# Apply the temporal mode
match (member.split("-"), self.temporal_mode):
+ case (["ArcticDEM"], _):
+ pass # No temporal dimension
+ case (_, "feature"):
+ pass
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
ds_mean = ds.mean(dim="year")
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
@@ -475,12 +480,8 @@ class DatasetEnsemble:
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
)
ds = xr.merge([ds_mean, ds_trend])
- case (["ArcticDEM"], "synopsis"):
- pass # No temporal dimension
case (_, int() as year):
ds = ds.sel(year=year, drop=True)
- case (_, "feature"):
- pass
case _:
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
@@ -556,7 +557,7 @@ class DatasetEnsemble:
cell_ids: pd.Series,
era5_agg: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
- era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
+ era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=True)
era5_df = _collapse_to_dataframe(era5)
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
@@ -565,7 +566,10 @@ class DatasetEnsemble:
@stopwatch("Preparing AlphaEarth Embeddings")
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
- embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
+ embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=True)
+ # For non-synopsis modes there is only a single data variable "embeddings"
+ if self.temporal_mode != "synopsis":
+ embeddings = embeddings["embeddings"]
embeddings_df = _collapse_to_dataframe(embeddings)
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
@@ -574,7 +578,7 @@ class DatasetEnsemble:
@stopwatch("Preparing ArcticDEM")
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
- arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
+ arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=False)
if len(arcticdem["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
@@ -620,7 +624,7 @@ class DatasetEnsemble:
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
- @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
+ # @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
def create_training_set(
self,
task: Task,
diff --git a/src/entropice/ml/training.py b/src/entropice/ml/training.py
index c5ea037..37c4499 100644
--- a/src/entropice/ml/training.py
+++ b/src/entropice/ml/training.py
@@ -3,6 +3,7 @@
import pickle
from dataclasses import asdict, dataclass
from functools import partial
+from pathlib import Path
import cyclopts
import numpy as np
@@ -115,17 +116,19 @@ class TrainingSettings(DatasetEnsemble, CVSettings):
def random_cv(
dataset_ensemble: DatasetEnsemble,
settings: CVSettings = CVSettings(),
-):
+ experiment: str | None = None,
+) -> Path:
"""Perform random cross-validation on the training dataset.
Args:
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
settings (CVSettings): The cross-validation settings.
+ experiment (str | None): Optional experiment name for results directory.
"""
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
- use_array_api = settings.model == "espa"
- device = "torch" if use_array_api else "cuda"
+ use_array_api = settings.model != "xgboost"
+ device = "torch" if settings.model == "espa" else "cuda"
set_config(array_api_dispatch=use_array_api)
print("Creating training data...")
@@ -173,7 +176,8 @@ def random_cv(
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
results_dir = get_cv_results_dir(
- "random_search",
+ experiment=experiment,
+ name="random_search",
grid=dataset_ensemble.grid,
level=dataset_ensemble.level,
task=settings.task,
@@ -283,6 +287,7 @@ def random_cv(
stopwatch.summary()
print("Done.")
+ return results_dir
if __name__ == "__main__":
diff --git a/src/entropice/utils/paths.py b/src/entropice/utils/paths.py
index 7708845..98a83b8 100644
--- a/src/entropice/utils/paths.py
+++ b/src/entropice/utils/paths.py
@@ -139,7 +139,14 @@ def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
return cache_dir / f"cells_{cells_hash}.parquet"
+def get_dataset_stats_cache() -> Path:
+ cache_dir = DATASET_ENSEMBLES_DIR / "cache"
+ cache_dir.mkdir(parents=True, exist_ok=True)
+ return cache_dir / "dataset_stats.pckl"
+
+
def get_cv_results_dir(
+ experiment: str | None,
name: str,
grid: Grid,
level: int,
@@ -147,7 +154,12 @@ def get_cv_results_dir(
) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
- results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
+ if experiment is not None:
+ experiment_dir = RESULTS_DIR / experiment
+ experiment_dir.mkdir(parents=True, exist_ok=True)
+ else:
+ experiment_dir = RESULTS_DIR
+ results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True)
return results_dir
diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py
index 1c228b1..322f2cd 100644
--- a/src/entropice/utils/types.py
+++ b/src/entropice/utils/types.py
@@ -15,12 +15,12 @@ type GridLevel = Literal[
"healpix9",
"healpix10",
]
-type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
+type TargetDataset = Literal["darts_v1", "darts_mllabels"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
# TODO: Consider implementing a "timeseries" temporal mode
-type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
+type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
type Model = Literal["espa", "xgboost", "rf", "knn"]
type Stage = Literal["train", "inference", "visualization"]
@@ -37,17 +37,20 @@ class GridConfig:
sort_key: str
@classmethod
- def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
+ def from_grid_level(cls, grid_level: GridLevel | tuple[Literal["hex", "healpix"], int]) -> "GridConfig":
"""Create a GridConfig from a GridLevel string."""
- if grid_level.startswith("hex"):
- grid = "hex"
- level = int(grid_level[3:])
- elif grid_level.startswith("healpix"):
- grid = "healpix"
- level = int(grid_level[7:])
+ if isinstance(grid_level, str):
+ if grid_level.startswith("hex"):
+ grid = "hex"
+ level = int(grid_level[3:])
+ elif grid_level.startswith("healpix"):
+ grid = "healpix"
+ level = int(grid_level[7:])
+ else:
+ raise ValueError(f"Invalid grid level: {grid_level}")
else:
- raise ValueError(f"Invalid grid level: {grid_level}")
-
+ grid, level = grid_level
+ grid_level: GridLevel = f"{grid}{level}" # ty:ignore[invalid-assignment]
display_name = f"{grid.capitalize()}-{level}"
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
@@ -74,8 +77,8 @@ class GridConfig:
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
-all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
-all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
+all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
+all_target_datasets: list[TargetDataset] = ["darts_v1", "darts_mllabels"]
all_l2_source_datasets: list[L2SourceDataset] = [
"ArcticDEM",
"ERA5-shoulder",
diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 0000000..4fd104a
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,222 @@
+"""Tests for training.py module, specifically random_cv function.
+
+This test suite validates the random_cv training function across all model-task
+combinations using a minimal hex level 3 grid with synopsis temporal mode.
+
+Test Coverage:
+- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
+- Device handling for each model type (torch/CUDA/cuML compatibility)
+- Multi-label target dataset support
+- Temporal mode configuration (synopsis)
+- Output file creation and validation
+
+Running Tests:
+ # Run all training tests (18 tests total, ~3 iterations each)
+ pixi run pytest tests/test_training.py -v
+
+ # Run only device handling tests
+ pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v
+
+ # Run a specific model-task combination
+ pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v
+
+Note: Tests use minimal iterations (3) and level 3 grid for speed.
+Full production runs use higher iteration counts (100-2000).
+"""
+
+import shutil
+
+import pytest
+
+from entropice.ml.dataset import DatasetEnsemble
+from entropice.ml.training import CVSettings, random_cv
+from entropice.utils.types import Model, Task
+
+
+@pytest.fixture(scope="module")
+def test_ensemble():
+ """Create a minimal DatasetEnsemble for testing.
+
+ Uses hex level 3 grid with synopsis temporal mode for fast testing.
+ """
+ return DatasetEnsemble(
+ grid="hex",
+ level=3,
+ temporal_mode="synopsis",
+ members=["AlphaEarth"], # Use only one member for faster tests
+ add_lonlat=True,
+ )
+
+
+@pytest.fixture
+def cleanup_results():
+ """Clean up results directory after each test.
+
+ This fixture collects the actual result directories created during tests
+ and removes them after the test completes.
+ """
+ created_dirs = []
+
+ def register_dir(results_dir):
+ """Register a directory to be cleaned up."""
+ created_dirs.append(results_dir)
+ return results_dir
+
+ yield register_dir
+
+ # Clean up only the directories created during this test
+ for results_dir in created_dirs:
+ if results_dir.exists():
+ shutil.rmtree(results_dir)
+
+
+# Model-task combinations to test
+# Note: Not all combinations make sense, but we test all to ensure robustness
+MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
+TASKS: list[Task] = ["binary", "count", "density"]
+
+
+class TestRandomCV:
+ """Test suite for random_cv function."""
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("task", TASKS)
+ def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
+ """Test random_cv with all model-task combinations.
+
+ This test runs 3 iterations for each combination to verify:
+ - The function completes without errors
+ - Device handling works correctly for each model type
+ - All output files are created
+ """
+ # Use darts_v1 as the primary target for all tests
+ settings = CVSettings(
+ n_iter=3,
+ task=task,
+ target="darts_v1",
+ model=model,
+ )
+
+ # Run the cross-validation and get the results directory
+ results_dir = random_cv(
+ dataset_ensemble=test_ensemble,
+ settings=settings,
+ experiment="test_training",
+ )
+ cleanup_results(results_dir)
+
+ # Verify results directory was created
+ assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"
+
+ # Verify all expected output files exist
+ expected_files = [
+ "search_settings.toml",
+ "best_estimator_model.pkl",
+ "search_results.parquet",
+ "metrics.toml",
+ "predicted_probabilities.parquet",
+ ]
+
+ # Add task-specific files
+ if task in ["binary", "count", "density"]:
+ # All tasks that use classification (including count/density when binned)
+ # Note: count and density without _regimes suffix might be regression
+ if task == "binary" or "_regimes" in task:
+ expected_files.append("confusion_matrix.nc")
+
+ # Add model-specific files
+ if model in ["espa", "xgboost", "rf"]:
+ expected_files.append("best_estimator_state.nc")
+
+ for filename in expected_files:
+ filepath = results_dir / filename
+            assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"
+
+ @pytest.mark.parametrize("model", MODELS)
+ def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
+ """Test that device handling works correctly for each model type.
+
+ Different models require different device configurations:
+ - espa: Uses torch with array API dispatch
+ - xgboost: Uses CUDA without array API dispatch
+ - rf/knn: GPU-accelerated via cuML
+ """
+ settings = CVSettings(
+ n_iter=3,
+ task="binary", # Simple binary task for device testing
+ target="darts_v1",
+ model=model,
+ )
+
+ # This should complete without device-related errors
+ try:
+ results_dir = random_cv(
+ dataset_ensemble=test_ensemble,
+ settings=settings,
+ experiment="test_training",
+ )
+ cleanup_results(results_dir)
+ except RuntimeError as e:
+ # Check if error is device-related
+ error_msg = str(e).lower()
+ device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"]
+ if any(keyword in error_msg for keyword in device_keywords):
+ pytest.fail(f"Device handling error for {model=}: {e}")
+ else:
+ # Re-raise non-device errors
+ raise
+
+ def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
+ """Test random_cv with multi-label target dataset."""
+ settings = CVSettings(
+ n_iter=3,
+ task="binary",
+ target="darts_mllabels",
+ model="espa",
+ )
+
+ # Run the cross-validation and get the results directory
+ results_dir = random_cv(
+ dataset_ensemble=test_ensemble,
+ settings=settings,
+ experiment="test_training",
+ )
+ cleanup_results(results_dir)
+
+ # Verify results were created
+ assert results_dir.exists(), "Results directory not created"
+ assert (results_dir / "search_settings.toml").exists()
+
+ def test_temporal_mode_synopsis(self, cleanup_results):
+ """Test that temporal_mode='synopsis' is correctly used."""
+ import toml
+
+ ensemble = DatasetEnsemble(
+ grid="hex",
+ level=3,
+ temporal_mode="synopsis",
+ members=["AlphaEarth"],
+ add_lonlat=True,
+ )
+
+ settings = CVSettings(
+ n_iter=3,
+ task="binary",
+ target="darts_v1",
+ model="espa",
+ )
+
+ # This should use synopsis mode (all years aggregated)
+ results_dir = random_cv(
+ dataset_ensemble=ensemble,
+ settings=settings,
+ experiment="test_training",
+ )
+ cleanup_results(results_dir)
+
+ # Verify the settings were stored correctly
+ assert results_dir.exists(), "Results directory not created"
+ with open(results_dir / "search_settings.toml") as f:
+ stored_settings = toml.load(f)
+
+ assert stored_settings["settings"]["temporal_mode"] == "synopsis"
diff --git a/tests/validate_datasets.py b/tests/validate_datasets.py
new file mode 100644
index 0000000..6e8286b
--- /dev/null
+++ b/tests/validate_datasets.py
@@ -0,0 +1,38 @@
+import cyclopts
+from rich import pretty, print, traceback
+
+from entropice.ml.dataset import DatasetEnsemble
+from entropice.utils.types import all_temporal_modes, grid_configs
+
+pretty.install()
+traceback.install()
+
+cli = cyclopts.App()
+
+
+def _gather_ensemble_stats(e: DatasetEnsemble):
+ # Get a small sample of the cell ids
+ sample_cell_ids = e.cell_ids[:5]
+ features = e.make_features(sample_cell_ids)
+
+ print(
+ f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]"
+ )
+ print(f"Number of feature columns: {len(features.columns)}")
+
+ for member in ["arcticdem", "embeddings", "era5"]:
+ member_feature_cols = [col for col in features.columns if col.startswith(f"{member}_")]
+ print(f" - {member.capitalize()} feature columns: {len(member_feature_cols)}")
+ print()
+
+
+@cli.default()
+def validate_datasets():
+ for gc in grid_configs:
+ for temporal_mode in all_temporal_modes:
+ e = DatasetEnsemble(grid=gc.grid, level=gc.level, temporal_mode=temporal_mode)
+ _gather_ensemble_stats(e)
+
+
+if __name__ == "__main__":
+ cli()