From c9c6af837096c82be086a622a83e14d710a6ba42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Fri, 16 Jan 2026 20:33:10 +0100 Subject: [PATCH] Fix training and overview page --- pyproject.toml | 2 +- src/entropice/dashboard/app.py | 5 +- .../dashboard/plots/dataset_statistics.py | 458 ++++++++++++++++ src/entropice/dashboard/plots/overview.py | 276 ---------- .../dashboard/sections/dataset_statistics.py | 376 +++++++++++++ .../dashboard/sections/experiment_results.py | 158 ++++++ src/entropice/dashboard/utils/colors.py | 10 +- src/entropice/dashboard/utils/loaders.py | 166 ++++-- src/entropice/dashboard/utils/stats.py | 471 +++++++--------- .../dashboard/views/overview_page.py | 502 +----------------- src/entropice/ml/autogluon_training.py | 4 +- src/entropice/ml/dataset.py | 24 +- src/entropice/ml/training.py | 13 +- src/entropice/utils/paths.py | 14 +- src/entropice/utils/types.py | 29 +- tests/test_training.py | 222 ++++++++ tests/validate_datasets.py | 38 ++ 17 files changed, 1643 insertions(+), 1125 deletions(-) create mode 100644 src/entropice/dashboard/plots/dataset_statistics.py delete mode 100644 src/entropice/dashboard/plots/overview.py create mode 100644 src/entropice/dashboard/sections/dataset_statistics.py create mode 100644 src/entropice/dashboard/sections/experiment_results.py create mode 100644 tests/test_training.py create mode 100644 tests/validate_datasets.py diff --git a/pyproject.toml b/pyproject.toml index 6615db7..fd9fe67 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ dependencies = [ "pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10", "autogluon-tabular[all,mitra]>=1.5.0", - "shap>=0.50.0,<0.51", + "shap>=0.50.0,<0.51", "h5py>=3.15.1,<4", ] [project.scripts] diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 713f7d0..6eea4d9 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -5,12 +5,14 @@ Pages: - Overview: List of available result 
directories with some summary statistics. - Training Data: Visualization of training data distributions. - Training Results Analysis: Analysis of training results and model performance. +- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations. - Model State: Visualization of model state and features. - Inference: Visualization of inference results. """ import streamlit as st +from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.overview_page import render_overview_page @@ -26,13 +28,14 @@ def main(): overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") + autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖") model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") pg = st.navigation( { "Overview": [overview_page], - "Training": [training_data_page, training_analysis_page], + "Training": [training_data_page, training_analysis_page, autogluon_page], "Model State": [model_state_page], "Inference": [inference_page], } diff --git a/src/entropice/dashboard/plots/dataset_statistics.py b/src/entropice/dashboard/plots/dataset_statistics.py new file mode 100644 index 0000000..a4685c5 --- /dev/null +++ b/src/entropice/dashboard/plots/dataset_statistics.py @@ -0,0 +1,458 @@ +"""Visualization functions for the overview page. 
+ +This module contains reusable plotting functions for dataset analysis visualizations, +including sample counts, feature counts, and dataset statistics. +""" + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots + +from entropice.dashboard.utils.colors import get_palette + + +def _get_colors_from_members(unique_members: list[str]) -> dict[str, str]: + unique_members = sorted(unique_members) + unique_sources = {member.split("-")[0] for member in unique_members} + source_member_map = { + source: [member for member in unique_members if member.split("-")[0] == source] for source in unique_sources + } + source_color_map = {} + for source, members in source_member_map.items(): + n_members = len(members) + 2 # Avoid very light/dark colors + palette = get_palette(source, n_colors=n_members) + for idx, member in enumerate(members, start=1): + source_color_map[member] = palette[idx] + return source_color_map + + +def create_sample_count_bar_chart(sample_df: pd.DataFrame) -> go.Figure: + """Create bar chart showing sample counts by grid, target and temporal mode. + + Args: + sample_df: DataFrame with columns: Grid, Target, Samples (Coverage). + + Returns: + Plotly Figure object containing the bar chart visualization. 
+ + """ + assert "Grid" in sample_df.columns + assert "Target" in sample_df.columns + assert "Temporal Mode" in sample_df.columns + assert "Samples (Coverage)" in sample_df.columns + + targets = sorted(sample_df["Target"].unique()) + modes = sorted(sample_df["Temporal Mode"].unique()) + + # Create subplots manually to have better control over colors + fig = make_subplots( + rows=1, + cols=len(targets), + subplot_titles=[f"Target: {target}" for target in targets], + shared_yaxes=True, + ) + # Get color palette for this plot + colors = get_palette("Number of Samples", n_colors=len(modes)) + + # Track which modes have been added to legend + modes_in_legend = set() + + for col_idx, target in enumerate(targets, start=1): + target_data = sample_df[sample_df["Target"] == target] + + for mode_idx, mode in enumerate(modes): + mode_data = target_data[target_data["Temporal Mode"] == mode] + if len(mode_data) == 0: + continue + color = colors[mode_idx] if colors and mode_idx < len(colors) else None + + # Show legend only the first time each mode appears + show_in_legend = mode not in modes_in_legend + if show_in_legend: + modes_in_legend.add(mode) + + # Create a unique legendgroup per mode so colors are consistent + fig.add_trace( + go.Bar( + x=mode_data["Grid"], + y=mode_data["Samples (Coverage)"], + name=mode, + marker_color=color, + legendgroup=mode, # Group by mode name + showlegend=show_in_legend, + ), + row=1, + col=col_idx, + ) + + fig.update_layout( + title_text="Training Sample Counts by Grid Configuration and Target Dataset", + barmode="group", + height=500, + showlegend=True, + ) + + fig.update_xaxes(tickangle=-45) + fig.update_yaxes(title_text="Number of Samples", row=1, col=1) + + return fig + + +def create_feature_count_stacked_bar(breakdown_df: pd.DataFrame) -> go.Figure: + """Create stacked bar chart showing feature counts by member. + + Args: + breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features. 
+ + Returns: + Plotly Figure object containing the stacked bar chart visualization. + + """ + assert "Grid" in breakdown_df.columns + assert "Temporal Mode" in breakdown_df.columns + assert "Member" in breakdown_df.columns + assert "Number of Features" in breakdown_df.columns + + unique_members = sorted(breakdown_df["Member"].unique()) + temporal_modes = sorted(breakdown_df["Temporal Mode"].unique()) + source_color_map = _get_colors_from_members(unique_members) + + # Create subplots for each temporal mode + fig = make_subplots( + rows=1, + cols=len(temporal_modes), + subplot_titles=[f"Temporal Mode: {mode}" for mode in temporal_modes], + shared_yaxes=True, + ) + + for col_idx, mode in enumerate(temporal_modes, start=1): + mode_data = breakdown_df[breakdown_df["Temporal Mode"] == mode] + + for member in unique_members: + member_data = mode_data[mode_data["Member"] == member] + color = source_color_map.get(member) + + fig.add_trace( + go.Bar( + x=member_data["Grid"], + y=member_data["Number of Features"], + name=member, + marker_color=color, + legendgroup=member, + showlegend=(col_idx == 1), + ), + row=1, + col=col_idx, + ) + + fig.update_layout( + title_text="Input Features by Member Across Grid Configurations", + barmode="stack", + height=500, + showlegend=True, + ) + fig.update_xaxes(tickangle=-45) + fig.update_yaxes(title_text="Number of Features", row=1, col=1) + + return fig + + +def create_inference_cells_bar(sample_df: pd.DataFrame) -> go.Figure: + """Create bar chart for inference cells by grid configuration. + + Args: + sample_df: DataFrame with columns: Grid, Temporal Mode, Inference Cells. + + Returns: + Plotly Figure object containing the bar chart visualization. 
+ + """ + assert "Grid" in sample_df.columns + assert "Inference Cells" in sample_df.columns + + n_grids = sample_df["Grid"].nunique() + mode_colors: list[str] = get_palette("Number of Samples", n_colors=n_grids) + + fig = px.bar( + sample_df, + x="Grid", + y="Inference Cells", + color="Grid", + title="Spatial Coverage (Grid Cells with Complete Data)", + labels={ + "Grid": "Grid Configuration", + "Inference Cells": "Number of Cells", + }, + color_discrete_sequence=mode_colors, + text="Inference Cells", + ) + + fig.update_traces(texttemplate="%{text:,}", textposition="outside") + fig.update_layout(xaxis_tickangle=-45, height=500, showlegend=False) + + return fig + + +def create_total_samples_bar( + comparison_df: pd.DataFrame, + grid_colors: list[str] | None = None, +) -> go.Figure: + """Create bar chart for total samples by grid configuration. + + Args: + comparison_df: DataFrame with columns: Grid, Total Samples. + grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the bar chart visualization. + + """ + fig = px.bar( + comparison_df, + x="Grid", + y="Total Samples", + color="Grid", + title="Training Samples (Binary Task)", + labels={ + "Grid": "Grid Configuration", + "Total Samples": "Number of Samples", + }, + color_discrete_sequence=grid_colors, + text="Total Samples", + ) + + fig.update_traces(texttemplate="%{text:,}", textposition="outside") + fig.update_layout(xaxis_tickangle=-45, showlegend=False) + + return fig + + +def create_feature_breakdown_donut( + grid_data: pd.DataFrame, + grid_config: str, + source_color_map: dict[str, str] | None = None, +) -> go.Figure: + """Create donut chart for feature breakdown by data source for a specific grid. + + Args: + grid_data: DataFrame with columns: Data Source, Number of Features. + grid_config: Grid configuration name for the title. + source_color_map: Optional dictionary mapping data source names to specific colors. 
+ If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the donut chart visualization. + + """ + fig = px.pie( + grid_data, + names="Data Source", + values="Number of Features", + title=grid_config, + hole=0.4, + color_discrete_map=source_color_map, + color="Data Source", + ) + + fig.update_traces(textposition="inside", textinfo="percent") + fig.update_layout(showlegend=True, height=350) + + return fig + + +def create_feature_distribution_pie( + breakdown_df: pd.DataFrame, + source_color_map: dict[str, str] | None = None, +) -> go.Figure: + """Create pie chart for feature distribution by data source. + + Args: + breakdown_df: DataFrame with columns: Data Source, Number of Features. + source_color_map: Optional dictionary mapping data source names to specific colors. + If None, uses default Plotly colors. + + Returns: + Plotly Figure object containing the pie chart visualization. + + """ + fig = px.pie( + breakdown_df, + names="Data Source", + values="Number of Features", + title="Feature Distribution by Data Source", + hole=0.4, + color_discrete_map=source_color_map, + color="Data Source", + ) + + fig.update_traces(textposition="inside", textinfo="percent+label") + fig.update_layout(height=400) + + return fig + + +def create_member_details_table(member_stats_dict: dict[str, dict]) -> pd.DataFrame: + """Create a detailed table for member statistics. + + Args: + member_stats_dict: Dictionary mapping member names to their statistics. + Each stats dict should contain: feature_count, variable_names, dimensions, size_bytes. + + Returns: + DataFrame with detailed member information. 
+ + """ + rows = [] + for member, stats in member_stats_dict.items(): + # Calculate total data points + total_points = 1 + for dim_size in stats["dimensions"].values(): + total_points *= dim_size + + # Format dimension string + dim_str = " x ".join([f"{k}={v:,}" for k, v in stats["dimensions"].items()]) + + rows.append( + { + "Member": member, + "Features": stats["feature_count"], + "Variables": len(stats["variable_names"]), + "Dimensions": dim_str, + "Data Points": f"{total_points:,}", + "Size (MB)": f"{stats['size_bytes'] / (1024 * 1024):.2f}", + } + ) + + return pd.DataFrame(rows) + + +def create_feature_breakdown_donuts_grid(breakdown_df: pd.DataFrame) -> go.Figure: + """Create grid of donut charts showing feature breakdown by member for each grid+temporal mode. + + Args: + breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features. + + Returns: + Plotly Figure object containing the grid of donut charts. + + """ + assert "Grid" in breakdown_df.columns + assert "Temporal Mode" in breakdown_df.columns + assert "Member" in breakdown_df.columns + assert "Number of Features" in breakdown_df.columns + + # Get unique grids and temporal modes + grids = breakdown_df["Grid"].unique().tolist() + temporal_modes = breakdown_df["Temporal Mode"].unique().tolist() + unique_members = sorted(breakdown_df["Member"].unique()) + source_color_map = _get_colors_from_members(unique_members) + + n_grids = len(grids) + n_modes = len(temporal_modes) + + # Create subplot grid (grids as rows, temporal modes as columns) + # Don't use subplot_titles - we'll add custom annotations for row/column labels + fig = make_subplots( + rows=n_grids, + cols=n_modes, + specs=[[{"type": "pie"}] * n_modes for _ in range(n_grids)], + horizontal_spacing=0.01, + vertical_spacing=0.02, + ) + + # Get all members to ensure consistent coloring across all donuts + all_members = sorted(breakdown_df["Member"].unique()) + + for row_idx, grid in enumerate(grids, start=1): + for col_idx, 
mode in enumerate(temporal_modes, start=1): + subset = breakdown_df[(breakdown_df["Grid"] == grid) & (breakdown_df["Temporal Mode"] == mode)] + + if len(subset) == 0: + continue + + # Build labels, values, and colors in consistent order + labels = [] + values = [] + colors_list = [] + + for member in all_members: + member_data = subset[subset["Member"] == member] + if len(member_data) > 0: + labels.append(member) + values.append(member_data["Number of Features"].iloc[0]) + if source_color_map: + colors_list.append(source_color_map[member]) + + # Only show legend for first subplot + show_legend = row_idx == 1 and col_idx == 1 + + fig.add_trace( + go.Pie( + labels=labels, + values=values, + name=f"{grid} - {mode}", + hole=0.4, + marker={"colors": colors_list} if colors_list else None, + textposition="inside", + textinfo="percent", + showlegend=show_legend, + ), + row=row_idx, + col=col_idx, + ) + + # Calculate appropriate height based on number of rows + height = max(500, n_grids * 400) + + fig.update_layout( + title_text="Feature Breakdown by Member Across Grid Configurations and Temporal Modes", + height=height, + showlegend=True, + margin={"l": 100, "r": 50, "t": 150, "b": 50}, # Add left margin for row labels + ) + + # Calculate subplot positions in paper coordinates + # Account for spacing between subplots + h_spacing = 0.01 + v_spacing = 0.02 + + # Calculate width and height of each subplot in paper coordinates + subplot_width = (1.0 - h_spacing * (n_modes - 1)) / n_modes + subplot_height = (1.0 - v_spacing * (n_grids - 1)) / n_grids + + # Add column headers at the top (temporal modes) + for col_idx, mode in enumerate(temporal_modes): + # Calculate x position for this column's center + x_pos = col_idx * (subplot_width + h_spacing) + subplot_width / 2 + fig.add_annotation( + text=f"{mode}", + xref="paper", + yref="paper", + x=x_pos, + y=1.01, + showarrow=False, + xanchor="center", + yanchor="bottom", + font={"size": 18}, + ) + + # Add row headers on the left 
(grid configurations) + for row_idx, grid in enumerate(grids): + # Calculate y position for this row's center + # Note: y coordinates go from bottom to top, but rows go from top to bottom + y_pos = 1.0 - (row_idx * (subplot_height + v_spacing) + subplot_height / 2) + fig.add_annotation( + text=f"{grid}", + xref="paper", + yref="paper", + x=-0.01, + y=y_pos, + showarrow=False, + xanchor="right", + yanchor="middle", + font={"size": 18}, + textangle=-90, + ) + + return fig diff --git a/src/entropice/dashboard/plots/overview.py b/src/entropice/dashboard/plots/overview.py deleted file mode 100644 index 91de690..0000000 --- a/src/entropice/dashboard/plots/overview.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Visualization functions for the overview page. - -This module contains reusable plotting functions for dataset analysis visualizations, -including sample counts, feature counts, and dataset statistics. -""" - -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go - - -def create_sample_count_heatmap( - pivot_df: pd.DataFrame, - target: str, - colorscale: list[str] | None = None, -) -> go.Figure: - """Create heatmap showing sample counts by grid and task. - - Args: - pivot_df: Pivoted dataframe with Grid as index, Task as columns, and sample counts as values. - target: Target dataset name for the title. - colorscale: Optional color palette for the heatmap. If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the heatmap visualization. 
- - """ - fig = px.imshow( - pivot_df, - labels={ - "x": "Task", - "y": "Grid Configuration", - "color": "Sample Count", - }, - x=pivot_df.columns, - y=pivot_df.index, - color_continuous_scale=colorscale, - aspect="auto", - title=f"Target: {target}", - ) - - # Add text annotations - fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10) - fig.update_layout(height=400) - - return fig - - -def create_sample_count_bar_chart( - sample_df: pd.DataFrame, - target_color_maps: dict[str, list[str]] | None = None, -) -> go.Figure: - """Create bar chart showing sample counts by grid, target, and task. - - Args: - sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage). - target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes. - If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the bar chart visualization. - - """ - # Create subplots manually to have better control over colors - from plotly.subplots import make_subplots - - targets = sorted(sample_df["Target"].unique()) - tasks = sorted(sample_df["Task"].unique()) - - fig = make_subplots( - rows=1, - cols=len(targets), - subplot_titles=[f"Target: {target}" for target in targets], - shared_yaxes=True, - ) - - for col_idx, target in enumerate(targets, 1): - target_data = sample_df[sample_df["Target"] == target] - # Get color palette for this target - colors = target_color_maps.get(target, None) if target_color_maps else None - - for task_idx, task in enumerate(tasks): - task_data = target_data[target_data["Task"] == task] - color = colors[task_idx] if colors and task_idx < len(colors) else None - - # Create a unique legendgroup per task so colors are consistent - fig.add_trace( - go.Bar( - x=task_data["Grid"], - y=task_data["Samples (Coverage)"], - name=task, - marker_color=color, - legendgroup=task, # Group by task name - showlegend=(col_idx == 1), # Only show legend for first subplot - ), - row=1, - 
col=col_idx, - ) - - fig.update_layout( - title_text="Training Sample Counts by Grid Configuration and Target Dataset", - barmode="group", - height=500, - showlegend=True, - ) - - fig.update_xaxes(tickangle=-45) - fig.update_yaxes(title_text="Number of Samples", row=1, col=1) - - return fig - - -def create_feature_count_stacked_bar( - breakdown_df: pd.DataFrame, - source_color_map: dict[str, str] | None = None, -) -> go.Figure: - """Create stacked bar chart showing feature counts by data source. - - Args: - breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features. - source_color_map: Optional dictionary mapping data source names to specific colors. - If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the stacked bar chart visualization. - - """ - fig = px.bar( - breakdown_df, - x="Grid", - y="Number of Features", - color="Data Source", - barmode="stack", - title="Input Features by Data Source Across Grid Configurations", - labels={ - "Grid": "Grid Configuration", - "Number of Features": "Number of Features", - }, - color_discrete_map=source_color_map, - text_auto=False, - ) - - fig.update_layout(height=500, xaxis_tickangle=-45) - - return fig - - -def create_inference_cells_bar( - comparison_df: pd.DataFrame, - grid_colors: list[str] | None = None, -) -> go.Figure: - """Create bar chart for inference cells by grid configuration. - - Args: - comparison_df: DataFrame with columns: Grid, Inference Cells. - grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the bar chart visualization. 
- - """ - fig = px.bar( - comparison_df, - x="Grid", - y="Inference Cells", - color="Grid", - title="Spatial Coverage (Grid Cells with Complete Data)", - labels={ - "Grid": "Grid Configuration", - "Inference Cells": "Number of Cells", - }, - color_discrete_sequence=grid_colors, - text="Inference Cells", - ) - - fig.update_traces(texttemplate="%{text:,}", textposition="outside") - fig.update_layout(xaxis_tickangle=-45, showlegend=False) - - return fig - - -def create_total_samples_bar( - comparison_df: pd.DataFrame, - grid_colors: list[str] | None = None, -) -> go.Figure: - """Create bar chart for total samples by grid configuration. - - Args: - comparison_df: DataFrame with columns: Grid, Total Samples. - grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the bar chart visualization. - - """ - fig = px.bar( - comparison_df, - x="Grid", - y="Total Samples", - color="Grid", - title="Training Samples (Binary Task)", - labels={ - "Grid": "Grid Configuration", - "Total Samples": "Number of Samples", - }, - color_discrete_sequence=grid_colors, - text="Total Samples", - ) - - fig.update_traces(texttemplate="%{text:,}", textposition="outside") - fig.update_layout(xaxis_tickangle=-45, showlegend=False) - - return fig - - -def create_feature_breakdown_donut( - grid_data: pd.DataFrame, - grid_config: str, - source_color_map: dict[str, str] | None = None, -) -> go.Figure: - """Create donut chart for feature breakdown by data source for a specific grid. - - Args: - grid_data: DataFrame with columns: Data Source, Number of Features. - grid_config: Grid configuration name for the title. - source_color_map: Optional dictionary mapping data source names to specific colors. - If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the donut chart visualization. 
- - """ - fig = px.pie( - grid_data, - names="Data Source", - values="Number of Features", - title=grid_config, - hole=0.4, - color_discrete_map=source_color_map, - color="Data Source", - ) - - fig.update_traces(textposition="inside", textinfo="percent") - fig.update_layout(showlegend=True, height=350) - - return fig - - -def create_feature_distribution_pie( - breakdown_df: pd.DataFrame, - source_color_map: dict[str, str] | None = None, -) -> go.Figure: - """Create pie chart for feature distribution by data source. - - Args: - breakdown_df: DataFrame with columns: Data Source, Number of Features. - source_color_map: Optional dictionary mapping data source names to specific colors. - If None, uses default Plotly colors. - - Returns: - Plotly Figure object containing the pie chart visualization. - - """ - fig = px.pie( - breakdown_df, - names="Data Source", - values="Number of Features", - title="Feature Distribution by Data Source", - hole=0.4, - color_discrete_map=source_color_map, - color="Data Source", - ) - - fig.update_traces(textposition="inside", textinfo="percent+label") - - return fig diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py new file mode 100644 index 0000000..34ae0cb --- /dev/null +++ b/src/entropice/dashboard/sections/dataset_statistics.py @@ -0,0 +1,376 @@ +"""Dataset Statistics Section for Entropice Dashboard.""" + +import pandas as pd +import streamlit as st + +from entropice.dashboard.plots.dataset_statistics import ( + create_feature_breakdown_donuts_grid, + create_feature_count_stacked_bar, + create_feature_distribution_pie, + create_inference_cells_bar, + create_member_details_table, + create_sample_count_bar_chart, +) +from entropice.dashboard.utils.colors import get_palette +from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics +from entropice.utils.types import ( + GridLevel, + L2SourceDataset, + TargetDataset, + Task, 
+ TemporalMode, + all_l2_source_datasets, + grid_configs, +) + +type DatasetStatsCache = dict[GridLevel, dict[TemporalMode, DatasetStatistics]] + + +def render_training_samples_tab(sample_df: pd.DataFrame): + """Render overview of sample counts per task+target+grid+level combination.""" + st.markdown( + """ + This visualization shows the number of available training samples for each combination of: + - **Task**: binary, count, density + - **Target Dataset**: darts_rts, darts_mllabels + - **Grid System**: hex, healpix + - **Grid Level**: varying by grid type + """ + ) + + # Create and display bar chart + fig = create_sample_count_bar_chart(sample_df) + st.plotly_chart(fig, width="stretch") + + # Display full table with formatting + st.markdown("#### Detailed Sample Counts") + # Format numbers with commas + sample_df["Samples (Coverage)"] = sample_df["Samples (Coverage)"].apply(lambda x: f"{x:,}") + # Format coverage as percentage with 2 decimal places + sample_df["Coverage %"] = sample_df["Coverage %"].apply(lambda x: f"{x:.2f}%") + + st.dataframe(sample_df, hide_index=True, width="stretch") + + +def render_dataset_characteristics_tab( + breakdown_df: pd.DataFrame, sample_df: pd.DataFrame, comparison_df: pd.DataFrame +): + """Render static comparison of feature counts across all grid configurations.""" + st.markdown( + """ + Comparing dataset characteristics for all grid configurations with all data sources enabled. 
+ - **Features**: Total number of input features from all data sources + - **Spatial Coverage**: Number of grid cells with complete data coverage + """ + ) + + # Get data from cache + + # Create and display stacked bar chart + fig = create_feature_count_stacked_bar(breakdown_df) + st.plotly_chart(fig, width="stretch") + + # Add spatial coverage metric + fig_cells = create_inference_cells_bar(sample_df) + st.plotly_chart(fig_cells, width="stretch") + + # Display full comparison table with formatting + st.markdown("#### Detailed Comparison Table") + st.dataframe(comparison_df, hide_index=True, width="stretch") + + +def render_feature_breakdown_tab(breakdown_df: pd.DataFrame): + """Render feature breakdown by member across grid configurations and temporal modes.""" + st.markdown( + """ + Visualizing feature contribution of each data source (member) across all grid configurations + and temporal modes. Each donut chart represents the feature breakdown for a specific + grid-temporal mode combination. + """ + ) + # Create and display the grid of donut charts + fig = create_feature_breakdown_donuts_grid(breakdown_df) + st.plotly_chart(fig, width="stretch") + + # Display full breakdown table with formatting + st.markdown("#### Detailed Breakdown Table") + st.dataframe(breakdown_df, hide_index=True, width="stretch") + + +@st.fragment +def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache): # noqa: C901 + """Render interactive detailed configuration explorer.""" + st.markdown( + """ + Explore detailed statistics for a specific dataset configuration. + Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics. 
+ """ + ) + + # Grid selection + grid_options = [gc.display_name for gc in grid_configs] + + col1, col2, col3, col4 = st.columns(4) + + with col1: + selected_grid_display = st.selectbox( + "Grid Configuration", + options=grid_options, + index=0, + help="Select the grid system and resolution level", + key="feature_grid_select", + ) + + with col2: + # Create temporal mode options with proper formatting + temporal_mode_options = ["feature", "synopsis"] + [str(year) for year in [2018, 2019, 2020, 2021, 2022, 2023]] + temporal_mode_display = st.selectbox( + "Temporal Mode", + options=temporal_mode_options, + index=0, + help="Select the temporal mode (feature, synopsis, or specific year)", + key="feature_temporal_mode_select", + ) + # Convert back to proper type + selected_temporal_mode: TemporalMode + if temporal_mode_display.isdigit(): + selected_temporal_mode = int(temporal_mode_display) # type: ignore[assignment] + else: + selected_temporal_mode = temporal_mode_display # pyright: ignore[reportAssignmentType] + + with col3: + selected_target: TargetDataset = st.selectbox( + "Target Dataset", + options=["darts_v1", "darts_mllabels"], + index=0, + help="Select the target dataset", + key="feature_target_select", + ) + + with col4: + selected_task: Task = st.selectbox( + "Task", + options=["binary", "count", "density", "count_regimes", "density_regimes"], + index=0, + help="Select the task type", + key="feature_task_select", + ) + + # Find the selected grid config + selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display) + + # Get stats for the selected configuration + mode_stats = dataset_stats[selected_grid_config.id] + + # Check if the selected temporal mode has stats + if selected_temporal_mode not in mode_stats: + st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.") + return + + stats = mode_stats[selected_temporal_mode] + available_members = list(stats.members.keys()) + + # 
Members selection + st.markdown("#### Select Data Sources") + + # Use columns for checkboxes + cols = st.columns(len(all_l2_source_datasets)) + selected_members: list[L2SourceDataset] = [] + + for idx, member in enumerate(all_l2_source_datasets): + with cols[idx]: + default_value = member in available_members + if st.checkbox(member, value=default_value, key=f"feature_member_{member}"): + selected_members.append(member) # type: ignore[arg-type] + + # Show results if at least one member is selected + if selected_members: + # Filter to selected members only + selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members} + + # Calculate total features for selected members + total_features = sum(ms.feature_count for ms in selected_member_stats.values()) + + # Get target stats if available + if selected_target not in stats.target: + st.warning(f"Target {selected_target} is not available for this configuration.") + return + + target_stats = stats.target[selected_target] + + # Check if task is available + if selected_task not in target_stats: + st.warning( + f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode." 
+ ) + return + + task_stats = target_stats[selected_task] + + # High-level metrics + st.markdown("#### Configuration Summary") + col1, col2, col3, col4, col5 = st.columns(5) + with col1: + st.metric("Total Features", f"{total_features:,}") + with col2: + # Calculate minimum cells across all data sources (for inference capability) + if selected_member_stats: + min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values()) + else: + min_cells = 0 + st.metric( + "Inference Cells", + f"{min_cells:,}", + help="Minimum number of cells across all selected data sources", + ) + with col3: + st.metric("Data Sources", len(selected_members)) + with col4: + st.metric("Training Samples", f"{task_stats.training_cells:,}") + with col5: + # Calculate total data points + total_points = total_features * task_stats.training_cells + st.metric("Total Data Points", f"{total_points:,}") + + # Task-specific statistics + st.markdown("#### Task Statistics") + task_col1, task_col2, task_col3 = st.columns(3) + with task_col1: + st.metric("Task Type", selected_task.replace("_", " ").title()) + with task_col2: + st.metric("Coverage", f"{task_stats.coverage:.2f}%") + with task_col3: + if task_stats.class_counts: + st.metric("Number of Classes", len(task_stats.class_counts)) + else: + st.metric("Task Mode", "Regression") + + # Class distribution for classification tasks + if task_stats.class_distribution: + st.markdown("#### Class Distribution") + class_dist_df = pd.DataFrame( + [ + { + "Class": class_name, + "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0, + "Percentage": f"{pct:.2f}%", + } + for class_name, pct in task_stats.class_distribution.items() + ] + ) + st.dataframe(class_dist_df, hide_index=True, width="stretch") + + # Feature breakdown by source + st.markdown("#### Feature Breakdown by Data Source") + + breakdown_data = [] + for member, member_stats in selected_member_stats.items(): + breakdown_data.append( + { + 
"Data Source": member, + "Number of Features": member_stats.feature_count, + "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%", + } + ) + + breakdown_df = pd.DataFrame(breakdown_data) + + # Get all unique data sources and create color map + unique_members_for_color = sorted(selected_member_stats.keys()) + source_color_map_raw = {} + for member in unique_members_for_color: + source = member.split("-")[0] + n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source) + palette = get_palette(source, n_colors=n_members + 2) + idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member) + source_color_map_raw[member] = palette[idx + 1] + + # Create and display pie chart + fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw) + st.plotly_chart(fig, width="stretch") + + # Show detailed table + st.dataframe(breakdown_df, hide_index=True, width="stretch") + + # Detailed member information + with st.expander("📦 Detailed Source Information", expanded=False): + # Create detailed table + member_details_dict = { + member: { + "feature_count": ms.feature_count, + "variable_names": ms.variable_names, + "dimensions": ms.dimensions, + "size_bytes": ms.size_bytes, + } + for member, ms in selected_member_stats.items() + } + details_df = create_member_details_table(member_details_dict) + st.dataframe(details_df, hide_index=True, width="stretch") + + # Individual member details + for member, member_stats in selected_member_stats.items(): + st.markdown(f"### {member}") + + # Variables + st.markdown("**Variables:**") + vars_html = " ".join( + [ + f'{v}' + for v in member_stats.variable_names + ] + ) + st.markdown(vars_html, unsafe_allow_html=True) + + # Dimensions + st.markdown("**Dimensions:**") + dim_html = " ".join( + [ + f'' + f"{dim_name}: {dim_size:,}" + for dim_name, dim_size in member_stats.dimensions.items() + ] + ) + st.markdown(dim_html, unsafe_allow_html=True) + + 
st.markdown("---")
+    else:
+        st.info("👆 Select at least one data source to see feature statistics")
+
+
+def render_dataset_statistics(all_stats: DatasetStatsCache):
+    """Render the dataset statistics section with sample and feature counts."""
+    st.header("📈 Dataset Statistics")
+
+    # Use the caller-supplied pre-computed statistics cache instead of reloading it here.
+    training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
+    feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
+    comparison_df = DatasetStatistics.get_comparison_df(all_stats)
+    inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
+
+    # Create tabs for different analysis views
+    analysis_tabs = st.tabs(
+        [
+            "📊 Training Samples",
+            "📈 Dataset Characteristics",
+            "🔍 Feature Breakdown",
+            "⚙️ Configuration Explorer",
+        ]
+    )
+
+    with analysis_tabs[0]:
+        st.subheader("Training Samples by Configuration")
+        render_training_samples_tab(training_sample_df)
+
+    with analysis_tabs[1]:
+        st.subheader("Dataset Characteristics Across Grid Configurations")
+        render_dataset_characteristics_tab(feature_breakdown_df, inference_sample_df, comparison_df)
+
+    with analysis_tabs[2]:
+        st.subheader("Feature Breakdown by Data Source")
+        render_feature_breakdown_tab(feature_breakdown_df)
+
+    with analysis_tabs[3]:
+        st.subheader("Interactive Configuration Explorer")
+        render_configuration_explorer_tab(all_stats)
diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py
new file mode 100644
index 0000000..04b3176
--- /dev/null
+++ b/src/entropice/dashboard/sections/experiment_results.py
@@ -0,0 +1,158 @@
+"""Experiment Results Section for the Entropice Dashboard."""
+
+from datetime import datetime
+
+import streamlit as st
+
+from entropice.dashboard.utils.loaders import TrainingResult
+from entropice.utils.types import (
+    GridConfig,
+)
+
+
+def render_training_results_summary(training_results: 
list[TrainingResult]): + """Render summary metrics for training results.""" + st.header("📊 Training Results Summary") + col1, col2, col3, col4 = st.columns(4) + + with col1: + experiments = {tr.experiment for tr in training_results if tr.experiment} + st.metric("Experiments", len(experiments)) + + with col2: + st.metric("Total Runs", len(training_results)) + + with col3: + models = {tr.settings.model for tr in training_results} + st.metric("Model Types", len(models)) + + with col4: + latest = training_results[0] # Already sorted by creation time + latest_datetime = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d %H:%M") + st.metric("Latest Run", latest_datetime) + + +@st.fragment +def render_experiment_results(training_results: list[TrainingResult]): # noqa: C901 + """Render detailed experiment results table and expandable details.""" + st.header("🎯 Experiment Results") + + # Filters + experiments = sorted({tr.experiment for tr in training_results if tr.experiment}) + tasks = sorted({tr.settings.task for tr in training_results}) + models = sorted({tr.settings.model for tr in training_results}) + grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results}) + + # Create filter columns + filter_cols = st.columns(4) + + with filter_cols[0]: + if experiments: + selected_experiment = st.selectbox( + "Experiment", + options=["All", *experiments], + index=0, + ) + else: + selected_experiment = "All" + + with filter_cols[1]: + selected_task = st.selectbox( + "Task", + options=["All", *tasks], + index=0, + ) + + with filter_cols[2]: + selected_model = st.selectbox( + "Model", + options=["All", *models], + index=0, + ) + + with filter_cols[3]: + selected_grid = st.selectbox( + "Grid", + options=["All", *grids], + index=0, + ) + + # Apply filters + filtered_results = training_results + if selected_experiment != "All": + filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment] + if selected_task != "All": 
+ filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task] + if selected_model != "All": + filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model] + if selected_grid != "All": + filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid] + + st.subheader("Results Table") + + summary_df = TrainingResult.to_dataframe(filtered_results) + # Display with color coding for best scores + st.dataframe( + summary_df, + width="stretch", + hide_index=True, + ) + + # Expandable details for each result + st.subheader("Individual Experiment Details") + + for tr in filtered_results: + tr_info = tr.display_info + display_name = tr_info.get_display_name("model_first") + with st.expander(display_name): + col1, col2 = st.columns([1, 2]) + + with col1: + grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level)) + st.write("**Configuration:**") + st.write(f"- **Experiment:** {tr.experiment}") + st.write(f"- **Task:** {tr.settings.task}") + st.write(f"- **Model:** {tr.settings.model}") + st.write(f"- **Grid:** {grid_config.display_name}") + st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}") + st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}") + st.write(f"- **Members:** {', '.join(tr.settings.members)}") + st.write(f"- **CV Splits:** {tr.settings.cv_splits}") + st.write(f"- **Classes:** {tr.settings.classes}") + + st.write("\n**Files:**") + for file in tr.files: + if file.name == "search_settings.toml": + st.write(f"- ⚙️ `{file.name}`") + elif file.name == "best_estimator_model.pkl": + st.write(f"- 🧮 `{file.name}`") + elif file.name == "search_results.parquet": + st.write(f"- 📊 `{file.name}`") + elif file.name == "predicted_probabilities.parquet": + st.write(f"- 🎯 `{file.name}`") + else: + st.write(f"- 📄 `{file.name}`") + with col2: + st.write("**CV Score Summary:**") + + # Extract all test scores + metric_df = 
tr.get_metric_dataframe() + if metric_df is not None: + st.dataframe(metric_df, width="stretch", hide_index=True) + else: + st.write("No test scores found in results.") + + # Show parameter space explored + if "initial_K" in tr.results.columns: # Common parameter + st.write("\n**Parameter Ranges Explored:**") + for param in ["initial_K", "eps_cl", "eps_e"]: + if param in tr.results.columns: + min_val = tr.results[param].min() + max_val = tr.results[param].max() + unique_vals = tr.results[param].nunique() + st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") + + with st.expander("Show CV Results DataFrame"): + st.dataframe(tr.results, width="stretch", hide_index=True) + + st.write(f"\n**Path:** `{tr.path}`") diff --git a/src/entropice/dashboard/utils/colors.py b/src/entropice/dashboard/utils/colors.py index 97b3838..aa1a758 100644 --- a/src/entropice/dashboard/utils/colors.py +++ b/src/entropice/dashboard/utils/colors.py @@ -90,7 +90,15 @@ def get_palette(variable: str, n_colors: int) -> list[str]: A list of hex color strings. 
""" - cmap = get_cmap(variable).resampled(n_colors) + # Hardcode some common variables to specific colormaps + if variable == "ERA5": + cmap = load_cmap(name="blue_material").resampled(n_colors) + elif variable == "ArcticDEM": + cmap = load_cmap(name="deep_purple_material").resampled(n_colors) + elif variable == "AlphaEarth": + cmap = load_cmap(name="green_material").resampled(n_colors) + else: + cmap = get_cmap(variable).resampled(n_colors) colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)] return colors diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index 93f3c79..83c9796 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -6,7 +6,6 @@ from datetime import datetime from pathlib import Path import antimeridian -import geopandas as gpd import pandas as pd import streamlit as st import toml @@ -16,9 +15,8 @@ from shapely.geometry import shape import entropice.spatial.grids import entropice.utils.paths from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo -from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble from entropice.ml.training import TrainingSettings -from entropice.utils.types import L2SourceDataset, Task +from entropice.utils.types import GridConfig def _fix_hex_geometry(geom): @@ -32,22 +30,29 @@ def _fix_hex_geometry(geom): @dataclass class TrainingResult: + """Wrapper for training result data and metadata.""" + path: Path + experiment: str settings: TrainingSettings results: pd.DataFrame - metrics: dict[str, float] - confusion_matrix: xr.DataArray + train_metrics: dict[str, float] + test_metrics: dict[str, float] + combined_metrics: dict[str, float] + confusion_matrix: xr.Dataset | None created_at: float available_metrics: list[str] + files: list[Path] @classmethod - def from_path(cls, result_path: Path) -> "TrainingResult": + def from_path(cls, result_path: Path, experiment_name: str | None = None) -> 
"TrainingResult": """Load a TrainingResult from a given result directory path.""" result_file = result_path / "search_results.parquet" preds_file = result_path / "predicted_probabilities.parquet" settings_file = result_path / "search_settings.toml" - metrics_file = result_path / "test_metrics.toml" + metrics_file = result_path / "metrics.toml" confusion_matrix_file = result_path / "confusion_matrix.nc" + all_files = list(result_path.iterdir()) if not result_file.exists(): raise FileNotFoundError(f"Missing results file in {result_path}") if not settings_file.exists(): @@ -56,28 +61,46 @@ class TrainingResult: raise FileNotFoundError(f"Missing predictions file in {result_path}") if not metrics_file.exists(): raise FileNotFoundError(f"Missing metrics file in {result_path}") - if not confusion_matrix_file.exists(): - raise FileNotFoundError(f"Missing confusion matrix file in {result_path}") created_at = result_path.stat().st_ctime - settings = TrainingSettings(**(toml.load(settings_file)["settings"])) + settings_dict = toml.load(settings_file)["settings"] + + # Handle backward compatibility: add missing fields with defaults + if "classes" not in settings_dict: + settings_dict["classes"] = None + if "param_grid" not in settings_dict: + settings_dict["param_grid"] = {} + if "cv_splits" not in settings_dict: + settings_dict["cv_splits"] = 5 + if "metrics" not in settings_dict: + settings_dict["metrics"] = [] + + settings = TrainingSettings(**settings_dict) results = pd.read_parquet(result_file) - metrics = toml.load(metrics_file)["test_metrics"] - confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf") + metrics = toml.load(metrics_file) + if not confusion_matrix_file.exists(): + confusion_matrix = None + else: + confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf") available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")] return cls( path=result_path, + 
experiment=experiment_name or "N/A", settings=settings, results=results, - metrics=metrics, + train_metrics=metrics["train_metrics"], + test_metrics=metrics["test_metrics"], + combined_metrics=metrics["combined_metrics"], confusion_matrix=confusion_matrix, created_at=created_at, available_metrics=available_metrics, + files=all_files, ) @property def display_info(self) -> TrainingResultDisplayInfo: + """Get display information for the training result.""" return TrainingResultDisplayInfo( task=self.settings.task, model=self.settings.model, @@ -126,9 +149,70 @@ class TrainingResult: st.error(f"Error loading predictions: {e}") return None + def get_metric_dataframe(self) -> pd.DataFrame | None: + """Get a DataFrame of available metrics for this training result.""" + metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")] + if not metric_cols: + return None + metric_data = [] + for col in metric_cols: + metric_name = col.replace("mean_test_", "").replace("neg_", "").title() + metrics = self.results[col] + # Check if the metric is negative + if col.startswith("mean_test_neg_"): + task_multiplier = 1 if self.settings.task != "density" else 100 + task_multiplier *= -1 + metrics = metrics * task_multiplier + + metric_data.append( + { + "Metric": metric_name, + "Best": f"{metrics.max():.4f}", + "Mean": f"{metrics.mean():.4f}", + "Std": f"{metrics.std():.4f}", + "Worst": f"{metrics.min():.4f}", + } + ) + + return pd.DataFrame(metric_data) + + def _get_best_metric_name(self) -> str: + """Get the primary metric name for a given task.""" + match self.settings.task: + case "binary": + return "f1" + case "count_regimes" | "density_regimes": + return "f1_weighted" + case _: # regression tasks + return "r2" + + @staticmethod + def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame: + """Convert a list of TrainingResult objects to a DataFrame for display.""" + records = [] + for tr in training_results: + info = tr.display_info + 
best_metric_name = tr._get_best_metric_name() + + record = { + "Experiment": tr.experiment if tr.experiment else "N/A", + "Task": info.task, + "Model": info.model, + "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name, + "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"), + "Score-Metric": best_metric_name.title(), + "Best Models Score (Train-Set)": tr.train_metrics.get(best_metric_name), + "Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name), + "Trials": len(tr.results), + "Path": str(tr.path.name), + } + records.append(record) + return pd.DataFrame.from_records(records) + @st.cache_data def load_all_training_results() -> list[TrainingResult]: + """Load all training results from the results directory.""" results_dir = entropice.utils.paths.RESULTS_DIR training_results: list[TrainingResult] = [] for result_path in results_dir.iterdir(): @@ -136,50 +220,22 @@ def load_all_training_results() -> list[TrainingResult]: continue try: training_result = TrainingResult.from_path(result_path) + training_results.append(training_result) except FileNotFoundError as e: - st.warning(f"Skipping incomplete training result: {e}") - continue - training_results.append(training_result) + is_experiment_dir = False + for experiment_path in result_path.iterdir(): + if not experiment_path.is_dir(): + continue + try: + experiment_name = experiment_path.name + training_result = TrainingResult.from_path(experiment_path, experiment_name) + training_results.append(training_result) + is_experiment_dir = True + except FileNotFoundError as e2: + st.warning(f"Skipping incomplete training result: {e2}") + if not is_experiment_dir: + st.warning(f"Skipping incomplete training result: {e}") # Sort by creation time (most recent first) training_results.sort(key=lambda tr: tr.created_at, reverse=True) return training_results - - -def load_all_training_data( - e: DatasetEnsemble, -) -> dict[Task, CategoricalTrainingDataset]: - """Load training data for all three 
tasks. - - Args: - e: DatasetEnsemble object. - - Returns: - Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. - - """ - dataset = e.create(filter_target_col=e.covcol) - return { - "binary": e._cat_and_split(dataset, "binary", device="cpu"), - "count": e._cat_and_split(dataset, "count", device="cpu"), - "density": e._cat_and_split(dataset, "density", device="cpu"), - } - - -def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]: - """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5). - - Args: - e: DatasetEnsemble object. - source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'. - - Returns: - xarray.Dataset with the raw data for the specified source. - - """ - targets = e._read_target() - - # Load the member data lazily to get metadata - ds = e._read_member(source, targets, lazy=False) - - return ds, targets diff --git a/src/entropice/dashboard/utils/stats.py b/src/entropice/dashboard/utils/stats.py index 004c7f5..9e9ddaf 100644 --- a/src/entropice/dashboard/utils/stats.py +++ b/src/entropice/dashboard/utils/stats.py @@ -3,29 +3,28 @@ - Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination """ +import pickle from collections import defaultdict from dataclasses import asdict, dataclass from typing import Literal -import geopandas as gpd import pandas as pd -import streamlit as st -import xarray as xr from stopuhr import stopwatch import entropice.spatial.grids import entropice.utils.paths from entropice.dashboard.utils.loaders import TrainingResult -from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol +from entropice.ml.dataset import DatasetEnsemble from entropice.utils.types import ( - Grid, GridLevel, L2SourceDataset, TargetDataset, Task, + TemporalMode, all_l2_source_datasets, all_target_datasets, all_tasks, + all_temporal_modes, grid_configs, 
) @@ -41,90 +40,63 @@ class MemberStatistics: size_bytes: int # Size of this member's data on disk in bytes @classmethod - def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics": - if member == "AlphaEarth": - store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level) - elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: - era5_agg = member.split("-")[1] - store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level) # ty:ignore[invalid-argument-type] - elif member == "ArcticDEM": - store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level) - else: - raise NotImplementedError(f"Member {member} not implemented.") + def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]: + """Pre-compute the statistics for a specific dataset member.""" + member_stats = {} + for member in all_l2_source_datasets: + ds = e.read_member(member, lazy=True) + size_bytes = ds.nbytes - size_bytes = store.stat().st_size - ds = xr.open_zarr(store, consolidated=False) + n_cols_member = len(ds.data_vars) + for dim in ds.sizes: + if dim != "cell_ids": + n_cols_member *= ds.sizes[dim] - # Delete all coordinates which are not in the dimension - for coord in ds.coords: - if coord not in ds.dims: - ds = ds.drop_vars(coord) - n_cols_member = len(ds.data_vars) - for dim in ds.sizes: - if dim != "cell_ids": - n_cols_member *= ds.sizes[dim] - - return cls( - feature_count=n_cols_member, - variable_names=list(ds.data_vars), - dimensions=dict(ds.sizes), - coordinates=list(ds.coords), - size_bytes=size_bytes, - ) + member_stats[member] = cls( + feature_count=n_cols_member, + variable_names=list(ds.data_vars), + dimensions=dict(ds.sizes), + coordinates=list(ds.coords), + size_bytes=size_bytes, + ) + return member_stats @dataclass(frozen=True) class TargetStatistics: """Statistics for a specific target dataset.""" - training_cells: dict[Task, int] # Number of cells used for training - 
coverage: dict[Task, float] # Percentage of total cells covered by training data - class_counts: dict[Task, dict[str, int]] # Class name to count mapping per task - class_distribution: dict[Task, dict[str, float]] # Class name to percentage mapping per task - size_bytes: int # Size of the target dataset on disk in bytes + training_cells: int # Number of cells used for training + coverage: float # Percentage of total cells covered by training data + class_counts: dict[str, int] | None # Class name to count mapping + class_distribution: dict[str, float] | None # Class name to percentage mapping @classmethod - def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics": - if target == "darts_rts": - target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level) - elif target == "darts_mllabels": - target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True) - else: - raise NotImplementedError(f"Target {target} not implemented.") - target_gdf = gpd.read_parquet(target_store) - size_bytes = target_store.stat().st_size - training_cells: dict[Task, int] = {} - training_coverage: dict[Task, float] = {} - class_counts: dict[Task, dict[str, int]] = {} - class_distribution: dict[Task, dict[str, float]] = {} + def compute(cls, e: DatasetEnsemble, target: TargetDataset, total_cells: int) -> dict[Task, "TargetStatistics"]: + """Pre-compute the statistics for a specific target dataset.""" + target_stats = {} for task in all_tasks: - task_col = taskcol[task][target] - cov_col = covcol[target] + targets = e.get_targets(target=target, task=task) - task_gdf = target_gdf[target_gdf[cov_col]] - training_cells[task] = len(task_gdf) - training_coverage[task] = len(task_gdf) / total_cells * 100 + training_cells = len(targets) + training_coverage = len(targets) / total_cells * 100 - model_labels = task_gdf[task_col].dropna() - if task == "binary": - binned = model_labels.map({False: "No RTS", True: 
"RTS"}).astype("category") - elif task == "count": - binned = bin_values(model_labels.astype(int), task=task) - elif task == "density": - binned = bin_values(model_labels, task=task) + if task in ["count", "density"]: + class_counts = None + class_distribution = None else: - raise ValueError("Invalid task.") - counts = binned.value_counts() - distribution = counts / counts.sum() * 100 - class_counts[task] = counts.to_dict() # ty:ignore[invalid-assignment] - class_distribution[task] = distribution.to_dict() - return TargetStatistics( - training_cells=training_cells, - coverage=training_coverage, - class_counts=class_counts, - class_distribution=class_distribution, - size_bytes=size_bytes, - ) + assert targets["y"].dtype == "category", "Classification tasks must have categorical target dtype." + counts = targets["y"].value_counts() + distribution = counts / counts.sum() * 100 + class_counts = counts.to_dict() + class_distribution = distribution.to_dict() + target_stats[task] = cls( + training_cells=training_cells, + coverage=training_coverage, + class_counts=class_counts, + class_distribution=class_distribution, + ) + return target_stats @dataclass(frozen=True) @@ -139,204 +111,167 @@ class DatasetStatistics: total_cells: int # Total number of grid cells potentially covered size_bytes: int # Size of the dataset on disk in bytes members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member - target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset + target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task @staticmethod - def get_sample_count_df( - all_stats: dict[GridLevel, "DatasetStatistics"], + def get_target_sample_count_df( + all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]], ) -> pd.DataFrame: """Convert sample count data to DataFrame.""" rows = [] for grid_config in grid_configs: - stats = all_stats[grid_config.id] - for target_name, target_stats in 
stats.target.items(): - for task in all_tasks: - training_cells = target_stats.training_cells[task] - coverage = target_stats.coverage[task] + mode_stats = all_stats[grid_config.id] + for temporal_mode, stats in mode_stats.items(): + for target_name, target_stats in stats.target.items(): + # We can assume that all tasks have equal counts, since they only affect distribution, not presence + target_task_stats = target_stats["binary"] + n_samples = target_task_stats.training_cells + coverage = target_task_stats.coverage + for task, target_task_stats in target_stats.items(): + assert n_samples == target_task_stats.training_cells, ( + "Inconsistent sample counts across tasks for the same target." + ) + assert coverage == target_task_stats.coverage, ( + "Inconsistent coverage across tasks for the same target." + ) rows.append( { "Grid": grid_config.display_name, - "Target": target_name.replace("darts_", ""), - "Task": task.capitalize(), - "Samples (Coverage)": training_cells, + "Target": target_name.replace("_", " ").title(), + "Temporal Mode": str(temporal_mode).capitalize(), + "Samples (Coverage)": n_samples, "Coverage %": coverage, "Grid_Level_Sort": grid_config.sort_key, } ) - return pd.DataFrame(rows) + return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort") @staticmethod - def get_feature_count_df( - all_stats: dict[GridLevel, "DatasetStatistics"], + def get_inference_sample_count_df( + all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]], ) -> pd.DataFrame: """Convert feature count data to DataFrame.""" rows = [] for grid_config in grid_configs: - stats = all_stats[grid_config.id] - data_sources = list(stats.members.keys()) - - # Determine minimum cells across all data sources + mode_stats = all_stats[grid_config.id] + # We can assume that all temporal modes have equal counts, since they only affect features, not samples + stats = mode_stats["feature"] + assert len(stats.members) > 0, "Dataset must have at least one 
member." min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) - # Get sample count from first target dataset (darts_rts) - first_target = stats.target["darts_rts"] - total_samples = first_target.training_cells["binary"] + for temporal_mode, stats in mode_stats.items(): + assert len(stats.members) > 0, "Dataset must have at least one member." + assert ( + min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) == min_cells + ), "Inconsistent inference cell counts across temporal modes." rows.append( { "Grid": grid_config.display_name, - "Total Features": stats.total_features, - "Data Sources": len(data_sources), "Inference Cells": min_cells, - "Total Samples": total_samples, "Grid_Level_Sort": grid_config.sort_key, } ) - return pd.DataFrame(rows) + return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort") + + @staticmethod + def get_comparison_df( + all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]], + ) -> pd.DataFrame: + """Convert comparison data to DataFrame for detailed table.""" + rows = [] + for grid_config in grid_configs: + mode_stats = all_stats[grid_config.id] + for temporal_mode, stats in mode_stats.items(): + min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) + rows.append( + { + "Grid": grid_config.display_name, + "Temporal Mode": str(temporal_mode).capitalize(), + "Total Features": f"{stats.total_features:,}", + "Total Cells": f"{stats.total_cells:,}", + "Min Cells": f"{min_cells:,}", + "Size (MB)": f"{stats.size_bytes / (1024 * 1024):.2f}", + "Grid_Level_Sort": grid_config.sort_key, + } + ) + return pd.DataFrame(rows).sort_values(["Grid_Level_Sort", "Temporal Mode"]).drop(columns="Grid_Level_Sort") @staticmethod def get_feature_breakdown_df( - all_stats: dict[GridLevel, "DatasetStatistics"], + all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]], ) -> pd.DataFrame: 
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
         rows = []
         for grid_config in grid_configs:
-            stats = all_stats[grid_config.id]
-            for member_name, member_stats in stats.members.items():
-                rows.append(
-                    {
-                        "Grid": grid_config.display_name,
-                        "Data Source": member_name,
-                        "Number of Features": member_stats.feature_count,
-                        "Grid_Level_Sort": grid_config.sort_key,
-                    }
-                )
-        return pd.DataFrame(rows)
+            mode_stats = all_stats[grid_config.id]
+            added_year_temporal_mode = False
+            for temporal_mode, stats in mode_stats.items():
+                if added_year_temporal_mode and isinstance(temporal_mode, int):
+                    continue
+                if isinstance(temporal_mode, int):
+                    # Only add one year-based temporal mode for the breakdown
+                    added_year_temporal_mode = True
+                    temporal_mode = "Year-Based"
+                for member_name, member_stats in stats.members.items():
+                    rows.append(
+                        {
+                            "Grid": grid_config.display_name,
+                            "Temporal Mode": temporal_mode.capitalize(),
+                            "Member": member_name,
+                            "Number of Features": member_stats.feature_count,
+                            "Grid_Level_Sort": grid_config.sort_key,
+                        }
+                    )
+        return (
+            pd.DataFrame(rows)
+            .sort_values(["Grid_Level_Sort", "Temporal Mode", "Member"])
+            .drop(columns="Grid_Level_Sort")
+        )
 
 
-@st.cache_data
-def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
-    dataset_stats: dict[GridLevel, DatasetStatistics] = {}
+# @st.cache_data  # ty:ignore[invalid-argument-type]
+def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode, DatasetStatistics]]:
+    """Precompute dataset statistics for all grid-level combinations and temporal modes."""
+    cache_file = entropice.utils.paths.get_dataset_stats_cache()
+    if cache_file.exists():
+        with open(cache_file, "rb") as f:
+            dataset_stats = pickle.load(f)
+        return dataset_stats
+
+    dataset_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
     for grid_config in grid_configs:
+        dataset_stats[grid_config.id] = {}
         with stopwatch(f"Loading statistics for grid={grid_config.grid}, 
level={grid_config.level}"): grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered total_cells = len(grid_gdf) assert total_cells > 0, "Grid must contain at least one cell." - target_statistics: dict[TargetDataset, TargetStatistics] = {} - for target in all_target_datasets: - target_statistics[target] = TargetStatistics.compute( - grid=grid_config.grid, - level=grid_config.level, - target=target, + for temporal_mode in all_temporal_modes: + e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode) + target_statistics = {} + for target in all_target_datasets: + if isinstance(temporal_mode, int) and target == "darts_mllabels": + # darts_mllabels does not support year-based temporal modes + continue + target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells) + member_statistics = MemberStatistics.compute(e) + + total_features = sum(ms.feature_count for ms in member_statistics.values()) + total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + dataset_stats[grid_config.id][temporal_mode] = DatasetStatistics( + total_features=total_features, total_cells=total_cells, - ) - member_statistics: dict[L2SourceDataset, MemberStatistics] = {} - for member in all_l2_source_datasets: - member_statistics[member] = MemberStatistics.compute( - grid=grid_config.grid, level=grid_config.level, member=member + size_bytes=total_size_bytes, + members=member_statistics, + target=target_statistics, ) - total_features = sum(ms.feature_count for ms in member_statistics.values()) - total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum( - ts.size_bytes for ts in target_statistics.values() - ) - dataset_stats[grid_config.id] = DatasetStatistics( - total_features=total_features, - total_cells=total_cells, - size_bytes=total_size_bytes, - members=member_statistics, - target=target_statistics, - ) + with open(cache_file, 
"wb") as f: + pickle.dump(dataset_stats, f) return dataset_stats -@dataclass(frozen=True) -class EnsembleMemberStatistics: - n_features: int # Number of features from this member in the ensemble - p_features: float # Percentage of features from this member in the ensemble - n_nanrows: int # Number of rows which contain any NaN - size_bytes: int # Size of this member's data in the ensemble in bytes - - @classmethod - def compute( - cls, - dataset: gpd.GeoDataFrame, - member: L2SourceDataset, - n_features: int, - ) -> "EnsembleMemberStatistics": - if member == "AlphaEarth": - member_dataset = dataset[dataset.columns.str.startswith("embeddings_")] - elif member == "ArcticDEM": - member_dataset = dataset[dataset.columns.str.startswith("arcticdem_")] - elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: - era5_cols = dataset.columns.str.startswith("era5_") - cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2) - cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3) - cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter")) - if member == "ERA5-yearly": - member_dataset = dataset[cols_with_three_splits] - elif member == "ERA5-seasonal": - member_dataset = dataset[cols_with_summer_winter & cols_with_four_splits] - elif member == "ERA5-shoulder": - member_dataset = dataset[~cols_with_summer_winter & cols_with_four_splits] - else: - raise NotImplementedError(f"Member {member} not implemented.") - size_bytes_member = member_dataset.memory_usage(deep=True).sum() - n_features_member = len(member_dataset.columns) - p_features_member = n_features_member / n_features - n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum() - return EnsembleMemberStatistics( - n_features=n_features_member, - p_features=p_features_member, - n_nanrows=n_rows_with_nan_member, - size_bytes=size_bytes_member, - ) - - -@dataclass(frozen=True) -class EnsembleDatasetStatistics: - """Statistics for a specified 
composition / ensemble at a specific grid and level. - - These statistics are meant to be only computed on user demand, thus by loading a dataset into memory. - That way, the real number of features and cells after NaN filtering can be reported. - """ - - n_features: int # Number of features in the dataset - n_cells: int # Number of grid cells covered - n_nanrows: int # Number of rows which contain any NaN - size_bytes: int # Size of the dataset on disk in bytes - members: dict[L2SourceDataset, EnsembleMemberStatistics] # Statistics per source dataset member - - @classmethod - def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics": - dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol) - # Assert that no column in all-nan - assert not dataset.isna().all("index").any(), "Some input columns are all NaN" - - size_bytes = dataset.memory_usage(deep=True).sum() - n_features = len(dataset.columns) - n_cells = len(dataset) - - # Number of rows which contain any NaN - n_rows_with_nan = dataset.isna().any(axis=1).sum() - - member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {} - for member in ensemble.members: - member_statistics[member] = EnsembleMemberStatistics.compute( - dataset=dataset, - member=member, - n_features=n_features, - ) - return cls( - n_features=n_features, - n_cells=n_cells, - n_nanrows=n_rows_with_nan, - size_bytes=size_bytes, - members=member_statistics, - ) - - @dataclass(frozen=True) class TrainingDatasetStatistics: n_samples: int # Total number of samples in the dataset @@ -346,14 +281,14 @@ class TrainingDatasetStatistics: n_test_samples: int # Number of cells used for testing train_test_ratio: float # Ratio of training to test samples - class_labels: list[str] # Ordered list of class labels - class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class - n_classes: int # Number of classes + class_labels: 
list[str] | None # Ordered list of class labels + class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] | None # Min/max raw values + n_classes: int | None # Number of classes - training_class_counts: dict[str, int] # Class counts in training set - training_class_distribution: dict[str, float] # Class percentages in training set - test_class_counts: dict[str, int] # Class counts in test set - test_class_distribution: dict[str, float] # Class percentages in test set + training_class_counts: dict[str, int] | None # Class counts in training set + training_class_distribution: dict[str, float] | None # Class percentages in training set + test_class_counts: dict[str, int] | None # Class counts in test set + test_class_distribution: dict[str, float] | None # Class percentages in test set raw_value_min: float # Minimum raw target value raw_value_max: float # Maximum raw target value @@ -361,7 +296,7 @@ class TrainingDatasetStatistics: raw_value_median: float # Median raw target value raw_value_std: float # Standard deviation of raw target values - imbalance_ratio: float # Smallest class count / largest class count (overall) + imbalance_ratio: float | None # Smallest class count / largest class count (overall) size_bytes: int # Total memory usage of features in bytes @classmethod @@ -369,58 +304,64 @@ class TrainingDatasetStatistics: cls, ensemble: DatasetEnsemble, task: Task, - dataset: gpd.GeoDataFrame | None = None, + target: TargetDataset, ) -> "TrainingDatasetStatistics": - dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol) - categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu") + training_dataset = ensemble.create_training_set(task=task, target=target) # Sample counts - n_samples = len(categorical_dataset) - n_training_samples = len(categorical_dataset.y.train) - n_test_samples = len(categorical_dataset.y.test) + n_samples = len(training_dataset) + n_training_samples = (training_dataset.split == 
"train").sum() + n_test_samples = (training_dataset.split == "test").sum() train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0 # Feature statistics - n_features = len(categorical_dataset.X.data.columns) - feature_names = list(categorical_dataset.X.data.columns) - size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum() + n_features = len(training_dataset.features.columns) + feature_names = training_dataset.features.columns.tolist() + size_bytes = training_dataset.features.memory_usage(deep=True).sum() # Class information - class_labels = categorical_dataset.y.labels - class_intervals = categorical_dataset.y.intervals - n_classes = len(class_labels) + class_labels = training_dataset.target_labels + if class_labels is None: + class_intervals = None + n_classes = None + training_class_counts = None + training_class_distribution = None + test_class_counts = None + test_class_distribution = None + imbalance_ratio = None + else: + class_intervals = training_dataset.target_intervals + n_classes = len(class_labels) - # Training class distribution - train_y_series = pd.Series(categorical_dataset.y.train) - train_counts = train_y_series.value_counts().sort_index() - training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)} - train_total = sum(training_class_counts.values()) - training_class_distribution = { - k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items() - } + train_y = training_dataset.targets[training_dataset.split == "train"]["y"] + test_y = training_dataset.targets[training_dataset.split == "test"]["y"] - # Test class distribution - test_y_series = pd.Series(categorical_dataset.y.test) - test_counts = test_y_series.value_counts().sort_index() - test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)} - test_total = sum(test_class_counts.values()) - test_class_distribution = { - k: (v / test_total * 
100) if test_total > 0 else 0.0 for k, v in test_class_counts.items() - } + train_counts = train_y.value_counts().sort_index() + training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)} + train_total = sum(training_class_counts.values()) + training_class_distribution = { + k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items() + } + test_counts = test_y.value_counts().sort_index() + test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)} + test_total = sum(test_class_counts.values()) + test_class_distribution = { + k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items() + } + + # Imbalance ratio (smallest class / largest class across both splits) + all_counts = list(training_class_counts.values()) + list(test_class_counts.values()) + nonzero_counts = [c for c in all_counts if c > 0] + imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0 # Raw value statistics - raw_values = categorical_dataset.y.raw_values + raw_values = training_dataset.targets["z"] raw_value_min = float(raw_values.min()) raw_value_max = float(raw_values.max()) raw_value_mean = float(raw_values.mean()) raw_value_median = float(raw_values.median()) raw_value_std = float(raw_values.std()) - # Imbalance ratio (smallest class / largest class across both splits) - all_counts = list(training_class_counts.values()) + list(test_class_counts.values()) - nonzero_counts = [c for c in all_counts if c > 0] - imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0 - return cls( n_samples=n_samples, n_features=n_features, diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index 799a392..0a87a8e 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -1,508 +1,17 @@ """Overview page: List of 
available result directories with some summary statistics.""" -from datetime import datetime -from typing import cast - -import pandas as pd import streamlit as st from stopuhr import stopwatch -from entropice.dashboard.plots.overview import ( - create_feature_breakdown_donut, - create_feature_count_stacked_bar, - create_feature_distribution_pie, - create_inference_cells_bar, - create_sample_count_bar_chart, +from entropice.dashboard.sections.dataset_statistics import render_dataset_statistics +from entropice.dashboard.sections.experiment_results import ( + render_experiment_results, + render_training_results_summary, ) -from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.loaders import load_all_training_results from entropice.dashboard.utils.stats import ( - DatasetStatistics, load_all_default_dataset_statistics, ) -from entropice.utils.types import ( - GridConfig, - L2SourceDataset, - TargetDataset, - grid_configs, -) - - -def render_sample_count_overview(): - """Render overview of sample counts per task+target+grid+level combination.""" - st.markdown( - """ - This visualization shows the number of available training samples for each combination of: - - **Task**: binary, count, density - - **Target Dataset**: darts_rts, darts_mllabels - - **Grid System**: hex, healpix - - **Grid Level**: varying by grid type - """ - ) - - # Get sample count DataFrame from cache - all_stats = load_all_default_dataset_statistics() - sample_df = DatasetStatistics.get_sample_count_df(all_stats) - - # Get color palettes for each target dataset - n_tasks = sample_df["Task"].nunique() - target_color_maps = { - "rts": get_palette("task_types", n_colors=n_tasks), - "mllabels": get_palette("data_sources", n_colors=n_tasks), - } - - # Create and display bar chart - fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps) - st.plotly_chart(fig, use_container_width=True) - - # Display full table with formatting - st.markdown("#### 
Detailed Sample Counts") - display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy() - - # Format numbers with commas - display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}") - # Format coverage as percentage with 2 decimal places - display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%") - - st.dataframe(display_df, hide_index=True, use_container_width=True) - - -def render_feature_count_comparison(): - """Render static comparison of feature counts across all grid configurations.""" - st.markdown( - """ - Comparing dataset characteristics for all grid configurations with all data sources enabled. - - **Features**: Total number of input features from all data sources - - **Spatial Coverage**: Number of grid cells with complete data coverage - """ - ) - - # Get data from cache - all_stats = load_all_default_dataset_statistics() - comparison_df = DatasetStatistics.get_feature_count_df(all_stats) - breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) - breakdown_df = breakdown_df.sort_values("Grid_Level_Sort") - - # Get all unique data sources and create color map - unique_sources = sorted(breakdown_df["Data Source"].unique()) - n_sources = len(unique_sources) - source_color_list = get_palette("data_sources", n_colors=n_sources) - source_color_map = dict(zip(unique_sources, source_color_list)) - - # Create and display stacked bar chart - fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map) - st.plotly_chart(fig, use_container_width=True) - - # Add spatial coverage metric - n_grids = len(comparison_df) - grid_colors = get_palette("grid_configs", n_colors=n_grids) - - fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors) - st.plotly_chart(fig_cells, use_container_width=True) - - # Display full comparison table with formatting - st.markdown("#### Detailed Comparison Table") - display_df = 
comparison_df[ - [ - "Grid", - "Total Features", - "Data Sources", - "Inference Cells", - ] - ].copy() - - # Format numbers with commas - for col in ["Total Features", "Inference Cells"]: - display_df[col] = display_df[col].apply(lambda x: f"{x:,}") - - st.dataframe(display_df, hide_index=True, use_container_width=True) - - -@st.fragment -def render_feature_count_explorer(): - """Render interactive detailed configuration explorer using fragments.""" - st.markdown("Select specific grid configuration and data sources for detailed statistics") - - # Grid selection - grid_options = [gc.display_name for gc in grid_configs] - - col1, col2 = st.columns(2) - - with col1: - grid_level_combined = st.selectbox( - "Grid Configuration", - options=grid_options, - index=0, - help="Select the grid system and resolution level", - key="feature_grid_select", - ) - - with col2: - target = st.selectbox( - "Target Dataset", - options=["darts_rts", "darts_mllabels"], - index=0, - help="Select the target dataset", - key="feature_target_select", - ) - - # Find the selected grid config - selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined) - - # Get available members from the stats - all_stats = load_all_default_dataset_statistics() - stats = all_stats[selected_grid_config.id] - available_members = cast(list[L2SourceDataset], list(stats.members.keys())) - - # Members selection - st.markdown("#### Select Data Sources") - - all_members = cast( - list[L2SourceDataset], - ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"], - ) - - # Use columns for checkboxes - cols = st.columns(len(all_members)) - selected_members: list[L2SourceDataset] = [] - - for idx, member in enumerate(all_members): - with cols[idx]: - default_value = member in available_members - if st.checkbox(member, value=default_value, key=f"feature_member_{member}"): - selected_members.append(cast(L2SourceDataset, member)) - - # Show results if at 
least one member is selected - if selected_members: - # Get statistics from cache (already loaded) - grid_stats = all_stats[selected_grid_config.id] - - # Filter to selected members only - selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members} - - # Calculate total features for selected members - total_features = sum(ms.feature_count for ms in selected_member_stats.values()) - - # Get target stats - target_stats = grid_stats.target[cast(TargetDataset, target)] - - # High-level metrics - col1, col2, col3, col4, col5 = st.columns(5) - with col1: - st.metric("Total Features", f"{total_features:,}") - with col2: - # Calculate minimum cells across all data sources (for inference capability) - min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values()) - st.metric( - "Inference Cells", - f"{min_cells:,}", - help="Number of union of cells across all data sources", - ) - with col3: - st.metric("Data Sources", len(selected_members)) - with col4: - # Use binary task training cells as sample count - st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}") - with col5: - # Calculate total data points - total_points = total_features * target_stats.training_cells["binary"] - st.metric("Total Data Points", f"{total_points:,}") - - # Feature breakdown by source - st.markdown("#### Feature Breakdown by Data Source") - - breakdown_data = [] - for member, member_stats in selected_member_stats.items(): - breakdown_data.append( - { - "Data Source": member, - "Number of Features": member_stats.feature_count, - "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%", - } - ) - - breakdown_df = pd.DataFrame(breakdown_data) - - # Get all unique data sources and create color map - unique_sources = sorted(breakdown_df["Data Source"].unique()) - n_sources = len(unique_sources) - source_color_list = get_palette("data_sources", n_colors=n_sources) - source_color_map = 
dict(zip(unique_sources, source_color_list)) - - # Create and display pie chart - fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map) - st.plotly_chart(fig, width="stretch") - - # Show detailed table - st.dataframe(breakdown_df, hide_index=True, width="stretch") - - # Detailed member information - with st.expander("📦 Detailed Source Information", expanded=False): - for member, member_stats in selected_member_stats.items(): - st.markdown(f"### {member}") - - metric_cols = st.columns(4) - with metric_cols[0]: - st.metric("Features", member_stats.feature_count) - with metric_cols[1]: - st.metric("Variables", len(member_stats.variable_names)) - with metric_cols[2]: - dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()]) - st.metric("Shape", dim_str) - with metric_cols[3]: - total_points = 1 - for dim_size in member_stats.dimensions.values(): - total_points *= dim_size - st.metric("Data Points", f"{total_points:,}") - - # Variables - st.markdown("**Variables:**") - vars_html = " ".join( - [ - f'{v}' - for v in member_stats.variable_names - ] - ) - st.markdown(vars_html, unsafe_allow_html=True) - - # Dimensions - st.markdown("**Dimensions:**") - dim_html = " ".join( - [ - f'' - f"{dim_name}: {dim_size}" - for dim_name, dim_size in member_stats.dimensions.items() - ] - ) - st.markdown(dim_html, unsafe_allow_html=True) - - # Size on disk - size_mb = member_stats.size_bytes / (1024 * 1024) - st.markdown(f"**Size on Disk:** {size_mb:.2f} MB") - - st.markdown("---") - else: - st.info("👆 Select at least one data source to see feature statistics") - - -def render_dataset_analysis(): - """Render the dataset analysis section with sample and feature counts.""" - st.header("📈 Dataset Analysis") - - # Create tabs for different analysis views - analysis_tabs = st.tabs( - [ - "📊 Training Samples", - "📈 Dataset Characteristics", - "🔍 Feature Breakdown", - "⚙️ Configuration Explorer", - ] - ) - - with analysis_tabs[0]: - 
st.subheader("Training Samples by Configuration") - render_sample_count_overview() - - with analysis_tabs[1]: - st.subheader("Dataset Characteristics Across Grid Configurations") - render_feature_count_comparison() - - with analysis_tabs[2]: - st.subheader("Feature Breakdown by Data Source") - # Get data from cache - all_stats = load_all_default_dataset_statistics() - comparison_df = DatasetStatistics.get_feature_count_df(all_stats) - breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) - breakdown_df = breakdown_df.sort_values("Grid_Level_Sort") - - # Get all unique data sources and create color map - unique_sources = sorted(breakdown_df["Data Source"].unique()) - n_sources = len(unique_sources) - source_color_list = get_palette("data_sources", n_colors=n_sources) - source_color_map = dict(zip(unique_sources, source_color_list)) - - st.markdown("Showing percentage contribution of each data source across all grid configurations") - - # Sparse Resolution girds - for res in ["sparse", "low", "medium"]: - cols = st.columns(2) - with cols[0]: - grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"] - for gc in grid_configs_res: - grid_display = gc.display_name - grid_data = breakdown_df[breakdown_df["Grid"] == grid_display] - fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map) - st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}") - with cols[1]: - grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"] - for gc in grid_configs_res: - grid_display = gc.display_name - grid_data = breakdown_df[breakdown_df["Grid"] == grid_display] - fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map) - st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}") - - # Create donut charts for each grid configuration - # num_grids = len(comparison_df) - # cols_per_row = 3 - # num_rows = (num_grids + 
cols_per_row - 1) // cols_per_row - - # for row_idx in range(num_rows): - # cols = st.columns(cols_per_row) - # for col_idx in range(cols_per_row): - # grid_idx = row_idx * cols_per_row + col_idx - # if grid_idx < num_grids: - # grid_config = comparison_df.iloc[grid_idx]["Grid"] - # grid_data = breakdown_df[breakdown_df["Grid"] == grid_config] - - # with cols[col_idx]: - # fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map) - # st.plotly_chart(fig, use_container_width=True) - - with analysis_tabs[3]: - st.subheader("Interactive Configuration Explorer") - render_feature_count_explorer() - - -def render_training_results_summary(training_results): - """Render summary metrics for training results.""" - st.header("📊 Training Results Summary") - col1, col2, col3, col4 = st.columns(4) - - with col1: - tasks = {tr.settings.task for tr in training_results} - st.metric("Tasks", len(tasks)) - - with col2: - grids = {tr.settings.grid for tr in training_results} - st.metric("Grid Types", len(grids)) - - with col3: - models = {tr.settings.model for tr in training_results} - st.metric("Model Types", len(models)) - - with col4: - latest = training_results[0] # Already sorted by creation time - latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d") - st.metric("Latest Run", latest_date) - - -def render_experiment_results(training_results): - """Render detailed experiment results table and expandable details.""" - st.header("🎯 Experiment Results") - st.subheader("Results Table") - - # Build a summary dataframe - summary_data = [] - for tr in training_results: - # Extract best scores from the results dataframe - score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")] - - best_scores = {} - for col in score_cols: - metric_name = col.replace("mean_test_", "") - best_score = tr.results[col].max() - best_scores[metric_name] = best_score - - # Get primary metric (usually the first one or accuracy) - 
primary_metric = ( - "accuracy" - if "mean_test_accuracy" in tr.results.columns - else score_cols[0].replace("mean_test_", "") - if score_cols - else "N/A" - ) - primary_score = best_scores.get(primary_metric, 0.0) - - summary_data.append( - { - "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"), - "Task": tr.settings.task, - "Grid": tr.settings.grid, - "Level": tr.settings.level, - "Model": tr.settings.model, - f"Best {primary_metric.title()}": f"{primary_score:.4f}", - "Trials": len(tr.results), - "Path": str(tr.path.name), - } - ) - - summary_df = pd.DataFrame(summary_data) - - # Display with color coding for best scores - st.dataframe( - summary_df, - width="stretch", - hide_index=True, - ) - - st.divider() - - # Expandable details for each result - st.subheader("Individual Experiment Details") - - for tr in training_results: - display_name = ( - f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}" - ) - with st.expander(display_name): - col1, col2 = st.columns([1, 2]) - - with col1: - st.write("**Configuration:**") - st.write(f"- **Task:** {tr.settings.task}") - st.write(f"- **Grid:** {tr.settings.grid}") - st.write(f"- **Level:** {tr.settings.level}") - st.write(f"- **Model:** {tr.settings.model}") - st.write(f"- **CV Splits:** {tr.settings.cv_splits}") - st.write(f"- **Classes:** {tr.settings.classes}") - - st.write("\n**Files:**") - st.write("- 📊 search_results.parquet") - st.write("- 🧮 best_estimator_state.nc") - st.write("- 🎯 predicted_probabilities.parquet") - st.write("- ⚙️ search_settings.toml") - - with col2: - st.write("**Best Scores:**") - - # Extract all test scores - score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")] - - if score_cols: - metric_data = [] - for col in score_cols: - metric_name = col.replace("mean_test_", "").title() - best_score = tr.results[col].max() - mean_score = tr.results[col].mean() - std_score = tr.results[col].std() - - 
metric_data.append( - { - "Metric": metric_name, - "Best": f"{best_score:.4f}", - "Mean": f"{mean_score:.4f}", - "Std": f"{std_score:.4f}", - } - ) - - metric_df = pd.DataFrame(metric_data) - st.dataframe(metric_df, width="stretch", hide_index=True) - else: - st.write("No test scores found in results.") - - # Show parameter space explored - if "initial_K" in tr.results.columns: # Common parameter - st.write("\n**Parameter Ranges Explored:**") - for param in ["initial_K", "eps_cl", "eps_e"]: - if param in tr.results.columns: - min_val = tr.results[param].min() - max_val = tr.results[param].max() - unique_vals = tr.results[param].nunique() - st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") - - st.write(f"\n**Path:** `{tr.path}`") def render_overview_page(): @@ -537,7 +46,8 @@ def render_overview_page(): st.divider() # Render dataset analysis section - render_dataset_analysis() + all_stats = load_all_default_dataset_statistics() + render_dataset_statistics(all_stats) st.balloons() stopwatch.summary() diff --git a/src/entropice/ml/autogluon_training.py b/src/entropice/ml/autogluon_training.py index 5070204..869d564 100644 --- a/src/entropice/ml/autogluon_training.py +++ b/src/entropice/ml/autogluon_training.py @@ -48,7 +48,7 @@ class AutoGluonSettings: class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings): """Combined settings for AutoGluon training.""" - classes: list[str] + classes: list[str] | None problem_type: str @@ -228,7 +228,7 @@ def autogluon_train( combined_settings = AutoGluonTrainingSettings( **asdict(settings), **asdict(dataset_ensemble), - classes=training_data.y.labels, + classes=training_data.target_labels, problem_type=problem_type, ) settings_file = results_dir / "training_settings.toml" diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index d4bbfea..f74deb8 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -78,7 +78,7 @@ def _collapse_to_dataframe(ds: 
xr.Dataset | xr.DataArray) -> pd.DataFrame: collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan pivcols = set(collapsed.index.names) - {"cell_ids"} collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols) - collapsed.columns = ["_".join(v) for v in collapsed.columns] + collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns] if use_dummy: collapsed = collapsed.dropna(how="all") return collapsed @@ -378,7 +378,8 @@ class DatasetEnsemble: case ("darts_v1" | "darts_v2", int()): version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment] target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version) - targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode) + targets = xr.open_zarr(target_store, consolidated=False) + targets = targets.sel(year=self.temporal_mode) case ("darts_mllabels", str()): # Years are not supported target_store = entropice.utils.paths.get_darts_file( grid=self.grid, level=self.level, version="mllabels" @@ -467,6 +468,10 @@ class DatasetEnsemble: # Apply the temporal mode match (member.split("-"), self.temporal_mode): + case (["ArcticDEM"], _): + pass # No temporal dimension + case (_, "feature"): + pass case (["ERA5", _] | ["AlphaEarth"], "synopsis"): ds_mean = ds.mean(dim="year") ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True) @@ -475,12 +480,8 @@ class DatasetEnsemble: {var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars} ) ds = xr.merge([ds_mean, ds_trend]) - case (["ArcticDEM"], "synopsis"): - pass # No temporal dimension case (_, int() as year): ds = ds.sel(year=year, drop=True) - case (_, "feature"): - pass case _: raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.") @@ -556,7 +557,7 @@ class DatasetEnsemble: cell_ids: pd.Series, era5_agg: Literal["yearly", "seasonal", "shoulder"], ) -> pd.DataFrame: - era5 = 
self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False) + era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=True) era5_df = _collapse_to_dataframe(era5) era5_df.columns = [f"era5_{col}" for col in era5_df.columns] # Ensure all target cell_ids are present, fill missing with NaN @@ -565,7 +566,10 @@ class DatasetEnsemble: @stopwatch("Preparing AlphaEarth Embeddings") def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame: - embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"] + embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=True) + # For non-synopsis modes there is only a single data variable "embeddings" + if self.temporal_mode != "synopsis": + embeddings = embeddings["embeddings"] embeddings_df = _collapse_to_dataframe(embeddings) embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns] # Ensure all target cell_ids are present, fill missing with NaN @@ -574,7 +578,7 @@ class DatasetEnsemble: @stopwatch("Preparing ArcticDEM") def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame: - arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True) + arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=False) if len(arcticdem["cell_ids"]) == 0: # No data for these cells - create empty DataFrame with expected columns # Use the Dataset metadata to determine column structure @@ -620,7 +624,7 @@ class DatasetEnsemble: batch_cell_ids = all_cell_ids.iloc[i : i + batch_size] yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode) - @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"]) + # @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"]) def create_training_set( self, task: Task, diff --git a/src/entropice/ml/training.py b/src/entropice/ml/training.py index c5ea037..37c4499 100644 --- a/src/entropice/ml/training.py +++ 
b/src/entropice/ml/training.py @@ -3,6 +3,7 @@ import pickle from dataclasses import asdict, dataclass from functools import partial +from pathlib import Path import cyclopts import numpy as np @@ -115,17 +116,19 @@ class TrainingSettings(DatasetEnsemble, CVSettings): def random_cv( dataset_ensemble: DatasetEnsemble, settings: CVSettings = CVSettings(), -): + experiment: str | None = None, +) -> Path: """Perform random cross-validation on the training dataset. Args: dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration. settings (CVSettings): The cross-validation settings. + experiment (str | None): Optional experiment name for results directory. """ # Since we use cuml and xgboost libraries, we can only enable array API for ESPA - use_array_api = settings.model == "espa" - device = "torch" if use_array_api else "cuda" + use_array_api = settings.model != "xgboost" + device = "torch" if settings.model == "espa" else "cuda" set_config(array_api_dispatch=use_array_api) print("Creating training data...") @@ -173,7 +176,8 @@ def random_cv( print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}") results_dir = get_cv_results_dir( - "random_search", + experiment=experiment, + name="random_search", grid=dataset_ensemble.grid, level=dataset_ensemble.level, task=settings.task, @@ -283,6 +287,7 @@ def random_cv( stopwatch.summary() print("Done.") + return results_dir if __name__ == "__main__": diff --git a/src/entropice/utils/paths.py b/src/entropice/utils/paths.py index 7708845..98a83b8 100644 --- a/src/entropice/utils/paths.py +++ b/src/entropice/utils/paths.py @@ -139,7 +139,14 @@ def get_features_cache(ensemble_id: str, cells_hash: str) -> Path: return cache_dir / f"cells_{cells_hash}.parquet" +def get_dataset_stats_cache() -> Path: + cache_dir = DATASET_ENSEMBLES_DIR / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / "dataset_stats.pckl" + + def get_cv_results_dir( + experiment: str | None, name: str, 
grid: Grid, level: int, @@ -147,7 +154,12 @@ def get_cv_results_dir( ) -> Path: gridname = _get_gridname(grid, level) now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}" + if experiment is not None: + experiment_dir = RESULTS_DIR / experiment + experiment_dir.mkdir(parents=True, exist_ok=True) + else: + experiment_dir = RESULTS_DIR + results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}" results_dir.mkdir(parents=True, exist_ok=True) return results_dir diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py index 1c228b1..322f2cd 100644 --- a/src/entropice/utils/types.py +++ b/src/entropice/utils/types.py @@ -15,12 +15,12 @@ type GridLevel = Literal[ "healpix9", "healpix10", ] -type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"] +type TargetDataset = Literal["darts_v1", "darts_mllabels"] type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"] # TODO: Consider implementing a "timeseries" temporal mode -type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024] +type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023] type Model = Literal["espa", "xgboost", "rf", "knn"] type Stage = Literal["train", "inference", "visualization"] @@ -37,17 +37,20 @@ class GridConfig: sort_key: str @classmethod - def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig": + def from_grid_level(cls, grid_level: GridLevel | tuple[Literal["hex", "healpix"], int]) -> "GridConfig": """Create a GridConfig from a GridLevel string.""" - if grid_level.startswith("hex"): - grid = "hex" - level = int(grid_level[3:]) - elif grid_level.startswith("healpix"): - grid = "healpix" - level = 
int(grid_level[7:]) + if isinstance(grid_level, str): + if grid_level.startswith("hex"): + grid = "hex" + level = int(grid_level[3:]) + elif grid_level.startswith("healpix"): + grid = "healpix" + level = int(grid_level[7:]) + else: + raise ValueError(f"Invalid grid level: {grid_level}") else: - raise ValueError(f"Invalid grid level: {grid_level}") - + grid, level = grid_level + grid_level: GridLevel = f"{grid}{level}" # ty:ignore[invalid-assignment] display_name = f"{grid.capitalize()}-{level}" resmap: dict[str, Literal["sparse", "low", "medium"]] = { @@ -74,8 +77,8 @@ class GridConfig: # Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"] -all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024] -all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"] +all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023] +all_target_datasets: list[TargetDataset] = ["darts_v1", "darts_mllabels"] all_l2_source_datasets: list[L2SourceDataset] = [ "ArcticDEM", "ERA5-shoulder", diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 0000000..4fd104a --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,222 @@ +"""Tests for training.py module, specifically random_cv function. + +This test suite validates the random_cv training function across all model-task +combinations using a minimal hex level 3 grid with synopsis temporal mode. 
+ +Test Coverage: +- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn +- Device handling for each model type (torch/CUDA/cuML compatibility) +- Multi-label target dataset support +- Temporal mode configuration (synopsis) +- Output file creation and validation + +Running Tests: + # Run all training tests (18 tests total, ~3 iterations each) + pixi run pytest tests/test_training.py -v + + # Run only device handling tests + pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v + + # Run a specific model-task combination + pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v + +Note: Tests use minimal iterations (3) and level 3 grid for speed. +Full production runs use higher iteration counts (100-2000). +""" + +import shutil + +import pytest + +from entropice.ml.dataset import DatasetEnsemble +from entropice.ml.training import CVSettings, random_cv +from entropice.utils.types import Model, Task + + +@pytest.fixture(scope="module") +def test_ensemble(): + """Create a minimal DatasetEnsemble for testing. + + Uses hex level 3 grid with synopsis temporal mode for fast testing. + """ + return DatasetEnsemble( + grid="hex", + level=3, + temporal_mode="synopsis", + members=["AlphaEarth"], # Use only one member for faster tests + add_lonlat=True, + ) + + +@pytest.fixture +def cleanup_results(): + """Clean up results directory after each test. + + This fixture collects the actual result directories created during tests + and removes them after the test completes. 
+ """ + created_dirs = [] + + def register_dir(results_dir): + """Register a directory to be cleaned up.""" + created_dirs.append(results_dir) + return results_dir + + yield register_dir + + # Clean up only the directories created during this test + for results_dir in created_dirs: + if results_dir.exists(): + shutil.rmtree(results_dir) + + +# Model-task combinations to test +# Note: Not all combinations make sense, but we test all to ensure robustness +MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"] +TASKS: list[Task] = ["binary", "count", "density"] + + +class TestRandomCV: + """Test suite for random_cv function.""" + + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("task", TASKS) + def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results): + """Test random_cv with all model-task combinations. + + This test runs 3 iterations for each combination to verify: + - The function completes without errors + - Device handling works correctly for each model type + - All output files are created + """ + # Use darts_v1 as the primary target for all tests + settings = CVSettings( + n_iter=3, + task=task, + target="darts_v1", + model=model, + ) + + # Run the cross-validation and get the results directory + results_dir = random_cv( + dataset_ensemble=test_ensemble, + settings=settings, + experiment="test_training", + ) + cleanup_results(results_dir) + + # Verify results directory was created + assert results_dir.exists(), f"Results directory not created for {model=}, {task=}" + + # Verify all expected output files exist + expected_files = [ + "search_settings.toml", + "best_estimator_model.pkl", + "search_results.parquet", + "metrics.toml", + "predicted_probabilities.parquet", + ] + + # Add task-specific files + if task in ["binary", "count", "density"]: + # All tasks that use classification (including count/density when binned) + # Note: count and density without _regimes suffix might be regression + if 
task == "binary" or "_regimes" in task: + expected_files.append("confusion_matrix.nc") + + # Add model-specific files + if model in ["espa", "xgboost", "rf"]: + expected_files.append("best_estimator_state.nc") + + for filename in expected_files: + filepath = results_dir / filename + assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}" + + @pytest.mark.parametrize("model", MODELS) + def test_device_handling(self, test_ensemble, model: Model, cleanup_results): + """Test that device handling works correctly for each model type. + + Different models require different device configurations: + - espa: Uses torch with array API dispatch + - xgboost: Uses CUDA without array API dispatch + - rf/knn: GPU-accelerated via cuML + """ + settings = CVSettings( + n_iter=3, + task="binary", # Simple binary task for device testing + target="darts_v1", + model=model, + ) + + # This should complete without device-related errors + try: + results_dir = random_cv( + dataset_ensemble=test_ensemble, + settings=settings, + experiment="test_training", + ) + cleanup_results(results_dir) + except RuntimeError as e: + # Check if error is device-related + error_msg = str(e).lower() + device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"] + if any(keyword in error_msg for keyword in device_keywords): + pytest.fail(f"Device handling error for {model=}: {e}") + else: + # Re-raise non-device errors + raise + + def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results): + """Test random_cv with multi-label target dataset.""" + settings = CVSettings( + n_iter=3, + task="binary", + target="darts_mllabels", + model="espa", + ) + + # Run the cross-validation and get the results directory + results_dir = random_cv( + dataset_ensemble=test_ensemble, + settings=settings, + experiment="test_training", + ) + cleanup_results(results_dir) + + # Verify results were created + assert results_dir.exists(), "Results directory not created" + assert 
(results_dir / "search_settings.toml").exists() + + def test_temporal_mode_synopsis(self, cleanup_results): + """Test that temporal_mode='synopsis' is correctly used.""" + import toml + + ensemble = DatasetEnsemble( + grid="hex", + level=3, + temporal_mode="synopsis", + members=["AlphaEarth"], + add_lonlat=True, + ) + + settings = CVSettings( + n_iter=3, + task="binary", + target="darts_v1", + model="espa", + ) + + # This should use synopsis mode (all years aggregated) + results_dir = random_cv( + dataset_ensemble=ensemble, + settings=settings, + experiment="test_training", + ) + cleanup_results(results_dir) + + # Verify the settings were stored correctly + assert results_dir.exists(), "Results directory not created" + with open(results_dir / "search_settings.toml") as f: + stored_settings = toml.load(f) + + assert stored_settings["settings"]["temporal_mode"] == "synopsis" diff --git a/tests/validate_datasets.py b/tests/validate_datasets.py new file mode 100644 index 0000000..6e8286b --- /dev/null +++ b/tests/validate_datasets.py @@ -0,0 +1,38 @@ +import cyclopts +from rich import pretty, print, traceback + +from entropice.ml.dataset import DatasetEnsemble +from entropice.utils.types import all_temporal_modes, grid_configs + +pretty.install() +traceback.install() + +cli = cyclopts.App() + + +def _gather_ensemble_stats(e: DatasetEnsemble): + # Get a small sample of the cell ids + sample_cell_ids = e.cell_ids[:5] + features = e.make_features(sample_cell_ids) + + print( + f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]" + ) + print(f"Number of feature columns: {len(features.columns)}") + + for member in ["arcticdem", "embeddings", "era5"]: + member_feature_cols = [col for col in features.columns if col.startswith(f"{member}_")] + print(f" - {member.capitalize()} feature columns: {len(member_feature_cols)}") + print() + + +@cli.default() +def validate_datasets(): + for gc in grid_configs: + for 
temporal_mode in all_temporal_modes: + e = DatasetEnsemble(grid=gc.grid, level=gc.level, temporal_mode=temporal_mode) + _gather_ensemble_stats(e) + + +if __name__ == "__main__": + cli()