Fix training and overview page
This commit is contained in:
parent
4445834895
commit
c9c6af8370
17 changed files with 1643 additions and 1125 deletions
|
|
@ -68,7 +68,7 @@ dependencies = [
|
|||
"pandas-stubs>=2.3.3.251201,<3",
|
||||
"pytest>=9.0.2,<10",
|
||||
"autogluon-tabular[all,mitra]>=1.5.0",
|
||||
"shap>=0.50.0,<0.51",
|
||||
"shap>=0.50.0,<0.51", "h5py>=3.15.1,<4",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@ Pages:
|
|||
- Overview: List of available result directories with some summary statistics.
|
||||
- Training Data: Visualization of training data distributions.
|
||||
- Training Results Analysis: Analysis of training results and model performance.
|
||||
- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
|
||||
- Model State: Visualization of model state and features.
|
||||
- Inference: Visualization of inference results.
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
|
|
@ -26,13 +28,14 @@ def main():
|
|||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Overview": [overview_page],
|
||||
"Training": [training_data_page, training_analysis_page],
|
||||
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
||||
"Model State": [model_state_page],
|
||||
"Inference": [inference_page],
|
||||
}
|
||||
|
|
|
|||
458
src/entropice/dashboard/plots/dataset_statistics.py
Normal file
458
src/entropice/dashboard/plots/dataset_statistics.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""Visualization functions for the overview page.
|
||||
|
||||
This module contains reusable plotting functions for dataset analysis visualizations,
|
||||
including sample counts, feature counts, and dataset statistics.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
|
||||
|
||||
def _get_colors_from_members(unique_members: list[str]) -> dict[str, str]:
    """Assign every member a colour taken from the palette of its source.

    The text before the first ``-`` in a member name is used as its source; a name
    without ``-`` is its own source. Members sharing a source get neighbouring entries
    of the palette requested for that source.

    Args:
        unique_members: Member names to colour.

    Returns:
        Mapping from member name to the palette entry chosen for it.

    """
    # Group members by source in one pass. Sorting first keeps the groups, the order of
    # palette requests and the returned mapping deterministic across runs.
    members_by_source: dict[str, list[str]] = {}
    for member in sorted(unique_members):
        members_by_source.setdefault(member.split("-")[0], []).append(member)

    member_colors: dict[str, str] = {}
    for source, members in members_by_source.items():
        # Ask for two extra colours and skip the first and last palette entries
        # to avoid very light/dark colours.
        palette = get_palette(source, n_colors=len(members) + 2)
        for idx, member in enumerate(members, start=1):
            member_colors[member] = palette[idx]
    return member_colors
|
||||
|
||||
|
||||
def create_sample_count_bar_chart(sample_df: pd.DataFrame) -> go.Figure:
    """Plot grouped bars of sample counts per grid, with one subplot per target.

    Within each target subplot there is one bar series per temporal mode. Every mode
    keeps the same colour in all subplots and is listed in the legend once.

    Args:
        sample_df: DataFrame with columns: Grid, Target, Temporal Mode, Samples (Coverage).

    Returns:
        Plotly Figure object containing the bar chart visualization.

    """
    for required in ("Grid", "Target", "Temporal Mode", "Samples (Coverage)"):
        assert required in sample_df.columns

    target_names = sorted(sample_df["Target"].unique())
    mode_names = sorted(sample_df["Temporal Mode"].unique())

    # Subplots are assembled by hand to keep full control over the colours.
    fig = make_subplots(
        rows=1,
        cols=len(target_names),
        subplot_titles=[f"Target: {target}" for target in target_names],
        shared_yaxes=True,
    )
    palette = get_palette("Number of Samples", n_colors=len(mode_names))

    # Modes that already own a legend entry.
    legend_modes = set()

    for subplot_col, target_name in enumerate(target_names, start=1):
        rows_for_target = sample_df[sample_df["Target"] == target_name]

        for position, mode_name in enumerate(mode_names):
            rows_for_mode = rows_for_target[rows_for_target["Temporal Mode"] == mode_name]
            if rows_for_mode.shape[0] == 0:
                continue

            bar_color = None
            if palette and position < len(palette):
                bar_color = palette[position]

            # Only the first trace of a mode is listed in the legend.
            first_occurrence = mode_name not in legend_modes
            legend_modes.add(mode_name)

            bars = go.Bar(
                x=rows_for_mode["Grid"],
                y=rows_for_mode["Samples (Coverage)"],
                name=mode_name,
                marker_color=bar_color,
                legendgroup=mode_name,  # one legend group per mode
                showlegend=first_occurrence,
            )
            fig.add_trace(bars, row=1, col=subplot_col)

    fig.update_layout(
        title_text="Training Sample Counts by Grid Configuration and Target Dataset",
        barmode="group",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Samples", row=1, col=1)

    return fig
|
||||
|
||||
|
||||
def create_feature_count_stacked_bar(breakdown_df: pd.DataFrame) -> go.Figure:
    """Create stacked bar chart showing feature counts by member.

    One subplot is drawn per temporal mode; within a subplot the bars of all members
    are stacked per grid. Every member keeps the same colour in all subplots and is
    listed in the legend exactly once.

    Args:
        breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.

    Returns:
        Plotly Figure object containing the stacked bar chart visualization.

    """
    assert "Grid" in breakdown_df.columns
    assert "Temporal Mode" in breakdown_df.columns
    assert "Member" in breakdown_df.columns
    assert "Number of Features" in breakdown_df.columns

    unique_members = sorted(breakdown_df["Member"].unique())
    temporal_modes = sorted(breakdown_df["Temporal Mode"].unique())
    source_color_map = _get_colors_from_members(unique_members)

    # Create subplots for each temporal mode
    fig = make_subplots(
        rows=1,
        cols=len(temporal_modes),
        subplot_titles=[f"Temporal Mode: {mode}" for mode in temporal_modes],
        shared_yaxes=True,
    )

    # Track which members already own a legend entry. Tying the legend to the first
    # subplot would drop every member that has no rows in the first temporal mode.
    members_in_legend: set[str] = set()

    for col_idx, mode in enumerate(temporal_modes, start=1):
        mode_data = breakdown_df[breakdown_df["Temporal Mode"] == mode]

        for member in unique_members:
            member_data = mode_data[mode_data["Member"] == member]
            if len(member_data) == 0:
                # Nothing to draw for this member in this mode; skip the empty trace.
                continue

            show_in_legend = member not in members_in_legend
            members_in_legend.add(member)

            fig.add_trace(
                go.Bar(
                    x=member_data["Grid"],
                    y=member_data["Number of Features"],
                    name=member,
                    marker_color=source_color_map.get(member),
                    legendgroup=member,  # Group by member so legend clicks toggle all subplots
                    showlegend=show_in_legend,
                ),
                row=1,
                col=col_idx,
            )

    fig.update_layout(
        title_text="Input Features by Member Across Grid Configurations",
        barmode="stack",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Features", row=1, col=1)

    return fig
|
||||
|
||||
|
||||
def create_inference_cells_bar(sample_df: pd.DataFrame) -> go.Figure:
    """Draw a bar chart of the number of inference cells per grid configuration.

    Args:
        sample_df: DataFrame that must provide the columns Grid and Inference Cells.

    Returns:
        Plotly Figure object containing the bar chart visualization.

    """
    for required in ("Grid", "Inference Cells"):
        assert required in sample_df.columns

    # One palette colour per distinct grid.
    palette: list[str] = get_palette("Number of Samples", n_colors=sample_df["Grid"].nunique())
    axis_labels = {
        "Grid": "Grid Configuration",
        "Inference Cells": "Number of Cells",
    }

    chart = px.bar(
        sample_df,
        x="Grid",
        y="Inference Cells",
        color="Grid",
        title="Spatial Coverage (Grid Cells with Complete Data)",
        labels=axis_labels,
        color_discrete_sequence=palette,
        text="Inference Cells",
    )
    # Print the cell count above each bar with thousands separators.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, height=500, showlegend=False)

    return chart
|
||||
|
||||
|
||||
def create_total_samples_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Draw one labelled bar per grid configuration showing its total sample count.

    Args:
        comparison_df: DataFrame with columns: Grid, Total Samples.
        grid_colors: Optional colour sequence for the grid configurations. When None,
            Plotly's default colours are used.

    Returns:
        Plotly Figure object containing the bar chart visualization.

    """
    axis_labels = {
        "Grid": "Grid Configuration",
        "Total Samples": "Number of Samples",
    }

    chart = px.bar(
        comparison_df,
        x="Grid",
        y="Total Samples",
        color="Grid",
        title="Training Samples (Binary Task)",
        labels=axis_labels,
        color_discrete_sequence=grid_colors,
        text="Total Samples",
    )
    # Print the sample count above each bar with thousands separators.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, showlegend=False)

    return chart
|
||||
|
||||
|
||||
def create_feature_breakdown_donut(
    grid_data: pd.DataFrame,
    grid_config: str,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Draw a donut chart of the feature share per data source for a single grid.

    Args:
        grid_data: DataFrame with columns: Data Source, Number of Features.
        grid_config: Grid configuration name, used as the chart title.
        source_color_map: Optional mapping from data source name to colour. When None,
            Plotly's default colours are used.

    Returns:
        Plotly Figure object containing the donut chart visualization.

    """
    donut_options = {
        "names": "Data Source",
        "values": "Number of Features",
        "title": grid_config,
        "hole": 0.4,  # the hole turns the pie into a donut
        "color_discrete_map": source_color_map,
        "color": "Data Source",
    }
    chart = px.pie(grid_data, **donut_options)

    # Slices show only their percentage; the legend names the data sources.
    chart.update_traces(textposition="inside", textinfo="percent")
    chart.update_layout(showlegend=True, height=350)

    return chart
|
||||
|
||||
|
||||
def create_feature_distribution_pie(
    breakdown_df: pd.DataFrame,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Draw a donut-style pie chart of the feature share per data source.

    Args:
        breakdown_df: DataFrame with columns: Data Source, Number of Features.
        source_color_map: Optional mapping from data source name to colour. When None,
            Plotly's default colours are used.

    Returns:
        Plotly Figure object containing the pie chart visualization.

    """
    pie_options = {
        "names": "Data Source",
        "values": "Number of Features",
        "title": "Feature Distribution by Data Source",
        "hole": 0.4,
        "color_discrete_map": source_color_map,
        "color": "Data Source",
    }
    chart = px.pie(breakdown_df, **pie_options)

    # Each slice is labelled with both its percentage and its data source name.
    chart.update_traces(textposition="inside", textinfo="percent+label")
    chart.update_layout(height=400)

    return chart
|
||||
|
||||
|
||||
def create_member_details_table(member_stats_dict: dict[str, dict]) -> pd.DataFrame:
    """Build a display table with one row of summary statistics per member.

    Args:
        member_stats_dict: Dictionary mapping member names to their statistics.
            Each stats dict must provide: feature_count, variable_names, dimensions, size_bytes.

    Returns:
        DataFrame with the columns Member, Features, Variables, Dimensions, Data Points
        and Size (MB); the last three hold pre-formatted strings.

    """

    def _summarise(member: str, stats: dict) -> dict:
        dimensions = stats["dimensions"]

        # Total data points = product of all dimension sizes (1 when there are none).
        n_points = 1
        for size in dimensions.values():
            n_points *= size

        dimension_text = " x ".join(f"{name}={size:,}" for name, size in dimensions.items())

        return {
            "Member": member,
            "Features": stats["feature_count"],
            "Variables": len(stats["variable_names"]),
            "Dimensions": dimension_text,
            "Data Points": f"{n_points:,}",
            "Size (MB)": f"{stats['size_bytes'] / (1024 * 1024):.2f}",
        }

    return pd.DataFrame([_summarise(member, stats) for member, stats in member_stats_dict.items()])
|
||||
|
||||
|
||||
def create_feature_breakdown_donuts_grid(breakdown_df: pd.DataFrame) -> go.Figure:
    """Create grid of donut charts showing feature breakdown by member for each grid+temporal mode.

    Grids form the rows and temporal modes the columns of the subplot grid. Combinations
    without any rows are left blank. Members keep the same colour in every donut, and
    every member that occurs anywhere in the grid is listed in the legend.

    Args:
        breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.

    Returns:
        Plotly Figure object containing the grid of donut charts.

    """
    assert "Grid" in breakdown_df.columns
    assert "Temporal Mode" in breakdown_df.columns
    assert "Member" in breakdown_df.columns
    assert "Number of Features" in breakdown_df.columns

    # Get unique grids and temporal modes (in order of first appearance)
    grids = breakdown_df["Grid"].unique().tolist()
    temporal_modes = breakdown_df["Temporal Mode"].unique().tolist()
    # Sorted member list gives a consistent slice order and colouring across all donuts
    all_members = sorted(breakdown_df["Member"].unique())
    source_color_map = _get_colors_from_members(all_members)

    n_grids = len(grids)
    n_modes = len(temporal_modes)

    # Spacing between subplots in paper coordinates. The same values drive both the
    # subplot layout and the header annotation positions below, so they cannot drift apart.
    h_spacing = 0.01
    v_spacing = 0.02

    # Create subplot grid (grids as rows, temporal modes as columns)
    # Don't use subplot_titles - custom annotations provide the row/column labels
    fig = make_subplots(
        rows=n_grids,
        cols=n_modes,
        specs=[[{"type": "pie"}] * n_modes for _ in range(n_grids)],
        horizontal_spacing=h_spacing,
        vertical_spacing=v_spacing,
    )

    # Members that already have a legend entry through an earlier donut
    members_in_legend: set[str] = set()

    for row_idx, grid in enumerate(grids, start=1):
        for col_idx, mode in enumerate(temporal_modes, start=1):
            subset = breakdown_df[(breakdown_df["Grid"] == grid) & (breakdown_df["Temporal Mode"] == mode)]

            if len(subset) == 0:
                continue

            # Build labels, values, and colors in consistent order
            labels = []
            values = []
            colors_list = []
            for member in all_members:
                member_data = subset[subset["Member"] == member]
                if len(member_data) == 0:
                    continue
                labels.append(member)
                values.append(member_data["Number of Features"].iloc[0])
                colors_list.append(source_color_map[member])

            # Enable the legend on every donut that introduces a member not shown yet.
            # Restricting it to the top-left donut loses the legend when that cell is
            # empty and omits members that only occur in other cells.
            new_members = set(labels) - members_in_legend
            members_in_legend |= new_members

            fig.add_trace(
                go.Pie(
                    labels=labels,
                    values=values,
                    name=f"{grid} - {mode}",
                    hole=0.4,
                    marker={"colors": colors_list} if colors_list else None,
                    textposition="inside",
                    textinfo="percent",
                    showlegend=bool(new_members),
                ),
                row=row_idx,
                col=col_idx,
            )

    # Calculate appropriate height based on number of rows
    height = max(500, n_grids * 400)

    fig.update_layout(
        title_text="Feature Breakdown by Member Across Grid Configurations and Temporal Modes",
        height=height,
        showlegend=True,
        margin={"l": 100, "r": 50, "t": 150, "b": 50},  # Add left margin for row labels
    )

    # Width and height of each subplot in paper coordinates, accounting for the spacing
    subplot_width = (1.0 - h_spacing * (n_modes - 1)) / n_modes
    subplot_height = (1.0 - v_spacing * (n_grids - 1)) / n_grids

    # Add column headers at the top (temporal modes)
    for col_idx, mode in enumerate(temporal_modes):
        # x position of this column's centre
        x_pos = col_idx * (subplot_width + h_spacing) + subplot_width / 2
        fig.add_annotation(
            text=f"<b>{mode}</b>",
            xref="paper",
            yref="paper",
            x=x_pos,
            y=1.01,
            showarrow=False,
            xanchor="center",
            yanchor="bottom",
            font={"size": 18},
        )

    # Add row headers on the left (grid configurations)
    for row_idx, grid in enumerate(grids):
        # y position of this row's centre
        # Note: y coordinates go from bottom to top, but rows go from top to bottom
        y_pos = 1.0 - (row_idx * (subplot_height + v_spacing) + subplot_height / 2)
        fig.add_annotation(
            text=f"<b>{grid}</b>",
            xref="paper",
            yref="paper",
            x=-0.01,
            y=y_pos,
            showarrow=False,
            xanchor="right",
            yanchor="middle",
            font={"size": 18},
            textangle=-90,
        )

    return fig
|
||||
|
|
@ -1,276 +0,0 @@
|
|||
"""Visualization functions for the overview page.
|
||||
|
||||
This module contains reusable plotting functions for dataset analysis visualizations,
|
||||
including sample counts, feature counts, and dataset statistics.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
def create_sample_count_heatmap(
|
||||
pivot_df: pd.DataFrame,
|
||||
target: str,
|
||||
colorscale: list[str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create heatmap showing sample counts by grid and task.
|
||||
|
||||
Args:
|
||||
pivot_df: Pivoted dataframe with Grid as index, Task as columns, and sample counts as values.
|
||||
target: Target dataset name for the title.
|
||||
colorscale: Optional color palette for the heatmap. If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the heatmap visualization.
|
||||
|
||||
"""
|
||||
fig = px.imshow(
|
||||
pivot_df,
|
||||
labels={
|
||||
"x": "Task",
|
||||
"y": "Grid Configuration",
|
||||
"color": "Sample Count",
|
||||
},
|
||||
x=pivot_df.columns,
|
||||
y=pivot_df.index,
|
||||
color_continuous_scale=colorscale,
|
||||
aspect="auto",
|
||||
title=f"Target: {target}",
|
||||
)
|
||||
|
||||
# Add text annotations
|
||||
fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
|
||||
fig.update_layout(height=400)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_sample_count_bar_chart(
|
||||
sample_df: pd.DataFrame,
|
||||
target_color_maps: dict[str, list[str]] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create bar chart showing sample counts by grid, target, and task.
|
||||
|
||||
Args:
|
||||
sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
|
||||
target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes.
|
||||
If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the bar chart visualization.
|
||||
|
||||
"""
|
||||
# Create subplots manually to have better control over colors
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
targets = sorted(sample_df["Target"].unique())
|
||||
tasks = sorted(sample_df["Task"].unique())
|
||||
|
||||
fig = make_subplots(
|
||||
rows=1,
|
||||
cols=len(targets),
|
||||
subplot_titles=[f"Target: {target}" for target in targets],
|
||||
shared_yaxes=True,
|
||||
)
|
||||
|
||||
for col_idx, target in enumerate(targets, 1):
|
||||
target_data = sample_df[sample_df["Target"] == target]
|
||||
# Get color palette for this target
|
||||
colors = target_color_maps.get(target, None) if target_color_maps else None
|
||||
|
||||
for task_idx, task in enumerate(tasks):
|
||||
task_data = target_data[target_data["Task"] == task]
|
||||
color = colors[task_idx] if colors and task_idx < len(colors) else None
|
||||
|
||||
# Create a unique legendgroup per task so colors are consistent
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=task_data["Grid"],
|
||||
y=task_data["Samples (Coverage)"],
|
||||
name=task,
|
||||
marker_color=color,
|
||||
legendgroup=task, # Group by task name
|
||||
showlegend=(col_idx == 1), # Only show legend for first subplot
|
||||
),
|
||||
row=1,
|
||||
col=col_idx,
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title_text="Training Sample Counts by Grid Configuration and Target Dataset",
|
||||
barmode="group",
|
||||
height=500,
|
||||
showlegend=True,
|
||||
)
|
||||
|
||||
fig.update_xaxes(tickangle=-45)
|
||||
fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_feature_count_stacked_bar(
|
||||
breakdown_df: pd.DataFrame,
|
||||
source_color_map: dict[str, str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create stacked bar chart showing feature counts by data source.
|
||||
|
||||
Args:
|
||||
breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features.
|
||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
||||
If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the stacked bar chart visualization.
|
||||
|
||||
"""
|
||||
fig = px.bar(
|
||||
breakdown_df,
|
||||
x="Grid",
|
||||
y="Number of Features",
|
||||
color="Data Source",
|
||||
barmode="stack",
|
||||
title="Input Features by Data Source Across Grid Configurations",
|
||||
labels={
|
||||
"Grid": "Grid Configuration",
|
||||
"Number of Features": "Number of Features",
|
||||
},
|
||||
color_discrete_map=source_color_map,
|
||||
text_auto=False,
|
||||
)
|
||||
|
||||
fig.update_layout(height=500, xaxis_tickangle=-45)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_inference_cells_bar(
|
||||
comparison_df: pd.DataFrame,
|
||||
grid_colors: list[str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create bar chart for inference cells by grid configuration.
|
||||
|
||||
Args:
|
||||
comparison_df: DataFrame with columns: Grid, Inference Cells.
|
||||
grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the bar chart visualization.
|
||||
|
||||
"""
|
||||
fig = px.bar(
|
||||
comparison_df,
|
||||
x="Grid",
|
||||
y="Inference Cells",
|
||||
color="Grid",
|
||||
title="Spatial Coverage (Grid Cells with Complete Data)",
|
||||
labels={
|
||||
"Grid": "Grid Configuration",
|
||||
"Inference Cells": "Number of Cells",
|
||||
},
|
||||
color_discrete_sequence=grid_colors,
|
||||
text="Inference Cells",
|
||||
)
|
||||
|
||||
fig.update_traces(texttemplate="%{text:,}", textposition="outside")
|
||||
fig.update_layout(xaxis_tickangle=-45, showlegend=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_total_samples_bar(
|
||||
comparison_df: pd.DataFrame,
|
||||
grid_colors: list[str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create bar chart for total samples by grid configuration.
|
||||
|
||||
Args:
|
||||
comparison_df: DataFrame with columns: Grid, Total Samples.
|
||||
grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the bar chart visualization.
|
||||
|
||||
"""
|
||||
fig = px.bar(
|
||||
comparison_df,
|
||||
x="Grid",
|
||||
y="Total Samples",
|
||||
color="Grid",
|
||||
title="Training Samples (Binary Task)",
|
||||
labels={
|
||||
"Grid": "Grid Configuration",
|
||||
"Total Samples": "Number of Samples",
|
||||
},
|
||||
color_discrete_sequence=grid_colors,
|
||||
text="Total Samples",
|
||||
)
|
||||
|
||||
fig.update_traces(texttemplate="%{text:,}", textposition="outside")
|
||||
fig.update_layout(xaxis_tickangle=-45, showlegend=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_feature_breakdown_donut(
|
||||
grid_data: pd.DataFrame,
|
||||
grid_config: str,
|
||||
source_color_map: dict[str, str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create donut chart for feature breakdown by data source for a specific grid.
|
||||
|
||||
Args:
|
||||
grid_data: DataFrame with columns: Data Source, Number of Features.
|
||||
grid_config: Grid configuration name for the title.
|
||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
||||
If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the donut chart visualization.
|
||||
|
||||
"""
|
||||
fig = px.pie(
|
||||
grid_data,
|
||||
names="Data Source",
|
||||
values="Number of Features",
|
||||
title=grid_config,
|
||||
hole=0.4,
|
||||
color_discrete_map=source_color_map,
|
||||
color="Data Source",
|
||||
)
|
||||
|
||||
fig.update_traces(textposition="inside", textinfo="percent")
|
||||
fig.update_layout(showlegend=True, height=350)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_feature_distribution_pie(
|
||||
breakdown_df: pd.DataFrame,
|
||||
source_color_map: dict[str, str] | None = None,
|
||||
) -> go.Figure:
|
||||
"""Create pie chart for feature distribution by data source.
|
||||
|
||||
Args:
|
||||
breakdown_df: DataFrame with columns: Data Source, Number of Features.
|
||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
||||
If None, uses default Plotly colors.
|
||||
|
||||
Returns:
|
||||
Plotly Figure object containing the pie chart visualization.
|
||||
|
||||
"""
|
||||
fig = px.pie(
|
||||
breakdown_df,
|
||||
names="Data Source",
|
||||
values="Number of Features",
|
||||
title="Feature Distribution by Data Source",
|
||||
hole=0.4,
|
||||
color_discrete_map=source_color_map,
|
||||
color="Data Source",
|
||||
)
|
||||
|
||||
fig.update_traces(textposition="inside", textinfo="percent+label")
|
||||
|
||||
return fig
|
||||
376
src/entropice/dashboard/sections/dataset_statistics.py
Normal file
376
src/entropice/dashboard/sections/dataset_statistics.py
Normal file
|
|
@ -0,0 +1,376 @@
|
|||
"""Dataset Statistics Section for Entropice Dashboard."""
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.dataset_statistics import (
|
||||
create_feature_breakdown_donuts_grid,
|
||||
create_feature_count_stacked_bar,
|
||||
create_feature_distribution_pie,
|
||||
create_inference_cells_bar,
|
||||
create_member_details_table,
|
||||
create_sample_count_bar_chart,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||
from entropice.utils.types import (
|
||||
GridLevel,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
Task,
|
||||
TemporalMode,
|
||||
all_l2_source_datasets,
|
||||
grid_configs,
|
||||
)
|
||||
|
||||
type DatasetStatsCache = dict[GridLevel, dict[TemporalMode, DatasetStatistics]]
|
||||
|
||||
|
||||
def render_training_samples_tab(sample_df: pd.DataFrame) -> None:
    """Render overview of sample counts per task+target+grid+level combination.

    Shows an explanatory text, the sample count bar chart and a formatted table of
    the same data. The passed DataFrame is left unmodified.

    Args:
        sample_df: DataFrame with the columns required by ``create_sample_count_bar_chart``
            plus a numeric ``Coverage %`` column.

    """
    st.markdown(
        """
        This visualization shows the number of available training samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    # Create and display bar chart
    fig = create_sample_count_bar_chart(sample_df)
    st.plotly_chart(fig, width="stretch")

    # Display full table with formatting
    st.markdown("#### Detailed Sample Counts")

    # Format a copy: writing the formatted strings back into the caller's frame would
    # corrupt the numeric columns for other consumers, and a second render would fail
    # when the number formats are applied to the already formatted strings.
    display_df = sample_df.copy()
    # Format numbers with commas
    display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
    # Format coverage as percentage with 2 decimal places
    display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")

    st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def render_dataset_characteristics_tab(
    breakdown_df: pd.DataFrame, sample_df: pd.DataFrame, comparison_df: pd.DataFrame
):
    """Render the static comparison of dataset characteristics across grid configurations.

    Shows, top to bottom: an explanatory text, the stacked feature count chart built from
    ``breakdown_df``, the inference cell chart built from ``sample_df`` and the raw
    ``comparison_df`` as a table.
    """
    st.markdown(
        """
        Comparing dataset characteristics for all grid configurations with all data sources enabled.
        - **Features**: Total number of input features from all data sources
        - **Spatial Coverage**: Number of grid cells with complete data coverage
        """
    )

    # Feature counts per member, stacked per grid
    feature_chart = create_feature_count_stacked_bar(breakdown_df)
    st.plotly_chart(feature_chart, width="stretch")

    # Spatial coverage per grid
    coverage_chart = create_inference_cells_bar(sample_df)
    st.plotly_chart(coverage_chart, width="stretch")

    # Full comparison table
    st.markdown("#### Detailed Comparison Table")
    st.dataframe(comparison_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
    """Render the per-member feature breakdown as a donut grid followed by the raw table.

    Args:
        breakdown_df: Breakdown data passed to ``create_feature_breakdown_donuts_grid``
            and displayed unchanged as a table.

    """
    st.markdown(
        """
        Visualizing feature contribution of each data source (member) across all grid configurations
        and temporal modes. Each donut chart represents the feature breakdown for a specific
        grid-temporal mode combination.
        """
    )

    # One donut per grid / temporal mode combination
    donut_grid = create_feature_breakdown_donuts_grid(breakdown_df)
    st.plotly_chart(donut_grid, width="stretch")

    # The underlying numbers as a table
    st.markdown("#### Detailed Breakdown Table")
    st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
@st.fragment
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache):  # noqa: C901
    """Render interactive detailed configuration explorer.

    The user picks a grid configuration, a temporal mode, a target dataset, a task and a
    set of data sources (members). For that selection the function shows summary metrics,
    task statistics, an optional class distribution table, a per-source feature breakdown
    (pie chart + table) and an expander with per-source details.

    The function returns early (after an ``st.warning``) when the chosen temporal mode,
    target or task has no entry in the statistics. It reads the module-level
    ``grid_configs`` and ``all_l2_source_datasets``.

    NOTE(review): the nesting of this function was reconstructed from a view without
    indentation; confirm the block structure against the real file.

    Args:
        dataset_stats: Statistics cache, indexed as
            ``dataset_stats[grid_config.id][temporal_mode]``.

    """
    st.markdown(
        """
        Explore detailed statistics for a specific dataset configuration.
        Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
        """
    )

    # Grid selection
    grid_options = [gc.display_name for gc in grid_configs]

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        selected_grid_display = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )

    with col2:
        # Create temporal mode options with proper formatting
        temporal_mode_options = ["feature", "synopsis"] + [str(year) for year in [2018, 2019, 2020, 2021, 2022, 2023]]
        temporal_mode_display = st.selectbox(
            "Temporal Mode",
            options=temporal_mode_options,
            index=0,
            help="Select the temporal mode (feature, synopsis, or specific year)",
            key="feature_temporal_mode_select",
        )
        # Convert back to proper type: year options were stringified for the selectbox,
        # so all-digit selections are turned back into ints.
        selected_temporal_mode: TemporalMode
        if temporal_mode_display.isdigit():
            selected_temporal_mode = int(temporal_mode_display)  # type: ignore[assignment]
        else:
            selected_temporal_mode = temporal_mode_display  # pyright: ignore[reportAssignmentType]

    with col3:
        selected_target: TargetDataset = st.selectbox(
            "Target Dataset",
            options=["darts_v1", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    with col4:
        selected_task: Task = st.selectbox(
            "Task",
            options=["binary", "count", "density", "count_regimes", "density_regimes"],
            index=0,
            help="Select the task type",
            key="feature_task_select",
        )

    # Find the selected grid config (map the display name back to its config object)
    selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)

    # Get stats for the selected configuration
    mode_stats = dataset_stats[selected_grid_config.id]

    # Check if the selected temporal mode has stats
    if selected_temporal_mode not in mode_stats:
        st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
        return

    stats = mode_stats[selected_temporal_mode]
    available_members = list(stats.members.keys())

    # Members selection
    st.markdown("#### Select Data Sources")

    # Use columns for checkboxes: one column per known source dataset
    cols = st.columns(len(all_l2_source_datasets))
    selected_members: list[L2SourceDataset] = []

    for idx, member in enumerate(all_l2_source_datasets):
        with cols[idx]:
            # Pre-tick only the members that have statistics for this configuration
            default_value = member in available_members
            if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
                selected_members.append(member)  # type: ignore[arg-type]

    # Show results if at least one member is selected
    if selected_members:
        # Filter to selected members only (ticked members without statistics are dropped)
        selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}

        # Calculate total features for selected members
        total_features = sum(ms.feature_count for ms in selected_member_stats.values())

        # Get target stats if available
        if selected_target not in stats.target:
            st.warning(f"Target {selected_target} is not available for this configuration.")
            return

        target_stats = stats.target[selected_target]

        # Check if task is available
        if selected_task not in target_stats:
            st.warning(
                f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
            )
            return

        task_stats = target_stats[selected_task]

        # High-level metrics
        st.markdown("#### Configuration Summary")
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Total Features", f"{total_features:,}")
        with col2:
            # Calculate minimum cells across all data sources (for inference capability)
            if selected_member_stats:
                min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
            else:
                # min() over an empty sequence would raise, so fall back to 0
                min_cells = 0
            st.metric(
                "Inference Cells",
                f"{min_cells:,}",
                help="Minimum number of cells across all selected data sources",
            )
        with col3:
            # Counts every ticked member, including ones without statistics
            st.metric("Data Sources", len(selected_members))
        with col4:
            st.metric("Training Samples", f"{task_stats.training_cells:,}")
        with col5:
            # Calculate total data points (features x training cells)
            total_points = total_features * task_stats.training_cells
            st.metric("Total Data Points", f"{total_points:,}")

        # Task-specific statistics
        st.markdown("#### Task Statistics")
        task_col1, task_col2, task_col3 = st.columns(3)
        with task_col1:
            st.metric("Task Type", selected_task.replace("_", " ").title())
        with task_col2:
            st.metric("Coverage", f"{task_stats.coverage:.2f}%")
        with task_col3:
            # A falsy class_counts is displayed as a regression task
            if task_stats.class_counts:
                st.metric("Number of Classes", len(task_stats.class_counts))
            else:
                st.metric("Task Mode", "Regression")

        # Class distribution for classification tasks
        if task_stats.class_distribution:
            st.markdown("#### Class Distribution")
            class_dist_df = pd.DataFrame(
                [
                    {
                        "Class": class_name,
                        "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
                        "Percentage": f"{pct:.2f}%",
                    }
                    for class_name, pct in task_stats.class_distribution.items()
                ]
            )
            st.dataframe(class_dist_df, hide_index=True, width="stretch")

        # Feature breakdown by source
        st.markdown("#### Feature Breakdown by Data Source")

        breakdown_data = []
        for member, member_stats in selected_member_stats.items():
            breakdown_data.append(
                {
                    "Data Source": member,
                    "Number of Features": member_stats.feature_count,
                    "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
                }
            )

        breakdown_df = pd.DataFrame(breakdown_data)

        # Get all unique data sources and create color map.
        # Members are grouped by the part before the first "-"; each group gets its own
        # palette with two extra colors, and members are assigned palette[1], palette[2], ...
        # (index 0 and the last index are never used).
        unique_members_for_color = sorted(selected_member_stats.keys())
        source_color_map_raw = {}
        for member in unique_members_for_color:
            source = member.split("-")[0]
            n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
            palette = get_palette(source, n_colors=n_members + 2)
            idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
            source_color_map_raw[member] = palette[idx + 1]

        # Create and display pie chart
        fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
        st.plotly_chart(fig, width="stretch")

        # Show detailed table
        st.dataframe(breakdown_df, hide_index=True, width="stretch")

        # Detailed member information
        with st.expander("📦 Detailed Source Information", expanded=False):
            # Create detailed table
            member_details_dict = {
                member: {
                    "feature_count": ms.feature_count,
                    "variable_names": ms.variable_names,
                    "dimensions": ms.dimensions,
                    "size_bytes": ms.size_bytes,
                }
                for member, ms in selected_member_stats.items()
            }
            details_df = create_member_details_table(member_details_dict)
            st.dataframe(details_df, hide_index=True, width="stretch")

            # Individual member details
            # NOTE(review): assumed to sit inside the expander — confirm against the real file.
            for member, member_stats in selected_member_stats.items():
                st.markdown(f"### {member}")

                # Variables: one inline HTML badge per variable name
                st.markdown("**Variables:**")
                vars_html = " ".join(
                    [
                        f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                        for v in member_stats.variable_names
                    ]
                )
                st.markdown(vars_html, unsafe_allow_html=True)

                # Dimensions: one inline HTML badge per "name: size" pair
                st.markdown("**Dimensions:**")
                dim_html = " ".join(
                    [
                        f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                        f"{dim_name}: {dim_size:,}</span>"
                        for dim_name, dim_size in member_stats.dimensions.items()
                    ]
                )
                st.markdown(dim_html, unsafe_allow_html=True)

                st.markdown("---")
    else:
        st.info("👆 Select at least one data source to see feature statistics")
|
||||
|
||||
|
||||
def render_dataset_statistics(all_stats: DatasetStatsCache):
    """Render the dataset statistics section with sample and feature counts.

    Derives four tables from the given statistics and shows them in four tabs:
    training samples, dataset characteristics, feature breakdown and the interactive
    configuration explorer.

    Args:
        all_stats: Pre-loaded dataset statistics. Used as-is for every derived table
            and passed on to the configuration explorer.

    """
    st.header("📈 Dataset Statistics")

    # Use the statistics handed in by the caller. The parameter used to be overwritten
    # right here by a fresh load_all_default_dataset_statistics() call, which ignored
    # the argument and loaded the statistics a second time.
    training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
    feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
    comparison_df = DatasetStatistics.get_comparison_df(all_stats)
    inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)

    # Create tabs for different analysis views
    analysis_tabs = st.tabs(
        [
            "📊 Training Samples",
            "📈 Dataset Characteristics",
            "🔍 Feature Breakdown",
            "⚙️ Configuration Explorer",
        ]
    )

    with analysis_tabs[0]:
        st.subheader("Training Samples by Configuration")
        render_training_samples_tab(training_sample_df)

    with analysis_tabs[1]:
        st.subheader("Dataset Characteristics Across Grid Configurations")
        render_dataset_characteristics_tab(feature_breakdown_df, inference_sample_df, comparison_df)

    with analysis_tabs[2]:
        st.subheader("Feature Breakdown by Data Source")
        render_feature_breakdown_tab(feature_breakdown_df)

    with analysis_tabs[3]:
        st.subheader("Interactive Configuration Explorer")
        render_configuration_explorer_tab(all_stats)
|
||||
158
src/entropice/dashboard/sections/experiment_results.py
Normal file
158
src/entropice/dashboard/sections/experiment_results.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Experiment Results Section for the Entropice Dashboard."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
)
|
||||
|
||||
|
||||
def render_training_results_summary(training_results: list[TrainingResult]):
    """Render summary metrics for training results.

    Shows four metrics side by side: the number of distinct (truthy) experiment names,
    the total number of runs, the number of distinct model types and the creation time
    of the newest run.

    Args:
        training_results: Loaded training results. May be empty, in which case the
            "Latest Run" metric shows "N/A".

    """
    st.header("📊 Training Results Summary")
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        experiments = {tr.experiment for tr in training_results if tr.experiment}
        st.metric("Experiments", len(experiments))

    with col2:
        st.metric("Total Runs", len(training_results))

    with col3:
        models = {tr.settings.model for tr in training_results}
        st.metric("Model Types", len(models))

    with col4:
        if training_results:
            # Pick the newest run explicitly rather than trusting the caller's sort order.
            latest = max(training_results, key=lambda tr: tr.created_at)
            latest_datetime = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d %H:%M")
        else:
            # Indexing an empty list would raise IndexError; show a placeholder instead.
            latest_datetime = "N/A"
        st.metric("Latest Run", latest_datetime)
|
||||
|
||||
|
||||
@st.fragment
def render_experiment_results(training_results: list[TrainingResult]):  # noqa: C901
    """Render detailed experiment results table and expandable details.

    Four select boxes (experiment, task, model, grid) filter the given results; "All"
    disables a filter. The filtered results are shown as one summary table and then as
    one expander per result with its configuration, files, CV score summary, explored
    parameter ranges, the raw CV results and the result path.

    NOTE(review): the nesting of this function was reconstructed from a view without
    indentation; confirm the block structure against the real file.

    Args:
        training_results: Training results to display.

    """
    st.header("🎯 Experiment Results")

    # Filters: collect the distinct values present in the results
    experiments = sorted({tr.experiment for tr in training_results if tr.experiment})
    tasks = sorted({tr.settings.task for tr in training_results})
    models = sorted({tr.settings.model for tr in training_results})
    grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results})

    # Create filter columns
    filter_cols = st.columns(4)

    with filter_cols[0]:
        # Only offer the experiment filter when at least one experiment name exists
        if experiments:
            selected_experiment = st.selectbox(
                "Experiment",
                options=["All", *experiments],
                index=0,
            )
        else:
            selected_experiment = "All"

    with filter_cols[1]:
        selected_task = st.selectbox(
            "Task",
            options=["All", *tasks],
            index=0,
        )

    with filter_cols[2]:
        selected_model = st.selectbox(
            "Model",
            options=["All", *models],
            index=0,
        )

    with filter_cols[3]:
        selected_grid = st.selectbox(
            "Grid",
            options=["All", *grids],
            index=0,
        )

    # Apply filters (each active filter narrows the list further)
    filtered_results = training_results
    if selected_experiment != "All":
        filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment]
    if selected_task != "All":
        filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task]
    if selected_model != "All":
        filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model]
    if selected_grid != "All":
        filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid]

    st.subheader("Results Table")

    summary_df = TrainingResult.to_dataframe(filtered_results)
    # Show the summary table (plain st.dataframe call; no styling is applied here)
    st.dataframe(
        summary_df,
        width="stretch",
        hide_index=True,
    )

    # Expandable details for each result
    st.subheader("Individual Experiment Details")

    for tr in filtered_results:
        tr_info = tr.display_info
        display_name = tr_info.get_display_name("model_first")
        with st.expander(display_name):
            col1, col2 = st.columns([1, 2])

            with col1:
                grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
                st.write("**Configuration:**")
                st.write(f"- **Experiment:** {tr.experiment}")
                st.write(f"- **Task:** {tr.settings.task}")
                st.write(f"- **Model:** {tr.settings.model}")
                st.write(f"- **Grid:** {grid_config.display_name}")
                st.write(f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}")
                st.write(f"- **Temporal Mode:** {tr.settings.temporal_mode}")
                st.write(f"- **Members:** {', '.join(tr.settings.members)}")
                st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
                st.write(f"- **Classes:** {tr.settings.classes}")

                # List the files of the result directory; known file names get their own icon
                st.write("\n**Files:**")
                for file in tr.files:
                    if file.name == "search_settings.toml":
                        st.write(f"- ⚙️ `{file.name}`")
                    elif file.name == "best_estimator_model.pkl":
                        st.write(f"- 🧮 `{file.name}`")
                    elif file.name == "search_results.parquet":
                        st.write(f"- 📊 `{file.name}`")
                    elif file.name == "predicted_probabilities.parquet":
                        st.write(f"- 🎯 `{file.name}`")
                    else:
                        st.write(f"- 📄 `{file.name}`")
            with col2:
                st.write("**CV Score Summary:**")

                # Extract all test scores
                metric_df = tr.get_metric_dataframe()
                if metric_df is not None:
                    st.dataframe(metric_df, width="stretch", hide_index=True)
                else:
                    st.write("No test scores found in results.")

                # Show parameter space explored
                if "initial_K" in tr.results.columns:  # Common parameter
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ["initial_K", "eps_cl", "eps_e"]:
                        if param in tr.results.columns:
                            min_val = tr.results[param].min()
                            max_val = tr.results[param].max()
                            unique_vals = tr.results[param].nunique()
                            st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")

            # NOTE(review): placement of this inner expander and of the path line below
            # (inside the outer expander, after the columns) is assumed — confirm.
            with st.expander("Show CV Results DataFrame"):
                st.dataframe(tr.results, width="stretch", hide_index=True)

            st.write(f"\n**Path:** `{tr.path}`")
|
||||
|
|
@ -90,6 +90,14 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
|
|||
A list of hex color strings.
|
||||
|
||||
"""
|
||||
# Hardcode some common variables to specific colormaps
|
||||
if variable == "ERA5":
|
||||
cmap = load_cmap(name="blue_material").resampled(n_colors)
|
||||
elif variable == "ArcticDEM":
|
||||
cmap = load_cmap(name="deep_purple_material").resampled(n_colors)
|
||||
elif variable == "AlphaEarth":
|
||||
cmap = load_cmap(name="green_material").resampled(n_colors)
|
||||
else:
|
||||
cmap = get_cmap(variable).resampled(n_colors)
|
||||
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
|
||||
return colors
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
|
||||
import antimeridian
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import toml
|
||||
|
|
@ -16,9 +15,8 @@ from shapely.geometry import shape
|
|||
import entropice.spatial.grids
|
||||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||
from entropice.ml.training import TrainingSettings
|
||||
from entropice.utils.types import L2SourceDataset, Task
|
||||
from entropice.utils.types import GridConfig
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
|
||||
|
|
@ -32,22 +30,29 @@ def _fix_hex_geometry(geom):
|
|||
|
||||
@dataclass
|
||||
class TrainingResult:
|
||||
"""Wrapper for training result data and metadata."""
|
||||
|
||||
path: Path
|
||||
experiment: str
|
||||
settings: TrainingSettings
|
||||
results: pd.DataFrame
|
||||
metrics: dict[str, float]
|
||||
confusion_matrix: xr.DataArray
|
||||
train_metrics: dict[str, float]
|
||||
test_metrics: dict[str, float]
|
||||
combined_metrics: dict[str, float]
|
||||
confusion_matrix: xr.Dataset | None
|
||||
created_at: float
|
||||
available_metrics: list[str]
|
||||
files: list[Path]
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
||||
def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "TrainingResult":
|
||||
"""Load a TrainingResult from a given result directory path."""
|
||||
result_file = result_path / "search_results.parquet"
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
settings_file = result_path / "search_settings.toml"
|
||||
metrics_file = result_path / "test_metrics.toml"
|
||||
metrics_file = result_path / "metrics.toml"
|
||||
confusion_matrix_file = result_path / "confusion_matrix.nc"
|
||||
all_files = list(result_path.iterdir())
|
||||
if not result_file.exists():
|
||||
raise FileNotFoundError(f"Missing results file in {result_path}")
|
||||
if not settings_file.exists():
|
||||
|
|
@ -56,28 +61,46 @@ class TrainingResult:
|
|||
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
||||
if not metrics_file.exists():
|
||||
raise FileNotFoundError(f"Missing metrics file in {result_path}")
|
||||
if not confusion_matrix_file.exists():
|
||||
raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
|
||||
settings_dict = toml.load(settings_file)["settings"]
|
||||
|
||||
# Handle backward compatibility: add missing fields with defaults
|
||||
if "classes" not in settings_dict:
|
||||
settings_dict["classes"] = None
|
||||
if "param_grid" not in settings_dict:
|
||||
settings_dict["param_grid"] = {}
|
||||
if "cv_splits" not in settings_dict:
|
||||
settings_dict["cv_splits"] = 5
|
||||
if "metrics" not in settings_dict:
|
||||
settings_dict["metrics"] = []
|
||||
|
||||
settings = TrainingSettings(**settings_dict)
|
||||
results = pd.read_parquet(result_file)
|
||||
metrics = toml.load(metrics_file)["test_metrics"]
|
||||
confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
|
||||
metrics = toml.load(metrics_file)
|
||||
if not confusion_matrix_file.exists():
|
||||
confusion_matrix = None
|
||||
else:
|
||||
confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf")
|
||||
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
|
||||
|
||||
return cls(
|
||||
path=result_path,
|
||||
experiment=experiment_name or "N/A",
|
||||
settings=settings,
|
||||
results=results,
|
||||
metrics=metrics,
|
||||
train_metrics=metrics["train_metrics"],
|
||||
test_metrics=metrics["test_metrics"],
|
||||
combined_metrics=metrics["combined_metrics"],
|
||||
confusion_matrix=confusion_matrix,
|
||||
created_at=created_at,
|
||||
available_metrics=available_metrics,
|
||||
files=all_files,
|
||||
)
|
||||
|
||||
@property
|
||||
def display_info(self) -> TrainingResultDisplayInfo:
|
||||
"""Get display information for the training result."""
|
||||
return TrainingResultDisplayInfo(
|
||||
task=self.settings.task,
|
||||
model=self.settings.model,
|
||||
|
|
@ -126,9 +149,70 @@ class TrainingResult:
|
|||
st.error(f"Error loading predictions: {e}")
|
||||
return None
|
||||
|
||||
def get_metric_dataframe(self) -> pd.DataFrame | None:
|
||||
"""Get a DataFrame of available metrics for this training result."""
|
||||
metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")]
|
||||
if not metric_cols:
|
||||
return None
|
||||
metric_data = []
|
||||
for col in metric_cols:
|
||||
metric_name = col.replace("mean_test_", "").replace("neg_", "").title()
|
||||
metrics = self.results[col]
|
||||
# Check if the metric is negative
|
||||
if col.startswith("mean_test_neg_"):
|
||||
task_multiplier = 1 if self.settings.task != "density" else 100
|
||||
task_multiplier *= -1
|
||||
metrics = metrics * task_multiplier
|
||||
|
||||
metric_data.append(
|
||||
{
|
||||
"Metric": metric_name,
|
||||
"Best": f"{metrics.max():.4f}",
|
||||
"Mean": f"{metrics.mean():.4f}",
|
||||
"Std": f"{metrics.std():.4f}",
|
||||
"Worst": f"{metrics.min():.4f}",
|
||||
}
|
||||
)
|
||||
|
||||
return pd.DataFrame(metric_data)
|
||||
|
||||
def _get_best_metric_name(self) -> str:
    """Return the name of the headline score metric for this result's task."""
    task = self.settings.task
    if task == "binary":
        return "f1"
    if task in ("count_regimes", "density_regimes"):
        return "f1_weighted"
    # Every other task is treated as a regression task.
    return "r2"
|
||||
|
||||
@staticmethod
def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame:
    """Build the display table with one row per training result.

    Each row holds the experiment name, task, model, grid display name, creation time,
    the task's headline metric with its train-set and test-set value, the number of
    rows in the result's ``results`` frame, and the name of the result directory.
    """

    def _row(result: "TrainingResult") -> dict:
        # One table row for a single training result.
        display = result.display_info
        score_metric = result._get_best_metric_name()
        grid_name = GridConfig.from_grid_level((display.grid, display.level)).display_name
        return {
            "Experiment": result.experiment if result.experiment else "N/A",
            "Task": display.task,
            "Model": display.model,
            "Grid": grid_name,
            "Created At": display.timestamp.strftime("%Y-%m-%d %H:%M"),
            "Score-Metric": score_metric.title(),
            "Best Models Score (Train-Set)": result.train_metrics.get(score_metric),
            "Best Models Score (Test-Set)": result.test_metrics.get(score_metric),
            "Trials": len(result.results),
            "Path": str(result.path.name),
        }

    return pd.DataFrame.from_records([_row(tr) for tr in training_results])
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_all_training_results() -> list[TrainingResult]:
|
||||
"""Load all training results from the results directory."""
|
||||
results_dir = entropice.utils.paths.RESULTS_DIR
|
||||
training_results: list[TrainingResult] = []
|
||||
for result_path in results_dir.iterdir():
|
||||
|
|
@ -136,50 +220,22 @@ def load_all_training_results() -> list[TrainingResult]:
|
|||
continue
|
||||
try:
|
||||
training_result = TrainingResult.from_path(result_path)
|
||||
except FileNotFoundError as e:
|
||||
st.warning(f"Skipping incomplete training result: {e}")
|
||||
continue
|
||||
training_results.append(training_result)
|
||||
except FileNotFoundError as e:
|
||||
is_experiment_dir = False
|
||||
for experiment_path in result_path.iterdir():
|
||||
if not experiment_path.is_dir():
|
||||
continue
|
||||
try:
|
||||
experiment_name = experiment_path.name
|
||||
training_result = TrainingResult.from_path(experiment_path, experiment_name)
|
||||
training_results.append(training_result)
|
||||
is_experiment_dir = True
|
||||
except FileNotFoundError as e2:
|
||||
st.warning(f"Skipping incomplete training result: {e2}")
|
||||
if not is_experiment_dir:
|
||||
st.warning(f"Skipping incomplete training result: {e}")
|
||||
|
||||
# Sort by creation time (most recent first)
|
||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||
return training_results
|
||||
|
||||
|
||||
def load_all_training_data(
    e: DatasetEnsemble,
) -> dict[Task, CategoricalTrainingDataset]:
    """Create the dataset once and split it for the binary, count and density tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    full_dataset = e.create(filter_target_col=e.covcol)
    # Same dataset, one split per task, always on the CPU.
    return {
        task_name: e._cat_and_split(full_dataset, task_name, device="cpu")
        for task_name in ("binary", "count", "density")
    }
|
||||
|
||||
|
||||
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Read the targets and the data of one source (member) from the ensemble.

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple ``(ds, targets)``: the member data returned by ``e._read_member`` for the
        given source, and the targets returned by ``e._read_target``.

    """
    target_frame = e._read_target()
    # The member is requested with lazy=False.
    member_ds = e._read_member(source, target_frame, lazy=False)
    return member_ds, target_frame
|
||||
|
|
|
|||
|
|
@ -3,29 +3,28 @@
|
|||
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
|
||||
"""
|
||||
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from stopuhr import stopwatch
|
||||
|
||||
import entropice.spatial.grids
|
||||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import (
|
||||
Grid,
|
||||
GridLevel,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
Task,
|
||||
TemporalMode,
|
||||
all_l2_source_datasets,
|
||||
all_target_datasets,
|
||||
all_tasks,
|
||||
all_temporal_modes,
|
||||
grid_configs,
|
||||
)
|
||||
|
||||
|
|
@ -41,90 +40,63 @@ class MemberStatistics:
|
|||
size_bytes: int # Size of this member's data on disk in bytes
|
||||
|
||||
@classmethod
|
||||
def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
|
||||
if member == "AlphaEarth":
|
||||
store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
|
||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||
era5_agg = member.split("-")[1]
|
||||
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level) # ty:ignore[invalid-argument-type]
|
||||
elif member == "ArcticDEM":
|
||||
store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
|
||||
else:
|
||||
raise NotImplementedError(f"Member {member} not implemented.")
|
||||
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
|
||||
"""Pre-compute the statistics for a specific dataset member."""
|
||||
member_stats = {}
|
||||
for member in all_l2_source_datasets:
|
||||
ds = e.read_member(member, lazy=True)
|
||||
size_bytes = ds.nbytes
|
||||
|
||||
size_bytes = store.stat().st_size
|
||||
ds = xr.open_zarr(store, consolidated=False)
|
||||
|
||||
# Delete all coordinates which are not in the dimension
|
||||
for coord in ds.coords:
|
||||
if coord not in ds.dims:
|
||||
ds = ds.drop_vars(coord)
|
||||
n_cols_member = len(ds.data_vars)
|
||||
for dim in ds.sizes:
|
||||
if dim != "cell_ids":
|
||||
n_cols_member *= ds.sizes[dim]
|
||||
|
||||
return cls(
|
||||
member_stats[member] = cls(
|
||||
feature_count=n_cols_member,
|
||||
variable_names=list(ds.data_vars),
|
||||
dimensions=dict(ds.sizes),
|
||||
coordinates=list(ds.coords),
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
return member_stats
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TargetStatistics:
|
||||
"""Statistics for a specific target dataset."""
|
||||
|
||||
training_cells: dict[Task, int] # Number of cells used for training
|
||||
coverage: dict[Task, float] # Percentage of total cells covered by training data
|
||||
class_counts: dict[Task, dict[str, int]] # Class name to count mapping per task
|
||||
class_distribution: dict[Task, dict[str, float]] # Class name to percentage mapping per task
|
||||
size_bytes: int # Size of the target dataset on disk in bytes
|
||||
training_cells: int # Number of cells used for training
|
||||
coverage: float # Percentage of total cells covered by training data
|
||||
class_counts: dict[str, int] | None # Class name to count mapping
|
||||
class_distribution: dict[str, float] | None # Class name to percentage mapping
|
||||
|
||||
@classmethod
|
||||
def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
|
||||
if target == "darts_rts":
|
||||
target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
|
||||
elif target == "darts_mllabels":
|
||||
target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
|
||||
else:
|
||||
raise NotImplementedError(f"Target {target} not implemented.")
|
||||
target_gdf = gpd.read_parquet(target_store)
|
||||
size_bytes = target_store.stat().st_size
|
||||
training_cells: dict[Task, int] = {}
|
||||
training_coverage: dict[Task, float] = {}
|
||||
class_counts: dict[Task, dict[str, int]] = {}
|
||||
class_distribution: dict[Task, dict[str, float]] = {}
|
||||
def compute(cls, e: DatasetEnsemble, target: TargetDataset, total_cells: int) -> dict[Task, "TargetStatistics"]:
|
||||
"""Pre-compute the statistics for a specific target dataset."""
|
||||
target_stats = {}
|
||||
for task in all_tasks:
|
||||
task_col = taskcol[task][target]
|
||||
cov_col = covcol[target]
|
||||
targets = e.get_targets(target=target, task=task)
|
||||
|
||||
task_gdf = target_gdf[target_gdf[cov_col]]
|
||||
training_cells[task] = len(task_gdf)
|
||||
training_coverage[task] = len(task_gdf) / total_cells * 100
|
||||
training_cells = len(targets)
|
||||
training_coverage = len(targets) / total_cells * 100
|
||||
|
||||
model_labels = task_gdf[task_col].dropna()
|
||||
if task == "binary":
|
||||
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
|
||||
elif task == "count":
|
||||
binned = bin_values(model_labels.astype(int), task=task)
|
||||
elif task == "density":
|
||||
binned = bin_values(model_labels, task=task)
|
||||
if task in ["count", "density"]:
|
||||
class_counts = None
|
||||
class_distribution = None
|
||||
else:
|
||||
raise ValueError("Invalid task.")
|
||||
counts = binned.value_counts()
|
||||
assert targets["y"].dtype == "category", "Classification tasks must have categorical target dtype."
|
||||
counts = targets["y"].value_counts()
|
||||
distribution = counts / counts.sum() * 100
|
||||
class_counts[task] = counts.to_dict() # ty:ignore[invalid-assignment]
|
||||
class_distribution[task] = distribution.to_dict()
|
||||
return TargetStatistics(
|
||||
class_counts = counts.to_dict()
|
||||
class_distribution = distribution.to_dict()
|
||||
target_stats[task] = cls(
|
||||
training_cells=training_cells,
|
||||
coverage=training_coverage,
|
||||
class_counts=class_counts,
|
||||
class_distribution=class_distribution,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
return target_stats
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -139,108 +111,154 @@ class DatasetStatistics:
|
|||
total_cells: int # Total number of grid cells potentially covered
|
||||
size_bytes: int # Size of the dataset on disk in bytes
|
||||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
|
||||
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
||||
|
||||
@staticmethod
|
||||
def get_sample_count_df(
|
||||
all_stats: dict[GridLevel, "DatasetStatistics"],
|
||||
def get_target_sample_count_df(
|
||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||
) -> pd.DataFrame:
|
||||
"""Convert sample count data to DataFrame."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
mode_stats = all_stats[grid_config.id]
|
||||
for temporal_mode, stats in mode_stats.items():
|
||||
for target_name, target_stats in stats.target.items():
|
||||
for task in all_tasks:
|
||||
training_cells = target_stats.training_cells[task]
|
||||
coverage = target_stats.coverage[task]
|
||||
# We can assume that all tasks have equal counts, since they only affect distribution, not presence
|
||||
target_task_stats = target_stats["binary"]
|
||||
n_samples = target_task_stats.training_cells
|
||||
coverage = target_task_stats.coverage
|
||||
for task, target_task_stats in target_stats.items():
|
||||
assert n_samples == target_task_stats.training_cells, (
|
||||
"Inconsistent sample counts across tasks for the same target."
|
||||
)
|
||||
assert coverage == target_task_stats.coverage, (
|
||||
"Inconsistent coverage across tasks for the same target."
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Target": target_name.replace("darts_", ""),
|
||||
"Task": task.capitalize(),
|
||||
"Samples (Coverage)": training_cells,
|
||||
"Target": target_name.replace("_", " ").title(),
|
||||
"Temporal Mode": str(temporal_mode).capitalize(),
|
||||
"Samples (Coverage)": n_samples,
|
||||
"Coverage %": coverage,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
|
||||
|
||||
@staticmethod
|
||||
def get_feature_count_df(
|
||||
all_stats: dict[GridLevel, "DatasetStatistics"],
|
||||
def get_inference_sample_count_df(
|
||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||
) -> pd.DataFrame:
|
||||
"""Convert feature count data to DataFrame."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
data_sources = list(stats.members.keys())
|
||||
|
||||
# Determine minimum cells across all data sources
|
||||
mode_stats = all_stats[grid_config.id]
|
||||
# We can assume that all temporal modes have equal counts, since they only affect features, not samples
|
||||
stats = mode_stats["feature"]
|
||||
assert len(stats.members) > 0, "Dataset must have at least one member."
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
|
||||
|
||||
# Get sample count from first target dataset (darts_rts)
|
||||
first_target = stats.target["darts_rts"]
|
||||
total_samples = first_target.training_cells["binary"]
|
||||
for temporal_mode, stats in mode_stats.items():
|
||||
assert len(stats.members) > 0, "Dataset must have at least one member."
|
||||
assert (
|
||||
min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values()) == min_cells
|
||||
), "Inconsistent inference cell counts across temporal modes."
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Total Features": stats.total_features,
|
||||
"Data Sources": len(data_sources),
|
||||
"Inference Cells": min_cells,
|
||||
"Total Samples": total_samples,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
return pd.DataFrame(rows).sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
|
||||
|
||||
@staticmethod
|
||||
def get_comparison_df(
|
||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||
) -> pd.DataFrame:
|
||||
"""Convert comparison data to DataFrame for detailed table."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
mode_stats = all_stats[grid_config.id]
|
||||
for temporal_mode, stats in mode_stats.items():
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Temporal Mode": str(temporal_mode).capitalize(),
|
||||
"Total Features": f"{stats.total_features:,}",
|
||||
"Total Cells": f"{stats.total_cells:,}",
|
||||
"Min Cells": f"{min_cells:,}",
|
||||
"Size (MB)": f"{stats.size_bytes / (1024 * 1024):.2f}",
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows).sort_values(["Grid_Level_Sort", "Temporal Mode"]).drop(columns="Grid_Level_Sort")
|
||||
|
||||
@staticmethod
|
||||
def get_feature_breakdown_df(
|
||||
all_stats: dict[GridLevel, "DatasetStatistics"],
|
||||
all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
|
||||
) -> pd.DataFrame:
|
||||
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
mode_stats = all_stats[grid_config.id]
|
||||
added_year_temporal_mode = False
|
||||
for temporal_mode, stats in mode_stats.items():
|
||||
if added_year_temporal_mode:
|
||||
continue
|
||||
if isinstance(temporal_mode, int):
|
||||
# Only add one year-based temporal mode for the breakdown
|
||||
added_year_temporal_mode = True
|
||||
temporal_mode = "Year-Based"
|
||||
for member_name, member_stats in stats.members.items():
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Data Source": member_name,
|
||||
"Temporal Mode": temporal_mode.capitalize(),
|
||||
"Member": member_name,
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
return (
|
||||
pd.DataFrame(rows)
|
||||
.sort_values(["Grid_Level_Sort", "Temporal Mode", "Member"])
|
||||
.drop(columns="Grid_Level_Sort")
|
||||
)
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
|
||||
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
|
||||
# @st.cache_data # ty:ignore[invalid-argument-type]
|
||||
def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode, DatasetStatistics]]:
|
||||
"""Precompute dataset statistics for all grid-level combinations and temporal modes."""
|
||||
cache_file = entropice.utils.paths.get_dataset_stats_cache()
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "rb") as f:
|
||||
dataset_stats = pickle.load(f)
|
||||
return dataset_stats
|
||||
|
||||
dataset_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
|
||||
for grid_config in grid_configs:
|
||||
dataset_stats[grid_config.id] = {}
|
||||
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
|
||||
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
|
||||
total_cells = len(grid_gdf)
|
||||
assert total_cells > 0, "Grid must contain at least one cell."
|
||||
target_statistics: dict[TargetDataset, TargetStatistics] = {}
|
||||
for temporal_mode in all_temporal_modes:
|
||||
e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode)
|
||||
target_statistics = {}
|
||||
for target in all_target_datasets:
|
||||
target_statistics[target] = TargetStatistics.compute(
|
||||
grid=grid_config.grid,
|
||||
level=grid_config.level,
|
||||
target=target,
|
||||
total_cells=total_cells,
|
||||
)
|
||||
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
|
||||
for member in all_l2_source_datasets:
|
||||
member_statistics[member] = MemberStatistics.compute(
|
||||
grid=grid_config.grid, level=grid_config.level, member=member
|
||||
)
|
||||
if isinstance(temporal_mode, int) and target == "darts_mllabels":
|
||||
# darts_mllabels does not support year-based temporal modes
|
||||
continue
|
||||
target_statistics[target] = TargetStatistics.compute(e, target=target, total_cells=total_cells)
|
||||
member_statistics = MemberStatistics.compute(e)
|
||||
|
||||
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
||||
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
|
||||
ts.size_bytes for ts in target_statistics.values()
|
||||
)
|
||||
dataset_stats[grid_config.id] = DatasetStatistics(
|
||||
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values())
|
||||
dataset_stats[grid_config.id][temporal_mode] = DatasetStatistics(
|
||||
total_features=total_features,
|
||||
total_cells=total_cells,
|
||||
size_bytes=total_size_bytes,
|
||||
|
|
@ -248,95 +266,12 @@ def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
|
|||
target=target_statistics,
|
||||
)
|
||||
|
||||
with open(cache_file, "wb") as f:
|
||||
pickle.dump(dataset_stats, f)
|
||||
|
||||
return dataset_stats
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Statistics for the columns a single source member contributes to a loaded ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Share of the ensemble's features from this member (n_member / n_features, i.e. 0-1)
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: "gpd.GeoDataFrame",
        member: "L2SourceDataset",
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute the statistics of one member from the ensemble's combined feature table.

        Member columns are identified by their name prefix: ``embeddings_`` for AlphaEarth,
        ``arcticdem_`` for ArcticDEM and ``era5_`` for the ERA5 members. ERA5 members are told
        apart by the number of underscores in the column name and by whether the name contains
        ``_summer`` / ``_winter``.

        Args:
            dataset: Combined feature table with one column per feature.
            member: Name of the source member whose columns should be analysed.
            n_features: Total number of features in the ensemble, used as denominator for ``p_features``.

        Returns:
            The statistics for the selected member.

        Raises:
            NotImplementedError: If ``member`` is not one of the supported members.

        """
        columns = dataset.columns
        if member == "AlphaEarth":
            column_mask = columns.str.startswith("embeddings_")
        elif member == "ArcticDEM":
            column_mask = columns.str.startswith("arcticdem_")
        elif member in ("ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"):
            is_era5 = columns.str.startswith("era5_")
            n_underscores = columns.str.count("_")
            is_summer_winter = columns.str.contains("_summer|_winter")
            if member == "ERA5-yearly":
                column_mask = is_era5 & (n_underscores == 2)
            elif member == "ERA5-seasonal":
                column_mask = is_era5 & (n_underscores == 3) & is_summer_winter
            else:  # ERA5-shoulder
                column_mask = is_era5 & (n_underscores == 3) & ~is_summer_winter
        else:
            raise NotImplementedError(f"Member {member} not implemented.")

        # A boolean mask passed to ``dataset[...]`` filters ROWS; ``.loc[:, mask]`` is required to select columns.
        member_dataset = dataset.loc[:, column_mask]

        n_features_member = len(member_dataset.columns)
        return cls(
            n_features=n_features_member,
            # Guard against an empty ensemble instead of raising ZeroDivisionError.
            p_features=n_features_member / n_features if n_features else 0.0,
            n_nanrows=int(member_dataset.isna().any(axis=1).sum()),
            size_bytes=int(member_dataset.memory_usage(deep=True).sum()),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # In-memory size of the loaded dataset in bytes (from DataFrame.memory_usage)
    members: dict[L2SourceDataset, EnsembleMemberStatistics]  # Statistics per source dataset member

    @classmethod
    def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
        """Compute the statistics of an ensemble from its loaded dataset.

        Args:
            ensemble: The ensemble whose members are analysed.
            dataset: Already loaded dataset of the ensemble. If ``None``, the dataset is created via
                ``ensemble.create(filter_target_col=ensemble.covcol)``.

        Returns:
            The statistics of the complete dataset and of each ensemble member.

        Raises:
            AssertionError: If any column of the dataset consists only of NaN values.

        """
        # ``dataset or ...`` would evaluate bool(DataFrame), which raises ValueError for any passed dataset.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)

        # Assert that no column is all-NaN
        assert not dataset.isna().all(axis="index").any(), "Some input columns are all NaN"

        n_features = len(dataset.columns)
        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {
            member: EnsembleMemberStatistics.compute(dataset=dataset, member=member, n_features=n_features)
            for member in ensemble.members
        }
        return cls(
            n_features=n_features,
            n_cells=len(dataset),
            n_nanrows=int(dataset.isna().any(axis=1).sum()),
            size_bytes=int(dataset.memory_usage(deep=True).sum()),
            members=member_statistics,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainingDatasetStatistics:
|
||||
n_samples: int # Total number of samples in the dataset
|
||||
|
|
@ -346,14 +281,14 @@ class TrainingDatasetStatistics:
|
|||
n_test_samples: int # Number of cells used for testing
|
||||
train_test_ratio: float # Ratio of training to test samples
|
||||
|
||||
class_labels: list[str] # Ordered list of class labels
|
||||
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class
|
||||
n_classes: int # Number of classes
|
||||
class_labels: list[str] | None # Ordered list of class labels
|
||||
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] | None # Min/max raw values
|
||||
n_classes: int | None # Number of classes
|
||||
|
||||
training_class_counts: dict[str, int] # Class counts in training set
|
||||
training_class_distribution: dict[str, float] # Class percentages in training set
|
||||
test_class_counts: dict[str, int] # Class counts in test set
|
||||
test_class_distribution: dict[str, float] # Class percentages in test set
|
||||
training_class_counts: dict[str, int] | None # Class counts in training set
|
||||
training_class_distribution: dict[str, float] | None # Class percentages in training set
|
||||
test_class_counts: dict[str, int] | None # Class counts in test set
|
||||
test_class_distribution: dict[str, float] | None # Class percentages in test set
|
||||
|
||||
raw_value_min: float # Minimum raw target value
|
||||
raw_value_max: float # Maximum raw target value
|
||||
|
|
@ -361,7 +296,7 @@ class TrainingDatasetStatistics:
|
|||
raw_value_median: float # Median raw target value
|
||||
raw_value_std: float # Standard deviation of raw target values
|
||||
|
||||
imbalance_ratio: float # Smallest class count / largest class count (overall)
|
||||
imbalance_ratio: float | None # Smallest class count / largest class count (overall)
|
||||
size_bytes: int # Total memory usage of features in bytes
|
||||
|
||||
@classmethod
|
||||
|
|
@ -369,58 +304,64 @@ class TrainingDatasetStatistics:
|
|||
cls,
|
||||
ensemble: DatasetEnsemble,
|
||||
task: Task,
|
||||
dataset: gpd.GeoDataFrame | None = None,
|
||||
target: TargetDataset,
|
||||
) -> "TrainingDatasetStatistics":
|
||||
dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
|
||||
categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
|
||||
training_dataset = ensemble.create_training_set(task=task, target=target)
|
||||
|
||||
# Sample counts
|
||||
n_samples = len(categorical_dataset)
|
||||
n_training_samples = len(categorical_dataset.y.train)
|
||||
n_test_samples = len(categorical_dataset.y.test)
|
||||
n_samples = len(training_dataset)
|
||||
n_training_samples = (training_dataset.split == "train").sum()
|
||||
n_test_samples = (training_dataset.split == "test").sum()
|
||||
train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
|
||||
|
||||
# Feature statistics
|
||||
n_features = len(categorical_dataset.X.data.columns)
|
||||
feature_names = list(categorical_dataset.X.data.columns)
|
||||
size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()
|
||||
n_features = len(training_dataset.features.columns)
|
||||
feature_names = training_dataset.features.columns.tolist()
|
||||
size_bytes = training_dataset.features.memory_usage(deep=True).sum()
|
||||
|
||||
# Class information
|
||||
class_labels = categorical_dataset.y.labels
|
||||
class_intervals = categorical_dataset.y.intervals
|
||||
class_labels = training_dataset.target_labels
|
||||
if class_labels is None:
|
||||
class_intervals = None
|
||||
n_classes = None
|
||||
training_class_counts = None
|
||||
training_class_distribution = None
|
||||
test_class_counts = None
|
||||
test_class_distribution = None
|
||||
imbalance_ratio = None
|
||||
else:
|
||||
class_intervals = training_dataset.target_intervals
|
||||
n_classes = len(class_labels)
|
||||
|
||||
# Training class distribution
|
||||
train_y_series = pd.Series(categorical_dataset.y.train)
|
||||
train_counts = train_y_series.value_counts().sort_index()
|
||||
train_y = training_dataset.targets[training_dataset.split == "train"]["y"]
|
||||
test_y = training_dataset.targets[training_dataset.split == "test"]["y"]
|
||||
|
||||
train_counts = train_y.value_counts().sort_index()
|
||||
training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
|
||||
train_total = sum(training_class_counts.values())
|
||||
training_class_distribution = {
|
||||
k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
|
||||
}
|
||||
|
||||
# Test class distribution
|
||||
test_y_series = pd.Series(categorical_dataset.y.test)
|
||||
test_counts = test_y_series.value_counts().sort_index()
|
||||
test_counts = test_y.value_counts().sort_index()
|
||||
test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
|
||||
test_total = sum(test_class_counts.values())
|
||||
test_class_distribution = {
|
||||
k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
|
||||
}
|
||||
|
||||
# Imbalance ratio (smallest class / largest class across both splits)
|
||||
all_counts = list(train_counts.values()) + list(test_class_counts.values())
|
||||
nonzero_counts = [c for c in all_counts if c > 0]
|
||||
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
|
||||
|
||||
# Raw value statistics
|
||||
raw_values = categorical_dataset.y.raw_values
|
||||
raw_values = training_dataset.targets["z"]
|
||||
raw_value_min = float(raw_values.min())
|
||||
raw_value_max = float(raw_values.max())
|
||||
raw_value_mean = float(raw_values.mean())
|
||||
raw_value_median = float(raw_values.median())
|
||||
raw_value_std = float(raw_values.std())
|
||||
|
||||
# Imbalance ratio (smallest class / largest class across both splits)
|
||||
all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
|
||||
nonzero_counts = [c for c in all_counts if c > 0]
|
||||
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
|
||||
|
||||
return cls(
|
||||
n_samples=n_samples,
|
||||
n_features=n_features,
|
||||
|
|
|
|||
|
|
@ -1,508 +1,17 @@
|
|||
"""Overview page: List of available result directories with some summary statistics."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.overview import (
|
||||
create_feature_breakdown_donut,
|
||||
create_feature_count_stacked_bar,
|
||||
create_feature_distribution_pie,
|
||||
create_inference_cells_bar,
|
||||
create_sample_count_bar_chart,
|
||||
from entropice.dashboard.sections.dataset_statistics import render_dataset_statistics
|
||||
from entropice.dashboard.sections.experiment_results import (
|
||||
render_experiment_results,
|
||||
render_training_results_summary,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.dashboard.utils.stats import (
|
||||
DatasetStatistics,
|
||||
load_all_default_dataset_statistics,
|
||||
)
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
grid_configs,
|
||||
)
|
||||
|
||||
|
||||
def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    Shows an explanatory text, a bar chart of the sample counts and a formatted table with the
    sample counts and coverage percentages.
    """
    st.markdown(
        """
        This visualization shows the number of available training samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    # Get sample count DataFrame from cache
    all_stats = load_all_default_dataset_statistics()
    sample_df = DatasetStatistics.get_sample_count_df(all_stats)

    # One palette per target dataset, with one color per task
    n_tasks = sample_df["Task"].nunique()
    target_color_maps = {
        "rts": get_palette("task_types", n_colors=n_tasks),
        "mllabels": get_palette("data_sources", n_colors=n_tasks),
    }

    # Create and display bar chart
    fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
    # width="stretch" replaces the deprecated use_container_width=True, as in the other charts of this page.
    st.plotly_chart(fig, width="stretch")

    # Display full table with formatting
    st.markdown("#### Detailed Sample Counts")
    display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()

    # Format numbers with thousands separators and coverage as percentage with 2 decimal places
    display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
    display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")

    st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def render_feature_count_comparison():
    """Render static comparison of feature counts across all grid configurations.

    Shows a stacked bar chart of the features per data source, a bar chart of the inference cells
    per grid configuration and a formatted comparison table.
    """
    st.markdown(
        """
        Comparing dataset characteristics for all grid configurations with all data sources enabled.
        - **Features**: Total number of input features from all data sources
        - **Spatial Coverage**: Number of grid cells with complete data coverage
        """
    )

    # Get data from cache
    all_stats = load_all_default_dataset_statistics()
    comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
    breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats).sort_values("Grid_Level_Sort")

    # Assign one stable color to every data source
    unique_sources = sorted(breakdown_df["Data Source"].unique())
    source_color_map = dict(zip(unique_sources, get_palette("data_sources", n_colors=len(unique_sources))))

    # Create and display stacked bar chart
    fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
    # width="stretch" replaces the deprecated use_container_width=True, as in the other charts of this page.
    st.plotly_chart(fig, width="stretch")

    # Add spatial coverage metric
    grid_colors = get_palette("grid_configs", n_colors=len(comparison_df))
    fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
    st.plotly_chart(fig_cells, width="stretch")

    # Display full comparison table with formatting
    st.markdown("#### Detailed Comparison Table")
    display_df = comparison_df[["Grid", "Total Features", "Data Sources", "Inference Cells"]].copy()

    # Format numbers with thousands separators
    for col in ["Total Features", "Inference Cells"]:
        display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

    st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
@st.fragment
def render_feature_count_explorer():
    """Render interactive detailed configuration explorer using fragments.

    The user picks a grid configuration, a target dataset and a set of data sources. For the
    selection, headline metrics, a feature breakdown (pie chart and table) and per-source details
    are shown. Nothing but a hint is rendered when no usable data source is selected.
    """
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    # Grid selection
    grid_options = [gc.display_name for gc in grid_configs]

    col1, col2 = st.columns(2)

    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )

    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    # Find the selected grid config
    selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)

    # Get the statistics of the selected grid and the members available for it
    all_stats = load_all_default_dataset_statistics()
    grid_stats = all_stats[selected_grid_config.id]
    available_members = cast(list[L2SourceDataset], list(grid_stats.members.keys()))

    # Members selection
    st.markdown("#### Select Data Sources")

    all_members = cast(
        list[L2SourceDataset],
        ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
    )

    # Use columns for checkboxes; sources that are available for this grid are ticked by default
    selected_members: list[L2SourceDataset] = []
    for checkbox_col, member in zip(st.columns(len(all_members)), all_members):
        with checkbox_col:
            if st.checkbox(member, value=member in available_members, key=f"feature_member_{member}"):
                selected_members.append(member)

    if not selected_members:
        st.info("👆 Select at least one data source to see feature statistics")
        return

    # Filter to selected members only; a ticked source may be missing for this grid
    selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
    if not selected_member_stats:
        # Without this guard the min() below raises ValueError on an empty sequence.
        st.warning("None of the selected data sources is available for this grid configuration")
        return

    # Calculate total features for selected members
    total_features = sum(ms.feature_count for ms in selected_member_stats.values())

    # Get target stats
    target_stats = grid_stats.target[cast(TargetDataset, target)]

    # High-level metrics
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Total Features", f"{total_features:,}")
    with col2:
        # Calculate minimum cells across all data sources (for inference capability)
        min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
        st.metric(
            "Inference Cells",
            f"{min_cells:,}",
            help="Smallest number of cells available in any of the selected data sources",
        )
    with col3:
        st.metric("Data Sources", len(selected_members))
    with col4:
        # Use binary task training cells as sample count
        st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
    with col5:
        # Calculate total data points
        total_points = total_features * target_stats.training_cells["binary"]
        st.metric("Total Data Points", f"{total_points:,}")

    # Feature breakdown by source
    st.markdown("#### Feature Breakdown by Data Source")

    breakdown_data = []
    for member, member_stats in selected_member_stats.items():
        # Avoid ZeroDivisionError when the selected sources contribute no features at all
        share = member_stats.feature_count / total_features * 100 if total_features else 0.0
        breakdown_data.append(
            {
                "Data Source": member,
                "Number of Features": member_stats.feature_count,
                "Percentage": f"{share:.1f}%",
            }
        )

    breakdown_df = pd.DataFrame(breakdown_data)

    # Get all unique data sources and create color map
    unique_sources = sorted(breakdown_df["Data Source"].unique())
    source_color_list = get_palette("data_sources", n_colors=len(unique_sources))
    source_color_map = dict(zip(unique_sources, source_color_list))

    # Create and display pie chart
    fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
    st.plotly_chart(fig, width="stretch")

    # Show detailed table
    st.dataframe(breakdown_df, hide_index=True, width="stretch")

    # Detailed member information
    with st.expander("📦 Detailed Source Information", expanded=False):
        for member, member_stats in selected_member_stats.items():
            st.markdown(f"### {member}")

            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Features", member_stats.feature_count)
            with metric_cols[1]:
                st.metric("Variables", len(member_stats.variable_names))
            with metric_cols[2]:
                dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
                st.metric("Shape", dim_str)
            with metric_cols[3]:
                # Product of all dimension sizes
                member_points = 1
                for dim_size in member_stats.dimensions.values():
                    member_points *= dim_size
                st.metric("Data Points", f"{member_points:,}")

            # Variables, rendered as colored inline badges
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats.variable_names
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)

            # Dimensions, rendered as colored inline badges
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size}</span>"
                    for dim_name, dim_size in member_stats.dimensions.items()
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)

            # Size on disk
            size_mb = member_stats.size_bytes / (1024 * 1024)
            st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")

            st.markdown("---")
|
||||
|
||||
|
||||
def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts.

    Builds four tabs:

    1. Training samples per configuration.
    2. Dataset characteristics across grid configurations.
    3. Feature breakdown by data source: one donut chart per grid
       configuration, with hex grids in the left column and healpix grids in
       the right column, grouped by resolution class.
    4. Interactive configuration explorer.
    """
    st.header("📈 Dataset Analysis")

    # Create tabs for different analysis views
    analysis_tabs = st.tabs(
        [
            "📊 Training Samples",
            "📈 Dataset Characteristics",
            "🔍 Feature Breakdown",
            "⚙️ Configuration Explorer",
        ]
    )

    with analysis_tabs[0]:
        st.subheader("Training Samples by Configuration")
        render_sample_count_overview()

    with analysis_tabs[1]:
        st.subheader("Dataset Characteristics Across Grid Configurations")
        render_feature_count_comparison()

    with analysis_tabs[2]:
        st.subheader("Feature Breakdown by Data Source")
        # Get data from cache
        all_stats = load_all_default_dataset_statistics()
        breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
        breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")

        # One shared color per data source so all donuts use the same legend colors
        unique_sources = sorted(breakdown_df["Data Source"].unique())
        source_color_list = get_palette("data_sources", n_colors=len(unique_sources))
        source_color_map = dict(zip(unique_sources, source_color_list))

        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        # One row of two columns per resolution class: hex on the left, healpix on the right
        for res in ["sparse", "low", "medium"]:
            cols = st.columns(2)
            for col, grid_kind in zip(cols, ("hex", "healpix")):
                with col:
                    for gc in grid_configs:
                        if gc.res != res or gc.grid != grid_kind:
                            continue
                        grid_display = gc.display_name
                        grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
                        fig = create_feature_breakdown_donut(
                            grid_data, grid_display, source_color_map=source_color_map
                        )
                        # The key must be unique per chart, otherwise Streamlit rejects duplicate elements
                        st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")

    with analysis_tabs[3]:
        st.subheader("Interactive Configuration Explorer")
        render_feature_count_explorer()
|
||||
|
||||
|
||||
def render_training_results_summary(training_results):
    """Render summary metrics for training results.

    Shows four metrics side by side: the number of distinct tasks, grid
    types and model types among the results, and the date of the latest run.

    Args:
        training_results: Sequence of training result objects exposing
            ``settings.task``, ``settings.grid``, ``settings.model`` and
            ``created_at``. The first element is treated as the latest run,
            so the sequence is expected to be ordered newest first. May be
            empty, in which case the latest-run metric shows "N/A".

    """
    st.header("📊 Training Results Summary")
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        tasks = {tr.settings.task for tr in training_results}
        st.metric("Tasks", len(tasks))

    with col2:
        grids = {tr.settings.grid for tr in training_results}
        st.metric("Grid Types", len(grids))

    with col3:
        models = {tr.settings.model for tr in training_results}
        st.metric("Model Types", len(models))

    with col4:
        if training_results:
            latest = training_results[0]  # Already sorted by creation time
            latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
        else:
            # No runs yet: indexing [0] would raise IndexError and break the page
            latest_date = "N/A"
        st.metric("Latest Run", latest_date)
|
||||
|
||||
|
||||
def render_experiment_results(training_results):
    """Render the experiment overview table and one expandable panel per run.

    The table lists date, task, grid, level, model, best primary score, trial
    count and directory name of every run. The primary metric is accuracy
    when a ``mean_test_accuracy`` column exists, otherwise the first
    ``mean_test_*`` column, otherwise "N/A" (shown with a score of 0.0).

    Each expander shows the run configuration, the expected output files,
    best/mean/std of every ``mean_test_*`` column and, when an ``initial_K``
    column is present, the explored ranges of selected hyperparameters.
    """
    score_prefix = "mean_test_"

    st.header("🎯 Experiment Results")
    st.subheader("Results Table")

    # Collect one table row per training run
    rows = []
    for run in training_results:
        results = run.results
        test_columns = [name for name in results.columns if name.startswith(score_prefix)]
        best_by_metric = {name.replace(score_prefix, ""): results[name].max() for name in test_columns}

        # Pick the headline metric: accuracy first, then the first available test score
        if "mean_test_accuracy" in results.columns:
            headline_metric = "accuracy"
        elif test_columns:
            headline_metric = test_columns[0].replace(score_prefix, "")
        else:
            headline_metric = "N/A"
        headline_score = best_by_metric.get(headline_metric, 0.0)

        rows.append(
            {
                "Date": datetime.fromtimestamp(run.created_at).strftime("%Y-%m-%d %H:%M"),
                "Task": run.settings.task,
                "Grid": run.settings.grid,
                "Level": run.settings.level,
                "Model": run.settings.model,
                f"Best {headline_metric.title()}": f"{headline_score:.4f}",
                "Trials": len(results),
                "Path": str(run.path.name),
            }
        )

    st.dataframe(pd.DataFrame(rows), width="stretch", hide_index=True)

    st.divider()

    # Per-run drill-down
    st.subheader("Individual Experiment Details")

    for run in training_results:
        info = run.display_info
        expander_title = f"{info.task} | {info.model} | {info.grid}{info.level}"
        with st.expander(expander_title):
            config_col, score_col = st.columns([1, 2])
            results = run.results

            with config_col:
                left_lines = (
                    "**Configuration:**",
                    f"- **Task:** {run.settings.task}",
                    f"- **Grid:** {run.settings.grid}",
                    f"- **Level:** {run.settings.level}",
                    f"- **Model:** {run.settings.model}",
                    f"- **CV Splits:** {run.settings.cv_splits}",
                    f"- **Classes:** {run.settings.classes}",
                    "\n**Files:**",
                    "- 📊 search_results.parquet",
                    "- 🧮 best_estimator_state.nc",
                    "- 🎯 predicted_probabilities.parquet",
                    "- ⚙️ search_settings.toml",
                )
                for line in left_lines:
                    st.write(line)

            with score_col:
                st.write("**Best Scores:**")

                test_columns = [name for name in results.columns if name.startswith(score_prefix)]
                if not test_columns:
                    st.write("No test scores found in results.")
                else:
                    metric_rows = [
                        {
                            "Metric": name.replace(score_prefix, "").title(),
                            "Best": f"{results[name].max():.4f}",
                            "Mean": f"{results[name].mean():.4f}",
                            "Std": f"{results[name].std():.4f}",
                        }
                        for name in test_columns
                    ]
                    st.dataframe(pd.DataFrame(metric_rows), width="stretch", hide_index=True)

                # Hyperparameter ranges are only reported when the common "initial_K" column exists
                if "initial_K" in results.columns:
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ("initial_K", "eps_cl", "eps_e"):
                        if param not in results.columns:
                            continue
                        lowest = results[param].min()
                        highest = results[param].max()
                        distinct = results[param].nunique()
                        st.write(f"- **{param}:** {distinct} values ({lowest:.2e} to {highest:.2e})")

            st.write(f"\n**Path:** `{run.path}`")
|
||||
|
||||
|
||||
def render_overview_page():
|
||||
|
|
@ -537,7 +46,8 @@ def render_overview_page():
|
|||
st.divider()
|
||||
|
||||
# Render dataset analysis section
|
||||
render_dataset_analysis()
|
||||
all_stats = load_all_default_dataset_statistics()
|
||||
render_dataset_statistics(all_stats)
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class AutoGluonSettings:
|
|||
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
|
||||
"""Combined settings for AutoGluon training."""
|
||||
|
||||
classes: list[str]
|
||||
classes: list[str] | None
|
||||
problem_type: str
|
||||
|
||||
|
||||
|
|
@ -228,7 +228,7 @@ def autogluon_train(
|
|||
combined_settings = AutoGluonTrainingSettings(
|
||||
**asdict(settings),
|
||||
**asdict(dataset_ensemble),
|
||||
classes=training_data.y.labels,
|
||||
classes=training_data.target_labels,
|
||||
problem_type=problem_type,
|
||||
)
|
||||
settings_file = results_dir / "training_settings.toml"
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
|||
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
|
||||
pivcols = set(collapsed.index.names) - {"cell_ids"}
|
||||
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
|
||||
collapsed.columns = ["_".join(v) for v in collapsed.columns]
|
||||
collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns]
|
||||
if use_dummy:
|
||||
collapsed = collapsed.dropna(how="all")
|
||||
return collapsed
|
||||
|
|
@ -378,7 +378,8 @@ class DatasetEnsemble:
|
|||
case ("darts_v1" | "darts_v2", int()):
|
||||
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
||||
targets = xr.open_zarr(target_store, consolidated=False)
|
||||
targets = targets.sel(year=self.temporal_mode)
|
||||
case ("darts_mllabels", str()): # Years are not supported
|
||||
target_store = entropice.utils.paths.get_darts_file(
|
||||
grid=self.grid, level=self.level, version="mllabels"
|
||||
|
|
@ -467,6 +468,10 @@ class DatasetEnsemble:
|
|||
|
||||
# Apply the temporal mode
|
||||
match (member.split("-"), self.temporal_mode):
|
||||
case (["ArcticDEM"], _):
|
||||
pass # No temporal dimension
|
||||
case (_, "feature"):
|
||||
pass
|
||||
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
|
||||
ds_mean = ds.mean(dim="year")
|
||||
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
|
||||
|
|
@ -475,12 +480,8 @@ class DatasetEnsemble:
|
|||
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
|
||||
)
|
||||
ds = xr.merge([ds_mean, ds_trend])
|
||||
case (["ArcticDEM"], "synopsis"):
|
||||
pass # No temporal dimension
|
||||
case (_, int() as year):
|
||||
ds = ds.sel(year=year, drop=True)
|
||||
case (_, "feature"):
|
||||
pass
|
||||
case _:
|
||||
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
|
||||
|
||||
|
|
@ -556,7 +557,7 @@ class DatasetEnsemble:
|
|||
cell_ids: pd.Series,
|
||||
era5_agg: Literal["yearly", "seasonal", "shoulder"],
|
||||
) -> pd.DataFrame:
|
||||
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
|
||||
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=True)
|
||||
era5_df = _collapse_to_dataframe(era5)
|
||||
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
|
||||
# Ensure all target cell_ids are present, fill missing with NaN
|
||||
|
|
@ -565,7 +566,10 @@ class DatasetEnsemble:
|
|||
|
||||
@stopwatch("Preparing AlphaEarth Embeddings")
|
||||
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
|
||||
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=True)
|
||||
# For non-synopsis modes there is only a single data variable "embeddings"
|
||||
if self.temporal_mode != "synopsis":
|
||||
embeddings = embeddings["embeddings"]
|
||||
embeddings_df = _collapse_to_dataframe(embeddings)
|
||||
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
|
||||
# Ensure all target cell_ids are present, fill missing with NaN
|
||||
|
|
@ -574,7 +578,7 @@ class DatasetEnsemble:
|
|||
|
||||
@stopwatch("Preparing ArcticDEM")
|
||||
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
|
||||
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=False)
|
||||
if len(arcticdem["cell_ids"]) == 0:
|
||||
# No data for these cells - create empty DataFrame with expected columns
|
||||
# Use the Dataset metadata to determine column structure
|
||||
|
|
@ -620,7 +624,7 @@ class DatasetEnsemble:
|
|||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||
|
||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
||||
# @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
||||
def create_training_set(
|
||||
self,
|
||||
task: Task,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import cyclopts
|
||||
import numpy as np
|
||||
|
|
@ -115,17 +116,19 @@ class TrainingSettings(DatasetEnsemble, CVSettings):
|
|||
def random_cv(
|
||||
dataset_ensemble: DatasetEnsemble,
|
||||
settings: CVSettings = CVSettings(),
|
||||
):
|
||||
experiment: str | None = None,
|
||||
) -> Path:
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
||||
settings (CVSettings): The cross-validation settings.
|
||||
experiment (str | None): Optional experiment name for results directory.
|
||||
|
||||
"""
|
||||
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
|
||||
use_array_api = settings.model == "espa"
|
||||
device = "torch" if use_array_api else "cuda"
|
||||
use_array_api = settings.model != "xgboost"
|
||||
device = "torch" if settings.model == "espa" else "cuda"
|
||||
set_config(array_api_dispatch=use_array_api)
|
||||
|
||||
print("Creating training data...")
|
||||
|
|
@ -173,7 +176,8 @@ def random_cv(
|
|||
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
|
||||
|
||||
results_dir = get_cv_results_dir(
|
||||
"random_search",
|
||||
experiment=experiment,
|
||||
name="random_search",
|
||||
grid=dataset_ensemble.grid,
|
||||
level=dataset_ensemble.level,
|
||||
task=settings.task,
|
||||
|
|
@ -283,6 +287,7 @@ def random_cv(
|
|||
|
||||
stopwatch.summary()
|
||||
print("Done.")
|
||||
return results_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -139,7 +139,14 @@ def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
|
|||
return cache_dir / f"cells_{cells_hash}.parquet"
|
||||
|
||||
|
||||
def get_dataset_stats_cache() -> Path:
    """Return the path of the pickled dataset statistics cache file.

    The ``cache`` directory below ``DATASET_ENSEMBLES_DIR`` is created if it
    does not exist yet; the cache file itself is not created here.
    """
    stats_cache_root = DATASET_ENSEMBLES_DIR / "cache"
    stats_cache_root.mkdir(exist_ok=True, parents=True)
    stats_cache_file = stats_cache_root / "dataset_stats.pckl"
    return stats_cache_file
|
||||
|
||||
|
||||
def get_cv_results_dir(
|
||||
experiment: str | None,
|
||||
name: str,
|
||||
grid: Grid,
|
||||
level: int,
|
||||
|
|
@ -147,7 +154,12 @@ def get_cv_results_dir(
|
|||
) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
|
||||
if experiment is not None:
|
||||
experiment_dir = RESULTS_DIR / experiment
|
||||
experiment_dir.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
experiment_dir = RESULTS_DIR
|
||||
results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}"
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
return results_dir
|
||||
|
||||
|
|
|
|||
|
|
@ -15,12 +15,12 @@ type GridLevel = Literal[
|
|||
"healpix9",
|
||||
"healpix10",
|
||||
]
|
||||
type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
|
||||
type TargetDataset = Literal["darts_v1", "darts_mllabels"]
|
||||
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||
# TODO: Consider implementing a "timeseries" temporal mode
|
||||
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
||||
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||
type Stage = Literal["train", "inference", "visualization"]
|
||||
|
||||
|
|
@ -37,8 +37,9 @@ class GridConfig:
|
|||
sort_key: str
|
||||
|
||||
@classmethod
|
||||
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
|
||||
def from_grid_level(cls, grid_level: GridLevel | tuple[Literal["hex", "healpix"], int]) -> "GridConfig":
|
||||
"""Create a GridConfig from a GridLevel string."""
|
||||
if isinstance(grid_level, str):
|
||||
if grid_level.startswith("hex"):
|
||||
grid = "hex"
|
||||
level = int(grid_level[3:])
|
||||
|
|
@ -47,7 +48,9 @@ class GridConfig:
|
|||
level = int(grid_level[7:])
|
||||
else:
|
||||
raise ValueError(f"Invalid grid level: {grid_level}")
|
||||
|
||||
else:
|
||||
grid, level = grid_level
|
||||
grid_level: GridLevel = f"{grid}{level}" # ty:ignore[invalid-assignment]
|
||||
display_name = f"{grid.capitalize()}-{level}"
|
||||
|
||||
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
|
||||
|
|
@ -74,8 +77,8 @@ class GridConfig:
|
|||
|
||||
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
||||
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
||||
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
|
||||
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
||||
all_target_datasets: list[TargetDataset] = ["darts_v1", "darts_mllabels"]
|
||||
all_l2_source_datasets: list[L2SourceDataset] = [
|
||||
"ArcticDEM",
|
||||
"ERA5-shoulder",
|
||||
|
|
|
|||
222
tests/test_training.py
Normal file
222
tests/test_training.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
"""Tests for training.py module, specifically random_cv function.
|
||||
|
||||
This test suite validates the random_cv training function across all model-task
|
||||
combinations using a minimal hex level 3 grid with synopsis temporal mode.
|
||||
|
||||
Test Coverage:
|
||||
- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
|
||||
- Device handling for each model type (torch/CUDA/cuML compatibility)
|
||||
- Multi-label target dataset support
|
||||
- Temporal mode configuration (synopsis)
|
||||
- Output file creation and validation
|
||||
|
||||
Running Tests:
|
||||
# Run all training tests (18 tests total, ~3 iterations each)
|
||||
pixi run pytest tests/test_training.py -v
|
||||
|
||||
# Run only device handling tests
|
||||
pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v
|
||||
|
||||
# Run a specific model-task combination
|
||||
pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v
|
||||
|
||||
Note: Tests use minimal iterations (3) and level 3 grid for speed.
|
||||
Full production runs use higher iteration counts (100-2000).
|
||||
"""
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.training import CVSettings, random_cv
|
||||
from entropice.utils.types import Model, Task
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def test_ensemble():
    """Provide a small, module-scoped DatasetEnsemble for the training tests.

    The ensemble is configured for the hex grid at level 3 in synopsis
    temporal mode with AlphaEarth as its only member, plus lon/lat features.
    """
    ensemble_options = {
        "grid": "hex",
        "level": 3,
        "temporal_mode": "synopsis",
        # A single member keeps the tests fast
        "members": ["AlphaEarth"],
        "add_lonlat": True,
    }
    return DatasetEnsemble(**ensemble_options)
|
||||
|
||||
|
||||
@pytest.fixture
def cleanup_results():
    """Yield a registration callback and delete registered directories afterwards.

    Tests pass each results directory they create to the yielded callable,
    which records it and hands it back unchanged. When the test finishes,
    every recorded directory that still exists is removed recursively.
    """
    pending_removal = []

    def register_dir(results_dir):
        """Remember ``results_dir`` for teardown and return it unchanged."""
        pending_removal.append(results_dir)
        return results_dir

    yield register_dir

    # Teardown: remove only what this test registered
    for registered in pending_removal:
        if not registered.exists():
            continue
        shutil.rmtree(registered)
|
||||
|
||||
|
||||
# Model-task combinations to test
|
||||
# Note: Not all combinations make sense, but we test all to ensure robustness
|
||||
MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
|
||||
TASKS: list[Task] = ["binary", "count", "density"]
|
||||
|
||||
|
||||
class TestRandomCV:
    """Test suite for the random_cv function."""

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("task", TASKS)
    def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
        """Test random_cv with all model-task combinations.

        This test runs 3 iterations for each combination to verify:
        - The function completes without errors
        - Device handling works correctly for each model type
        - All output files are created
        """
        # Use darts_v1 as the primary target for all tests
        settings = CVSettings(
            n_iter=3,
            task=task,
            target="darts_v1",
            model=model,
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify results directory was created
        assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"

        # Verify all expected output files exist
        expected_files = [
            "search_settings.toml",
            "best_estimator_model.pkl",
            "search_results.parquet",
            "metrics.toml",
            "predicted_probabilities.parquet",
        ]

        # Classification tasks (binary and the "*_regimes" variants) are expected to
        # write a confusion matrix. Plain "count"/"density" might be regression, so
        # no confusion matrix is required for them.
        if task == "binary" or "_regimes" in task:
            expected_files.append("confusion_matrix.nc")

        # Add model-specific files
        if model in ["espa", "xgboost", "rf"]:
            expected_files.append("best_estimator_state.nc")

        for filename in expected_files:
            filepath = results_dir / filename
            assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"

    @pytest.mark.parametrize("model", MODELS)
    def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
        """Test that device handling works correctly for each model type.

        Different models require different device configurations:
        - espa: Uses torch with array API dispatch
        - xgboost: Uses CUDA without array API dispatch
        - rf/knn: GPU-accelerated via cuML
        """
        settings = CVSettings(
            n_iter=3,
            task="binary",  # Simple binary task for device testing
            target="darts_v1",
            model=model,
        )

        # This should complete without device-related errors
        try:
            results_dir = random_cv(
                dataset_ensemble=test_ensemble,
                settings=settings,
                experiment="test_training",
            )
            cleanup_results(results_dir)
        except RuntimeError as e:
            # Report device-related failures explicitly; anything else propagates unchanged
            error_msg = str(e).lower()
            device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"]
            if any(keyword in error_msg for keyword in device_keywords):
                pytest.fail(f"Device handling error for {model=}: {e}")
            else:
                # Re-raise non-device errors
                raise

    def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
        """Test random_cv with multi-label target dataset."""
        settings = CVSettings(
            n_iter=3,
            task="binary",
            target="darts_mllabels",
            model="espa",
        )

        # Run the cross-validation and get the results directory
        results_dir = random_cv(
            dataset_ensemble=test_ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify results were created
        assert results_dir.exists(), "Results directory not created"
        assert (results_dir / "search_settings.toml").exists()

    def test_temporal_mode_synopsis(self, cleanup_results):
        """Test that temporal_mode='synopsis' is correctly used."""
        import toml

        ensemble = DatasetEnsemble(
            grid="hex",
            level=3,
            temporal_mode="synopsis",
            members=["AlphaEarth"],
            add_lonlat=True,
        )

        settings = CVSettings(
            n_iter=3,
            task="binary",
            target="darts_v1",
            model="espa",
        )

        # This should use synopsis mode (all years aggregated)
        results_dir = random_cv(
            dataset_ensemble=ensemble,
            settings=settings,
            experiment="test_training",
        )
        cleanup_results(results_dir)

        # Verify the settings were stored correctly
        assert results_dir.exists(), "Results directory not created"
        with open(results_dir / "search_settings.toml") as f:
            stored_settings = toml.load(f)

        assert stored_settings["settings"]["temporal_mode"] == "synopsis"
|
||||
38
tests/validate_datasets.py
Normal file
38
tests/validate_datasets.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import cyclopts
|
||||
from rich import pretty, print, traceback
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import all_temporal_modes, grid_configs
|
||||
|
||||
pretty.install()
|
||||
traceback.install()
|
||||
|
||||
cli = cyclopts.App()
|
||||
|
||||
|
||||
def _gather_ensemble_stats(e: DatasetEnsemble):
    """Print feature-column counts for a small sample of an ensemble's cells.

    Features are built for the first five cell ids only. The total number of
    feature columns is printed, followed by the number of columns whose name
    starts with each of the prefixes ``arcticdem_``, ``embeddings_`` and
    ``era5_``.
    """
    # A handful of cells is enough to inspect the column layout
    sample_features = e.make_features(e.cell_ids[:5])
    column_names = sample_features.columns

    print(
        f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]"
    )
    print(f"Number of feature columns: {len(column_names)}")

    for member in ("arcticdem", "embeddings", "era5"):
        prefix = f"{member}_"
        n_member_cols = sum(1 for name in column_names if name.startswith(prefix))
        print(f"  - {member.capitalize()} feature columns: {n_member_cols}")
    print()
|
||||
|
||||
|
||||
@cli.default()
def validate_datasets():
    """Print feature statistics for every grid configuration and temporal mode.

    Iterates over all grid configurations and, for each one, over all
    temporal modes, building a DatasetEnsemble and reporting its stats.
    """
    # Lazy generator: each ensemble is built right before it is inspected
    ensembles = (
        DatasetEnsemble(grid=config.grid, level=config.level, temporal_mode=mode)
        for config in grid_configs
        for mode in all_temporal_modes
    )
    for ensemble in ensembles:
        _gather_ensemble_stats(ensemble)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue