Fix training and overview page
This commit is contained in:
parent
4445834895
commit
c9c6af8370
17 changed files with 1643 additions and 1125 deletions
|
|
@ -68,7 +68,7 @@ dependencies = [
|
||||||
"pandas-stubs>=2.3.3.251201,<3",
|
"pandas-stubs>=2.3.3.251201,<3",
|
||||||
"pytest>=9.0.2,<10",
|
"pytest>=9.0.2,<10",
|
||||||
"autogluon-tabular[all,mitra]>=1.5.0",
|
"autogluon-tabular[all,mitra]>=1.5.0",
|
||||||
"shap>=0.50.0,<0.51",
|
"shap>=0.50.0,<0.51", "h5py>=3.15.1,<4",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,14 @@ Pages:
|
||||||
- Overview: List of available result directories with some summary statistics.
|
- Overview: List of available result directories with some summary statistics.
|
||||||
- Training Data: Visualization of training data distributions.
|
- Training Data: Visualization of training data distributions.
|
||||||
- Training Results Analysis: Analysis of training results and model performance.
|
- Training Results Analysis: Analysis of training results and model performance.
|
||||||
|
- AutoGluon Analysis: Analysis of AutoGluon training results with SHAP visualizations.
|
||||||
- Model State: Visualization of model state and features.
|
- Model State: Visualization of model state and features.
|
||||||
- Inference: Visualization of inference results.
|
- Inference: Visualization of inference results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
|
||||||
from entropice.dashboard.views.inference_page import render_inference_page
|
from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.views.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
|
|
@ -26,13 +28,14 @@ def main():
|
||||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
|
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
||||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||||
|
|
||||||
pg = st.navigation(
|
pg = st.navigation(
|
||||||
{
|
{
|
||||||
"Overview": [overview_page],
|
"Overview": [overview_page],
|
||||||
"Training": [training_data_page, training_analysis_page],
|
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
||||||
"Model State": [model_state_page],
|
"Model State": [model_state_page],
|
||||||
"Inference": [inference_page],
|
"Inference": [inference_page],
|
||||||
}
|
}
|
||||||
|
|
|
||||||
458
src/entropice/dashboard/plots/dataset_statistics.py
Normal file
458
src/entropice/dashboard/plots/dataset_statistics.py
Normal file
|
|
@ -0,0 +1,458 @@
|
||||||
|
"""Visualization functions for the overview page.
|
||||||
|
|
||||||
|
This module contains reusable plotting functions for dataset analysis visualizations,
|
||||||
|
including sample counts, feature counts, and dataset statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.express as px
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
from plotly.subplots import make_subplots
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
|
|
||||||
|
|
||||||
|
def _get_colors_from_members(unique_members: list[str]) -> dict[str, str]:
    """Assign one color per member, drawn from a palette keyed by the member's source.

    The source of a member is the part of its name before the first "-".
    Members sharing a source get neighbouring shades of that source's palette.

    Args:
        unique_members: Member names; expected to be unique (duplicates would
            enlarge the requested palette).

    Returns:
        Dictionary mapping each member name to its color.

    """
    # Group members by source in a single pass. Iterating the sorted names keeps
    # both the grouping and the resulting dict order deterministic.
    members_by_source: dict[str, list[str]] = {}
    for member in sorted(unique_members):
        source = member.split("-")[0]
        members_by_source.setdefault(source, []).append(member)

    member_colors: dict[str, str] = {}
    for source, members in members_by_source.items():
        # Request two extra colors and only use indices 1..len(members), which
        # skips the first and last palette entries (avoid very light/dark colors).
        palette = get_palette(source, n_colors=len(members) + 2)
        for idx, member in enumerate(members, start=1):
            member_colors[member] = palette[idx]
    return member_colors
|
||||||
|
|
||||||
|
|
||||||
|
def create_sample_count_bar_chart(sample_df: pd.DataFrame) -> go.Figure:
    """Create bar chart showing sample counts by grid, target and temporal mode.

    One subplot is drawn per target; within each subplot the bars are grouped
    by temporal mode and share a color per mode.

    Args:
        sample_df: DataFrame with columns: Grid, Target, Temporal Mode, Samples (Coverage).

    Returns:
        Plotly Figure object containing the bar chart visualization.

    """
    for column in ("Grid", "Target", "Temporal Mode", "Samples (Coverage)"):
        assert column in sample_df.columns, f"sample_df is missing required column {column!r}"

    targets = sorted(sample_df["Target"].unique())
    modes = sorted(sample_df["Temporal Mode"].unique())

    # Create subplots manually to have better control over colors.
    # make_subplots rejects cols=0, so an empty frame still gets one (empty) subplot.
    fig = make_subplots(
        rows=1,
        cols=max(len(targets), 1),
        subplot_titles=[f"Target: {target}" for target in targets],
        shared_yaxes=True,
    )

    # One color per temporal mode; skip the palette lookup when there is nothing to color.
    colors = get_palette("Number of Samples", n_colors=len(modes)) if modes else []

    # Track which modes have been added to the legend so each appears exactly once,
    # even if a mode has no data for the first target.
    modes_in_legend: set = set()

    for col_idx, target in enumerate(targets, start=1):
        target_data = sample_df[sample_df["Target"] == target]

        for mode_idx, mode in enumerate(modes):
            mode_data = target_data[target_data["Temporal Mode"] == mode]
            if mode_data.empty:
                continue
            color = colors[mode_idx] if colors and mode_idx < len(colors) else None

            # Show legend only the first time each mode appears
            show_in_legend = mode not in modes_in_legend
            modes_in_legend.add(mode)

            fig.add_trace(
                go.Bar(
                    x=mode_data["Grid"],
                    y=mode_data["Samples (Coverage)"],
                    name=mode,
                    marker_color=color,
                    legendgroup=mode,  # Group by mode name so toggling affects all subplots
                    showlegend=show_in_legend,
                ),
                row=1,
                col=col_idx,
            )

    fig.update_layout(
        title_text="Training Sample Counts by Grid Configuration and Target Dataset",
        barmode="group",
        height=500,
        showlegend=True,
    )

    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Samples", row=1, col=1)

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_count_stacked_bar(breakdown_df: pd.DataFrame) -> go.Figure:
    """Create stacked bar chart showing feature counts by member.

    One subplot is drawn per temporal mode; within each subplot the bars are
    stacked by member, with member colors derived from the member's source.

    Args:
        breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.

    Returns:
        Plotly Figure object containing the stacked bar chart visualization.

    """
    for column in ("Grid", "Temporal Mode", "Member", "Number of Features"):
        assert column in breakdown_df.columns, f"breakdown_df is missing required column {column!r}"

    unique_members = sorted(breakdown_df["Member"].unique())
    temporal_modes = sorted(breakdown_df["Temporal Mode"].unique())
    source_color_map = _get_colors_from_members(unique_members)

    # Create subplots for each temporal mode.
    # make_subplots rejects cols=0, so an empty frame still gets one (empty) subplot.
    fig = make_subplots(
        rows=1,
        cols=max(len(temporal_modes), 1),
        subplot_titles=[f"Temporal Mode: {mode}" for mode in temporal_modes],
        shared_yaxes=True,
    )

    # Track legend entries per member instead of tying them to the first subplot:
    # a member without rows in the first temporal mode must still get an entry.
    members_in_legend: set = set()

    for col_idx, mode in enumerate(temporal_modes, start=1):
        mode_data = breakdown_df[breakdown_df["Temporal Mode"] == mode]

        for member in unique_members:
            member_data = mode_data[mode_data["Member"] == member]
            if member_data.empty:
                # Nothing to draw for this member in this mode.
                continue

            show_in_legend = member not in members_in_legend
            members_in_legend.add(member)

            fig.add_trace(
                go.Bar(
                    x=member_data["Grid"],
                    y=member_data["Number of Features"],
                    name=member,
                    marker_color=source_color_map.get(member),
                    legendgroup=member,
                    showlegend=show_in_legend,
                ),
                row=1,
                col=col_idx,
            )

    fig.update_layout(
        title_text="Input Features by Member Across Grid Configurations",
        barmode="stack",
        height=500,
        showlegend=True,
    )
    fig.update_xaxes(tickangle=-45)
    fig.update_yaxes(title_text="Number of Features", row=1, col=1)

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_inference_cells_bar(sample_df: pd.DataFrame) -> go.Figure:
    """Build a bar chart of the number of inference cells per grid configuration.

    Args:
        sample_df: DataFrame that must provide the "Grid" and "Inference Cells" columns.

    Returns:
        Plotly Figure with one bar per row, colored by grid and labelled with the cell count.

    """
    assert "Grid" in sample_df.columns
    assert "Inference Cells" in sample_df.columns

    # One palette entry per distinct grid.
    palette: list[str] = get_palette("Number of Samples", n_colors=sample_df["Grid"].nunique())
    axis_labels = {
        "Grid": "Grid Configuration",
        "Inference Cells": "Number of Cells",
    }

    chart = px.bar(
        sample_df,
        x="Grid",
        y="Inference Cells",
        color="Grid",
        title="Spatial Coverage (Grid Cells with Complete Data)",
        labels=axis_labels,
        color_discrete_sequence=palette,
        text="Inference Cells",
    )

    # Thousands-separated value labels above each bar; the legend would only repeat the x axis.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, height=500, showlegend=False)

    return chart
|
||||||
|
|
||||||
|
|
||||||
|
def create_total_samples_bar(
    comparison_df: pd.DataFrame,
    grid_colors: list[str] | None = None,
) -> go.Figure:
    """Build a bar chart of total training samples per grid configuration.

    Args:
        comparison_df: DataFrame providing the "Grid" and "Total Samples" columns.
        grid_colors: Optional color sequence passed to Plotly as the discrete color
            sequence. None leaves Plotly's default colors in place.

    Returns:
        Plotly Figure with one bar per row, colored by grid and labelled with the sample count.

    """
    axis_labels = {
        "Grid": "Grid Configuration",
        "Total Samples": "Number of Samples",
    }
    bar_options = {
        "x": "Grid",
        "y": "Total Samples",
        "color": "Grid",
        "title": "Training Samples (Binary Task)",
        "labels": axis_labels,
        "color_discrete_sequence": grid_colors,
        "text": "Total Samples",
    }

    chart = px.bar(comparison_df, **bar_options)

    # Thousands-separated value labels above each bar; the legend would only repeat the x axis.
    chart.update_traces(texttemplate="%{text:,}", textposition="outside")
    chart.update_layout(xaxis_tickangle=-45, showlegend=False)

    return chart
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_breakdown_donut(
    grid_data: pd.DataFrame,
    grid_config: str,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Build a donut chart of feature counts per data source for a single grid.

    Args:
        grid_data: DataFrame providing the "Data Source" and "Number of Features" columns.
        grid_config: Grid configuration name, used verbatim as the chart title.
        source_color_map: Optional mapping from data source name to color.
            None leaves Plotly's default colors in place.

    Returns:
        Plotly Figure object containing the donut chart.

    """
    donut_options = {
        "names": "Data Source",
        "values": "Number of Features",
        "title": grid_config,
        "hole": 0.4,  # the hole turns the pie into a donut
        "color_discrete_map": source_color_map,
        "color": "Data Source",
    }
    chart = px.pie(grid_data, **donut_options)

    # Only percentages inside the slices; the legend carries the source names.
    chart.update_traces(textposition="inside", textinfo="percent")
    chart.update_layout(showlegend=True, height=350)

    return chart
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_distribution_pie(
    breakdown_df: pd.DataFrame,
    source_color_map: dict[str, str] | None = None,
) -> go.Figure:
    """Build a donut-style pie chart of feature counts per data source.

    Args:
        breakdown_df: DataFrame providing the "Data Source" and "Number of Features" columns.
        source_color_map: Optional mapping from data source name to color.
            None leaves Plotly's default colors in place.

    Returns:
        Plotly Figure object containing the pie chart.

    """
    pie_options = {
        "names": "Data Source",
        "values": "Number of Features",
        "title": "Feature Distribution by Data Source",
        "hole": 0.4,
        "color_discrete_map": source_color_map,
        "color": "Data Source",
    }
    chart = px.pie(breakdown_df, **pie_options)

    # Slices show both their share and their source name.
    chart.update_traces(textposition="inside", textinfo="percent+label")
    chart.update_layout(height=400)

    return chart
|
||||||
|
|
||||||
|
|
||||||
|
def create_member_details_table(member_stats_dict: dict[str, dict]) -> pd.DataFrame:
    """Create a detailed table for member statistics.

    Args:
        member_stats_dict: Dictionary mapping member names to their statistics.
            Each stats dict should contain: feature_count, variable_names, dimensions, size_bytes.
            ``dimensions`` maps a dimension name to its integer size.

    Returns:
        DataFrame with one row per member and the columns Member, Features,
        Variables, Dimensions, Data Points and Size (MB). "Data Points" and
        "Size (MB)" are pre-formatted strings. An empty input yields an empty
        frame that still carries these columns.

    Raises:
        KeyError: If a stats dict lacks one of the required keys.

    """
    import math  # local import: only needed for the dimension product below

    columns = ["Member", "Features", "Variables", "Dimensions", "Data Points", "Size (MB)"]

    rows = []
    for member, stats in member_stats_dict.items():
        dimensions = stats["dimensions"]

        # Total data points is the product of all dimension sizes (1 when there are none).
        total_points = math.prod(dimensions.values())

        # Human-readable dimension string, e.g. "cell=1,000 x time=3"
        dim_str = " x ".join(f"{name}={size:,}" for name, size in dimensions.items())

        rows.append(
            {
                "Member": member,
                "Features": stats["feature_count"],
                "Variables": len(stats["variable_names"]),
                "Dimensions": dim_str,
                "Data Points": f"{total_points:,}",
                "Size (MB)": f"{stats['size_bytes'] / (1024 * 1024):.2f}",
            }
        )

    # Passing the columns explicitly keeps the schema stable for empty input.
    return pd.DataFrame(rows, columns=columns)
|
||||||
|
|
||||||
|
|
||||||
|
def create_feature_breakdown_donuts_grid(breakdown_df: pd.DataFrame) -> go.Figure:
    """Create grid of donut charts showing feature breakdown by member for each grid+temporal mode.

    Grids form the rows and temporal modes the columns of the subplot grid.
    Row and column headers are drawn as paper-coordinate annotations instead of
    subplot titles.

    Args:
        breakdown_df: DataFrame with columns: Grid, Temporal Mode, Member, Number of Features.

    Returns:
        Plotly Figure object containing the grid of donut charts. An input
        without any grid or temporal mode yields an empty, titled figure.

    """
    for column in ("Grid", "Temporal Mode", "Member", "Number of Features"):
        assert column in breakdown_df.columns, f"breakdown_df is missing required column {column!r}"

    title = "Feature Breakdown by Member Across Grid Configurations and Temporal Modes"

    # Get unique grids and temporal modes (in order of first appearance)
    grids = breakdown_df["Grid"].unique().tolist()
    temporal_modes = breakdown_df["Temporal Mode"].unique().tolist()
    # Sorted so every donut lists its members, and therefore its colors, in the same order.
    all_members = sorted(breakdown_df["Member"].unique())
    source_color_map = _get_colors_from_members(all_members)

    n_grids = len(grids)
    n_modes = len(temporal_modes)

    if n_grids == 0 or n_modes == 0:
        # make_subplots rejects a 0-sized grid; return an empty figure instead of crashing.
        empty_fig = go.Figure()
        empty_fig.update_layout(title_text=title)
        return empty_fig

    # Spacing between subplots in paper coordinates. Defined once because both the
    # subplot layout and the header annotation positions below depend on it.
    h_spacing = 0.01
    v_spacing = 0.02

    # Create subplot grid (grids as rows, temporal modes as columns)
    # Don't use subplot_titles - we'll add custom annotations for row/column labels
    fig = make_subplots(
        rows=n_grids,
        cols=n_modes,
        specs=[[{"type": "pie"}] * n_modes for _ in range(n_grids)],
        horizontal_spacing=h_spacing,
        vertical_spacing=v_spacing,
    )

    # Members that already have a legend entry. The legend is enabled on every donut
    # that introduces a not-yet-listed member, so members missing from the first donut
    # (or a first donut that is skipped entirely) are still represented.
    members_in_legend: set = set()

    for row_idx, grid in enumerate(grids, start=1):
        grid_data = breakdown_df[breakdown_df["Grid"] == grid]

        for col_idx, mode in enumerate(temporal_modes, start=1):
            subset = grid_data[grid_data["Temporal Mode"] == mode]
            if subset.empty:
                continue

            # Build labels, values, and colors in consistent order
            labels = []
            values = []
            colors_list = []
            for member in all_members:
                member_data = subset[subset["Member"] == member]
                if member_data.empty:
                    continue
                labels.append(member)
                # Only the first row per member is used for a given grid+mode.
                values.append(member_data["Number of Features"].iloc[0])
                colors_list.append(source_color_map[member])

            show_legend = any(label not in members_in_legend for label in labels)
            members_in_legend.update(labels)

            fig.add_trace(
                go.Pie(
                    labels=labels,
                    values=values,
                    name=f"{grid} - {mode}",
                    hole=0.4,
                    marker={"colors": colors_list} if colors_list else None,
                    textposition="inside",
                    textinfo="percent",
                    showlegend=show_legend,
                ),
                row=row_idx,
                col=col_idx,
            )

    # Calculate appropriate height based on number of rows
    height = max(500, n_grids * 400)

    fig.update_layout(
        title_text=title,
        height=height,
        showlegend=True,
        margin={"l": 100, "r": 50, "t": 150, "b": 50},  # Add left margin for row labels
    )

    # Width and height of each subplot in paper coordinates, accounting for spacing
    subplot_width = (1.0 - h_spacing * (n_modes - 1)) / n_modes
    subplot_height = (1.0 - v_spacing * (n_grids - 1)) / n_grids

    # Add column headers at the top (temporal modes)
    for col_idx, mode in enumerate(temporal_modes):
        # x position of this column's center
        x_pos = col_idx * (subplot_width + h_spacing) + subplot_width / 2
        fig.add_annotation(
            text=f"<b>{mode}</b>",
            xref="paper",
            yref="paper",
            x=x_pos,
            y=1.01,
            showarrow=False,
            xanchor="center",
            yanchor="bottom",
            font={"size": 18},
        )

    # Add row headers on the left (grid configurations)
    for row_idx, grid in enumerate(grids):
        # y position of this row's center.
        # Note: y coordinates go from bottom to top, but rows go from top to bottom
        y_pos = 1.0 - (row_idx * (subplot_height + v_spacing) + subplot_height / 2)
        fig.add_annotation(
            text=f"<b>{grid}</b>",
            xref="paper",
            yref="paper",
            x=-0.01,
            y=y_pos,
            showarrow=False,
            xanchor="right",
            yanchor="middle",
            font={"size": 18},
            textangle=-90,
        )

    return fig
|
||||||
|
|
@ -1,276 +0,0 @@
|
||||||
"""Visualization functions for the overview page.
|
|
||||||
|
|
||||||
This module contains reusable plotting functions for dataset analysis visualizations,
|
|
||||||
including sample counts, feature counts, and dataset statistics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import plotly.express as px
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
|
|
||||||
|
|
||||||
def create_sample_count_heatmap(
|
|
||||||
pivot_df: pd.DataFrame,
|
|
||||||
target: str,
|
|
||||||
colorscale: list[str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create heatmap showing sample counts by grid and task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pivot_df: Pivoted dataframe with Grid as index, Task as columns, and sample counts as values.
|
|
||||||
target: Target dataset name for the title.
|
|
||||||
colorscale: Optional color palette for the heatmap. If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the heatmap visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.imshow(
|
|
||||||
pivot_df,
|
|
||||||
labels={
|
|
||||||
"x": "Task",
|
|
||||||
"y": "Grid Configuration",
|
|
||||||
"color": "Sample Count",
|
|
||||||
},
|
|
||||||
x=pivot_df.columns,
|
|
||||||
y=pivot_df.index,
|
|
||||||
color_continuous_scale=colorscale,
|
|
||||||
aspect="auto",
|
|
||||||
title=f"Target: {target}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add text annotations
|
|
||||||
fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)
|
|
||||||
fig.update_layout(height=400)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_sample_count_bar_chart(
|
|
||||||
sample_df: pd.DataFrame,
|
|
||||||
target_color_maps: dict[str, list[str]] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create bar chart showing sample counts by grid, target, and task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
|
|
||||||
target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes.
|
|
||||||
If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the bar chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Create subplots manually to have better control over colors
|
|
||||||
from plotly.subplots import make_subplots
|
|
||||||
|
|
||||||
targets = sorted(sample_df["Target"].unique())
|
|
||||||
tasks = sorted(sample_df["Task"].unique())
|
|
||||||
|
|
||||||
fig = make_subplots(
|
|
||||||
rows=1,
|
|
||||||
cols=len(targets),
|
|
||||||
subplot_titles=[f"Target: {target}" for target in targets],
|
|
||||||
shared_yaxes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for col_idx, target in enumerate(targets, 1):
|
|
||||||
target_data = sample_df[sample_df["Target"] == target]
|
|
||||||
# Get color palette for this target
|
|
||||||
colors = target_color_maps.get(target, None) if target_color_maps else None
|
|
||||||
|
|
||||||
for task_idx, task in enumerate(tasks):
|
|
||||||
task_data = target_data[target_data["Task"] == task]
|
|
||||||
color = colors[task_idx] if colors and task_idx < len(colors) else None
|
|
||||||
|
|
||||||
# Create a unique legendgroup per task so colors are consistent
|
|
||||||
fig.add_trace(
|
|
||||||
go.Bar(
|
|
||||||
x=task_data["Grid"],
|
|
||||||
y=task_data["Samples (Coverage)"],
|
|
||||||
name=task,
|
|
||||||
marker_color=color,
|
|
||||||
legendgroup=task, # Group by task name
|
|
||||||
showlegend=(col_idx == 1), # Only show legend for first subplot
|
|
||||||
),
|
|
||||||
row=1,
|
|
||||||
col=col_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title_text="Training Sample Counts by Grid Configuration and Target Dataset",
|
|
||||||
barmode="group",
|
|
||||||
height=500,
|
|
||||||
showlegend=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_xaxes(tickangle=-45)
|
|
||||||
fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_feature_count_stacked_bar(
|
|
||||||
breakdown_df: pd.DataFrame,
|
|
||||||
source_color_map: dict[str, str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create stacked bar chart showing feature counts by data source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features.
|
|
||||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
|
||||||
If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the stacked bar chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.bar(
|
|
||||||
breakdown_df,
|
|
||||||
x="Grid",
|
|
||||||
y="Number of Features",
|
|
||||||
color="Data Source",
|
|
||||||
barmode="stack",
|
|
||||||
title="Input Features by Data Source Across Grid Configurations",
|
|
||||||
labels={
|
|
||||||
"Grid": "Grid Configuration",
|
|
||||||
"Number of Features": "Number of Features",
|
|
||||||
},
|
|
||||||
color_discrete_map=source_color_map,
|
|
||||||
text_auto=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(height=500, xaxis_tickangle=-45)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_inference_cells_bar(
|
|
||||||
comparison_df: pd.DataFrame,
|
|
||||||
grid_colors: list[str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create bar chart for inference cells by grid configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
comparison_df: DataFrame with columns: Grid, Inference Cells.
|
|
||||||
grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the bar chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.bar(
|
|
||||||
comparison_df,
|
|
||||||
x="Grid",
|
|
||||||
y="Inference Cells",
|
|
||||||
color="Grid",
|
|
||||||
title="Spatial Coverage (Grid Cells with Complete Data)",
|
|
||||||
labels={
|
|
||||||
"Grid": "Grid Configuration",
|
|
||||||
"Inference Cells": "Number of Cells",
|
|
||||||
},
|
|
||||||
color_discrete_sequence=grid_colors,
|
|
||||||
text="Inference Cells",
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_traces(texttemplate="%{text:,}", textposition="outside")
|
|
||||||
fig.update_layout(xaxis_tickangle=-45, showlegend=False)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_total_samples_bar(
|
|
||||||
comparison_df: pd.DataFrame,
|
|
||||||
grid_colors: list[str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create bar chart for total samples by grid configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
comparison_df: DataFrame with columns: Grid, Total Samples.
|
|
||||||
grid_colors: Optional color palette for grid configurations. If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the bar chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.bar(
|
|
||||||
comparison_df,
|
|
||||||
x="Grid",
|
|
||||||
y="Total Samples",
|
|
||||||
color="Grid",
|
|
||||||
title="Training Samples (Binary Task)",
|
|
||||||
labels={
|
|
||||||
"Grid": "Grid Configuration",
|
|
||||||
"Total Samples": "Number of Samples",
|
|
||||||
},
|
|
||||||
color_discrete_sequence=grid_colors,
|
|
||||||
text="Total Samples",
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_traces(texttemplate="%{text:,}", textposition="outside")
|
|
||||||
fig.update_layout(xaxis_tickangle=-45, showlegend=False)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_feature_breakdown_donut(
|
|
||||||
grid_data: pd.DataFrame,
|
|
||||||
grid_config: str,
|
|
||||||
source_color_map: dict[str, str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create donut chart for feature breakdown by data source for a specific grid.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
grid_data: DataFrame with columns: Data Source, Number of Features.
|
|
||||||
grid_config: Grid configuration name for the title.
|
|
||||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
|
||||||
If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the donut chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.pie(
|
|
||||||
grid_data,
|
|
||||||
names="Data Source",
|
|
||||||
values="Number of Features",
|
|
||||||
title=grid_config,
|
|
||||||
hole=0.4,
|
|
||||||
color_discrete_map=source_color_map,
|
|
||||||
color="Data Source",
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_traces(textposition="inside", textinfo="percent")
|
|
||||||
fig.update_layout(showlegend=True, height=350)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def create_feature_distribution_pie(
|
|
||||||
breakdown_df: pd.DataFrame,
|
|
||||||
source_color_map: dict[str, str] | None = None,
|
|
||||||
) -> go.Figure:
|
|
||||||
"""Create pie chart for feature distribution by data source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
breakdown_df: DataFrame with columns: Data Source, Number of Features.
|
|
||||||
source_color_map: Optional dictionary mapping data source names to specific colors.
|
|
||||||
If None, uses default Plotly colors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly Figure object containing the pie chart visualization.
|
|
||||||
|
|
||||||
"""
|
|
||||||
fig = px.pie(
|
|
||||||
breakdown_df,
|
|
||||||
names="Data Source",
|
|
||||||
values="Number of Features",
|
|
||||||
title="Feature Distribution by Data Source",
|
|
||||||
hole=0.4,
|
|
||||||
color_discrete_map=source_color_map,
|
|
||||||
color="Data Source",
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_traces(textposition="inside", textinfo="percent+label")
|
|
||||||
|
|
||||||
return fig
|
|
||||||
376
src/entropice/dashboard/sections/dataset_statistics.py
Normal file
376
src/entropice/dashboard/sections/dataset_statistics.py
Normal file
|
|
@ -0,0 +1,376 @@
|
||||||
|
"""Dataset Statistics Section for Entropice Dashboard."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.dataset_statistics import (
|
||||||
|
create_feature_breakdown_donuts_grid,
|
||||||
|
create_feature_count_stacked_bar,
|
||||||
|
create_feature_distribution_pie,
|
||||||
|
create_inference_cells_bar,
|
||||||
|
create_member_details_table,
|
||||||
|
create_sample_count_bar_chart,
|
||||||
|
)
|
||||||
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
|
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||||
|
from entropice.utils.types import (
|
||||||
|
GridLevel,
|
||||||
|
L2SourceDataset,
|
||||||
|
TargetDataset,
|
||||||
|
Task,
|
||||||
|
TemporalMode,
|
||||||
|
all_l2_source_datasets,
|
||||||
|
grid_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
type DatasetStatsCache = dict[GridLevel, dict[TemporalMode, DatasetStatistics]]
|
||||||
|
|
||||||
|
|
||||||
|
def render_training_samples_tab(sample_df: pd.DataFrame):
    """Render the overview of sample counts per task, target and grid configuration.

    Draws the sample-count bar chart followed by a table in which the sample counts
    and coverage values are formatted for display.

    Args:
        sample_df: Sample count table. The columns ``Samples (Coverage)`` and
            ``Coverage %`` must hold numeric values; the frame is not modified.

    """
    st.markdown(
        """
    This visualization shows the number of available training samples for each combination of:
    - **Task**: binary, count, density
    - **Target Dataset**: darts_rts, darts_mllabels
    - **Grid System**: hex, healpix
    - **Grid Level**: varying by grid type
    """
    )

    # Create and display bar chart (needs the raw numeric values)
    fig = create_sample_count_bar_chart(sample_df)
    st.plotly_chart(fig, width="stretch")

    # Display full table with formatting
    st.markdown("#### Detailed Sample Counts")

    # Format on a copy: writing the strings back into the caller's frame would make a
    # second call on the same frame fail, because the number formats reject strings.
    display_df = sample_df.assign(
        **{
            # Thousands separators for the counts
            "Samples (Coverage)": sample_df["Samples (Coverage)"].apply(lambda x: f"{x:,}"),
            # Coverage as percentage with 2 decimal places
            "Coverage %": sample_df["Coverage %"].apply(lambda x: f"{x:.2f}%"),
        }
    )

    st.dataframe(display_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
|
def render_dataset_characteristics_tab(
    breakdown_df: pd.DataFrame, sample_df: pd.DataFrame, comparison_df: pd.DataFrame
):
    """Show the static feature-count and coverage comparison across grid configurations.

    Args:
        breakdown_df: Input for the stacked feature-count bar chart.
        sample_df: Input for the inference-cells bar chart.
        comparison_df: Table displayed below the two charts.

    """
    intro_text = """
    Comparing dataset characteristics for all grid configurations with all data sources enabled.
    - **Features**: Total number of input features from all data sources
    - **Spatial Coverage**: Number of grid cells with complete data coverage
    """
    st.markdown(intro_text)

    # Stacked feature counts first, then the spatial coverage chart; each chart is
    # built and drawn before the next one is created.
    chart_inputs = (
        (create_feature_count_stacked_bar, breakdown_df),
        (create_inference_cells_bar, sample_df),
    )
    for build_chart, frame in chart_inputs:
        st.plotly_chart(build_chart(frame), width="stretch")

    # Full comparison table
    st.markdown("#### Detailed Comparison Table")
    st.dataframe(comparison_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
|
def render_feature_breakdown_tab(breakdown_df: pd.DataFrame):
    """Show the per-member feature breakdown as a donut grid plus the raw table.

    Args:
        breakdown_df: Feature breakdown table; passed to the donut grid builder and
            displayed unchanged underneath.

    """
    description = """
    Visualizing feature contribution of each data source (member) across all grid configurations
    and temporal modes. Each donut chart represents the feature breakdown for a specific
    grid-temporal mode combination.
    """
    st.markdown(description)

    # One donut per grid / temporal mode combination
    donut_grid = create_feature_breakdown_donuts_grid(breakdown_df)
    st.plotly_chart(donut_grid, width="stretch")

    # The numbers behind the donuts
    st.markdown("#### Detailed Breakdown Table")
    st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_member_color_map(members) -> dict[str, str]:
    """Assign one palette color to every member, grouped by its source prefix.

    The source is the part of the member name before the first ``-``. Each source gets
    a palette with two more colors than it has members, and the first palette entry is
    skipped, so the n-th member (in sorted order) of a source receives ``palette[n]``.
    """
    members_by_source: dict[str, list[str]] = {}
    for member in sorted(members):
        members_by_source.setdefault(member.split("-")[0], []).append(member)

    color_map: dict[str, str] = {}
    for source, group in members_by_source.items():
        # One palette lookup per source instead of one per member
        palette = get_palette(source, n_colors=len(group) + 2)
        for position, member in enumerate(group, start=1):
            color_map[member] = palette[position]
    return color_map


def _render_member_details(selected_member_stats: dict) -> None:
    """Render the collapsible per-source details (summary table, variables, dimensions)."""
    with st.expander("📦 Detailed Source Information", expanded=False):
        # Create detailed table
        member_details_dict = {
            member: {
                "feature_count": ms.feature_count,
                "variable_names": ms.variable_names,
                "dimensions": ms.dimensions,
                "size_bytes": ms.size_bytes,
            }
            for member, ms in selected_member_stats.items()
        }
        details_df = create_member_details_table(member_details_dict)
        st.dataframe(details_df, hide_index=True, width="stretch")

        # Individual member details
        for member, member_stats in selected_member_stats.items():
            st.markdown(f"### {member}")

            # Variables
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats.variable_names
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)

            # Dimensions
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size:,}</span>"
                    for dim_name, dim_size in member_stats.dimensions.items()
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)

            st.markdown("---")


@st.fragment
def render_configuration_explorer_tab(dataset_stats: DatasetStatsCache):  # noqa: C901
    """Render the interactive explorer for one dataset configuration.

    The user picks a grid, temporal mode, target dataset, task and a set of data
    sources. For that selection the function shows summary metrics, task statistics,
    the class distribution (when the task statistics carry one), the per-source feature
    breakdown and detailed per-source information. Whenever the selection cannot be
    resolved against ``dataset_stats`` a warning or info box is shown and rendering stops.

    Args:
        dataset_stats: Statistics indexed first by grid config id, then by temporal mode.

    """
    st.markdown(
        """
    Explore detailed statistics for a specific dataset configuration.
    Select your grid, temporal mode, target dataset, task, and data sources to see comprehensive statistics.
    """
    )

    # Grid selection
    grid_options = [gc.display_name for gc in grid_configs]

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        selected_grid_display = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )

    with col2:
        # Create temporal mode options with proper formatting
        temporal_mode_options = ["feature", "synopsis"] + [str(year) for year in [2018, 2019, 2020, 2021, 2022, 2023]]
        temporal_mode_display = st.selectbox(
            "Temporal Mode",
            options=temporal_mode_options,
            index=0,
            help="Select the temporal mode (feature, synopsis, or specific year)",
            key="feature_temporal_mode_select",
        )
        # Convert back to proper type: years are shown as strings but looked up as ints
        selected_temporal_mode: TemporalMode
        if temporal_mode_display.isdigit():
            selected_temporal_mode = int(temporal_mode_display)  # type: ignore[assignment]
        else:
            selected_temporal_mode = temporal_mode_display  # pyright: ignore[reportAssignmentType]

    with col3:
        selected_target: TargetDataset = st.selectbox(
            "Target Dataset",
            options=["darts_v1", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    with col4:
        selected_task: Task = st.selectbox(
            "Task",
            options=["binary", "count", "density", "count_regimes", "density_regimes"],
            index=0,
            help="Select the task type",
            key="feature_task_select",
        )

    # Find the selected grid config
    selected_grid_config = next(gc for gc in grid_configs if gc.display_name == selected_grid_display)

    # Get stats for the selected configuration
    mode_stats = dataset_stats[selected_grid_config.id]

    # Check if the selected temporal mode has stats
    if selected_temporal_mode not in mode_stats:
        st.warning(f"No statistics available for {selected_temporal_mode} mode with this grid configuration.")
        return

    stats = mode_stats[selected_temporal_mode]
    available_members = list(stats.members.keys())

    # Members selection
    st.markdown("#### Select Data Sources")

    # Use columns for checkboxes
    cols = st.columns(len(all_l2_source_datasets))
    selected_members: list[L2SourceDataset] = []

    for idx, member in enumerate(all_l2_source_datasets):
        with cols[idx]:
            # Only sources that exist for this configuration are ticked by default
            default_value = member in available_members
            if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
                selected_members.append(member)  # type: ignore[arg-type]

    # Nothing to show until at least one member is selected
    if not selected_members:
        st.info("👆 Select at least one data source to see feature statistics")
        return

    # Filter to selected members only (a ticked source may be missing from the stats)
    selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}

    # Calculate total features for selected members
    total_features = sum(ms.feature_count for ms in selected_member_stats.values())

    # Get target stats if available
    if selected_target not in stats.target:
        st.warning(f"Target {selected_target} is not available for this configuration.")
        return

    target_stats = stats.target[selected_target]

    # Check if task is available
    if selected_task not in target_stats:
        st.warning(
            f"Task {selected_task} is not available for target {selected_target} in {selected_temporal_mode} mode."
        )
        return

    task_stats = target_stats[selected_task]

    # High-level metrics
    st.markdown("#### Configuration Summary")
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Total Features", f"{total_features:,}")
    with col2:
        # Calculate minimum cells across all data sources (for inference capability)
        min_cells = min(
            (member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values()),
            default=0,
        )
        st.metric(
            "Inference Cells",
            f"{min_cells:,}",
            help="Minimum number of cells across all selected data sources",
        )
    with col3:
        st.metric("Data Sources", len(selected_members))
    with col4:
        st.metric("Training Samples", f"{task_stats.training_cells:,}")
    with col5:
        # Calculate total data points
        total_points = total_features * task_stats.training_cells
        st.metric("Total Data Points", f"{total_points:,}")

    # Task-specific statistics
    st.markdown("#### Task Statistics")
    task_col1, task_col2, task_col3 = st.columns(3)
    with task_col1:
        st.metric("Task Type", selected_task.replace("_", " ").title())
    with task_col2:
        st.metric("Coverage", f"{task_stats.coverage:.2f}%")
    with task_col3:
        if task_stats.class_counts:
            st.metric("Number of Classes", len(task_stats.class_counts))
        else:
            st.metric("Task Mode", "Regression")

    # Class distribution for classification tasks
    if task_stats.class_distribution:
        st.markdown("#### Class Distribution")
        class_dist_df = pd.DataFrame(
            [
                {
                    "Class": class_name,
                    "Count": task_stats.class_counts[class_name] if task_stats.class_counts else 0,
                    "Percentage": f"{pct:.2f}%",
                }
                for class_name, pct in task_stats.class_distribution.items()
            ]
        )
        st.dataframe(class_dist_df, hide_index=True, width="stretch")

    # Feature breakdown by source
    st.markdown("#### Feature Breakdown by Data Source")

    # Without any resolvable member the breakdown frame would have no columns and the
    # pie chart could not be built, so stop here with a hint instead of an exception.
    if not selected_member_stats:
        st.info("None of the selected data sources are available for this configuration.")
        return

    breakdown_df = pd.DataFrame(
        [
            {
                "Data Source": member,
                "Number of Features": member_stats.feature_count,
                # Guard the share against a zero feature total
                "Percentage": (
                    f"{(member_stats.feature_count / total_features * 100) if total_features else 0.0:.1f}%"
                ),
            }
            for member, member_stats in selected_member_stats.items()
        ]
    )

    # Create and display pie chart with one color per member
    source_color_map_raw = _build_member_color_map(selected_member_stats)
    fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
    st.plotly_chart(fig, width="stretch")

    # Show detailed table
    st.dataframe(breakdown_df, hide_index=True, width="stretch")

    # Detailed member information
    _render_member_details(selected_member_stats)
|
||||||
|
|
||||||
|
|
||||||
|
def render_dataset_statistics(all_stats: DatasetStatsCache | None = None):
    """Render the dataset statistics section with sample and feature counts.

    Derives the summary tables from ``all_stats`` and lays them out in four tabs:
    training samples, dataset characteristics, feature breakdown and the interactive
    configuration explorer.

    Args:
        all_stats: Statistics per grid level and temporal mode. The passed value is used
            as-is; the default statistics are loaded only when ``None`` is given.

    """
    st.header("📈 Dataset Statistics")

    # Honour the caller's statistics instead of unconditionally reloading them
    if all_stats is None:
        all_stats = load_all_default_dataset_statistics()

    training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
    feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
    comparison_df = DatasetStatistics.get_comparison_df(all_stats)
    inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)

    # Create tabs for different analysis views
    analysis_tabs = st.tabs(
        [
            "📊 Training Samples",
            "📈 Dataset Characteristics",
            "🔍 Feature Breakdown",
            "⚙️ Configuration Explorer",
        ]
    )

    with analysis_tabs[0]:
        st.subheader("Training Samples by Configuration")
        render_training_samples_tab(training_sample_df)

    with analysis_tabs[1]:
        st.subheader("Dataset Characteristics Across Grid Configurations")
        render_dataset_characteristics_tab(feature_breakdown_df, inference_sample_df, comparison_df)

    with analysis_tabs[2]:
        st.subheader("Feature Breakdown by Data Source")
        render_feature_breakdown_tab(feature_breakdown_df)

    with analysis_tabs[3]:
        st.subheader("Interactive Configuration Explorer")
        render_configuration_explorer_tab(all_stats)
|
||||||
158
src/entropice/dashboard/sections/experiment_results.py
Normal file
158
src/entropice/dashboard/sections/experiment_results.py
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
"""Experiment Results Section for the Entropice Dashboard."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
|
from entropice.utils.types import (
|
||||||
|
GridConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_training_results_summary(training_results: list[TrainingResult]):
    """Render summary metrics for training results.

    Shows four metrics side by side: number of distinct experiments, total number of
    runs, number of distinct model types and the creation time of the newest run.

    Args:
        training_results: Loaded training results in any order; may be empty, in which
            case the latest run is reported as ``N/A``.

    """
    st.header("📊 Training Results Summary")
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        experiments = {tr.experiment for tr in training_results if tr.experiment}
        st.metric("Experiments", len(experiments))

    with col2:
        st.metric("Total Runs", len(training_results))

    with col3:
        models = {tr.settings.model for tr in training_results}
        st.metric("Model Types", len(models))

    with col4:
        # Pick the newest run explicitly so the metric neither depends on the caller's
        # sort order nor fails with an IndexError on an empty list.
        latest = max(training_results, key=lambda tr: tr.created_at, default=None)
        if latest is None:
            latest_datetime = "N/A"
        else:
            latest_datetime = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d %H:%M")
        st.metric("Latest Run", latest_datetime)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_experiment_results(training_results: list[TrainingResult]):  # noqa: C901
    """Render the filterable results table and one expandable detail view per run.

    Args:
        training_results: Training results to list; the four select boxes
            (experiment, task, model, grid) narrow down what is shown.

    """
    st.header("🎯 Experiment Results")

    def grid_label(result: TrainingResult) -> str:
        return f"{result.settings.grid}-{result.settings.level}"

    # Distinct values offered by the filters
    experiment_names = sorted({tr.experiment for tr in training_results if tr.experiment})
    task_names = sorted({tr.settings.task for tr in training_results})
    model_names = sorted({tr.settings.model for tr in training_results})
    grid_names = sorted({grid_label(tr) for tr in training_results})

    experiment_col, task_col, model_col, grid_col = st.columns(4)

    with experiment_col:
        # The experiment filter is only drawn when there is something to choose from
        selected_experiment = "All"
        if experiment_names:
            selected_experiment = st.selectbox("Experiment", options=["All", *experiment_names], index=0)

    with task_col:
        selected_task = st.selectbox("Task", options=["All", *task_names], index=0)

    with model_col:
        selected_model = st.selectbox("Model", options=["All", *model_names], index=0)

    with grid_col:
        selected_grid = st.selectbox("Grid", options=["All", *grid_names], index=0)

    # A result is kept when every filter is either "All" or matches its value
    criteria = (
        (selected_experiment, lambda tr: tr.experiment),
        (selected_task, lambda tr: tr.settings.task),
        (selected_model, lambda tr: tr.settings.model),
        (selected_grid, grid_label),
    )
    filtered_results = [
        tr
        for tr in training_results
        if all(wanted == "All" or read_value(tr) == wanted for wanted, read_value in criteria)
    ]

    st.subheader("Results Table")
    summary_df = TrainingResult.to_dataframe(filtered_results)
    st.dataframe(summary_df, width="stretch", hide_index=True)

    st.subheader("Individual Experiment Details")

    # Icons for the well-known result files; anything else gets the generic document icon
    file_icons = {
        "search_settings.toml": "⚙️",
        "best_estimator_model.pkl": "🧮",
        "search_results.parquet": "📊",
        "predicted_probabilities.parquet": "🎯",
    }

    for tr in filtered_results:
        tr_info = tr.display_info
        with st.expander(tr_info.get_display_name("model_first")):
            config_col, score_col = st.columns([1, 2])

            with config_col:
                grid_config = GridConfig.from_grid_level((tr.settings.grid, tr.settings.level))
                config_lines = [
                    "**Configuration:**",
                    f"- **Experiment:** {tr.experiment}",
                    f"- **Task:** {tr.settings.task}",
                    f"- **Model:** {tr.settings.model}",
                    f"- **Grid:** {grid_config.display_name}",
                    f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}",
                    f"- **Temporal Mode:** {tr.settings.temporal_mode}",
                    f"- **Members:** {', '.join(tr.settings.members)}",
                    f"- **CV Splits:** {tr.settings.cv_splits}",
                    f"- **Classes:** {tr.settings.classes}",
                ]
                for line in config_lines:
                    st.write(line)

                st.write("\n**Files:**")
                for file in tr.files:
                    icon = file_icons.get(file.name, "📄")
                    st.write(f"- {icon} `{file.name}`")

            with score_col:
                st.write("**CV Score Summary:**")

                # Summary of all cross-validation test scores
                metric_df = tr.get_metric_dataframe()
                if metric_df is None:
                    st.write("No test scores found in results.")
                else:
                    st.dataframe(metric_df, width="stretch", hide_index=True)

                # Parameter space explored; "initial_K" acts as the marker column
                if "initial_K" in tr.results.columns:
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ["initial_K", "eps_cl", "eps_e"]:
                        if param not in tr.results.columns:
                            continue
                        column = tr.results[param]
                        min_val = column.min()
                        max_val = column.max()
                        unique_vals = column.nunique()
                        st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")

            with st.expander("Show CV Results DataFrame"):
                st.dataframe(tr.results, width="stretch", hide_index=True)

            st.write(f"\n**Path:** `{tr.path}`")
|
||||||
|
|
@ -90,7 +90,15 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
|
||||||
A list of hex color strings.
|
A list of hex color strings.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cmap = get_cmap(variable).resampled(n_colors)
|
# Hardcode some common variables to specific colormaps
|
||||||
|
if variable == "ERA5":
|
||||||
|
cmap = load_cmap(name="blue_material").resampled(n_colors)
|
||||||
|
elif variable == "ArcticDEM":
|
||||||
|
cmap = load_cmap(name="deep_purple_material").resampled(n_colors)
|
||||||
|
elif variable == "AlphaEarth":
|
||||||
|
cmap = load_cmap(name="green_material").resampled(n_colors)
|
||||||
|
else:
|
||||||
|
cmap = get_cmap(variable).resampled(n_colors)
|
||||||
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
|
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
|
||||||
return colors
|
return colors
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import antimeridian
|
import antimeridian
|
||||||
import geopandas as gpd
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import toml
|
import toml
|
||||||
|
|
@ -16,9 +15,8 @@ from shapely.geometry import shape
|
||||||
import entropice.spatial.grids
|
import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
|
||||||
from entropice.ml.training import TrainingSettings
|
from entropice.ml.training import TrainingSettings
|
||||||
from entropice.utils.types import L2SourceDataset, Task
|
from entropice.utils.types import GridConfig
|
||||||
|
|
||||||
|
|
||||||
def _fix_hex_geometry(geom):
|
def _fix_hex_geometry(geom):
|
||||||
|
|
@ -32,22 +30,29 @@ def _fix_hex_geometry(geom):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingResult:
|
class TrainingResult:
|
||||||
|
"""Wrapper for training result data and metadata."""
|
||||||
|
|
||||||
path: Path
|
path: Path
|
||||||
|
experiment: str
|
||||||
settings: TrainingSettings
|
settings: TrainingSettings
|
||||||
results: pd.DataFrame
|
results: pd.DataFrame
|
||||||
metrics: dict[str, float]
|
train_metrics: dict[str, float]
|
||||||
confusion_matrix: xr.DataArray
|
test_metrics: dict[str, float]
|
||||||
|
combined_metrics: dict[str, float]
|
||||||
|
confusion_matrix: xr.Dataset | None
|
||||||
created_at: float
|
created_at: float
|
||||||
available_metrics: list[str]
|
available_metrics: list[str]
|
||||||
|
files: list[Path]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "TrainingResult":
|
||||||
"""Load a TrainingResult from a given result directory path."""
|
"""Load a TrainingResult from a given result directory path."""
|
||||||
result_file = result_path / "search_results.parquet"
|
result_file = result_path / "search_results.parquet"
|
||||||
preds_file = result_path / "predicted_probabilities.parquet"
|
preds_file = result_path / "predicted_probabilities.parquet"
|
||||||
settings_file = result_path / "search_settings.toml"
|
settings_file = result_path / "search_settings.toml"
|
||||||
metrics_file = result_path / "test_metrics.toml"
|
metrics_file = result_path / "metrics.toml"
|
||||||
confusion_matrix_file = result_path / "confusion_matrix.nc"
|
confusion_matrix_file = result_path / "confusion_matrix.nc"
|
||||||
|
all_files = list(result_path.iterdir())
|
||||||
if not result_file.exists():
|
if not result_file.exists():
|
||||||
raise FileNotFoundError(f"Missing results file in {result_path}")
|
raise FileNotFoundError(f"Missing results file in {result_path}")
|
||||||
if not settings_file.exists():
|
if not settings_file.exists():
|
||||||
|
|
@ -56,28 +61,46 @@ class TrainingResult:
|
||||||
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
||||||
if not metrics_file.exists():
|
if not metrics_file.exists():
|
||||||
raise FileNotFoundError(f"Missing metrics file in {result_path}")
|
raise FileNotFoundError(f"Missing metrics file in {result_path}")
|
||||||
if not confusion_matrix_file.exists():
|
|
||||||
raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
|
|
||||||
|
|
||||||
created_at = result_path.stat().st_ctime
|
created_at = result_path.stat().st_ctime
|
||||||
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
|
settings_dict = toml.load(settings_file)["settings"]
|
||||||
|
|
||||||
|
# Handle backward compatibility: add missing fields with defaults
|
||||||
|
if "classes" not in settings_dict:
|
||||||
|
settings_dict["classes"] = None
|
||||||
|
if "param_grid" not in settings_dict:
|
||||||
|
settings_dict["param_grid"] = {}
|
||||||
|
if "cv_splits" not in settings_dict:
|
||||||
|
settings_dict["cv_splits"] = 5
|
||||||
|
if "metrics" not in settings_dict:
|
||||||
|
settings_dict["metrics"] = []
|
||||||
|
|
||||||
|
settings = TrainingSettings(**settings_dict)
|
||||||
results = pd.read_parquet(result_file)
|
results = pd.read_parquet(result_file)
|
||||||
metrics = toml.load(metrics_file)["test_metrics"]
|
metrics = toml.load(metrics_file)
|
||||||
confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
|
if not confusion_matrix_file.exists():
|
||||||
|
confusion_matrix = None
|
||||||
|
else:
|
||||||
|
confusion_matrix = xr.open_dataset(confusion_matrix_file, engine="h5netcdf")
|
||||||
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
|
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
path=result_path,
|
path=result_path,
|
||||||
|
experiment=experiment_name or "N/A",
|
||||||
settings=settings,
|
settings=settings,
|
||||||
results=results,
|
results=results,
|
||||||
metrics=metrics,
|
train_metrics=metrics["train_metrics"],
|
||||||
|
test_metrics=metrics["test_metrics"],
|
||||||
|
combined_metrics=metrics["combined_metrics"],
|
||||||
confusion_matrix=confusion_matrix,
|
confusion_matrix=confusion_matrix,
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
available_metrics=available_metrics,
|
available_metrics=available_metrics,
|
||||||
|
files=all_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_info(self) -> TrainingResultDisplayInfo:
|
def display_info(self) -> TrainingResultDisplayInfo:
|
||||||
|
"""Get display information for the training result."""
|
||||||
return TrainingResultDisplayInfo(
|
return TrainingResultDisplayInfo(
|
||||||
task=self.settings.task,
|
task=self.settings.task,
|
||||||
model=self.settings.model,
|
model=self.settings.model,
|
||||||
|
|
@ -126,9 +149,70 @@ class TrainingResult:
|
||||||
st.error(f"Error loading predictions: {e}")
|
st.error(f"Error loading predictions: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_metric_dataframe(self) -> pd.DataFrame | None:
|
||||||
|
"""Get a DataFrame of available metrics for this training result."""
|
||||||
|
metric_cols = [col for col in self.results.columns if col.startswith("mean_test_")]
|
||||||
|
if not metric_cols:
|
||||||
|
return None
|
||||||
|
metric_data = []
|
||||||
|
for col in metric_cols:
|
||||||
|
metric_name = col.replace("mean_test_", "").replace("neg_", "").title()
|
||||||
|
metrics = self.results[col]
|
||||||
|
# Check if the metric is negative
|
||||||
|
if col.startswith("mean_test_neg_"):
|
||||||
|
task_multiplier = 1 if self.settings.task != "density" else 100
|
||||||
|
task_multiplier *= -1
|
||||||
|
metrics = metrics * task_multiplier
|
||||||
|
|
||||||
|
metric_data.append(
|
||||||
|
{
|
||||||
|
"Metric": metric_name,
|
||||||
|
"Best": f"{metrics.max():.4f}",
|
||||||
|
"Mean": f"{metrics.mean():.4f}",
|
||||||
|
"Std": f"{metrics.std():.4f}",
|
||||||
|
"Worst": f"{metrics.min():.4f}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pd.DataFrame(metric_data)
|
||||||
|
|
||||||
|
def _get_best_metric_name(self) -> str:
    """Return the name of the headline metric for this result's task.

    ``binary`` maps to ``f1``, the two ``*_regimes`` tasks map to ``f1_weighted`` and
    every other task (the regression tasks) maps to ``r2``.
    """
    task = self.settings.task
    if task == "binary":
        return "f1"
    if task in ("count_regimes", "density_regimes"):
        return "f1_weighted"
    # Everything else is treated as a regression task
    return "r2"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_dataframe(training_results: list["TrainingResult"]) -> pd.DataFrame:
|
||||||
|
"""Convert a list of TrainingResult objects to a DataFrame for display."""
|
||||||
|
records = []
|
||||||
|
for tr in training_results:
|
||||||
|
info = tr.display_info
|
||||||
|
best_metric_name = tr._get_best_metric_name()
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"Experiment": tr.experiment if tr.experiment else "N/A",
|
||||||
|
"Task": info.task,
|
||||||
|
"Model": info.model,
|
||||||
|
"Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
|
||||||
|
"Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
|
||||||
|
"Score-Metric": best_metric_name.title(),
|
||||||
|
"Best Models Score (Train-Set)": tr.train_metrics.get(best_metric_name),
|
||||||
|
"Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
|
||||||
|
"Trials": len(tr.results),
|
||||||
|
"Path": str(tr.path.name),
|
||||||
|
}
|
||||||
|
records.append(record)
|
||||||
|
return pd.DataFrame.from_records(records)
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def load_all_training_results() -> list[TrainingResult]:
|
def load_all_training_results() -> list[TrainingResult]:
|
||||||
|
"""Load all training results from the results directory."""
|
||||||
results_dir = entropice.utils.paths.RESULTS_DIR
|
results_dir = entropice.utils.paths.RESULTS_DIR
|
||||||
training_results: list[TrainingResult] = []
|
training_results: list[TrainingResult] = []
|
||||||
for result_path in results_dir.iterdir():
|
for result_path in results_dir.iterdir():
|
||||||
|
|
@ -136,50 +220,22 @@ def load_all_training_results() -> list[TrainingResult]:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
training_result = TrainingResult.from_path(result_path)
|
training_result = TrainingResult.from_path(result_path)
|
||||||
|
training_results.append(training_result)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
st.warning(f"Skipping incomplete training result: {e}")
|
is_experiment_dir = False
|
||||||
continue
|
for experiment_path in result_path.iterdir():
|
||||||
training_results.append(training_result)
|
if not experiment_path.is_dir():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
experiment_name = experiment_path.name
|
||||||
|
training_result = TrainingResult.from_path(experiment_path, experiment_name)
|
||||||
|
training_results.append(training_result)
|
||||||
|
is_experiment_dir = True
|
||||||
|
except FileNotFoundError as e2:
|
||||||
|
st.warning(f"Skipping incomplete training result: {e2}")
|
||||||
|
if not is_experiment_dir:
|
||||||
|
st.warning(f"Skipping incomplete training result: {e}")
|
||||||
|
|
||||||
# Sort by creation time (most recent first)
|
# Sort by creation time (most recent first)
|
||||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||||
return training_results
|
return training_results
|
||||||
|
|
||||||
|
|
||||||
def load_all_training_data(
    e: DatasetEnsemble,
) -> dict[Task, CategoricalTrainingDataset]:
    """Load the training data for the binary, count and density tasks.

    The ensemble dataset is created once (filtered on ``e.covcol``) and then
    categorised and split per task on the CPU.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' holding the result of
        ``e._cat_and_split`` for the respective task.

    """
    full_dataset = e.create(filter_target_col=e.covcol)
    tasks = ("binary", "count", "density")
    return {task: e._cat_and_split(full_dataset, task, device="cpu") for task in tasks}
|
|
||||||
|
|
||||||
|
|
||||||
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5) together with the targets.

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        A ``(ds, targets)`` tuple: the member data returned by ``e._read_member``
        for the specified source, and the targets returned by ``e._read_target``.

    """
    targets = e._read_target()

    # The member is requested with lazy=False (i.e. not lazily).
    ds = e._read_member(source, targets, lazy=False)

    return ds, targets
|
|
||||||
|
|
|
||||||
|
|
@ -3,29 +3,28 @@
|
||||||
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
|
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import geopandas as gpd
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
|
||||||
import xarray as xr
|
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
import entropice.spatial.grids
|
import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
from entropice.dashboard.utils.loaders import TrainingResult
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.utils.types import (
|
from entropice.utils.types import (
|
||||||
Grid,
|
|
||||||
GridLevel,
|
GridLevel,
|
||||||
L2SourceDataset,
|
L2SourceDataset,
|
||||||
TargetDataset,
|
TargetDataset,
|
||||||
Task,
|
Task,
|
||||||
|
TemporalMode,
|
||||||
all_l2_source_datasets,
|
all_l2_source_datasets,
|
||||||
all_target_datasets,
|
all_target_datasets,
|
||||||
all_tasks,
|
all_tasks,
|
||||||
|
all_temporal_modes,
|
||||||
grid_configs,
|
grid_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -41,90 +40,63 @@ class MemberStatistics:
|
||||||
size_bytes: int # Size of this member's data on disk in bytes
|
size_bytes: int # Size of this member's data on disk in bytes
|
||||||
|
|
||||||
@classmethod
def compute(cls, e: DatasetEnsemble) -> dict[L2SourceDataset, "MemberStatistics"]:
    """Pre-compute the statistics of every L2 source dataset member of an ensemble.

    Each member listed in ``all_l2_source_datasets`` is read lazily via
    ``e.read_member``. Its feature count is the number of data variables
    multiplied by the size of every dimension except ``"cell_ids"``.

    Args:
        e: The ensemble whose members are read.

    Returns:
        Mapping from member name to its statistics.

    """
    stats_by_member = {}
    for member in all_l2_source_datasets:
        ds = e.read_member(member, lazy=True)

        feature_count = len(ds.data_vars)
        for dim_name, dim_size in ds.sizes.items():
            if dim_name == "cell_ids":
                # Cells are samples, not features.
                continue
            feature_count *= dim_size

        stats_by_member[member] = cls(
            feature_count=feature_count,
            variable_names=list(ds.data_vars),
            dimensions=dict(ds.sizes),
            coordinates=list(ds.coords),
            size_bytes=ds.nbytes,
        )
    return stats_by_member
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class TargetStatistics:
    """Statistics for a specific target dataset and task."""

    training_cells: int  # Number of cells used for training
    coverage: float  # Percentage of total cells covered by training data
    class_counts: dict[str, int] | None  # Class name to count mapping (None for "count" / "density")
    class_distribution: dict[str, float] | None  # Class name to percentage mapping (None for "count" / "density")

    @classmethod
    def compute(cls, e: DatasetEnsemble, target: TargetDataset, total_cells: int) -> dict[Task, "TargetStatistics"]:
        """Pre-compute the statistics of one target dataset for every task.

        Args:
            e: Ensemble used to obtain the targets via ``e.get_targets``.
            target: The target dataset to summarise.
            total_cells: Number of grid cells that coverage is expressed against.

        Returns:
            Mapping from each task in ``all_tasks`` to its statistics. Class
            counts and distribution are ``None`` for the "count" and "density"
            tasks.

        """
        stats_by_task = {}
        for task in all_tasks:
            targets = e.get_targets(target=target, task=task)
            n_cells = len(targets)
            coverage_pct = n_cells / total_cells * 100

            class_counts = None
            class_distribution = None
            if task not in ["count", "density"]:
                labels = targets["y"]
                assert labels.dtype == "category", "Classification tasks must have categorical target dtype."
                counts = labels.value_counts()
                class_counts = counts.to_dict()
                class_distribution = (counts / counts.sum() * 100).to_dict()

            stats_by_task[task] = cls(
                training_cells=n_cells,
                coverage=coverage_pct,
                class_counts=class_counts,
                class_distribution=class_distribution,
            )
        return stats_by_task
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -139,204 +111,167 @@ class DatasetStatistics:
|
||||||
total_cells: int # Total number of grid cells potentially covered
|
total_cells: int # Total number of grid cells potentially covered
|
||||||
size_bytes: int # Size of the dataset on disk in bytes
|
size_bytes: int # Size of the dataset on disk in bytes
|
||||||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||||
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
|
target: dict[TargetDataset, dict[Task, TargetStatistics]] # Statistics per target dataset and Task
|
||||||
|
|
||||||
@staticmethod
def get_target_sample_count_df(
    all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
    """Tabulate training sample counts and coverage per grid, temporal mode and target.

    Args:
        all_stats: Dataset statistics keyed by grid-level id, then temporal mode.

    Returns:
        DataFrame with the columns "Grid", "Target", "Temporal Mode",
        "Samples (Coverage)" and "Coverage %", sorted by the grid sort key.

    """
    records = []
    for grid_config in grid_configs:
        for temporal_mode, stats in all_stats[grid_config.id].items():
            for target_name, stats_by_task in stats.target.items():
                # Tasks only affect the distribution of the target, not which cells
                # are present, so the "binary" entry is used as the reference.
                reference = stats_by_task["binary"]
                n_samples = reference.training_cells
                coverage = reference.coverage
                for task_stats in stats_by_task.values():
                    assert n_samples == task_stats.training_cells, (
                        "Inconsistent sample counts across tasks for the same target."
                    )
                    assert coverage == task_stats.coverage, (
                        "Inconsistent coverage across tasks for the same target."
                    )
                records.append(
                    {
                        "Grid": grid_config.display_name,
                        "Target": target_name.replace("_", " ").title(),
                        "Temporal Mode": str(temporal_mode).capitalize(),
                        "Samples (Coverage)": n_samples,
                        "Coverage %": coverage,
                        "Grid_Level_Sort": grid_config.sort_key,
                    }
                )
    table = pd.DataFrame(records)
    # The sort key is only a helper column and is dropped from the result.
    return table.sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
|
||||||
|
|
||||||
@staticmethod
def get_inference_sample_count_df(
    all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
    """Tabulate the number of inference cells per grid.

    The inference cell count of a grid is the smallest "cell_ids" dimension
    among the members of its "feature" temporal mode; all other temporal
    modes are asserted to agree with it.

    Args:
        all_stats: Dataset statistics keyed by grid-level id, then temporal mode.

    Returns:
        DataFrame with the columns "Grid" and "Inference Cells", sorted by the
        grid sort key.

    """
    records = []
    for grid_config in grid_configs:
        mode_stats = all_stats[grid_config.id]
        # Temporal modes only affect features, not samples, so the "feature"
        # mode serves as the reference for the cell count.
        reference = mode_stats["feature"]
        assert len(reference.members) > 0, "Dataset must have at least one member."
        min_cells = min(member.dimensions["cell_ids"] for member in reference.members.values())

        for stats in mode_stats.values():
            assert len(stats.members) > 0, "Dataset must have at least one member."
            assert (
                min(member.dimensions["cell_ids"] for member in stats.members.values()) == min_cells
            ), "Inconsistent inference cell counts across temporal modes."

        records.append(
            {
                "Grid": grid_config.display_name,
                "Inference Cells": min_cells,
                "Grid_Level_Sort": grid_config.sort_key,
            }
        )
    table = pd.DataFrame(records)
    return table.sort_values("Grid_Level_Sort").drop(columns="Grid_Level_Sort")
|
||||||
|
|
||||||
|
@staticmethod
def get_comparison_df(
    all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
    """Build the detailed comparison table with one row per grid and temporal mode.

    All numeric values are pre-formatted as strings (thousands separators,
    size in MB with two decimals).

    Args:
        all_stats: Dataset statistics keyed by grid-level id, then temporal mode.

    Returns:
        DataFrame with the columns "Grid", "Temporal Mode", "Total Features",
        "Total Cells", "Min Cells" and "Size (MB)", sorted by the grid sort key
        and then by temporal mode.

    """

    def _as_record(grid_config, temporal_mode, stats) -> dict:
        # Smallest "cell_ids" dimension among all members of this dataset.
        min_cells = min(member.dimensions["cell_ids"] for member in stats.members.values())
        return {
            "Grid": grid_config.display_name,
            "Temporal Mode": str(temporal_mode).capitalize(),
            "Total Features": f"{stats.total_features:,}",
            "Total Cells": f"{stats.total_cells:,}",
            "Min Cells": f"{min_cells:,}",
            "Size (MB)": f"{stats.size_bytes / (1024 * 1024):.2f}",
            "Grid_Level_Sort": grid_config.sort_key,
        }

    records = []
    for grid_config in grid_configs:
        for temporal_mode, stats in all_stats[grid_config.id].items():
            records.append(_as_record(grid_config, temporal_mode, stats))

    table = pd.DataFrame(records)
    return table.sort_values(["Grid_Level_Sort", "Temporal Mode"]).drop(columns="Grid_Level_Sort")
|
||||||
|
|
||||||
@staticmethod
def get_feature_breakdown_df(
    all_stats: dict[GridLevel, dict[TemporalMode, "DatasetStatistics"]],
) -> pd.DataFrame:
    """Convert feature breakdown data to DataFrame for stacked/donut charts.

    One row is emitted per grid, temporal mode and member. Year-based temporal
    modes (integer keys) are represented by a single entry per grid: only the
    first one encountered is used. Non-year temporal modes are always
    included, regardless of where they appear in the iteration order.

    Args:
        all_stats: Dataset statistics keyed by grid-level id, then temporal mode.

    Returns:
        DataFrame with the columns "Grid", "Temporal Mode", "Member" and
        "Number of Features", sorted by grid sort key, temporal mode and member.

    """
    rows = []
    for grid_config in grid_configs:
        mode_stats = all_stats[grid_config.id]
        added_year_temporal_mode = False
        for temporal_mode, stats in mode_stats.items():
            if isinstance(temporal_mode, int):
                # Only add one year-based temporal mode for the breakdown. The
                # skip is limited to year-based modes so that a non-year mode
                # iterated after a year is not dropped as well.
                if added_year_temporal_mode:
                    continue
                added_year_temporal_mode = True
                mode_label = "Year-Based"
            else:
                mode_label = temporal_mode
            for member_name, member_stats in stats.members.items():
                rows.append(
                    {
                        "Grid": grid_config.display_name,
                        # NOTE: str.capitalize() lower-cases the rest, so the
                        # year-based label is rendered as "Year-based".
                        "Temporal Mode": mode_label.capitalize(),
                        "Member": member_name,
                        "Number of Features": member_stats.feature_count,
                        "Grid_Level_Sort": grid_config.sort_key,
                    }
                )
    return (
        pd.DataFrame(rows)
        .sort_values(["Grid_Level_Sort", "Temporal Mode", "Member"])
        .drop(columns="Grid_Level_Sort")
    )
|
||||||
|
|
||||||
|
|
||||||
# @st.cache_data  # ty:ignore[invalid-argument-type]
def load_all_default_dataset_statistics() -> dict[GridLevel, dict[TemporalMode, DatasetStatistics]]:
    """Precompute dataset statistics for all grid-level combinations and temporal modes.

    The result is cached as a pickle at the path returned by
    ``entropice.utils.paths.get_dataset_stats_cache()``. If that file exists
    it is loaded and returned as is; otherwise the statistics are computed and
    the cache file is written.

    Returns:
        Statistics keyed by grid-level id, then by temporal mode.

    """
    cache_file = entropice.utils.paths.get_dataset_stats_cache()
    if cache_file.exists():
        # NOTE(review): an existing cache is trusted unconditionally, with no
        # invalidation, and pickle must only ever be loaded from a trusted
        # location. Confirm that the cache path is not user-writable.
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    all_stats: dict[GridLevel, dict[TemporalMode, DatasetStatistics]] = {}
    for grid_config in grid_configs:
        stats_by_mode = all_stats[grid_config.id] = {}
        with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
            grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level)  # Ensure grid is registered
            total_cells = len(grid_gdf)
            assert total_cells > 0, "Grid must contain at least one cell."

            for temporal_mode in all_temporal_modes:
                e = DatasetEnsemble(grid=grid_config.grid, level=grid_config.level, temporal_mode=temporal_mode)
                year_based = isinstance(temporal_mode, int)
                target_statistics = {
                    target: TargetStatistics.compute(e, target=target, total_cells=total_cells)
                    for target in all_target_datasets
                    # darts_mllabels does not support year-based temporal modes
                    if not (year_based and target == "darts_mllabels")
                }
                member_statistics = MemberStatistics.compute(e)

                stats_by_mode[temporal_mode] = DatasetStatistics(
                    total_features=sum(ms.feature_count for ms in member_statistics.values()),
                    total_cells=total_cells,
                    size_bytes=sum(ms.size_bytes for ms in member_statistics.values()),
                    members=member_statistics,
                    target=target_statistics,
                )

    with open(cache_file, "wb") as f:
        pickle.dump(all_stats, f)

    return all_stats
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Statistics of the columns one member contributes to an ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Fraction (0-1) of the ensemble's features that come from this member
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: gpd.GeoDataFrame,
        member: L2SourceDataset,
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute the statistics of one member from the assembled ensemble dataset.

        The member's columns are identified by name: ``embeddings_*`` for
        AlphaEarth, ``arcticdem_*`` for ArcticDEM and ``era5_*`` for the ERA5
        members. ERA5 columns are told apart by their underscore count (two
        for yearly, three for seasonal/shoulder) and by whether the name
        contains ``_summer`` or ``_winter`` (seasonal) or not (shoulder).

        Args:
            dataset: The assembled ensemble dataset (one column per feature).
            member: The member whose columns are summarised.
            n_features: Total number of features in the ensemble, used as the
                denominator of ``p_features``.

        Returns:
            The statistics of the member's columns.

        Raises:
            NotImplementedError: If ``member`` is not a known member.

        """
        columns = dataset.columns
        if member == "AlphaEarth":
            column_mask = columns.str.startswith("embeddings_")
        elif member == "ArcticDEM":
            column_mask = columns.str.startswith("arcticdem_")
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_cols = columns.str.startswith("era5_")
            n_underscores = columns.str.count("_")
            cols_with_three_splits = era5_cols & (n_underscores == 2)
            cols_with_four_splits = era5_cols & (n_underscores == 3)
            cols_with_summer_winter = era5_cols & (columns.str.contains("_summer|_winter"))
            if member == "ERA5-yearly":
                column_mask = cols_with_three_splits
            elif member == "ERA5-seasonal":
                column_mask = cols_with_summer_winter & cols_with_four_splits
            else:  # "ERA5-shoulder"
                column_mask = ~cols_with_summer_winter & cols_with_four_splits
        else:
            raise NotImplementedError(f"Member {member} not implemented.")

        # Select COLUMNS with the mask. Plain ``dataset[mask]`` would treat the
        # boolean array as a row filter and fail on the length mismatch.
        member_dataset = dataset.loc[:, column_mask]

        n_features_member = len(member_dataset.columns)
        return cls(
            n_features=n_features_member,
            # Guard against an ensemble without any features.
            p_features=n_features_member / n_features if n_features else 0.0,
            n_nanrows=int(member_dataset.isna().any(axis=1).sum()),
            size_bytes=int(member_dataset.memory_usage(deep=True).sum()),
        )
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # In-memory size of the dataset in bytes (deep memory usage)
    members: dict[L2SourceDataset, EnsembleMemberStatistics]  # Statistics per source dataset member

    @classmethod
    def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
        """Compute the statistics of an ensemble dataset.

        Args:
            ensemble: The ensemble describing the composition. Its ``members``
                determine which per-member statistics are computed.
            dataset: An already created dataset. When ``None``, the dataset is
                created via ``ensemble.create(filter_target_col=ensemble.covcol)``.

        Returns:
            The statistics of the whole dataset and of each member.

        """
        # An explicit None check is required: the truth value of a DataFrame is
        # ambiguous, so ``dataset or ...`` raises for any dataset passed in.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)

        # Assert that no column is all-NaN
        assert not dataset.isna().all(axis="index").any(), "Some input columns are all NaN"

        n_features = len(dataset.columns)
        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {
            member: EnsembleMemberStatistics.compute(
                dataset=dataset,
                member=member,
                n_features=n_features,
            )
            for member in ensemble.members
        }
        return cls(
            n_features=n_features,
            n_cells=len(dataset),
            n_nanrows=int(dataset.isna().any(axis=1).sum()),
            size_bytes=int(dataset.memory_usage(deep=True).sum()),
            members=member_statistics,
        )
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TrainingDatasetStatistics:
|
class TrainingDatasetStatistics:
|
||||||
n_samples: int # Total number of samples in the dataset
|
n_samples: int # Total number of samples in the dataset
|
||||||
|
|
@ -346,14 +281,14 @@ class TrainingDatasetStatistics:
|
||||||
n_test_samples: int # Number of cells used for testing
|
n_test_samples: int # Number of cells used for testing
|
||||||
train_test_ratio: float # Ratio of training to test samples
|
train_test_ratio: float # Ratio of training to test samples
|
||||||
|
|
||||||
class_labels: list[str] # Ordered list of class labels
|
class_labels: list[str] | None # Ordered list of class labels
|
||||||
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] # Min/max raw values per class
|
class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] | None # Min/max raw values
|
||||||
n_classes: int # Number of classes
|
n_classes: int | None # Number of classes
|
||||||
|
|
||||||
training_class_counts: dict[str, int] # Class counts in training set
|
training_class_counts: dict[str, int] | None # Class counts in training set
|
||||||
training_class_distribution: dict[str, float] # Class percentages in training set
|
training_class_distribution: dict[str, float] | None # Class percentages in training set
|
||||||
test_class_counts: dict[str, int] # Class counts in test set
|
test_class_counts: dict[str, int] | None # Class counts in test set
|
||||||
test_class_distribution: dict[str, float] # Class percentages in test set
|
test_class_distribution: dict[str, float] | None # Class percentages in test set
|
||||||
|
|
||||||
raw_value_min: float # Minimum raw target value
|
raw_value_min: float # Minimum raw target value
|
||||||
raw_value_max: float # Maximum raw target value
|
raw_value_max: float # Maximum raw target value
|
||||||
|
|
@ -361,7 +296,7 @@ class TrainingDatasetStatistics:
|
||||||
raw_value_median: float # Median raw target value
|
raw_value_median: float # Median raw target value
|
||||||
raw_value_std: float # Standard deviation of raw target values
|
raw_value_std: float # Standard deviation of raw target values
|
||||||
|
|
||||||
imbalance_ratio: float # Smallest class count / largest class count (overall)
|
imbalance_ratio: float | None # Smallest class count / largest class count (overall)
|
||||||
size_bytes: int # Total memory usage of features in bytes
|
size_bytes: int # Total memory usage of features in bytes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -369,58 +304,64 @@ class TrainingDatasetStatistics:
|
||||||
cls,
|
cls,
|
||||||
ensemble: DatasetEnsemble,
|
ensemble: DatasetEnsemble,
|
||||||
task: Task,
|
task: Task,
|
||||||
dataset: gpd.GeoDataFrame | None = None,
|
target: TargetDataset,
|
||||||
) -> "TrainingDatasetStatistics":
|
) -> "TrainingDatasetStatistics":
|
||||||
dataset = dataset or ensemble.create(filter_target_col=ensemble.covcol)
|
training_dataset = ensemble.create_training_set(task=task, target=target)
|
||||||
categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
|
|
||||||
|
|
||||||
# Sample counts
|
# Sample counts
|
||||||
n_samples = len(categorical_dataset)
|
n_samples = len(training_dataset)
|
||||||
n_training_samples = len(categorical_dataset.y.train)
|
n_training_samples = (training_dataset.split == "train").sum()
|
||||||
n_test_samples = len(categorical_dataset.y.test)
|
n_test_samples = (training_dataset.split == "test").sum()
|
||||||
train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
|
train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
|
||||||
|
|
||||||
# Feature statistics
|
# Feature statistics
|
||||||
n_features = len(categorical_dataset.X.data.columns)
|
n_features = len(training_dataset.features.columns)
|
||||||
feature_names = list(categorical_dataset.X.data.columns)
|
feature_names = training_dataset.features.columns.tolist()
|
||||||
size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()
|
size_bytes = training_dataset.features.memory_usage(deep=True).sum()
|
||||||
|
|
||||||
# Class information
|
# Class information
|
||||||
class_labels = categorical_dataset.y.labels
|
class_labels = training_dataset.target_labels
|
||||||
class_intervals = categorical_dataset.y.intervals
|
if class_labels is None:
|
||||||
n_classes = len(class_labels)
|
class_intervals = None
|
||||||
|
n_classes = None
|
||||||
|
training_class_counts = None
|
||||||
|
training_class_distribution = None
|
||||||
|
test_class_counts = None
|
||||||
|
test_class_distribution = None
|
||||||
|
imbalance_ratio = None
|
||||||
|
else:
|
||||||
|
class_intervals = training_dataset.target_intervals
|
||||||
|
n_classes = len(class_labels)
|
||||||
|
|
||||||
# Training class distribution
|
train_y = training_dataset.targets[training_dataset.split == "train"]["y"]
|
||||||
train_y_series = pd.Series(categorical_dataset.y.train)
|
test_y = training_dataset.targets[training_dataset.split == "test"]["y"]
|
||||||
train_counts = train_y_series.value_counts().sort_index()
|
|
||||||
training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
|
|
||||||
train_total = sum(training_class_counts.values())
|
|
||||||
training_class_distribution = {
|
|
||||||
k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Test class distribution
|
train_counts = train_y.value_counts().sort_index()
|
||||||
test_y_series = pd.Series(categorical_dataset.y.test)
|
training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
|
||||||
test_counts = test_y_series.value_counts().sort_index()
|
train_total = sum(training_class_counts.values())
|
||||||
test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
|
training_class_distribution = {
|
||||||
test_total = sum(test_class_counts.values())
|
k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
|
||||||
test_class_distribution = {
|
}
|
||||||
k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
|
test_counts = test_y.value_counts().sort_index()
|
||||||
}
|
test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
|
||||||
|
test_total = sum(test_class_counts.values())
|
||||||
|
test_class_distribution = {
|
||||||
|
k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Imbalance ratio (smallest class / largest class across both splits)
|
||||||
|
# Use the per-label count dicts here: train_counts is the pandas value_counts() result,
# whose .values is an array attribute and cannot be called.
all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
|
||||||
|
nonzero_counts = [c for c in all_counts if c > 0]
|
||||||
|
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
|
||||||
|
|
||||||
# Raw value statistics
|
# Raw value statistics
|
||||||
raw_values = categorical_dataset.y.raw_values
|
raw_values = training_dataset.targets["z"]
|
||||||
raw_value_min = float(raw_values.min())
|
raw_value_min = float(raw_values.min())
|
||||||
raw_value_max = float(raw_values.max())
|
raw_value_max = float(raw_values.max())
|
||||||
raw_value_mean = float(raw_values.mean())
|
raw_value_mean = float(raw_values.mean())
|
||||||
raw_value_median = float(raw_values.median())
|
raw_value_median = float(raw_values.median())
|
||||||
raw_value_std = float(raw_values.std())
|
raw_value_std = float(raw_values.std())
|
||||||
|
|
||||||
# Imbalance ratio (smallest class / largest class across both splits)
|
|
||||||
all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
|
|
||||||
nonzero_counts = [c for c in all_counts if c > 0]
|
|
||||||
imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
n_samples=n_samples,
|
n_samples=n_samples,
|
||||||
n_features=n_features,
|
n_features=n_features,
|
||||||
|
|
|
||||||
|
|
@ -1,508 +1,17 @@
|
||||||
"""Overview page: List of available result directories with some summary statistics."""
|
"""Overview page: List of available result directories with some summary statistics."""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.dashboard.plots.overview import (
|
from entropice.dashboard.sections.dataset_statistics import render_dataset_statistics
|
||||||
create_feature_breakdown_donut,
|
from entropice.dashboard.sections.experiment_results import (
|
||||||
create_feature_count_stacked_bar,
|
render_experiment_results,
|
||||||
create_feature_distribution_pie,
|
render_training_results_summary,
|
||||||
create_inference_cells_bar,
|
|
||||||
create_sample_count_bar_chart,
|
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
|
||||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||||
from entropice.dashboard.utils.stats import (
|
from entropice.dashboard.utils.stats import (
|
||||||
DatasetStatistics,
|
|
||||||
load_all_default_dataset_statistics,
|
load_all_default_dataset_statistics,
|
||||||
)
|
)
|
||||||
from entropice.utils.types import (
|
|
||||||
GridConfig,
|
|
||||||
L2SourceDataset,
|
|
||||||
TargetDataset,
|
|
||||||
grid_configs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def render_sample_count_overview():
|
|
||||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
This visualization shows the number of available training samples for each combination of:
|
|
||||||
- **Task**: binary, count, density
|
|
||||||
- **Target Dataset**: darts_rts, darts_mllabels
|
|
||||||
- **Grid System**: hex, healpix
|
|
||||||
- **Grid Level**: varying by grid type
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get sample count DataFrame from cache
|
|
||||||
all_stats = load_all_default_dataset_statistics()
|
|
||||||
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
|
|
||||||
|
|
||||||
# Get color palettes for each target dataset
|
|
||||||
n_tasks = sample_df["Task"].nunique()
|
|
||||||
target_color_maps = {
|
|
||||||
"rts": get_palette("task_types", n_colors=n_tasks),
|
|
||||||
"mllabels": get_palette("data_sources", n_colors=n_tasks),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create and display bar chart
|
|
||||||
fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
|
|
||||||
st.plotly_chart(fig, use_container_width=True)
|
|
||||||
|
|
||||||
# Display full table with formatting
|
|
||||||
st.markdown("#### Detailed Sample Counts")
|
|
||||||
display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
|
|
||||||
|
|
||||||
# Format numbers with commas
|
|
||||||
display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
|
|
||||||
# Format coverage as percentage with 2 decimal places
|
|
||||||
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
|
|
||||||
|
|
||||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
|
||||||
|
|
||||||
|
|
||||||
def render_feature_count_comparison():
|
|
||||||
"""Render static comparison of feature counts across all grid configurations."""
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
Comparing dataset characteristics for all grid configurations with all data sources enabled.
|
|
||||||
- **Features**: Total number of input features from all data sources
|
|
||||||
- **Spatial Coverage**: Number of grid cells with complete data coverage
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get data from cache
|
|
||||||
all_stats = load_all_default_dataset_statistics()
|
|
||||||
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
|
|
||||||
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
|
||||||
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
|
||||||
|
|
||||||
# Get all unique data sources and create color map
|
|
||||||
unique_sources = sorted(breakdown_df["Data Source"].unique())
|
|
||||||
n_sources = len(unique_sources)
|
|
||||||
source_color_list = get_palette("data_sources", n_colors=n_sources)
|
|
||||||
source_color_map = dict(zip(unique_sources, source_color_list))
|
|
||||||
|
|
||||||
# Create and display stacked bar chart
|
|
||||||
fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
|
|
||||||
st.plotly_chart(fig, use_container_width=True)
|
|
||||||
|
|
||||||
# Add spatial coverage metric
|
|
||||||
n_grids = len(comparison_df)
|
|
||||||
grid_colors = get_palette("grid_configs", n_colors=n_grids)
|
|
||||||
|
|
||||||
fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
|
|
||||||
st.plotly_chart(fig_cells, use_container_width=True)
|
|
||||||
|
|
||||||
# Display full comparison table with formatting
|
|
||||||
st.markdown("#### Detailed Comparison Table")
|
|
||||||
display_df = comparison_df[
|
|
||||||
[
|
|
||||||
"Grid",
|
|
||||||
"Total Features",
|
|
||||||
"Data Sources",
|
|
||||||
"Inference Cells",
|
|
||||||
]
|
|
||||||
].copy()
|
|
||||||
|
|
||||||
# Format numbers with commas
|
|
||||||
for col in ["Total Features", "Inference Cells"]:
|
|
||||||
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
|
||||||
|
|
||||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
|
||||||
def render_feature_count_explorer():
|
|
||||||
"""Render interactive detailed configuration explorer using fragments."""
|
|
||||||
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
|
||||||
|
|
||||||
# Grid selection
|
|
||||||
grid_options = [gc.display_name for gc in grid_configs]
|
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
grid_level_combined = st.selectbox(
|
|
||||||
"Grid Configuration",
|
|
||||||
options=grid_options,
|
|
||||||
index=0,
|
|
||||||
help="Select the grid system and resolution level",
|
|
||||||
key="feature_grid_select",
|
|
||||||
)
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
target = st.selectbox(
|
|
||||||
"Target Dataset",
|
|
||||||
options=["darts_rts", "darts_mllabels"],
|
|
||||||
index=0,
|
|
||||||
help="Select the target dataset",
|
|
||||||
key="feature_target_select",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find the selected grid config
|
|
||||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
|
||||||
|
|
||||||
# Get available members from the stats
|
|
||||||
all_stats = load_all_default_dataset_statistics()
|
|
||||||
stats = all_stats[selected_grid_config.id]
|
|
||||||
available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
|
|
||||||
|
|
||||||
# Members selection
|
|
||||||
st.markdown("#### Select Data Sources")
|
|
||||||
|
|
||||||
all_members = cast(
|
|
||||||
list[L2SourceDataset],
|
|
||||||
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use columns for checkboxes
|
|
||||||
cols = st.columns(len(all_members))
|
|
||||||
selected_members: list[L2SourceDataset] = []
|
|
||||||
|
|
||||||
for idx, member in enumerate(all_members):
|
|
||||||
with cols[idx]:
|
|
||||||
default_value = member in available_members
|
|
||||||
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
|
||||||
selected_members.append(cast(L2SourceDataset, member))
|
|
||||||
|
|
||||||
# Show results if at least one member is selected
|
|
||||||
if selected_members:
|
|
||||||
# Get statistics from cache (already loaded)
|
|
||||||
grid_stats = all_stats[selected_grid_config.id]
|
|
||||||
|
|
||||||
# Filter to selected members only
|
|
||||||
selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
|
|
||||||
|
|
||||||
# Calculate total features for selected members
|
|
||||||
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
|
||||||
|
|
||||||
# Get target stats
|
|
||||||
target_stats = grid_stats.target[cast(TargetDataset, target)]
|
|
||||||
|
|
||||||
# High-level metrics
|
|
||||||
col1, col2, col3, col4, col5 = st.columns(5)
|
|
||||||
with col1:
|
|
||||||
st.metric("Total Features", f"{total_features:,}")
|
|
||||||
with col2:
|
|
||||||
# Calculate minimum cells across all data sources (for inference capability)
|
|
||||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
|
|
||||||
st.metric(
|
|
||||||
"Inference Cells",
|
|
||||||
f"{min_cells:,}",
|
|
||||||
help="Number of union of cells across all data sources",
|
|
||||||
)
|
|
||||||
with col3:
|
|
||||||
st.metric("Data Sources", len(selected_members))
|
|
||||||
with col4:
|
|
||||||
# Use binary task training cells as sample count
|
|
||||||
st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
|
|
||||||
with col5:
|
|
||||||
# Calculate total data points
|
|
||||||
total_points = total_features * target_stats.training_cells["binary"]
|
|
||||||
st.metric("Total Data Points", f"{total_points:,}")
|
|
||||||
|
|
||||||
# Feature breakdown by source
|
|
||||||
st.markdown("#### Feature Breakdown by Data Source")
|
|
||||||
|
|
||||||
breakdown_data = []
|
|
||||||
for member, member_stats in selected_member_stats.items():
|
|
||||||
breakdown_data.append(
|
|
||||||
{
|
|
||||||
"Data Source": member,
|
|
||||||
"Number of Features": member_stats.feature_count,
|
|
||||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
breakdown_df = pd.DataFrame(breakdown_data)
|
|
||||||
|
|
||||||
# Get all unique data sources and create color map
|
|
||||||
unique_sources = sorted(breakdown_df["Data Source"].unique())
|
|
||||||
n_sources = len(unique_sources)
|
|
||||||
source_color_list = get_palette("data_sources", n_colors=n_sources)
|
|
||||||
source_color_map = dict(zip(unique_sources, source_color_list))
|
|
||||||
|
|
||||||
# Create and display pie chart
|
|
||||||
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
|
|
||||||
st.plotly_chart(fig, width="stretch")
|
|
||||||
|
|
||||||
# Show detailed table
|
|
||||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
|
||||||
|
|
||||||
# Detailed member information
|
|
||||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
|
||||||
for member, member_stats in selected_member_stats.items():
|
|
||||||
st.markdown(f"### {member}")
|
|
||||||
|
|
||||||
metric_cols = st.columns(4)
|
|
||||||
with metric_cols[0]:
|
|
||||||
st.metric("Features", member_stats.feature_count)
|
|
||||||
with metric_cols[1]:
|
|
||||||
st.metric("Variables", len(member_stats.variable_names))
|
|
||||||
with metric_cols[2]:
|
|
||||||
dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
|
|
||||||
st.metric("Shape", dim_str)
|
|
||||||
with metric_cols[3]:
|
|
||||||
total_points = 1
|
|
||||||
for dim_size in member_stats.dimensions.values():
|
|
||||||
total_points *= dim_size
|
|
||||||
st.metric("Data Points", f"{total_points:,}")
|
|
||||||
|
|
||||||
# Variables
|
|
||||||
st.markdown("**Variables:**")
|
|
||||||
vars_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
|
||||||
for v in member_stats.variable_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(vars_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Dimensions
|
|
||||||
st.markdown("**Dimensions:**")
|
|
||||||
dim_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
|
||||||
f"{dim_name}: {dim_size}</span>"
|
|
||||||
for dim_name, dim_size in member_stats.dimensions.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Size on disk
|
|
||||||
size_mb = member_stats.size_bytes / (1024 * 1024)
|
|
||||||
st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
else:
|
|
||||||
st.info("👆 Select at least one data source to see feature statistics")
|
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_analysis():
|
|
||||||
"""Render the dataset analysis section with sample and feature counts."""
|
|
||||||
st.header("📈 Dataset Analysis")
|
|
||||||
|
|
||||||
# Create tabs for different analysis views
|
|
||||||
analysis_tabs = st.tabs(
|
|
||||||
[
|
|
||||||
"📊 Training Samples",
|
|
||||||
"📈 Dataset Characteristics",
|
|
||||||
"🔍 Feature Breakdown",
|
|
||||||
"⚙️ Configuration Explorer",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with analysis_tabs[0]:
|
|
||||||
st.subheader("Training Samples by Configuration")
|
|
||||||
render_sample_count_overview()
|
|
||||||
|
|
||||||
with analysis_tabs[1]:
|
|
||||||
st.subheader("Dataset Characteristics Across Grid Configurations")
|
|
||||||
render_feature_count_comparison()
|
|
||||||
|
|
||||||
with analysis_tabs[2]:
|
|
||||||
st.subheader("Feature Breakdown by Data Source")
|
|
||||||
# Get data from cache
|
|
||||||
all_stats = load_all_default_dataset_statistics()
|
|
||||||
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
|
|
||||||
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
|
||||||
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
|
||||||
|
|
||||||
# Get all unique data sources and create color map
|
|
||||||
unique_sources = sorted(breakdown_df["Data Source"].unique())
|
|
||||||
n_sources = len(unique_sources)
|
|
||||||
source_color_list = get_palette("data_sources", n_colors=n_sources)
|
|
||||||
source_color_map = dict(zip(unique_sources, source_color_list))
|
|
||||||
|
|
||||||
st.markdown("Showing percentage contribution of each data source across all grid configurations")
|
|
||||||
|
|
||||||
# Sparse, low, and medium resolution grids
|
|
||||||
for res in ["sparse", "low", "medium"]:
|
|
||||||
cols = st.columns(2)
|
|
||||||
with cols[0]:
|
|
||||||
grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"]
|
|
||||||
for gc in grid_configs_res:
|
|
||||||
grid_display = gc.display_name
|
|
||||||
grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
|
|
||||||
fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
|
|
||||||
st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
|
|
||||||
with cols[1]:
|
|
||||||
grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"]
|
|
||||||
for gc in grid_configs_res:
|
|
||||||
grid_display = gc.display_name
|
|
||||||
grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
|
|
||||||
fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
|
|
||||||
st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
|
|
||||||
|
|
||||||
# Create donut charts for each grid configuration
|
|
||||||
# num_grids = len(comparison_df)
|
|
||||||
# cols_per_row = 3
|
|
||||||
# num_rows = (num_grids + cols_per_row - 1) // cols_per_row
|
|
||||||
|
|
||||||
# for row_idx in range(num_rows):
|
|
||||||
# cols = st.columns(cols_per_row)
|
|
||||||
# for col_idx in range(cols_per_row):
|
|
||||||
# grid_idx = row_idx * cols_per_row + col_idx
|
|
||||||
# if grid_idx < num_grids:
|
|
||||||
# grid_config = comparison_df.iloc[grid_idx]["Grid"]
|
|
||||||
# grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
|
|
||||||
|
|
||||||
# with cols[col_idx]:
|
|
||||||
# fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map)
|
|
||||||
# st.plotly_chart(fig, use_container_width=True)
|
|
||||||
|
|
||||||
with analysis_tabs[3]:
|
|
||||||
st.subheader("Interactive Configuration Explorer")
|
|
||||||
render_feature_count_explorer()
|
|
||||||
|
|
||||||
|
|
||||||
def render_training_results_summary(training_results):
|
|
||||||
"""Render summary metrics for training results."""
|
|
||||||
st.header("📊 Training Results Summary")
|
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
tasks = {tr.settings.task for tr in training_results}
|
|
||||||
st.metric("Tasks", len(tasks))
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
grids = {tr.settings.grid for tr in training_results}
|
|
||||||
st.metric("Grid Types", len(grids))
|
|
||||||
|
|
||||||
with col3:
|
|
||||||
models = {tr.settings.model for tr in training_results}
|
|
||||||
st.metric("Model Types", len(models))
|
|
||||||
|
|
||||||
with col4:
|
|
||||||
latest = training_results[0] # Already sorted by creation time
|
|
||||||
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
|
|
||||||
st.metric("Latest Run", latest_date)
|
|
||||||
|
|
||||||
|
|
||||||
def render_experiment_results(training_results):
|
|
||||||
"""Render detailed experiment results table and expandable details."""
|
|
||||||
st.header("🎯 Experiment Results")
|
|
||||||
st.subheader("Results Table")
|
|
||||||
|
|
||||||
# Build a summary dataframe
|
|
||||||
summary_data = []
|
|
||||||
for tr in training_results:
|
|
||||||
# Extract best scores from the results dataframe
|
|
||||||
score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
|
|
||||||
|
|
||||||
best_scores = {}
|
|
||||||
for col in score_cols:
|
|
||||||
metric_name = col.replace("mean_test_", "")
|
|
||||||
best_score = tr.results[col].max()
|
|
||||||
best_scores[metric_name] = best_score
|
|
||||||
|
|
||||||
# Get primary metric (usually the first one or accuracy)
|
|
||||||
primary_metric = (
|
|
||||||
"accuracy"
|
|
||||||
if "mean_test_accuracy" in tr.results.columns
|
|
||||||
else score_cols[0].replace("mean_test_", "")
|
|
||||||
if score_cols
|
|
||||||
else "N/A"
|
|
||||||
)
|
|
||||||
primary_score = best_scores.get(primary_metric, 0.0)
|
|
||||||
|
|
||||||
summary_data.append(
|
|
||||||
{
|
|
||||||
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
|
|
||||||
"Task": tr.settings.task,
|
|
||||||
"Grid": tr.settings.grid,
|
|
||||||
"Level": tr.settings.level,
|
|
||||||
"Model": tr.settings.model,
|
|
||||||
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
|
|
||||||
"Trials": len(tr.results),
|
|
||||||
"Path": str(tr.path.name),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
summary_df = pd.DataFrame(summary_data)
|
|
||||||
|
|
||||||
# Display with color coding for best scores
|
|
||||||
st.dataframe(
|
|
||||||
summary_df,
|
|
||||||
width="stretch",
|
|
||||||
hide_index=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Expandable details for each result
|
|
||||||
st.subheader("Individual Experiment Details")
|
|
||||||
|
|
||||||
for tr in training_results:
|
|
||||||
display_name = (
|
|
||||||
f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
|
|
||||||
)
|
|
||||||
with st.expander(display_name):
|
|
||||||
col1, col2 = st.columns([1, 2])
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
st.write("**Configuration:**")
|
|
||||||
st.write(f"- **Task:** {tr.settings.task}")
|
|
||||||
st.write(f"- **Grid:** {tr.settings.grid}")
|
|
||||||
st.write(f"- **Level:** {tr.settings.level}")
|
|
||||||
st.write(f"- **Model:** {tr.settings.model}")
|
|
||||||
st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
|
|
||||||
st.write(f"- **Classes:** {tr.settings.classes}")
|
|
||||||
|
|
||||||
st.write("\n**Files:**")
|
|
||||||
st.write("- 📊 search_results.parquet")
|
|
||||||
st.write("- 🧮 best_estimator_state.nc")
|
|
||||||
st.write("- 🎯 predicted_probabilities.parquet")
|
|
||||||
st.write("- ⚙️ search_settings.toml")
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
st.write("**Best Scores:**")
|
|
||||||
|
|
||||||
# Extract all test scores
|
|
||||||
score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
|
|
||||||
|
|
||||||
if score_cols:
|
|
||||||
metric_data = []
|
|
||||||
for col in score_cols:
|
|
||||||
metric_name = col.replace("mean_test_", "").title()
|
|
||||||
best_score = tr.results[col].max()
|
|
||||||
mean_score = tr.results[col].mean()
|
|
||||||
std_score = tr.results[col].std()
|
|
||||||
|
|
||||||
metric_data.append(
|
|
||||||
{
|
|
||||||
"Metric": metric_name,
|
|
||||||
"Best": f"{best_score:.4f}",
|
|
||||||
"Mean": f"{mean_score:.4f}",
|
|
||||||
"Std": f"{std_score:.4f}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
metric_df = pd.DataFrame(metric_data)
|
|
||||||
st.dataframe(metric_df, width="stretch", hide_index=True)
|
|
||||||
else:
|
|
||||||
st.write("No test scores found in results.")
|
|
||||||
|
|
||||||
# Show parameter space explored
|
|
||||||
if "initial_K" in tr.results.columns: # Common parameter
|
|
||||||
st.write("\n**Parameter Ranges Explored:**")
|
|
||||||
for param in ["initial_K", "eps_cl", "eps_e"]:
|
|
||||||
if param in tr.results.columns:
|
|
||||||
min_val = tr.results[param].min()
|
|
||||||
max_val = tr.results[param].max()
|
|
||||||
unique_vals = tr.results[param].nunique()
|
|
||||||
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
|
||||||
|
|
||||||
st.write(f"\n**Path:** `{tr.path}`")
|
|
||||||
|
|
||||||
|
|
||||||
def render_overview_page():
|
def render_overview_page():
|
||||||
|
|
@ -537,7 +46,8 @@ def render_overview_page():
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Render dataset analysis section
|
# Render dataset analysis section
|
||||||
render_dataset_analysis()
|
all_stats = load_all_default_dataset_statistics()
|
||||||
|
render_dataset_statistics(all_stats)
|
||||||
|
|
||||||
st.balloons()
|
st.balloons()
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class AutoGluonSettings:
|
||||||
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
|
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
|
||||||
"""Combined settings for AutoGluon training."""
|
"""Combined settings for AutoGluon training."""
|
||||||
|
|
||||||
classes: list[str]
|
classes: list[str] | None
|
||||||
problem_type: str
|
problem_type: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -228,7 +228,7 @@ def autogluon_train(
|
||||||
combined_settings = AutoGluonTrainingSettings(
|
combined_settings = AutoGluonTrainingSettings(
|
||||||
**asdict(settings),
|
**asdict(settings),
|
||||||
**asdict(dataset_ensemble),
|
**asdict(dataset_ensemble),
|
||||||
classes=training_data.y.labels,
|
classes=training_data.target_labels,
|
||||||
problem_type=problem_type,
|
problem_type=problem_type,
|
||||||
)
|
)
|
||||||
settings_file = results_dir / "training_settings.toml"
|
settings_file = results_dir / "training_settings.toml"
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
||||||
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
|
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
|
||||||
pivcols = set(collapsed.index.names) - {"cell_ids"}
|
pivcols = set(collapsed.index.names) - {"cell_ids"}
|
||||||
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
|
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
|
||||||
collapsed.columns = ["_".join(v) for v in collapsed.columns]
|
collapsed.columns = ["_".join(map(str, v)) for v in collapsed.columns]
|
||||||
if use_dummy:
|
if use_dummy:
|
||||||
collapsed = collapsed.dropna(how="all")
|
collapsed = collapsed.dropna(how="all")
|
||||||
return collapsed
|
return collapsed
|
||||||
|
|
@ -378,7 +378,8 @@ class DatasetEnsemble:
|
||||||
case ("darts_v1" | "darts_v2", int()):
|
case ("darts_v1" | "darts_v2", int()):
|
||||||
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
|
version: Literal["v1", "v2"] = target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||||
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||||
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
targets = xr.open_zarr(target_store, consolidated=False)
|
||||||
|
targets = targets.sel(year=self.temporal_mode)
|
||||||
case ("darts_mllabels", str()): # Years are not supported
|
case ("darts_mllabels", str()): # Years are not supported
|
||||||
target_store = entropice.utils.paths.get_darts_file(
|
target_store = entropice.utils.paths.get_darts_file(
|
||||||
grid=self.grid, level=self.level, version="mllabels"
|
grid=self.grid, level=self.level, version="mllabels"
|
||||||
|
|
@ -467,6 +468,10 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
# Apply the temporal mode
|
# Apply the temporal mode
|
||||||
match (member.split("-"), self.temporal_mode):
|
match (member.split("-"), self.temporal_mode):
|
||||||
|
case (["ArcticDEM"], _):
|
||||||
|
pass # No temporal dimension
|
||||||
|
case (_, "feature"):
|
||||||
|
pass
|
||||||
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
|
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
|
||||||
ds_mean = ds.mean(dim="year")
|
ds_mean = ds.mean(dim="year")
|
||||||
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
|
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
|
||||||
|
|
@ -475,12 +480,8 @@ class DatasetEnsemble:
|
||||||
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
|
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
|
||||||
)
|
)
|
||||||
ds = xr.merge([ds_mean, ds_trend])
|
ds = xr.merge([ds_mean, ds_trend])
|
||||||
case (["ArcticDEM"], "synopsis"):
|
|
||||||
pass # No temporal dimension
|
|
||||||
case (_, int() as year):
|
case (_, int() as year):
|
||||||
ds = ds.sel(year=year, drop=True)
|
ds = ds.sel(year=year, drop=True)
|
||||||
case (_, "feature"):
|
|
||||||
pass
|
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
|
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
|
||||||
|
|
||||||
|
|
@ -556,7 +557,7 @@ class DatasetEnsemble:
|
||||||
cell_ids: pd.Series,
|
cell_ids: pd.Series,
|
||||||
era5_agg: Literal["yearly", "seasonal", "shoulder"],
|
era5_agg: Literal["yearly", "seasonal", "shoulder"],
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
|
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=True)
|
||||||
era5_df = _collapse_to_dataframe(era5)
|
era5_df = _collapse_to_dataframe(era5)
|
||||||
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
|
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
|
||||||
# Ensure all target cell_ids are present, fill missing with NaN
|
# Ensure all target cell_ids are present, fill missing with NaN
|
||||||
|
|
@ -565,7 +566,10 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
@stopwatch("Preparing AlphaEarth Embeddings")
|
@stopwatch("Preparing AlphaEarth Embeddings")
|
||||||
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||||
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
|
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=True)
|
||||||
|
# For non-synopsis modes there is only a single data variable "embeddings"
|
||||||
|
if self.temporal_mode != "synopsis":
|
||||||
|
embeddings = embeddings["embeddings"]
|
||||||
embeddings_df = _collapse_to_dataframe(embeddings)
|
embeddings_df = _collapse_to_dataframe(embeddings)
|
||||||
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
|
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
|
||||||
# Ensure all target cell_ids are present, fill missing with NaN
|
# Ensure all target cell_ids are present, fill missing with NaN
|
||||||
|
|
@ -574,7 +578,7 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
@stopwatch("Preparing ArcticDEM")
|
@stopwatch("Preparing ArcticDEM")
|
||||||
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||||
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
|
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=False)
|
||||||
if len(arcticdem["cell_ids"]) == 0:
|
if len(arcticdem["cell_ids"]) == 0:
|
||||||
# No data for these cells - create empty DataFrame with expected columns
|
# No data for these cells - create empty DataFrame with expected columns
|
||||||
# Use the Dataset metadata to determine column structure
|
# Use the Dataset metadata to determine column structure
|
||||||
|
|
@ -620,7 +624,7 @@ class DatasetEnsemble:
|
||||||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||||
|
|
||||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
# @stopwatch.f("Creating training Dataset", print_kwargs=["task", "target", "device", "cache_mode"])
|
||||||
def create_training_set(
|
def create_training_set(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import pickle
|
import pickle
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -115,17 +116,19 @@ class TrainingSettings(DatasetEnsemble, CVSettings):
|
||||||
def random_cv(
|
def random_cv(
|
||||||
dataset_ensemble: DatasetEnsemble,
|
dataset_ensemble: DatasetEnsemble,
|
||||||
settings: CVSettings = CVSettings(),
|
settings: CVSettings = CVSettings(),
|
||||||
):
|
experiment: str | None = None,
|
||||||
|
) -> Path:
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
dataset_ensemble (DatasetEnsemble): The dataset ensemble configuration.
|
||||||
settings (CVSettings): The cross-validation settings.
|
settings (CVSettings): The cross-validation settings.
|
||||||
|
experiment (str | None): Optional experiment name for results directory.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
|
# Since we use cuml and xgboost libraries, we can only enable array API for ESPA
|
||||||
use_array_api = settings.model == "espa"
|
use_array_api = settings.model != "xgboost"
|
||||||
device = "torch" if use_array_api else "cuda"
|
device = "torch" if settings.model == "espa" else "cuda"
|
||||||
set_config(array_api_dispatch=use_array_api)
|
set_config(array_api_dispatch=use_array_api)
|
||||||
|
|
||||||
print("Creating training data...")
|
print("Creating training data...")
|
||||||
|
|
@ -173,7 +176,8 @@ def random_cv(
|
||||||
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
|
print(f"{refit.replace('_', ' ').capitalize()} on test set: {test_score:.3f}")
|
||||||
|
|
||||||
results_dir = get_cv_results_dir(
|
results_dir = get_cv_results_dir(
|
||||||
"random_search",
|
experiment=experiment,
|
||||||
|
name="random_search",
|
||||||
grid=dataset_ensemble.grid,
|
grid=dataset_ensemble.grid,
|
||||||
level=dataset_ensemble.level,
|
level=dataset_ensemble.level,
|
||||||
task=settings.task,
|
task=settings.task,
|
||||||
|
|
@ -283,6 +287,7 @@ def random_cv(
|
||||||
|
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
return results_dir
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,14 @@ def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
|
||||||
return cache_dir / f"cells_{cells_hash}.parquet"
|
return cache_dir / f"cells_{cells_hash}.parquet"
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_stats_cache() -> Path:
|
||||||
|
cache_dir = DATASET_ENSEMBLES_DIR / "cache"
|
||||||
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return cache_dir / "dataset_stats.pckl"
|
||||||
|
|
||||||
|
|
||||||
def get_cv_results_dir(
|
def get_cv_results_dir(
|
||||||
|
experiment: str | None,
|
||||||
name: str,
|
name: str,
|
||||||
grid: Grid,
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
|
|
@ -147,7 +154,12 @@ def get_cv_results_dir(
|
||||||
) -> Path:
|
) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
|
if experiment is not None:
|
||||||
|
experiment_dir = RESULTS_DIR / experiment
|
||||||
|
experiment_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
experiment_dir = RESULTS_DIR
|
||||||
|
results_dir = experiment_dir / f"{gridname}_{name}_cv{now}_{task}"
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return results_dir
|
return results_dir
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,12 @@ type GridLevel = Literal[
|
||||||
"healpix9",
|
"healpix9",
|
||||||
"healpix10",
|
"healpix10",
|
||||||
]
|
]
|
||||||
type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
|
type TargetDataset = Literal["darts_v1", "darts_mllabels"]
|
||||||
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||||
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
# TODO: Consider implementing a "timeseries" temporal mode
|
# TODO: Consider implementing a "timeseries" temporal mode
|
||||||
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
||||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||||
type Stage = Literal["train", "inference", "visualization"]
|
type Stage = Literal["train", "inference", "visualization"]
|
||||||
|
|
||||||
|
|
@ -37,17 +37,20 @@ class GridConfig:
|
||||||
sort_key: str
|
sort_key: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
|
def from_grid_level(cls, grid_level: GridLevel | tuple[Literal["hex", "healpix"], int]) -> "GridConfig":
|
||||||
"""Create a GridConfig from a GridLevel string."""
|
"""Create a GridConfig from a GridLevel string."""
|
||||||
if grid_level.startswith("hex"):
|
if isinstance(grid_level, str):
|
||||||
grid = "hex"
|
if grid_level.startswith("hex"):
|
||||||
level = int(grid_level[3:])
|
grid = "hex"
|
||||||
elif grid_level.startswith("healpix"):
|
level = int(grid_level[3:])
|
||||||
grid = "healpix"
|
elif grid_level.startswith("healpix"):
|
||||||
level = int(grid_level[7:])
|
grid = "healpix"
|
||||||
|
level = int(grid_level[7:])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid grid level: {grid_level}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid grid level: {grid_level}")
|
grid, level = grid_level
|
||||||
|
grid_level: GridLevel = f"{grid}{level}" # ty:ignore[invalid-assignment]
|
||||||
display_name = f"{grid.capitalize()}-{level}"
|
display_name = f"{grid.capitalize()}-{level}"
|
||||||
|
|
||||||
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
|
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
|
||||||
|
|
@ -74,8 +77,8 @@ class GridConfig:
|
||||||
|
|
||||||
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
||||||
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
|
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023]
|
||||||
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
|
all_target_datasets: list[TargetDataset] = ["darts_v1", "darts_mllabels"]
|
||||||
all_l2_source_datasets: list[L2SourceDataset] = [
|
all_l2_source_datasets: list[L2SourceDataset] = [
|
||||||
"ArcticDEM",
|
"ArcticDEM",
|
||||||
"ERA5-shoulder",
|
"ERA5-shoulder",
|
||||||
|
|
|
||||||
222
tests/test_training.py
Normal file
222
tests/test_training.py
Normal file
|
|
@ -0,0 +1,222 @@
|
||||||
|
"""Tests for training.py module, specifically random_cv function.
|
||||||
|
|
||||||
|
This test suite validates the random_cv training function across all model-task
|
||||||
|
combinations using a minimal hex level 3 grid with synopsis temporal mode.
|
||||||
|
|
||||||
|
Test Coverage:
|
||||||
|
- All 12 model-task combinations (4 models x 3 tasks): espa, xgboost, rf, knn
|
||||||
|
- Device handling for each model type (torch/CUDA/cuML compatibility)
|
||||||
|
- Multi-label target dataset support
|
||||||
|
- Temporal mode configuration (synopsis)
|
||||||
|
- Output file creation and validation
|
||||||
|
|
||||||
|
Running Tests:
|
||||||
|
# Run all training tests (18 tests total, ~3 iterations each)
|
||||||
|
pixi run pytest tests/test_training.py -v
|
||||||
|
|
||||||
|
# Run only device handling tests
|
||||||
|
pixi run pytest tests/test_training.py::TestRandomCV::test_device_handling -v
|
||||||
|
|
||||||
|
# Run a specific model-task combination
|
||||||
|
pixi run pytest tests/test_training.py::TestRandomCV::test_random_cv_all_combinations[binary-espa] -v
|
||||||
|
|
||||||
|
Note: Tests use minimal iterations (3) and level 3 grid for speed.
|
||||||
|
Full production runs use higher iteration counts (100-2000).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.ml.training import CVSettings, random_cv
|
||||||
|
from entropice.utils.types import Model, Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_ensemble():
|
||||||
|
"""Create a minimal DatasetEnsemble for testing.
|
||||||
|
|
||||||
|
Uses hex level 3 grid with synopsis temporal mode for fast testing.
|
||||||
|
"""
|
||||||
|
return DatasetEnsemble(
|
||||||
|
grid="hex",
|
||||||
|
level=3,
|
||||||
|
temporal_mode="synopsis",
|
||||||
|
members=["AlphaEarth"], # Use only one member for faster tests
|
||||||
|
add_lonlat=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cleanup_results():
|
||||||
|
"""Clean up results directory after each test.
|
||||||
|
|
||||||
|
This fixture collects the actual result directories created during tests
|
||||||
|
and removes them after the test completes.
|
||||||
|
"""
|
||||||
|
created_dirs = []
|
||||||
|
|
||||||
|
def register_dir(results_dir):
|
||||||
|
"""Register a directory to be cleaned up."""
|
||||||
|
created_dirs.append(results_dir)
|
||||||
|
return results_dir
|
||||||
|
|
||||||
|
yield register_dir
|
||||||
|
|
||||||
|
# Clean up only the directories created during this test
|
||||||
|
for results_dir in created_dirs:
|
||||||
|
if results_dir.exists():
|
||||||
|
shutil.rmtree(results_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# Model-task combinations to test
|
||||||
|
# Note: Not all combinations make sense, but we test all to ensure robustness
|
||||||
|
MODELS: list[Model] = ["espa", "xgboost", "rf", "knn"]
|
||||||
|
TASKS: list[Task] = ["binary", "count", "density"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRandomCV:
|
||||||
|
"""Test suite for random_cv function."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("task", TASKS)
|
||||||
|
def test_random_cv_all_combinations(self, test_ensemble, model: Model, task: Task, cleanup_results):
|
||||||
|
"""Test random_cv with all model-task combinations.
|
||||||
|
|
||||||
|
This test runs 3 iterations for each combination to verify:
|
||||||
|
- The function completes without errors
|
||||||
|
- Device handling works correctly for each model type
|
||||||
|
- All output files are created
|
||||||
|
"""
|
||||||
|
# Use darts_v1 as the primary target for all tests
|
||||||
|
settings = CVSettings(
|
||||||
|
n_iter=3,
|
||||||
|
task=task,
|
||||||
|
target="darts_v1",
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the cross-validation and get the results directory
|
||||||
|
results_dir = random_cv(
|
||||||
|
dataset_ensemble=test_ensemble,
|
||||||
|
settings=settings,
|
||||||
|
experiment="test_training",
|
||||||
|
)
|
||||||
|
cleanup_results(results_dir)
|
||||||
|
|
||||||
|
# Verify results directory was created
|
||||||
|
assert results_dir.exists(), f"Results directory not created for {model=}, {task=}"
|
||||||
|
|
||||||
|
# Verify all expected output files exist
|
||||||
|
expected_files = [
|
||||||
|
"search_settings.toml",
|
||||||
|
"best_estimator_model.pkl",
|
||||||
|
"search_results.parquet",
|
||||||
|
"metrics.toml",
|
||||||
|
"predicted_probabilities.parquet",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add task-specific files
|
||||||
|
if task in ["binary", "count", "density"]:
|
||||||
|
# All tasks that use classification (including count/density when binned)
|
||||||
|
# Note: count and density without _regimes suffix might be regression
|
||||||
|
if task == "binary" or "_regimes" in task:
|
||||||
|
expected_files.append("confusion_matrix.nc")
|
||||||
|
|
||||||
|
# Add model-specific files
|
||||||
|
if model in ["espa", "xgboost", "rf"]:
|
||||||
|
expected_files.append("best_estimator_state.nc")
|
||||||
|
|
||||||
|
for filename in expected_files:
|
||||||
|
filepath = results_dir / filename
|
||||||
|
assert filepath.exists(), f"Expected file {filename} not found for {model=}, {task=}"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
def test_device_handling(self, test_ensemble, model: Model, cleanup_results):
|
||||||
|
"""Test that device handling works correctly for each model type.
|
||||||
|
|
||||||
|
Different models require different device configurations:
|
||||||
|
- espa: Uses torch with array API dispatch
|
||||||
|
- xgboost: Uses CUDA without array API dispatch
|
||||||
|
- rf/knn: GPU-accelerated via cuML
|
||||||
|
"""
|
||||||
|
settings = CVSettings(
|
||||||
|
n_iter=3,
|
||||||
|
task="binary", # Simple binary task for device testing
|
||||||
|
target="darts_v1",
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should complete without device-related errors
|
||||||
|
try:
|
||||||
|
results_dir = random_cv(
|
||||||
|
dataset_ensemble=test_ensemble,
|
||||||
|
settings=settings,
|
||||||
|
experiment="test_training",
|
||||||
|
)
|
||||||
|
cleanup_results(results_dir)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Check if error is device-related
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
device_keywords = ["cuda", "gpu", "device", "cpu", "torch", "cupy"]
|
||||||
|
if any(keyword in error_msg for keyword in device_keywords):
|
||||||
|
pytest.fail(f"Device handling error for {model=}: {e}")
|
||||||
|
else:
|
||||||
|
# Re-raise non-device errors
|
||||||
|
raise
|
||||||
|
|
||||||
|
def test_random_cv_with_mllabels(self, test_ensemble, cleanup_results):
|
||||||
|
"""Test random_cv with multi-label target dataset."""
|
||||||
|
settings = CVSettings(
|
||||||
|
n_iter=3,
|
||||||
|
task="binary",
|
||||||
|
target="darts_mllabels",
|
||||||
|
model="espa",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the cross-validation and get the results directory
|
||||||
|
results_dir = random_cv(
|
||||||
|
dataset_ensemble=test_ensemble,
|
||||||
|
settings=settings,
|
||||||
|
experiment="test_training",
|
||||||
|
)
|
||||||
|
cleanup_results(results_dir)
|
||||||
|
|
||||||
|
# Verify results were created
|
||||||
|
assert results_dir.exists(), "Results directory not created"
|
||||||
|
assert (results_dir / "search_settings.toml").exists()
|
||||||
|
|
||||||
|
def test_temporal_mode_synopsis(self, cleanup_results):
|
||||||
|
"""Test that temporal_mode='synopsis' is correctly used."""
|
||||||
|
import toml
|
||||||
|
|
||||||
|
ensemble = DatasetEnsemble(
|
||||||
|
grid="hex",
|
||||||
|
level=3,
|
||||||
|
temporal_mode="synopsis",
|
||||||
|
members=["AlphaEarth"],
|
||||||
|
add_lonlat=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = CVSettings(
|
||||||
|
n_iter=3,
|
||||||
|
task="binary",
|
||||||
|
target="darts_v1",
|
||||||
|
model="espa",
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should use synopsis mode (all years aggregated)
|
||||||
|
results_dir = random_cv(
|
||||||
|
dataset_ensemble=ensemble,
|
||||||
|
settings=settings,
|
||||||
|
experiment="test_training",
|
||||||
|
)
|
||||||
|
cleanup_results(results_dir)
|
||||||
|
|
||||||
|
# Verify the settings were stored correctly
|
||||||
|
assert results_dir.exists(), "Results directory not created"
|
||||||
|
with open(results_dir / "search_settings.toml") as f:
|
||||||
|
stored_settings = toml.load(f)
|
||||||
|
|
||||||
|
assert stored_settings["settings"]["temporal_mode"] == "synopsis"
|
||||||
38
tests/validate_datasets.py
Normal file
38
tests/validate_datasets.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
import cyclopts
|
||||||
|
from rich import pretty, print, traceback
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
|
from entropice.utils.types import all_temporal_modes, grid_configs
|
||||||
|
|
||||||
|
pretty.install()
|
||||||
|
traceback.install()
|
||||||
|
|
||||||
|
cli = cyclopts.App()
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_ensemble_stats(e: DatasetEnsemble):
|
||||||
|
# Get a small sample of the cell ids
|
||||||
|
sample_cell_ids = e.cell_ids[:5]
|
||||||
|
features = e.make_features(sample_cell_ids)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]"
|
||||||
|
)
|
||||||
|
print(f"Number of feature columns: {len(features.columns)}")
|
||||||
|
|
||||||
|
for member in ["arcticdem", "embeddings", "era5"]:
|
||||||
|
member_feature_cols = [col for col in features.columns if col.startswith(f"{member}_")]
|
||||||
|
print(f" - {member.capitalize()} feature columns: {len(member_feature_cols)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
@cli.default()
|
||||||
|
def validate_datasets():
|
||||||
|
for gc in grid_configs:
|
||||||
|
for temporal_mode in all_temporal_modes:
|
||||||
|
e = DatasetEnsemble(grid=gc.grid, level=gc.level, temporal_mode=temporal_mode)
|
||||||
|
_gather_ensemble_stats(e)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue