diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 85e3671..dd8e508 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -12,7 +12,6 @@ Pages:
import streamlit as st
-from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
from entropice.dashboard.views.dataset_page import render_dataset_page
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
@@ -28,7 +27,6 @@ def main():
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
- autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
@@ -36,7 +34,7 @@ def main():
{
"Overview": [overview_page],
"Data": [data_page],
- "Experiments": [training_analysis_page, autogluon_page, model_state_page],
+ "Experiments": [training_analysis_page, model_state_page],
"Inference": [inference_page],
}
)
diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py
deleted file mode 100644
index 9ada355..0000000
--- a/src/entropice/dashboard/plots/hyperparameter_analysis.py
+++ /dev/null
@@ -1,1591 +0,0 @@
-"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
-
-from pathlib import Path
-
-import altair as alt
-import geopandas as gpd
-import matplotlib.colors as mcolors
-import numpy as np
-import pandas as pd
-import pydeck as pdk
-import streamlit as st
-import xarray as xr
-
-from entropice.dashboard.utils.class_ordering import get_ordered_classes
-from entropice.dashboard.utils.colors import get_cmap, get_palette
-from entropice.dashboard.utils.geometry import fix_hex_geometry
-from entropice.ml.training import TrainingSettings
-
-
-def render_performance_summary(results: pd.DataFrame, refit_metric: str):
- """Render summary statistics of model performance.
-
- Args:
- results: DataFrame with CV results.
- refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
-
- """
- # Get all test score columns
- score_cols = [col for col in results.columns if col.startswith("mean_test_")]
-
- if not score_cols:
- st.warning("No test score columns found in results.")
- return
-
- # Calculate statistics for each metric
- col1, col2 = st.columns(2)
-
- with col1:
- st.markdown("#### Best Scores")
- best_scores = []
- for col in score_cols:
- metric_name = col.replace("mean_test_", "").replace("_", " ").title()
- best_score = results[col].max()
- best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})
-
- st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch")
-
- with col2:
- st.markdown("#### Score Statistics")
- score_stats = []
- for col in score_cols:
- metric_name = col.replace("mean_test_", "").replace("_", " ").title()
- mean_score = results[col].mean()
- std_score = results[col].std()
- score_stats.append(
- {
- "Metric": metric_name,
- "Mean ± Std": f"{mean_score:.4f} ± {std_score:.4f}",
- }
- )
-
- st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch")
-
- # Show best parameter combination in a cleaner format (similar to old dashboard)
- st.markdown("#### 🏆 Best Parameter Combination")
- refit_col = f"mean_test_{refit_metric}"
-
- # Check if refit metric exists in results
- if refit_col not in results.columns:
- available_metrics = [col.replace("mean_test_", "") for col in score_cols]
- st.warning(f"Refit metric '{refit_metric}' not found in results. Available metrics: {available_metrics}")
- # Use the first available metric as fallback
- refit_col = score_cols[0]
- refit_metric = refit_col.replace("mean_test_", "")
- st.info(f"Using '{refit_metric}' as fallback metric.")
-
- best_idx = results[refit_col].idxmax()
- best_row = results.loc[best_idx]
-
- # Extract parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
-
- if param_cols:
- best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
-
- # Display in a container with metrics (similar to old dashboard style)
- with st.container(border=True):
- st.caption(f"Parameters of the best model (selected by {refit_metric.replace('_', ' ').title()} score)")
-
- # Display parameters as metrics
- n_params = len(best_params)
- cols = st.columns(n_params)
-
- for idx, (param_name, param_value) in enumerate(best_params.items()):
- with cols[idx]:
- # Format value based on type and magnitude
- if isinstance(param_value, (int, np.integer)):
- formatted_value = f"{param_value:.0f}"
- elif isinstance(param_value, (float, np.floating)):
- # Use scientific notation for very small numbers
- if abs(param_value) < 0.01:
- formatted_value = f"{param_value:.2e}"
- else:
- formatted_value = f"{param_value:.4f}"
- else:
- formatted_value = str(param_value)
-
- st.metric(param_name, formatted_value)
-
- # Show all metrics for the best model
- st.divider()
- st.caption("Performance across all metrics")
-
- # Get all metrics from score_cols
- all_metrics = [col.replace("mean_test_", "") for col in score_cols]
- metric_cols = st.columns(len(all_metrics))
-
- for idx, metric in enumerate(all_metrics):
- with metric_cols[idx]:
- best_score = results.loc[best_idx, f"mean_test_{metric}"]
- best_std = results.loc[best_idx, f"std_test_{metric}"]
- st.metric(
- metric.replace("_", " ").title(),
- f"{best_score:.4f}",
- delta=f"±{best_std:.4f}",
- help="Mean ± std across cross-validation folds",
- )
-
-
-def render_parameter_distributions(results: pd.DataFrame, settings: TrainingSettings | None = None):
- """Render histograms of parameter distributions explored.
-
- Args:
- results: DataFrame with CV results.
- settings: Optional settings dictionary containing param_grid configuration.
-
- """
- # Get parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
-
- if not param_cols:
- st.warning("No parameter columns found in results.")
- return
-
- # Extract scale information from settings if available
- param_scales = {}
- if settings and hasattr(settings, "param_grid"):
- param_grid = settings.param_grid
- for param_name, param_config in param_grid.items():
- if isinstance(param_config, dict) and "distribution" in param_config:
- # loguniform distribution indicates log scale
- if param_config["distribution"] == "loguniform":
- param_scales[param_name] = "log"
- else:
- param_scales[param_name] = "linear"
- else:
- param_scales[param_name] = "linear"
-
- # Get colormap from colors module
- cmap = get_cmap("parameter_distribution")
- bar_color = mcolors.rgb2hex(cmap(0.5))
-
- # Create histograms for each parameter
- n_params = len(param_cols)
- n_cols = min(3, n_params)
- n_rows = (n_params + n_cols - 1) // n_cols
-
- for row in range(n_rows):
- cols = st.columns(n_cols)
- for col_idx in range(n_cols):
- param_idx = row * n_cols + col_idx
- if param_idx >= n_params:
- break
-
- param_col = param_cols[param_idx]
- param_name = param_col.replace("param_", "")
-
- with cols[col_idx]:
- # Check if parameter is numeric or categorical
- param_values = results[param_col].dropna()
-
- if pd.api.types.is_numeric_dtype(param_values):
- # Check if we have enough unique values for a histogram
- n_unique = param_values.nunique()
-
- if n_unique == 1:
- # Only one unique value - show as text
- value = param_values.iloc[0]
- formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}"
- st.metric(param_name, formatted)
- st.caption(f"All {len(param_values)} samples have the same value")
- continue
-
- # Numeric parameter - use histogram
- df_plot = pd.DataFrame({param_name: param_values.to_numpy()})
-
- # Use log scale if the range spans multiple orders of magnitude OR values are very small
- max_val = param_values.max()
-
- # Determine number of bins based on unique values
- n_bins = min(20, max(5, n_unique))
-
- # Determine x-axis scale from config or infer from data
- use_log_scale = param_scales.get(param_name, "linear") == "log"
- if not use_log_scale:
- # Fall back to automatic detection if not in config
- value_range = max_val / (param_values.min() + 1e-10)
- use_log_scale = value_range > 100 or max_val < 0.01
-
- # For very small values or large ranges, use a simple bar chart instead of histogram
- if n_unique <= 10:
- # Few unique values - use bar chart
- value_counts = param_values.value_counts().reset_index()
- value_counts.columns = [param_name, "count"]
- value_counts = value_counts.sort_values(param_name)
-
- x_scale = alt.Scale(type="log") if use_log_scale else alt.Scale()
- title_suffix = " (log scale)" if use_log_scale else ""
-
- # Format x-axis values
- if max_val < 0.01:
- formatted_col = f"{param_name}_formatted"
- value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}")
- chart = (
- alt.Chart(value_counts)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(
- f"{formatted_col}:N",
- title=param_name,
- sort=None,
- ),
- alt.Y("count:Q", title="Count"),
- tooltip=[
- alt.Tooltip(param_name, format=".2e"),
- alt.Tooltip("count", title="Count"),
- ],
- )
- .properties(height=250, title=f"{param_name}{title_suffix}")
- )
- else:
- chart = (
- alt.Chart(value_counts)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(
- f"{param_name}:Q",
- title=param_name,
- scale=x_scale,
- ),
- alt.Y("count:Q", title="Count"),
- tooltip=[
- alt.Tooltip(param_name, format=".3f"),
- alt.Tooltip("count", title="Count"),
- ],
- )
- .properties(height=250, title=f"{param_name}{title_suffix}")
- )
- else:
- # Many unique values - use binned histogram
- title_suffix = " (log scale)" if use_log_scale else ""
-
- if use_log_scale:
- # For log scale parameters, create bins in log space then transform back
- # This gives better distribution visualization
- log_values = np.log10(param_values.to_numpy())
- log_bins = np.linspace(log_values.min(), log_values.max(), n_bins + 1)
- bins_linear = 10**log_bins
-
- # Manually bin the data
- binned = pd.cut(param_values, bins=bins_linear)
- bin_counts = binned.value_counts().sort_index()
-
- # Create dataframe for plotting
- bin_data = []
- for interval, count in bin_counts.items():
- bin_mid = (interval.left + interval.right) / 2
- bin_data.append(
- {
- param_name: bin_mid,
- "count": count,
- "bin_label": f"{interval.left:.2e} - {interval.right:.2e}",
- }
- )
-
- df_binned = pd.DataFrame(bin_data)
-
- chart = (
- alt.Chart(df_binned)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(
- f"{param_name}:Q",
- title=param_name,
- scale=alt.Scale(type="log"),
- axis=alt.Axis(format=".2e"),
- ),
- alt.Y("count:Q", title="Count"),
- tooltip=[
- alt.Tooltip("bin_label:N", title="Range"),
- alt.Tooltip("count:Q", title="Count"),
- ],
- )
- .properties(height=250, title=f"{param_name}{title_suffix}")
- )
- else:
- # Linear scale - use standard binning
- format_str = ".2e" if max_val < 0.01 else ".3f"
- chart = (
- alt.Chart(df_plot)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(
- f"{param_name}:Q",
- bin=alt.Bin(maxbins=n_bins),
- title=param_name,
- ),
- alt.Y("count()", title="Count"),
- tooltip=[
- alt.Tooltip(
- f"{param_name}:Q",
- format=format_str,
- bin=True,
- ),
- "count()",
- ],
- )
- .properties(height=250, title=param_name)
- )
-
- st.altair_chart(chart, width="stretch")
-
- else:
- # Categorical parameter - use bar chart
- value_counts = param_values.value_counts().reset_index()
- value_counts.columns = [param_name, "count"]
-
- chart = (
- alt.Chart(value_counts)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(param_name, title=param_name, sort="-y"),
- alt.Y("count", title="Count"),
- tooltip=[param_name, "count"],
- )
- .properties(height=250, title=param_name)
- )
-
- st.altair_chart(chart, width="stretch")
-
-
-def render_score_vs_parameter(results: pd.DataFrame, metric: str):
- """Render scatter plots of score vs each parameter.
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to plot (e.g., 'f1', 'accuracy').
-
- """
- st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters")
-
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Get parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
-
- if not param_cols:
- st.warning("No parameter columns found in results.")
- return
-
- # Create scatter plots for each parameter
- n_params = len(param_cols)
- n_cols = min(2, n_params)
- n_rows = (n_params + n_cols - 1) // n_cols
-
- for row in range(n_rows):
- cols = st.columns(n_cols)
- for col_idx in range(n_cols):
- param_idx = row * n_cols + col_idx
- if param_idx >= n_params:
- break
-
- param_col = param_cols[param_idx]
- param_name = param_col.replace("param_", "")
-
- with cols[col_idx]:
- param_values = results[param_col].dropna()
-
- if pd.api.types.is_numeric_dtype(param_values):
- # Numeric parameter - scatter plot
- df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
-
- # Use log scale if needed
- value_range = param_values.max() / (param_values.min() + 1e-10)
- use_log = value_range > 100
-
- if use_log:
- chart = (
- alt.Chart(df_plot)
- .mark_circle(size=60, opacity=0.6)
- .encode(
- alt.X(
- param_name,
- scale=alt.Scale(type="log"),
- title=param_name,
- ),
- alt.Y(metric, title=metric.replace("_", " ").title()),
- alt.Color(
- metric,
- scale=alt.Scale(range=get_palette(metric, n_colors=256)),
- legend=None,
- ),
- tooltip=[
- alt.Tooltip(param_name, format=".2e"),
- alt.Tooltip(metric, format=".4f"),
- ],
- )
- .properties(
- height=300,
- title=f"{metric} vs {param_name} (log scale)",
- )
- )
- else:
- chart = (
- alt.Chart(df_plot)
- .mark_circle(size=60, opacity=0.6)
- .encode(
- alt.X(param_name, title=param_name),
- alt.Y(metric, title=metric.replace("_", " ").title()),
- alt.Color(
- metric,
- scale=alt.Scale(range=get_palette(metric, n_colors=256)),
- legend=None,
- ),
- tooltip=[
- alt.Tooltip(param_name, format=".3f"),
- alt.Tooltip(metric, format=".4f"),
- ],
- )
- .properties(height=300, title=f"{metric} vs {param_name}")
- )
-
- st.altair_chart(chart, width="stretch")
-
- else:
- # Categorical parameter - box plot
- df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
-
- chart = (
- alt.Chart(df_plot)
- .mark_boxplot()
- .encode(
- alt.X(param_name, title=param_name),
- alt.Y(metric, title=metric.replace("_", " ").title()),
- tooltip=[param_name, alt.Tooltip(metric, format=".4f")],
- )
- .properties(height=300, title=f"{metric} vs {param_name}")
- )
-
- st.altair_chart(chart, width="stretch")
-
-
-def render_parameter_correlation(results: pd.DataFrame, metric: str):
- """Render correlation bar chart between parameters and score.
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to analyze (e.g., 'f1', 'accuracy').
-
- """
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Get numeric parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
- numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
-
- if not numeric_params:
- st.warning("No numeric parameters found for correlation analysis.")
- return
-
- # Calculate correlations
- correlations = []
- for param_col in numeric_params:
- param_name = param_col.replace("param_", "")
- corr = results[[param_col, score_col]].corr().iloc[0, 1]
- correlations.append({"Parameter": param_name, "Correlation": corr})
-
- corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
-
- # Get colormap from colors module (use diverging colormap for correlation)
- hex_colors = get_palette("correlation", n_colors=256)
-
- # Create bar chart
- chart = (
- alt.Chart(corr_df)
- .mark_bar()
- .encode(
- alt.X("Correlation", title="Correlation with Score"),
- alt.Y("Parameter", sort="-x", title="Parameter"),
- alt.Color(
- "Correlation",
- scale=alt.Scale(range=hex_colors, domain=[-1, 1]),
- legend=None,
- ),
- tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
- )
- .properties(height=max(200, len(correlations) * 30))
- )
-
- st.altair_chart(chart, width="stretch")
-
-
-def render_binned_parameter_space(results: pd.DataFrame, metric: str):
- """Render binned parameter space plots similar to old dashboard.
-
- This creates plots where parameters are binned and plotted against each other,
- showing the metric value as color. Handles different hyperparameters dynamically.
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to visualize (e.g., 'f1', 'accuracy').
-
- """
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Get numeric parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
- numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
-
- if len(numeric_params) < 2:
- st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
- return
-
- # Prepare binned data and gather parameter info
- results_binned = results.copy()
- bin_info = {}
-
- for param_col in numeric_params:
- param_name = param_col.replace("param_", "")
- param_values = results[param_col].dropna()
-
- if len(param_values) == 0:
- continue
-
- # Determine if we should use log bins
- min_val = param_values.min()
- max_val = param_values.max()
-
- if min_val > 0:
- value_range = max_val / min_val
- use_log = value_range > 100 or max_val < 0.01
- else:
- use_log = False
-
- if use_log and min_val > 0:
- # Logarithmic binning for parameters spanning many orders of magnitude
- log_min = np.log10(min_val)
- log_max = np.log10(max_val)
- n_bins = min(10, max(3, int(log_max - log_min) + 1))
- bins = np.logspace(log_min, log_max, num=n_bins)
- else:
- # Linear binning
- n_bins = min(10, max(3, int(np.sqrt(len(param_values)))))
- bins = np.linspace(min_val, max_val, num=n_bins)
-
- results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
- bin_info[param_name] = {
- "use_log": use_log,
- "bins": bins,
- "min": min_val,
- "max": max_val,
- "n_unique": param_values.nunique(),
- }
-
- # Get colormap from colors module
- hex_colors = get_palette(metric, n_colors=256)
-
- # Get parameter names sorted by scale type (log parameters first, then linear)
- param_names = [col.replace("param_", "") for col in numeric_params]
- param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p))
-
- st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
-
- # Create all pairwise combinations systematically
- if len(param_names) == 2:
- # Simple case: just one pair
- x_param, y_param = param_names_sorted
- _render_2d_param_plot(
- results_binned,
- x_param,
- y_param,
- score_col,
- bin_info,
- hex_colors,
- metric,
- height=500,
- )
- else:
- # Multiple parameters: create structured plots
- st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
-
- # Strategy: Show all combinations (including duplicates with swapped axes)
- # Group by first parameter to create organized sections
-
- for i, x_param in enumerate(param_names_sorted):
- # Get all other parameters to pair with this one
- other_params = [p for p in param_names_sorted if p != x_param]
-
- if not other_params:
- continue
-
- # Create expander for each primary parameter
- with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)):
- # Create plots in rows of 2
- n_others = len(other_params)
- for row_idx in range(0, n_others, 2):
- cols = st.columns(2)
-
- for col_idx in range(2):
- other_idx = row_idx + col_idx
- if other_idx >= n_others:
- break
-
- y_param = other_params[other_idx]
-
- with cols[col_idx]:
- _render_2d_param_plot(
- results_binned,
- x_param,
- y_param,
- score_col,
- bin_info,
- hex_colors,
- metric,
- height=350,
- )
-
-
-def _render_2d_param_plot(
- results_binned: pd.DataFrame,
- x_param: str,
- y_param: str,
- score_col: str,
- bin_info: dict,
- hex_colors: list,
- metric: str,
- height: int = 400,
-):
- """Render a 2D parameter space plot.
-
- Args:
- results_binned: DataFrame with binned results.
- x_param: Name of x-axis parameter.
- y_param: Name of y-axis parameter.
- score_col: Column name for the score.
- bin_info: Dictionary with binning information for each parameter.
- hex_colors: List of hex colors for the colormap.
- metric: Name of the metric being visualized.
- height: Height of the plot in pixels.
-
- """
- plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
-
- if len(plot_data) == 0:
- st.warning(f"No data available for {x_param} vs {y_param}")
- return
-
- x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
- y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
-
- # Determine marker size based on number of points
- n_points = len(plot_data)
- marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40
-
- chart = (
- alt.Chart(plot_data)
- .mark_circle(size=marker_size, opacity=0.7)
- .encode(
- x=alt.X(
- f"param_{x_param}:Q",
- title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}",
- scale=x_scale,
- ),
- y=alt.Y(
- f"param_{y_param}:Q",
- title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}",
- scale=y_scale,
- ),
- color=alt.Color(
- f"{score_col}:Q",
- scale=alt.Scale(range=hex_colors),
- title=metric.replace("_", " ").title(),
- ),
- tooltip=[
- alt.Tooltip(
- f"param_{x_param}:Q",
- title=x_param,
- format=".2e" if bin_info[x_param]["use_log"] else ".3f",
- ),
- alt.Tooltip(
- f"param_{y_param}:Q",
- title=y_param,
- format=".2e" if bin_info[y_param]["use_log"] else ".3f",
- ),
- alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"),
- ],
- )
- .properties(
- height=height,
- title=f"{x_param} vs {y_param}",
- )
- .interactive()
- )
-
- st.altair_chart(chart, width="stretch")
-
-
-def render_score_evolution(results: pd.DataFrame, metric: str):
- """Render evolution of scores during search.
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to plot (e.g., 'f1', 'accuracy').
-
- """
- st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution")
-
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Create a copy with iteration number
- df_plot = results[[score_col]].copy()
- df_plot["Iteration"] = range(len(df_plot))
- df_plot["Best So Far"] = df_plot[score_col].cummax()
- df_plot = df_plot.rename(columns={score_col: "Score"})
-
- # Reshape for Altair
- df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")
-
- # Get colormap for score evolution
- evolution_cmap = get_cmap("score_evolution")
- evolution_colors = [
- mcolors.rgb2hex(evolution_cmap(0.3)),
- mcolors.rgb2hex(evolution_cmap(0.7)),
- ]
-
- # Create line chart
- chart = (
- alt.Chart(df_long)
- .mark_line()
- .encode(
- alt.X("Iteration", title="Iteration"),
- alt.Y("value", title=metric.replace("_", " ").title()),
- alt.Color(
- "Type",
- legend=alt.Legend(title=""),
- scale=alt.Scale(range=evolution_colors),
- ),
- strokeDash=alt.StrokeDash(
- "Type",
- legend=None,
- scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
- ),
- tooltip=[
- "Iteration",
- "Type",
- alt.Tooltip("value", format=".4f", title="Score"),
- ],
- )
- .properties(height=400)
- )
-
- st.altair_chart(chart, width="stretch")
-
- # Show statistics
- col1, col2, col3, col4 = st.columns(4)
- with col1:
- st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}")
- with col2:
- st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}")
- with col3:
- st.metric("Std Dev", f"{df_plot['Score'].std():.4f}")
- with col4:
- # Find iteration where best was found
- best_iter = df_plot["Score"].idxmax()
- st.metric("Best at Iteration", best_iter)
-
-
-@st.fragment
-def render_multi_metric_comparison(results: pd.DataFrame):
- """Render comparison of multiple metrics.
-
- Args:
- results: DataFrame with CV results.
-
- """
- # Get all test score columns
- score_cols = [col for col in results.columns if col.startswith("mean_test_")]
-
- if len(score_cols) < 2:
- st.warning("Need at least 2 metrics for comparison.")
- return
-
- # Get parameter columns for color options
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
- param_names = [col.replace("param_", "") for col in param_cols]
-
- if not param_names:
- st.warning("No parameters found for coloring.")
- return
-
- # Let user select two metrics and color parameter
- col1, col2, col3 = st.columns(3)
- with col1:
- metric1 = st.selectbox(
- "Select First Metric",
- options=[col.replace("mean_test_", "") for col in score_cols],
- index=0,
- key="metric1_select",
- )
-
- with col2:
- metric2 = st.selectbox(
- "Select Second Metric",
- options=[col.replace("mean_test_", "") for col in score_cols],
- index=min(1, len(score_cols) - 1),
- key="metric2_select",
- )
-
- with col3:
- color_by = st.selectbox(
- "Color By",
- options=param_names,
- index=0,
- key="color_select",
- )
-
- if metric1 == metric2:
- st.warning("Please select different metrics.")
- return
-
- # Create scatter plot data
- df_plot = pd.DataFrame(
- {
- metric1: results[f"mean_test_{metric1}"],
- metric2: results[f"mean_test_{metric2}"],
- }
- )
-
- # Add color parameter to dataframe
- param_col = f"param_{color_by}"
- df_plot[color_by] = results[param_col]
- color_col = color_by
-
- # Check if parameter is numeric and should use log scale
- param_values = results[param_col].dropna()
- if pd.api.types.is_numeric_dtype(param_values):
- value_range = param_values.max() / (param_values.min() + 1e-10)
- use_log = value_range > 100
-
- if use_log:
- color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256))
- color_format = ".2e"
- else:
- color_scale = alt.Scale(range=get_palette(color_by, n_colors=256))
- color_format = ".3f"
- else:
- # Categorical parameter
- color_scale = alt.Scale(scheme="category20")
- color_format = None
-
- # Build tooltip list
- tooltip_list = [
- alt.Tooltip(metric1, format=".4f"),
- alt.Tooltip(metric2, format=".4f"),
- ]
- if color_format:
- tooltip_list.append(alt.Tooltip(color_col, format=color_format))
- else:
- tooltip_list.append(color_col)
-
- # Calculate axis domains starting from min values
- x_min = df_plot[metric1].min()
- x_max = df_plot[metric1].max()
- y_min = df_plot[metric2].min()
- y_max = df_plot[metric2].max()
-
- # Add small padding (2% of range) for better visualization
- x_padding = (x_max - x_min) * 0.02
- y_padding = (y_max - y_min) * 0.02
-
- chart = (
- alt.Chart(df_plot)
- .mark_circle(size=60, opacity=0.6)
- .encode(
- alt.X(
- metric1,
- title=metric1.replace("_", " ").title(),
- scale=alt.Scale(domain=[x_min - x_padding, x_max + x_padding]),
- ),
- alt.Y(
- metric2,
- title=metric2.replace("_", " ").title(),
- scale=alt.Scale(domain=[y_min - y_padding, y_max + y_padding]),
- ),
- alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
- tooltip=tooltip_list,
- )
- .properties(height=500)
- )
-
- st.altair_chart(chart, width="stretch")
-
- # Calculate correlation
- corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
- st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}")
-
-
-def render_espa_binned_parameter_space(results: pd.DataFrame, metric: str, k_bin_width: int = 40):
- """Render ESPA-specific binned parameter space plots.
-
- Creates faceted plots for all combinations of the three ESPA parameters:
- - eps_cl vs eps_e (binned by initial_K)
- - eps_cl vs initial_K (binned by eps_e)
- - eps_e vs initial_K (binned by eps_cl)
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to visualize (e.g., 'f1', 'accuracy').
- k_bin_width: Width of bins for initial_K parameter.
-
- """
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Check if this is an ESPA model with the required parameters
- required_params = ["param_initial_K", "param_eps_cl", "param_eps_e"]
- if not all(param in results.columns for param in required_params):
- st.info("ESPA-specific parameters not found. This visualization is only for ESPA models.")
- return
-
- # Get colormap from colors module
- hex_colors = get_palette(metric, n_colors=256)
-
- # Prepare base plot data
- base_data = results[["param_eps_e", "param_eps_cl", "param_initial_K", score_col]].copy()
- base_data = base_data.dropna()
-
- if len(base_data) == 0:
- st.warning("No data available for ESPA binned parameter space.")
- return
-
- # Configuration for each plot combination
- plot_configs = [
- {
- "x_param": "param_eps_e",
- "y_param": "param_eps_cl",
- "bin_param": "param_initial_K",
- "x_label": "eps_e",
- "y_label": "eps_cl",
- "bin_label": "initial_K",
- "x_scale": "log",
- "y_scale": "log",
- "bin_type": "linear",
- "bin_width": k_bin_width,
- "title": "eps_cl vs eps_e (binned by initial_K)",
- },
- {
- "x_param": "param_eps_cl",
- "y_param": "param_initial_K",
- "bin_param": "param_eps_e",
- "x_label": "eps_cl",
- "y_label": "initial_K",
- "bin_label": "eps_e",
- "x_scale": "log",
- "y_scale": "linear",
- "bin_type": "log",
- "bin_width": None, # Will use log bins
- "title": "initial_K vs eps_cl (binned by eps_e)",
- },
- {
- "x_param": "param_eps_e",
- "y_param": "param_initial_K",
- "bin_param": "param_eps_cl",
- "x_label": "eps_e",
- "y_label": "initial_K",
- "bin_label": "eps_cl",
- "x_scale": "log",
- "y_scale": "linear",
- "bin_type": "log",
- "bin_width": None, # Will use log bins
- "title": "initial_K vs eps_e (binned by eps_cl)",
- },
- ]
-
- # Create each plot
- for config in plot_configs:
- st.markdown(f"**{config['title']}**")
-
- plot_data = base_data.copy()
-
- # Create bins for the binning parameter
- bin_values = plot_data[config["bin_param"]].dropna()
- if len(bin_values) == 0:
- st.warning(f"No {config['bin_label']} values found.")
- continue
-
- if config["bin_type"] == "log":
- # Logarithmic binning for epsilon parameters
- log_min = np.log10(bin_values.min())
- log_max = np.log10(bin_values.max())
- n_bins = min(10, max(5, int(log_max - log_min) + 1))
- bins = np.logspace(log_min, log_max, num=n_bins)
- # Adjust bins to ensure all values are captured
- bins[0] = bins[0] * 0.999 # Extend first bin to capture minimum
- bins[-1] = bins[-1] * 1.001 # Extend last bin to capture maximum
- else:
- # Linear binning for initial_K
- bin_min = bin_values.min()
- bin_max = bin_values.max()
- bins = np.arange(bin_min, bin_max + config["bin_width"], config["bin_width"])
-
- # Bin the parameter
- plot_data["binned_param"] = pd.cut(plot_data[config["bin_param"]], bins=bins, right=False)
-
- # Remove any NaN bins (shouldn't happen, but just in case)
- plot_data = plot_data.dropna(subset=["binned_param"])
-
- # Sort bins and convert to string for ordering
- plot_data = plot_data.sort_values("binned_param")
- plot_data["binned_param_str"] = plot_data["binned_param"].astype(str)
- bin_order = plot_data["binned_param_str"].unique().tolist()
-
- # Create faceted scatter plot
- x_scale = alt.Scale(type=config["x_scale"]) if config["x_scale"] == "log" else alt.Scale()
- # For initial_K plots, set domain to start from the minimum value in the search space
- if config["y_label"] == "initial_K":
- y_min = plot_data[config["y_param"]].min()
- y_max = plot_data[config["y_param"]].max()
- y_scale = alt.Scale(domain=[y_min, y_max])
- else:
- y_scale = alt.Scale(type=config["y_scale"]) if config["y_scale"] == "log" else alt.Scale()
-
- chart = (
- alt.Chart(plot_data)
- .mark_circle(size=60, opacity=0.7)
- .encode(
- x=alt.X(
- f"{config['x_param']}:Q",
- scale=x_scale,
- axis=alt.Axis(title=config["x_label"], grid=True, gridOpacity=0.5),
- ),
- y=alt.Y(
- f"{config['y_param']}:Q",
- scale=y_scale,
- axis=alt.Axis(title=config["y_label"], grid=True, gridOpacity=0.5),
- ),
- color=alt.Color(
- f"{score_col}:Q",
- scale=alt.Scale(range=hex_colors),
- title=metric.replace("_", " ").title(),
- ),
- tooltip=[
- alt.Tooltip("param_eps_e:Q", title="eps_e", format=".2e"),
- alt.Tooltip("param_eps_cl:Q", title="eps_cl", format=".2e"),
- alt.Tooltip("param_initial_K:Q", title="initial_K", format=".0f"),
- alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"),
- alt.Tooltip("binned_param_str:N", title=f"{config['bin_label']} Bin"),
- ],
- )
- .properties(width=200, height=200)
- .facet(
- facet=alt.Facet(
- "binned_param_str:N",
- title=f"{config['bin_label']} (binned)",
- sort=bin_order,
- ),
- columns=5,
- )
- )
-
- st.altair_chart(chart, width="stretch")
-
- # Show statistics about the binning
- n_bins = len(bin_order)
- n_total = len(plot_data)
- st.caption(
- f"Showing {n_total} data points across {n_bins} bins of {config['bin_label']}. "
- f"Each facet shows {config['y_label']} vs {config['x_label']} for a range of {config['bin_label']} values."
- )
-
- st.write("") # Add some spacing between plots
-
-
-def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
- """Render table of top N configurations.
-
- Args:
- results: DataFrame with CV results.
- metric: The metric to rank by (e.g., 'f1', 'accuracy').
- top_n: Number of top configurations to show.
-
- """
- score_col = f"mean_test_{metric}"
- if score_col not in results.columns:
- st.warning(f"Metric {metric} not found in results.")
- return
-
- # Get parameter columns
- param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
-
- if not param_cols:
- st.warning("No parameter columns found in results.")
- return
-
- # Get top N configurations
- top_configs = results.nlargest(top_n, score_col)
-
- # Create display dataframe
- display_cols = ["rank_test_" + metric, score_col, *param_cols]
- display_cols = [col for col in display_cols if col in top_configs.columns]
-
- display_df = top_configs[display_cols].copy()
-
- # Rename columns for better display
- display_df = display_df.rename(
- columns={
- "rank_test_" + metric: "Rank",
- score_col: metric.replace("_", " ").title(),
- }
- )
-
- # Rename parameter columns
- display_df.columns = [col.replace("param_", "") if col.startswith("param_") else col for col in display_df.columns]
-
- # Format score column
- score_col_display = metric.replace("_", " ").title()
- display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
-
- st.dataframe(display_df, hide_index=True, width="stretch")
-
-
-@st.fragment
-def render_confusion_matrix_map(
- result_path: Path, settings: TrainingSettings, merged_predictions: gpd.GeoDataFrame | None = None
-):
- """Render 3D pydeck map showing model performance on training data.
-
- Displays cells from the training dataset with predictions, colored by correctness.
- Uses true labels for elevation (height) and different shades of red for incorrect predictions.
-
- Args:
- result_path: Path to the training result directory (not used, kept for compatibility).
- settings: Settings dictionary containing grid, level, task, and target information.
- merged_predictions: GeoDataFrame with predictions, true labels, and split info.
-
- """
- if merged_predictions is None:
- st.warning("Prediction data not available. Cannot display map.")
- return
-
- # Get grid type and task from settings
- grid = settings.grid
- task = settings.task
-
- # Use the merged predictions which already have true labels, predictions, and split info
- merged = merged_predictions.copy()
- merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
-
- if len(merged) == 0:
- st.warning("No predictions found for labeled cells.")
- return
-
- # Get ordered class labels for the task
- ordered_classes = get_ordered_classes(task)
-
- # Create controls
- col1, col2, col3 = st.columns([2, 1, 1])
-
- with col1:
- # Split selector (similar to confusion matrix)
- split_type = st.selectbox(
- "Select Data Split",
- options=["test", "train", "all"],
- format_func=lambda x: {"test": "Test Set", "train": "Training Set (CV)", "all": "All Data"}[x],
- help="Choose which data split to display on the map",
- key="prediction_map_split_select",
- )
-
- with col2:
- # Color scheme selector
- show_only_incorrect = st.checkbox(
- "Highlight Errors Only",
- value=False,
- help="Show only incorrect predictions in red, hide correct ones",
- key="prediction_map_errors_only",
- )
-
- with col3:
- opacity = st.slider(
- "Opacity",
- min_value=0.1,
- max_value=1.0,
- value=0.7,
- step=0.1,
- key="prediction_map_opacity",
- )
-
- # Filter data by split
- if split_type == "test":
- display_gdf = merged[merged["split"] == "test"].copy()
- split_caption = "Test Set (held-out data)"
- elif split_type == "train":
- display_gdf = merged[merged["split"] == "train"].copy()
- split_caption = "Training Set (CV data)"
- else: # "all"
- display_gdf = merged.copy()
- split_caption = "All Available Data"
-
- # Optionally filter to show only incorrect predictions
- if show_only_incorrect:
- display_gdf = display_gdf[~display_gdf["is_correct"]].copy()
-
- if len(display_gdf) == 0:
- st.warning(f"No cells found for {split_caption}.")
- return
-
- st.caption(f"📍 Showing {len(display_gdf)} cells from {split_caption}")
-
- # Convert to WGS84 for pydeck
- display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
-
- # Fix antimeridian issues for hex grids
- if grid == "hex":
- display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
-
- # Get red material colormap for incorrect predictions
- red_cmap = get_cmap("red_predictions") # Use red_material palette
- n_classes = len(ordered_classes)
-
- # Assign colors based on correctness
- def get_color(row):
- if row["is_correct"]:
- # Green for correct predictions
- return [46, 204, 113]
- else:
- # Different shades of red for each predicted class (ordered)
- pred_class = row["predicted_class"]
- if pred_class in ordered_classes:
- class_idx = ordered_classes.index(pred_class)
- # Sample from red colormap based on class index
- color_value = red_cmap(class_idx / max(n_classes - 1, 1))
- return [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
- else:
- # Fallback red if class not found
- return [231, 76, 60]
-
- display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
-
- # Add line color based on split: blue for test, orange for train
- def get_line_color(row):
- if row["split"] == "test":
- return [52, 152, 219] # Blue for test split
- else:
- return [230, 126, 34] # Orange for train split
-
- display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
-
- # Add elevation based on TRUE label (not predicted)
- # Map each true class to a height based on its position in the ordered list
- def get_elevation(row):
- true_class = row["true_class"]
- if true_class in ordered_classes:
- class_idx = ordered_classes.index(true_class)
- # Normalize to 0-1 range based on class position
- return (class_idx + 1) / n_classes
- else:
- return 0.5 # Default elevation
-
- display_gdf_wgs84["elevation"] = display_gdf_wgs84.apply(get_elevation, axis=1)
-
- # Convert to GeoJSON format
- geojson_data = []
- for _, row in display_gdf_wgs84.iterrows():
- # Determine split and status for tooltip
- split_name = "Test" if row["split"] == "test" else "Training (CV)"
- status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
-
- feature = {
- "type": "Feature",
- "geometry": row["geometry"].__geo_interface__,
- "properties": {
- "true_label": str(row["true_class"]),
- "predicted_label": str(row["predicted_class"]),
- "is_correct": bool(row["is_correct"]),
- "split": split_name,
- "status": status,
- "fill_color": row["fill_color"],
- "line_color": row["line_color"],
- "elevation": float(row["elevation"]),
- },
- }
- geojson_data.append(feature)
-
- # Create pydeck layer
- layer = pdk.Layer(
- "GeoJsonLayer",
- geojson_data,
- opacity=opacity,
- stroked=True,
- filled=True,
- extruded=True,
- wireframe=False,
- get_fill_color="properties.fill_color",
- get_line_color="properties.line_color",
- line_width_min_pixels=2,
- get_elevation="properties.elevation",
- elevation_scale=500000,
- pickable=True,
- )
-
- # Set initial view state (centered on the Arctic)
- view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
-
- # Create deck
- deck = pdk.Deck(
- layers=[layer],
- initial_view_state=view_state,
- tooltip={
- "html": "Status: {status}
"
- "True Label: {true_label}
"
- "Predicted Label: {predicted_label}
"
- "Split: {split}",
- "style": {"backgroundColor": "steelblue", "color": "white"},
- },
- map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
- )
-
- # Render the map
- st.pydeck_chart(deck)
-
- # Show statistics for displayed data
- col1, col2, col3 = st.columns(3)
-
- with col1:
- st.metric("Cells Displayed", len(display_gdf))
-
- with col2:
- correct = len(display_gdf[display_gdf["is_correct"]])
- st.metric("Correct Predictions", correct)
-
- with col3:
- if len(display_gdf) > 0:
- accuracy = correct / len(display_gdf)
- st.metric("Accuracy", f"{accuracy:.2%}")
- else:
- st.metric("Accuracy", "N/A")
-
- # Add legend
- with st.expander("Legend", expanded=True):
- st.markdown("**Fill Color (Prediction Correctness):**")
-
- # Correct predictions
- correct_count = len(display_gdf[display_gdf["is_correct"]])
- incorrect_count = len(display_gdf[~display_gdf["is_correct"]])
-
- st.markdown(
- f'
'
- f'
'
- f"
Correct Predictions ({correct_count} cells, "
- f"{correct_count / len(display_gdf) * 100 if len(display_gdf) > 0 else 0:.1f}%) ",
- unsafe_allow_html=True,
- )
-
- # Incorrect predictions by predicted class (shades of red)
- if incorrect_count > 0:
- st.markdown(
- f"Incorrect Predictions by Predicted Class ({incorrect_count} cells):", unsafe_allow_html=True
- )
-
- for class_idx, class_label in enumerate(ordered_classes):
- # Get count of incorrect predictions for this predicted class
- count = len(display_gdf[(~display_gdf["is_correct"]) & (display_gdf["predicted_class"] == class_label)])
- if count > 0:
- # Get color for this predicted class
- color_value = red_cmap(class_idx / max(n_classes - 1, 1))
- rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
-
- percentage = count / incorrect_count * 100
-
- st.markdown(
- f''
- f'
'
- f"
Predicted as {class_label}: {count} ({percentage:.1f}%) ",
- unsafe_allow_html=True,
- )
-
- st.markdown("---")
- st.markdown("**Border Color (Data Split):**")
-
- # Count by split in displayed data
- test_in_display = len(display_gdf[display_gdf["split"] == "test"])
- train_in_display = len(display_gdf[display_gdf["split"] == "train"])
-
- if test_in_display > 0:
- st.markdown(
- f''
- f'
'
- f"
Test Split ({test_in_display} cells) ",
- unsafe_allow_html=True,
- )
-
- if train_in_display > 0:
- st.markdown(
- f''
- f'
'
- f"
Training Split ({train_in_display} cells) ",
- unsafe_allow_html=True,
- )
-
- st.markdown("---")
- st.markdown("**Elevation (3D Height):**")
-
- # Show elevation mapping for each true class
- st.markdown("Height represents the true label:", unsafe_allow_html=True)
- for class_idx, class_label in enumerate(ordered_classes):
- elevation_value = (class_idx + 1) / n_classes
- height_km = elevation_value * 500 # Since elevation_scale is 500000
- st.markdown(
- f'• {class_label}: {height_km:.0f} km
',
- unsafe_allow_html=True,
- )
- st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")
-
-
-def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str):
- """Render confusion matrix as an interactive heatmap.
-
- Args:
- confusion_matrix: xarray DataArray with dimensions (true_label, predicted_label).
- task: Task type ('binary' or 'multiclass').
-
- """
- import plotly.express as px
-
- # Convert to DataFrame for plotting
- cm_df = confusion_matrix.to_pandas()
-
- # Get labels (convert numeric labels to semantic labels if possible)
- true_labels = confusion_matrix.coords["true_label"].values
- pred_labels = confusion_matrix.coords["predicted_label"].values
-
- # Check if labels are already strings (from predictions) or numeric (from stored confusion matrices)
- first_true_label = true_labels[0]
- is_string_labels = isinstance(first_true_label, str) or (
- hasattr(first_true_label, "dtype") and first_true_label.dtype.kind in ("U", "O")
- )
-
- if is_string_labels:
- # Labels are already string labels, use them directly
- true_labels_str = [str(label) for label in true_labels]
- pred_labels_str = [str(label) for label in pred_labels]
- elif task == "binary":
- # Numeric binary labels - map 0/1 to No-RTS/RTS
- label_map = {0: "No-RTS", 1: "RTS"}
- true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
- pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
- else:
- # Numeric multiclass labels - use as is
- true_labels_str = [str(label) for label in true_labels]
- pred_labels_str = [str(label) for label in pred_labels]
-
- # Rename DataFrame indices and columns for display
- cm_df.index = true_labels_str
- cm_df.columns = pred_labels_str
-
- # Store raw counts for annotations
- cm_counts = cm_df.copy()
-
- # Normalize by row (true label) to get percentages
- cm_normalized = cm_df.div(cm_df.sum(axis=1), axis=0)
-
- # Create custom text annotations showing both percentage and count
- text_annotations = []
- for i, true_label in enumerate(true_labels_str):
- row_annotations = []
- for j, pred_label in enumerate(pred_labels_str):
- count = int(cm_counts.iloc[i, j])
- percentage = cm_normalized.iloc[i, j] * 100
- row_annotations.append(f"{percentage:.1f}%
({count:,})")
- text_annotations.append(row_annotations)
-
- # Create heatmap with normalized values
- fig = px.imshow(
- cm_normalized,
- labels=dict(x="Predicted Label", y="True Label", color="Proportion"),
- x=pred_labels_str,
- y=true_labels_str,
- color_continuous_scale="Blues",
- aspect="auto",
- zmin=0,
- zmax=1,
- )
-
- # Update with custom annotations
- fig.update_traces(
- text=text_annotations,
- texttemplate="%{text}",
- textfont={"size": 12},
- )
-
- # Update layout for better readability
- fig.update_layout(
- title="Confusion Matrix (Normalized by True Label)",
- xaxis_title="Predicted Label",
- yaxis_title="True Label",
- height=500,
- )
-
- # Update colorbar to show percentage
- fig.update_coloraxes(
- colorbar=dict(
- title="Proportion",
- tickformat=".0%",
- )
- )
-
- st.plotly_chart(fig, width="stretch")
-
- st.caption(
- "📊 Values show **row-normalized percentages** (percentage of each true class predicted as each label). "
- "Raw counts shown in parentheses."
- )
-
- # Calculate and display metrics from confusion matrix
- col1, col2, col3 = st.columns(3)
-
- total_samples = int(cm_df.values.sum())
- correct_predictions = int(np.trace(cm_df.values))
- accuracy = correct_predictions / total_samples if total_samples > 0 else 0
-
- with col1:
- st.metric("Total Samples", f"{total_samples:,}")
-
- with col2:
- st.metric("Correct Predictions", f"{correct_predictions:,}")
-
- with col3:
- st.metric("Accuracy", f"{accuracy:.2%}")
-
- # Add detailed breakdown for binary classification
- if task == "binary":
- st.markdown("#### Binary Classification Metrics")
-
- # Extract TP, TN, FP, FN from confusion matrix
- # Assuming 0=No-RTS (negative), 1=RTS (positive)
- tn = int(cm_df.iloc[0, 0])
- fp = int(cm_df.iloc[0, 1])
- fn = int(cm_df.iloc[1, 0])
- tp = int(cm_df.iloc[1, 1])
-
- col1, col2, col3, col4 = st.columns(4)
-
- with col1:
- st.metric("True Positives (TP)", f"{tp:,}")
-
- with col2:
- st.metric("True Negatives (TN)", f"{tn:,}")
-
- with col3:
- st.metric("False Positives (FP)", f"{fp:,}")
-
- with col4:
- st.metric("False Negatives (FN)", f"{fn:,}")
diff --git a/src/entropice/dashboard/plots/hyperparameter_space.py b/src/entropice/dashboard/plots/hyperparameter_space.py
new file mode 100644
index 0000000..77f57e6
--- /dev/null
+++ b/src/entropice/dashboard/plots/hyperparameter_space.py
@@ -0,0 +1,417 @@
+"""Hyperparameter space plotting functions."""
+
+import matplotlib.colors as mcolors
+import pandas as pd
+import plotly.graph_objects as go
+
+from entropice.dashboard.utils.colors import get_cmap, get_palette
+
+
def plot_performance_summary(results: pd.DataFrame, refit_metric: str) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Compute performance summary statistics for CV search results.

    Args:
        results: DataFrame with CV results (``mean_test_*`` and ``param_*`` columns).
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').

    Returns:
        Tuple of (best_scores_df, score_stats_df, best_params_dict).

    """
    score_cols = [c for c in results.columns if c.startswith("mean_test_")]
    if not score_cols:
        # Nothing to summarize without at least one test-score column.
        return pd.DataFrame(), pd.DataFrame(), {}

    def _pretty(col: str) -> str:
        # "mean_test_f1_weighted" -> "F1 Weighted"
        return col.replace("mean_test_", "").replace("_", " ").title()

    # Best score achieved for every metric.
    best_scores = [{"Metric": _pretty(c), "Best Score": f"{results[c].max():.4f}"} for c in score_cols]

    # Mean and spread of each metric across all sampled configurations.
    score_stats = [
        {"Metric": _pretty(c), "Mean ± Std": f"{results[c].mean():.4f} ± {results[c].std():.4f}"}
        for c in score_cols
    ]

    # Rank by the refit metric; fall back to the first score column when absent.
    refit_col = f"mean_test_{refit_metric}"
    if refit_col not in results.columns:
        refit_col = score_cols[0]

    winner = results.loc[results[refit_col].idxmax()]
    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    best_params = {c.replace("param_", ""): winner[c] for c in param_cols}

    return pd.DataFrame(best_scores), pd.DataFrame(score_stats), best_params
+
+
def plot_parameter_distributions(results: pd.DataFrame, param_grid: dict | None = None) -> dict[str, go.Figure]:
    """Create histogram charts for parameter distributions.

    Args:
        results: DataFrame with CV results.
        param_grid: Optional parameter grid with distribution information.

    Returns:
        Dictionary mapping parameter names to Plotly figures.

    """
    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not param_cols:
        return {}

    # Single mid-range color sampled from the shared dashboard colormap.
    bar_color = mcolors.rgb2hex(get_cmap("parameter_distribution")(0.5))

    figures: dict[str, go.Figure] = {}
    for col in param_cols:
        name = col.replace("param_", "")
        values = results[col].dropna()
        if values.empty:
            continue

        fig = go.Figure()
        if pd.api.types.is_numeric_dtype(values):
            # Numeric parameters: histogram of the sampled values.
            fig.add_trace(go.Histogram(x=values, nbinsx=30, marker_color=bar_color, name=name))
        else:
            # Categorical parameters: bar chart of value frequencies (descending).
            counts = values.value_counts()
            fig.add_trace(go.Bar(x=counts.index.tolist(), y=counts.to_numpy(), marker_color=bar_color, name=name))

        # Shared layout for both chart flavors.
        fig.update_layout(
            title=f"Distribution of {name}",
            xaxis_title=name,
            yaxis_title="Count",
            height=400,
            showlegend=False,
        )
        figures[name] = fig

    return figures
+
+
def plot_score_vs_parameters(
    results: pd.DataFrame, metric: str, param_grid: dict | None = None
) -> dict[str, go.Figure]:
    """Create scatter plots of score vs each parameter.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
        param_grid: Optional parameter grid with distribution information.

    Returns:
        Dictionary mapping parameter names to Plotly figures.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return {}

    # Get parameter columns
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    if not param_cols:
        return {}

    # 256-step palette so the marker colorscale covers the metric range smoothly.
    hex_colors = get_palette(metric, n_colors=256)
    colorscale = [[i / 255, c] for i, c in enumerate(hex_colors)]

    charts = {}
    for param_col in param_cols:
        param_name = param_col.replace("param_", "")
        param_values = results[param_col].dropna()
        if len(param_values) == 0:
            continue

        # Log x-axis when the search space sampled this parameter log-uniformly.
        use_log = False
        if param_grid and param_name in param_grid:
            param_config = param_grid[param_name]
            if isinstance(param_config, dict) and param_config.get("distribution") == "loguniform":
                use_log = True

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=results[param_col],
                y=results[score_col],
                mode="markers",
                marker={
                    "size": 8,
                    "color": results[score_col],
                    "colorscale": colorscale,
                    "showscale": False,
                    "opacity": 0.6,
                },
                # FIX: "<br>" restored — the HTML line-break tag was stripped,
                # leaving the f-string literal broken across two lines.
                text=[
                    f"{param_name}: {val}<br>Score: {score:.4f}"
                    for val, score in zip(results[param_col], results[score_col])
                ],
                hovertemplate="%{text}",
            )
        )
        fig.update_layout(
            title=f"{metric.replace('_', ' ').title()} vs {param_name}",
            xaxis_title=param_name,
            xaxis_type="log" if use_log else "linear",
            yaxis_title=metric.replace("_", " ").title(),
            height=400,
            showlegend=False,
        )

        charts[param_name] = fig

    return charts
+
+
def plot_parameter_correlations(results: pd.DataFrame, metric: str) -> go.Figure | None:
    """Create correlation bar chart between parameters and score.

    Args:
        results: DataFrame with CV results.
        metric: The metric to analyze (e.g., 'f1', 'accuracy').

    Returns:
        Plotly figure or None if no numeric parameters found.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return None

    # Only numeric parameters have a meaningful Pearson correlation with the score.
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
    if not numeric_params:
        return None

    # Pairwise Pearson correlation between each parameter and the score column.
    correlations = []
    for param_col in numeric_params:
        param_name = param_col.replace("param_", "")
        corr = results[[param_col, score_col]].corr().iloc[0, 1]
        correlations.append({"Parameter": param_name, "Correlation": corr})

    corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)

    # Diverging palette: bars are colored on a fixed [-1, 1] scale.
    hex_colors = get_palette("correlation", n_colors=256)

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=corr_df["Correlation"],
            y=corr_df["Parameter"],
            orientation="h",
            marker={
                "color": corr_df["Correlation"],
                "colorscale": [[i / 255, c] for i, c in enumerate(hex_colors)],
                "cmin": -1,
                "cmax": 1,
                "showscale": False,
            },
            text=[f"{c:.3f}" for c in corr_df["Correlation"]],
            # FIX: "<br>" restored — the HTML line-break tag was stripped,
            # leaving the hovertemplate literal broken across two lines.
            hovertemplate="%{y}<br>Correlation: %{x:.3f}",
        )
    )
    fig.update_layout(
        xaxis_title="Correlation with Score",
        yaxis_title="Parameter",
        # Scale the chart height with the number of parameters shown.
        height=max(300, len(correlations) * 30),
        showlegend=False,
    )

    return fig
+
+
def plot_parameter_interactions(results: pd.DataFrame, metric: str, param_grid: dict | None = None) -> list[go.Figure]:
    """Create scatter plots showing pairwise parameter interactions.

    One figure is produced per unordered pair of numeric parameters, with the
    markers colored by the selected score metric.

    Args:
        results: DataFrame with CV results.
        metric: The metric to visualize (e.g., 'f1', 'accuracy').
        param_grid: Optional parameter grid with distribution information.

    Returns:
        List of Plotly figures showing parameter interactions.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return []

    # Interactions are only plotted between numeric parameters.
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
    if len(numeric_params) < 2:
        return []

    # Built once — identical for every figure (hoisted out of the pair loop).
    hex_colors = get_palette(metric, n_colors=256)
    colorscale = [[i / 255, c] for i, c in enumerate(hex_colors)]

    def _is_loguniform(name: str) -> bool:
        """Return True when the search space sampled this parameter log-uniformly."""
        if not param_grid or name not in param_grid:
            return False
        config = param_grid[name]
        return isinstance(config, dict) and config.get("distribution") == "loguniform"

    charts = []
    param_names = [col.replace("param_", "") for col in numeric_params]

    # All unordered pairs (x_param, y_param).
    for i, x_param in enumerate(param_names[:-1]):
        for y_param in param_names[i + 1 :]:
            x_col = f"param_{x_param}"
            y_col = f"param_{y_param}"

            fig = go.Figure()
            fig.add_trace(
                go.Scatter(
                    x=results[x_col],
                    y=results[y_col],
                    mode="markers",
                    marker={
                        "size": 8,
                        "color": results[score_col],
                        "colorscale": colorscale,
                        "showscale": True,
                        "colorbar": {"title": metric.replace("_", " ").title()},
                        "opacity": 0.7,
                    },
                    # FIX: "<br>" restored — the HTML line-break tags were stripped,
                    # leaving the f-string literal broken across three lines.
                    text=[
                        f"{x_param}: {x_val}<br>{y_param}: {y_val}<br>Score: {score:.4f}"
                        for x_val, y_val, score in zip(results[x_col], results[y_col], results[score_col])
                    ],
                    hovertemplate="%{text}",
                )
            )
            fig.update_layout(
                title=f"{metric.replace('_', ' ').title()} by {x_param} and {y_param}",
                xaxis_title=x_param,
                xaxis_type="log" if _is_loguniform(x_param) else "linear",
                yaxis_title=y_param,
                yaxis_type="log" if _is_loguniform(y_param) else "linear",
                height=500,
                width=500,
            )

            charts.append(fig)

    return charts
+
+
def plot_score_evolution(results: pd.DataFrame, metric: str) -> go.Figure | None:
    """Create line chart showing score evolution over search iterations.

    Plots the raw per-iteration score together with the running best ("best so
    far") so convergence of the random search is visible at a glance.

    Args:
        results: DataFrame with CV results (row order is the iteration order).
        metric: The metric to visualize (e.g., 'f1', 'accuracy').

    Returns:
        Plotly figure or None if metric not found.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return None

    # Iteration index is simply the row position in the results frame.
    iterations = list(range(len(results)))
    scores = results[score_col].to_numpy()
    best_so_far = results[score_col].cummax().to_numpy()

    # Two tones sampled from the shared colormap: muted for raw, strong for best.
    cmap = get_cmap("score_evolution")
    score_color = mcolors.rgb2hex(cmap(0.3))
    best_color = mcolors.rgb2hex(cmap(0.7))

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=iterations,
            y=scores,
            mode="lines",
            name="Score",
            line={"color": score_color, "width": 1},
            opacity=0.6,
            # FIX: "<br>" restored — the HTML line-break tag was stripped,
            # leaving the hovertemplate literal broken across two lines.
            hovertemplate="Iteration: %{x}<br>Score: %{y:.4f}",
        )
    )

    fig.add_trace(
        go.Scatter(
            x=iterations,
            y=best_so_far,
            mode="lines",
            name="Best So Far",
            line={"color": best_color, "width": 2},
            # FIX: "<br>" restored here as well.
            hovertemplate="Iteration: %{x}<br>Best So Far: %{y:.4f}",
        )
    )

    fig.update_layout(
        title=f"{metric.replace('_', ' ').title()} Evolution",
        xaxis_title="Iteration",
        yaxis_title=metric.replace("_", " ").title(),
        height=300,
        hovermode="x unified",
    )

    return fig
diff --git a/src/entropice/dashboard/plots/metrics.py b/src/entropice/dashboard/plots/metrics.py
new file mode 100644
index 0000000..2ba9ba3
--- /dev/null
+++ b/src/entropice/dashboard/plots/metrics.py
@@ -0,0 +1,97 @@
+"""Metrics visualization plots."""
+
+import numpy as np
+import plotly.graph_objects as go
+import xarray as xr
+
+
def plot_confusion_matrix(cm_data: xr.DataArray, title: str = "Confusion Matrix", normalize: str = "none") -> go.Figure:
    """Plot an interactive confusion matrix heatmap.

    Args:
        cm_data: XArray DataArray with confusion matrix data (dimensions: true_label, predicted_label).
        title: Title for the plot.
        normalize: Normalization mode - "none", "true", or "pred".

    Returns:
        Plotly figure with the interactive confusion matrix heatmap.

    """
    # Get the data as numpy array; the same coordinate labels are used for both
    # axes (assumes a square matrix with matching true/predicted coords).
    cm_array = cm_data.values.astype(float)
    labels = cm_data.coords["true_label"].values.tolist()

    # Store original counts for display
    cm_counts = cm_data.values

    # Apply normalization
    if normalize == "true":
        # Normalize over true labels (rows) - each row sums to 1.
        # BUG FIX: np.divide with `where=` but no `out=` leaves the unselected
        # cells uninitialized; zero-fill rows whose sum is 0 instead.
        row_sums = cm_array.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(cm_array, row_sums, out=np.zeros_like(cm_array), where=row_sums != 0)
        colorbar_title = "Proportion"
    elif normalize == "pred":
        # Normalize over predicted labels (columns) - each column sums to 1.
        col_sums = cm_array.sum(axis=0, keepdims=True)
        cm_normalized = np.divide(cm_array, col_sums, out=np.zeros_like(cm_array), where=col_sums != 0)
        colorbar_title = "Proportion"
    else:
        # No normalization - color by raw counts.
        cm_normalized = cm_array
        colorbar_title = "Count"

    # Loop invariants hoisted out of the annotation loop.
    # Text flips to white above half the maximum so annotations stay legible.
    threshold = cm_normalized.max() / 2 if cm_normalized.max() > 0 else 0.5
    total = cm_counts.sum()

    # Create annotations for the heatmap
    annotations = []
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            count = int(cm_counts[i, j])
            normalized_val = cm_normalized[i, j]

            # Format text based on normalization mode
            if normalize == "none":
                # Show count and percentage of total.
                # FIX: "<br>" restored — the HTML line-break tag was stripped,
                # leaving the f-string literal broken across two lines.
                pct = (count / total * 100) if total > 0 else 0
                text = f"{count}<br>({pct:.1f}%)"
            else:
                # Show percentage only for normalized versions
                text = f"{normalized_val:.1%}"

            annotations.append(
                {
                    "x": pred_label,
                    "y": true_label,
                    "text": text,
                    "showarrow": False,
                    "font": {"size": 10, "color": "white" if normalized_val > threshold else "black"},
                }
            )

    # Create the heatmap with normalized values for coloring
    fig = go.Figure(
        data=go.Heatmap(
            z=cm_normalized,
            x=labels,
            y=labels,
            colorscale="Blues",
            colorbar={"title": colorbar_title},
            hoverongaps=False,
            # FIX: "<br>" tags restored in the broken hovertemplate literal.
            hovertemplate="True: %{y}<br>Predicted: %{x}<br>Count: %{customdata}",
            customdata=cm_counts,
        )
    )

    # Add annotations and axis styling
    fig.update_layout(
        title=title,  # FIX: `title` was accepted but never applied to the layout
        annotations=annotations,
        xaxis={"title": "Predicted Label", "side": "bottom"},
        yaxis={"title": "True Label", "autorange": "reversed"},
        width=600,
        height=550,
    )

    return fig
diff --git a/src/entropice/dashboard/plots/regression.py b/src/entropice/dashboard/plots/regression.py
new file mode 100644
index 0000000..4012c0e
--- /dev/null
+++ b/src/entropice/dashboard/plots/regression.py
@@ -0,0 +1,180 @@
+"""Regression analysis plotting functions."""
+
+from typing import cast
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+
+from entropice.dashboard.utils.colors import get_palette
+
+
+def plot_regression_scatter(
+ y_true: np.ndarray | pd.Series,
+ y_pred: np.ndarray | pd.Series,
+ title: str = "True vs Predicted",
+) -> go.Figure:
+ """Create scatter plot of true vs predicted values for regression.
+
+ Args:
+ y_true: True target values.
+ y_pred: Predicted target values.
+ title: Title for the plot.
+
+ Returns:
+ Plotly figure with regression scatter plot.
+
+ """
+ # Convert to numpy arrays if needed
+ y_true_np = cast(np.ndarray, y_true.to_numpy()) if isinstance(y_true, pd.Series) else y_true
+ y_pred_np = cast(np.ndarray, y_pred.to_numpy()) if isinstance(y_pred, pd.Series) else y_pred
+
+ # Calculate metrics
+ mse = np.mean((y_true_np - y_pred_np) ** 2)
+ mae = np.mean(np.abs(y_true_np - y_pred_np))
+ r2 = 1 - (np.sum((y_true_np - y_pred_np) ** 2) / np.sum((y_true_np - np.mean(y_true_np)) ** 2))
+
+ # Get colormap
+ hex_colors = get_palette("r2", n_colors=256)
+
+ # Calculate point density for coloring
+ from scipy.stats import gaussian_kde
+
+ try:
+ # Create KDE for density estimation
+ xy = np.vstack([y_true_np, y_pred_np])
+ kde = gaussian_kde(xy)
+ density = kde(xy)
+ except (np.linalg.LinAlgError, ValueError):
+ # Fallback if KDE fails (e.g., all points identical)
+ density = np.ones(len(y_true_np))
+
+ # Create figure
+ fig = go.Figure()
+
+ # Add scatter plot
+ fig.add_trace(
+ go.Scatter(
+ x=y_true_np,
+ y=y_pred_np,
+ mode="markers",
+ marker={
+ "size": 6,
+ "color": density,
+ "colorscale": [[i / 255, c] for i, c in enumerate(hex_colors)],
+ "showscale": False,
+ "opacity": 0.6,
+ },
+            text=[f"True: {true:.3f}<br>Pred: {pred:.3f}" for true, pred in zip(y_true_np, y_pred_np)],
+ hovertemplate="%{text}",
+ name="Data",
+ )
+ )
+
+ # Add diagonal line (perfect prediction)
+ min_val = min(y_true_np.min(), y_pred_np.min())
+ max_val = max(y_true_np.max(), y_pred_np.max())
+ fig.add_trace(
+ go.Scatter(
+ x=[min_val, max_val],
+ y=[min_val, max_val],
+ mode="lines",
+ line={"color": "red", "dash": "dash", "width": 2},
+ name="Perfect Prediction",
+ hovertemplate="y = x",
+ )
+ )
+
+ # Add metrics as annotation
+    metrics_text = f"R² = {r2:.4f}<br>MSE = {mse:.4f}<br>MAE = {mae:.4f}"
+
+ fig.add_annotation(
+ x=0.02,
+ y=0.98,
+ xref="paper",
+ yref="paper",
+ text=metrics_text,
+ showarrow=False,
+ bgcolor="white",
+ bordercolor="black",
+ borderwidth=1,
+ xanchor="left",
+ yanchor="top",
+ font={"size": 12},
+ )
+
+ fig.update_layout(
+ title=title,
+ xaxis_title="True Values",
+ yaxis_title="Predicted Values",
+ height=500,
+ showlegend=True,
+ legend={"x": 0.98, "y": 0.02, "xanchor": "right", "yanchor": "bottom"},
+ )
+
+ # Make axes equal
+ fig.update_xaxes(scaleanchor="y", scaleratio=1)
+
+ return fig
+
+
+def plot_residuals(
+ y_true: np.ndarray | pd.Series,
+ y_pred: np.ndarray | pd.Series,
+ title: str = "Residual Plot",
+) -> go.Figure:
+ """Create residual plot for regression diagnostics.
+
+ Args:
+ y_true: True target values.
+ y_pred: Predicted target values.
+ title: Title for the plot.
+
+ Returns:
+ Plotly figure with residual plot.
+
+ """
+ # Convert to numpy arrays if needed
+ y_true_np = cast(np.ndarray, y_true.to_numpy()) if isinstance(y_true, pd.Series) else y_true
+ y_pred_np = cast(np.ndarray, y_pred.to_numpy()) if isinstance(y_pred, pd.Series) else y_pred
+
+ # Calculate residuals
+ residuals = y_true_np - y_pred_np
+
+ # Get colormap
+ hex_colors = get_palette("r2", n_colors=256)
+
+ # Create figure
+ fig = go.Figure()
+
+ # Add scatter plot
+ fig.add_trace(
+ go.Scatter(
+ x=y_pred,
+ y=residuals,
+ mode="markers",
+ marker={
+ "size": 6,
+ "color": np.abs(residuals),
+ "colorscale": [[i / 255, c] for i, c in enumerate(hex_colors)],
+ "showscale": True,
+ "colorbar": {"title": "Abs Residual"},
+ "opacity": 0.6,
+ },
+            text=[f"Pred: {pred:.3f}<br>Residual: {res:.3f}" for pred, res in zip(y_pred, residuals)],
+ hovertemplate="%{text}",
+ )
+ )
+
+ # Add zero line
+ fig.add_hline(y=0, line_dash="dash", line_color="red", line_width=2)
+
+ fig.update_layout(
+ title=title,
+ xaxis_title="Predicted Values",
+ yaxis_title="Residuals (True - Predicted)",
+ height=400,
+ showlegend=False,
+ )
+
+ return fig
diff --git a/src/entropice/dashboard/sections/cv_result.py b/src/entropice/dashboard/sections/cv_result.py
new file mode 100644
index 0000000..b1da3b2
--- /dev/null
+++ b/src/entropice/dashboard/sections/cv_result.py
@@ -0,0 +1,185 @@
+"""Training Result Sections."""
+
+import streamlit as st
+
+from entropice.dashboard.plots.metrics import plot_confusion_matrix
+from entropice.dashboard.utils.formatters import format_metric_name
+from entropice.dashboard.utils.loaders import TrainingResult
+from entropice.dashboard.utils.stats import CVMetricStatistics
+from entropice.utils.types import GridConfig
+
+
+def render_run_information(selected_result: TrainingResult, refit_metric):
+ """Render training run configuration overview.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+ refit_metric: The refit metric used for model selection.
+
+ """
+ st.header("📋 Run Information")
+
+ grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type]
+
+ col1, col2, col3, col4, col5 = st.columns(5)
+ with col1:
+ st.metric("Task", selected_result.settings.task.capitalize())
+ with col2:
+ st.metric("Target", selected_result.settings.target.capitalize())
+ with col3:
+ st.metric("Grid", grid_config.display_name)
+ with col4:
+ st.metric("Model", selected_result.settings.model.upper())
+ with col5:
+ st.metric("Trials", len(selected_result.results))
+
+ st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
+
+
+def _render_metrics(metrics: dict[str, float]):
+ """Render a set of metrics in a two-column layout.
+
+ Args:
+ metrics: Dictionary of metric names and their values.
+
+ """
+ ncols = min(5, len(metrics))
+ cols = st.columns(ncols)
+ for idx, (metric_name, metric_value) in enumerate(metrics.items()):
+ with cols[idx % ncols]:
+ st.metric(format_metric_name(metric_name), f"{metric_value:.4f}")
+
+
+def render_metrics_section(selected_result: TrainingResult):
+ """Render test metrics overview showing final model performance.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+
+ """
+ # Test
+ st.header("🎯 Test Set Performance")
+ st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)")
+ _render_metrics(selected_result.test_metrics)
+
+ # Train
+ st.header("🏋️♂️ Training Set Performance")
+ st.caption("Performance metrics on the training set (best model from hyperparameter search)")
+ _render_metrics(selected_result.train_metrics)
+
+ # Combined / All
+ st.header("🧮 Overall Performance")
+ st.caption("Overall performance metrics combining training and test sets")
+ _render_metrics(selected_result.combined_metrics)
+
+
+@st.fragment
+def render_confusion_matrices(selected_result: TrainingResult):
+ """Render confusion matrices for classification tasks.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+
+ """
+ st.header("🎭 Confusion Matrices")
+
+ # Check if this is a classification task
+ if selected_result.settings.task not in ["binary", "count_regimes", "density_regimes"]:
+ st.info(
+ "📊 Confusion matrices are only available for classification tasks "
+ "(binary, count_regimes, density_regimes)."
+ )
+ st.caption("Coming soon for regression tasks: residual plots and error distributions.")
+ return
+
+ # Check if confusion matrix data is available
+ if selected_result.confusion_matrix is None:
+ st.warning("⚠️ No confusion matrix data found for this training result.")
+ return
+
+ cm = selected_result.confusion_matrix
+
+ # Add normalization selection
+ st.subheader("Display Options")
+ normalize_option = st.radio(
+ "Normalization",
+ options=["No normalization", "Normalize over True Labels", "Normalize over Predicted Labels"],
+ horizontal=True,
+ help="Choose how to normalize the confusion matrix values",
+ )
+
+ # Map selection to normalization mode
+ normalize_map = {
+ "No normalization": "none",
+ "Normalize over True Labels": "true",
+ "Normalize over Predicted Labels": "pred",
+ }
+ normalize_mode = normalize_map[normalize_option]
+
+ cols = st.columns(3)
+
+ with cols[0]:
+ # Test Set Confusion Matrix
+ st.subheader("Test Set")
+ st.caption("Held-out test set")
+ fig_test = plot_confusion_matrix(cm["test"], title="Test Set", normalize=normalize_mode)
+ st.plotly_chart(fig_test, width="stretch")
+ with cols[1]:
+ # Training Set Confusion Matrix
+ st.subheader("Training Set")
+ st.caption("Training set")
+ fig_train = plot_confusion_matrix(cm["train"], title="Training Set", normalize=normalize_mode)
+ st.plotly_chart(fig_train, width="stretch")
+ with cols[2]:
+ # Combined Confusion Matrix
+ st.subheader("Combined")
+ st.caption("Train + Test sets")
+ fig_combined = plot_confusion_matrix(cm["combined"], title="Combined", normalize=normalize_mode)
+ st.plotly_chart(fig_combined, width="stretch")
+
+
+def render_cv_statistics_section(cv_stats: CVMetricStatistics, test_score: float):
+ """Render cross-validation statistics for selected metric.
+
+ Args:
+ cv_stats: CVMetricStatistics object containing cross-validation statistics.
+ test_score: The test set score for the selected metric.
+
+ """
+ st.header("📈 Cross-Validation Statistics")
+ st.caption("Performance during hyperparameter search (averaged across CV folds)")
+
+ col1, col2, col3, col4, col5 = st.columns(5)
+
+ with col1:
+ st.metric("Best Score", f"{cv_stats.best_score:.4f}")
+ with col2:
+ st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
+ with col3:
+ st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
+ with col4:
+ st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
+ with col5:
+ st.metric("Median Score", f"{cv_stats.median_score:.4f}")
+
+ if cv_stats.mean_cv_std is not None:
+ st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
+
+ # Compare with test metric
+ st.subheader("CV vs Test Performance")
+
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Best CV Score", f"{cv_stats.best_score:.4f}")
+ with col2:
+ st.metric("Test Score", f"{test_score:.4f}")
+ with col3:
+ delta = test_score - cv_stats.best_score
+ delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0
+ st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%")
+
+ if abs(delta) > cv_stats.std_score:
+ st.warning(
+ "⚠️ Test performance differs significantly (larger than the CV standard deviation) from CV performance. "
+ "This may indicate overfitting or data distribution mismatch between training and test sets."
+ )
diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py
index c8a59a4..ecfb448 100644
--- a/src/entropice/dashboard/sections/experiment_results.py
+++ b/src/entropice/dashboard/sections/experiment_results.py
@@ -2,15 +2,16 @@
from datetime import datetime
+import pandas as pd
import streamlit as st
-from entropice.dashboard.utils.loaders import TrainingResult
+from entropice.dashboard.utils.loaders import AutogluonTrainingResult, TrainingResult
from entropice.utils.types import (
GridConfig,
)
-def render_training_results_summary(training_results: list[TrainingResult]):
+def render_training_results_summary(training_results: list[TrainingResult | AutogluonTrainingResult]):
"""Render summary metrics for training results."""
st.header("📊 Training Results Summary")
col1, col2, col3, col4 = st.columns(4)
@@ -23,7 +24,7 @@ def render_training_results_summary(training_results: list[TrainingResult]):
st.metric("Total Runs", len(training_results))
with col3:
- models = {tr.settings.model for tr in training_results}
+ models = {tr.settings.model for tr in training_results if hasattr(tr.settings, "model")}
st.metric("Model Types", len(models))
with col4:
@@ -33,14 +34,14 @@ def render_training_results_summary(training_results: list[TrainingResult]):
@st.fragment
-def render_experiment_results(training_results: list[TrainingResult]): # noqa: C901
+def render_experiment_results(training_results: list[TrainingResult | AutogluonTrainingResult]): # noqa: C901
"""Render detailed experiment results table and expandable details."""
st.header("🎯 Experiment Results")
# Filters
experiments = sorted({tr.experiment for tr in training_results if tr.experiment})
tasks = sorted({tr.settings.task for tr in training_results})
- models = sorted({tr.settings.model for tr in training_results})
+ models = sorted({tr.settings.model if isinstance(tr, TrainingResult) else "autogluon" for tr in training_results})
grids = sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results})
# Create filter columns
@@ -87,14 +88,26 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment]
if selected_task != "All":
filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task]
- if selected_model != "All":
- filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model]
+ if selected_model != "All" and selected_model != "autogluon":
+ filtered_results = [
+ tr for tr in filtered_results if isinstance(tr, TrainingResult) and tr.settings.model == selected_model
+ ]
+ elif selected_model == "autogluon":
+ filtered_results = [tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)]
if selected_grid != "All":
filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid]
st.subheader("Results Table")
- summary_df = TrainingResult.to_dataframe(filtered_results)
+ summary_df = TrainingResult.to_dataframe([tr for tr in filtered_results if isinstance(tr, TrainingResult)])
+ autogluon_df = AutogluonTrainingResult.to_dataframe(
+ [tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)]
+ )
+ if len(summary_df) == 0:
+ summary_df = autogluon_df
+ elif len(autogluon_df) > 0:
+ summary_df = pd.concat([summary_df, autogluon_df], ignore_index=True)
+
# Display with color coding for best scores
st.dataframe(
summary_df,
@@ -107,6 +120,8 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
for tr in filtered_results:
tr_info = tr.display_info
display_name = tr_info.get_display_name("model_first")
+ model = "autogluon" if isinstance(tr, AutogluonTrainingResult) else tr.settings.model
+ cv_splits = tr.settings.cv_splits if hasattr(tr.settings, "cv_splits") else "N/A"
with st.expander(display_name):
col1, col2 = st.columns([1, 2])
@@ -117,12 +132,12 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
f"- **Experiment:** {tr.experiment}\n"
f"- **Task:** {tr.settings.task}\n"
f"- **Target:** {tr.settings.target}\n"
- f"- **Model:** {tr.settings.model}\n"
+ f"- **Model:** {model}\n"
f"- **Grid:** {grid_config.display_name}\n"
f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n"
f"- **Temporal Mode:** {tr.settings.temporal_mode}\n"
f"- **Members:** {', '.join(tr.settings.members)}\n"
- f"- **CV Splits:** {tr.settings.cv_splits}\n"
+ f"- **CV Splits:** {cv_splits}\n"
f"- **Classes:** {tr.settings.classes}\n"
)
@@ -140,26 +155,29 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
file_str += f"- 📄 `{file.name}`\n"
st.write(file_str)
with col2:
- st.write("**CV Score Summary:**")
-
- # Extract all test scores
- metric_df = tr.get_metric_dataframe()
- if metric_df is not None:
- st.dataframe(metric_df, width="stretch", hide_index=True)
+ if isinstance(tr, AutogluonTrainingResult):
+ st.write("**Leaderboard:**")
+ st.dataframe(tr.leaderboard, width="stretch", hide_index=True)
else:
- st.write("No test scores found in results.")
+ st.write("**CV Score Summary:**")
+ # Extract all test scores
+ metric_df = tr.get_metric_dataframe()
+ if metric_df is not None:
+ st.dataframe(metric_df, width="stretch", hide_index=True)
+ else:
+ st.write("No test scores found in results.")
- # Show parameter space explored
- if "initial_K" in tr.results.columns: # Common parameter
- st.write("\n**Parameter Ranges Explored:**")
- for param in ["initial_K", "eps_cl", "eps_e"]:
- if param in tr.results.columns:
- min_val = tr.results[param].min()
- max_val = tr.results[param].max()
- unique_vals = tr.results[param].nunique()
- st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
+ # Show parameter space explored
+ if "initial_K" in tr.results.columns: # Common parameter
+ st.write("\n**Parameter Ranges Explored:**")
+ for param in ["initial_K", "eps_cl", "eps_e"]:
+ if param in tr.results.columns:
+ min_val = tr.results[param].min()
+ max_val = tr.results[param].max()
+ unique_vals = tr.results[param].nunique()
+ st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
- st.write("**CV Results DataFrame:**")
- st.dataframe(tr.results, width="stretch", hide_index=True)
+ st.write("**CV Results DataFrame:**")
+ st.dataframe(tr.results, width="stretch", hide_index=True)
st.write(f"\n**Path:** `{tr.path}`")
diff --git a/src/entropice/dashboard/sections/hparam_space.py b/src/entropice/dashboard/sections/hparam_space.py
new file mode 100644
index 0000000..194202a
--- /dev/null
+++ b/src/entropice/dashboard/sections/hparam_space.py
@@ -0,0 +1,172 @@
+"""Hyperparameter Space Visualization Section."""
+
+import streamlit as st
+
+from entropice.dashboard.plots.hyperparameter_space import (
+ plot_parameter_correlations,
+ plot_parameter_distributions,
+ plot_parameter_interactions,
+ plot_score_evolution,
+ plot_score_vs_parameters,
+)
+from entropice.dashboard.utils.formatters import format_metric_name
+from entropice.dashboard.utils.loaders import TrainingResult
+
+
+def _render_performance_summary(results, refit_metric: str):
+ """Render performance summary subsection."""
+ best_idx = results[f"mean_test_{refit_metric}"].idxmax()
+ best_row = results.loc[best_idx]
+ # Extract parameter columns
+ param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
+ best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
+
+ # Display best parameter combination
+ if not best_params:
+ return
+
+ with st.container(border=True):
+ st.subheader("🏆 Best Parameter Combination")
+ st.caption(f"Parameters of the best model (selected by {format_metric_name(refit_metric)} score)")
+ n_params = len(best_params)
+ cols = st.columns(n_params)
+ for idx, (param_name, param_value) in enumerate(best_params.items()):
+ with cols[idx]:
+ # Format value based on type and magnitude
+ if isinstance(param_value, int):
+ formatted_value = f"{param_value:.0f}"
+ elif isinstance(param_value, float):
+ # Use scientific notation for very small numbers
+ if abs(param_value) < 0.001 and param_value != 0:
+ formatted_value = f"{param_value:.2e}"
+ else:
+ formatted_value = f"{param_value:.4f}"
+ else:
+ formatted_value = str(param_value)
+
+ st.metric(param_name, formatted_value)
+
+
+def _render_parameter_distributions(results, param_grid: dict | None):
+ """Render parameter distributions subsection."""
+ st.subheader("Parameter Distributions")
+ st.caption("Distribution of hyperparameter values explored during random search")
+
+ param_charts = plot_parameter_distributions(results, param_grid)
+
+ if not param_charts:
+ st.info("No parameter distribution data available.")
+ return
+
+ # Display charts in a grid
+ param_names = list(param_charts.keys())
+ n_cols = min(3, len(param_names))
+ n_rows = (len(param_names) + n_cols - 1) // n_cols
+
+ for row in range(n_rows):
+ cols = st.columns(n_cols)
+ for col_idx in range(n_cols):
+ param_idx = row * n_cols + col_idx
+ if param_idx < len(param_names):
+ param_name = param_names[param_idx]
+ with cols[col_idx]:
+ st.plotly_chart(param_charts[param_name], width="stretch")
+
+
+def _render_score_evolution(results, selected_metric: str):
+ """Render score evolution subsection."""
+ st.subheader("Score Evolution Over Iterations")
+ st.caption(f"How {format_metric_name(selected_metric)} evolved during the random search")
+
+ evolution_chart = plot_score_evolution(results, selected_metric)
+ if evolution_chart:
+ st.plotly_chart(evolution_chart, width="stretch")
+ else:
+ st.warning(f"Score evolution not available for metric: {selected_metric}")
+
+
+def _render_score_vs_parameters(results, selected_metric: str, param_grid: dict | None):
+ """Render score vs parameters subsection."""
+ st.subheader("Score vs Individual Parameters")
+ st.caption(f"Relationship between {format_metric_name(selected_metric)} and each hyperparameter")
+
+ score_vs_param_charts = plot_score_vs_parameters(results, selected_metric, param_grid)
+
+ if not score_vs_param_charts:
+ st.info("No score vs parameter data available.")
+ return
+
+ param_names = list(score_vs_param_charts.keys())
+ n_cols = min(2, len(param_names))
+ n_rows = (len(param_names) + n_cols - 1) // n_cols
+
+ for row in range(n_rows):
+ cols = st.columns(n_cols)
+ for col_idx in range(n_cols):
+ param_idx = row * n_cols + col_idx
+ if param_idx < len(param_names):
+ param_name = param_names[param_idx]
+ with cols[col_idx]:
+ st.plotly_chart(score_vs_param_charts[param_name], width="stretch")
+
+
+def _render_parameter_correlations(results, selected_metric: str):
+ """Render parameter correlations subsection."""
+ st.subheader("Parameter-Score Correlations")
+ st.caption(f"Correlation between numeric parameters and {format_metric_name(selected_metric)}")
+
+ corr_chart = plot_parameter_correlations(results, selected_metric)
+ if corr_chart:
+ st.plotly_chart(corr_chart, width="stretch")
+ else:
+ st.info("No numeric parameters found for correlation analysis.")
+
+
+def _render_parameter_interactions(results, selected_metric: str, param_grid: dict | None):
+ """Render parameter interactions subsection."""
+ st.subheader("Parameter Interactions")
+ st.caption(f"Interaction between parameter pairs and their effect on {format_metric_name(selected_metric)}")
+
+ interaction_charts = plot_parameter_interactions(results, selected_metric, param_grid)
+
+ if not interaction_charts:
+ st.info("Not enough numeric parameters for parameter interaction visualization.")
+ return
+
+ n_cols = min(2, len(interaction_charts))
+ n_rows = (len(interaction_charts) + n_cols - 1) // n_cols
+
+ for row in range(n_rows):
+ cols = st.columns(n_cols)
+ for col_idx in range(n_cols):
+ chart_idx = row * n_cols + col_idx
+ if chart_idx < len(interaction_charts):
+ with cols[col_idx]:
+ st.plotly_chart(interaction_charts[chart_idx], width="stretch")
+
+
+def render_hparam_space_section(selected_result: TrainingResult, selected_metric: str):
+ """Render the hyperparameter space visualization section.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+ selected_metric: The metric to focus analysis on.
+
+ """
+ st.header("🧩 Hyperparameter Space Exploration")
+
+ results = selected_result.results
+ refit_metric = selected_result._get_best_metric_name()
+ param_grid = selected_result.settings.param_grid
+
+ _render_performance_summary(results, refit_metric)
+
+ _render_parameter_distributions(results, param_grid)
+
+ _render_score_evolution(results, selected_metric)
+
+ _render_score_vs_parameters(results, selected_metric, param_grid)
+
+ _render_parameter_correlations(results, selected_metric)
+
+ _render_parameter_interactions(results, selected_metric, param_grid)
diff --git a/src/entropice/dashboard/sections/regression_analysis.py b/src/entropice/dashboard/sections/regression_analysis.py
new file mode 100644
index 0000000..c819d06
--- /dev/null
+++ b/src/entropice/dashboard/sections/regression_analysis.py
@@ -0,0 +1,122 @@
+"""Regression Analysis Section."""
+
+import streamlit as st
+
+from entropice.dashboard.plots.regression import plot_regression_scatter, plot_residuals
+from entropice.dashboard.utils.loaders import TrainingResult
+from entropice.ml.dataset import DatasetEnsemble
+
+
+def render_regression_analysis(selected_result: TrainingResult):
+ """Render regression analysis with true vs predicted scatter plots.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+
+ """
+ st.header("📊 Regression Analysis")
+
+ # Check if this is a regression task
+ if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]:
+ st.info("📈 Regression analysis is only available for regression tasks (count, density).")
+ return
+
+ # Load predictions
+ predictions_df = selected_result.load_predictions()
+ if predictions_df is None:
+ st.warning("⚠️ No prediction data found for this training result.")
+ return
+
+ # Create DatasetEnsemble from settings
+ with st.spinner("Loading training data to get true values..."):
+ ensemble = DatasetEnsemble(
+ grid=selected_result.settings.grid,
+ level=selected_result.settings.level,
+ members=selected_result.settings.members,
+ temporal_mode=selected_result.settings.temporal_mode,
+ dimension_filters=selected_result.settings.dimension_filters,
+ variable_filters=selected_result.settings.variable_filters,
+ add_lonlat=selected_result.settings.add_lonlat,
+ )
+
+ # Create training set to get true values
+ training_set = ensemble.create_training_set(
+ task=selected_result.settings.task,
+ target=selected_result.settings.target,
+ device="cpu",
+ cache_mode="read",
+ )
+
+ # Get split information
+ split_series = training_set.split
+
+ # Merge predictions with true values and split info
+ # predictions_df should have 'cell_id' and 'predicted' columns
+ # training_set.targets has 'y' (true values) with cell_id as index
+ true_values = training_set.targets[["y"]].reset_index()
+
+ # Merge on cell_id
+ merged = predictions_df.merge(true_values, on="cell_id", how="inner")
+ merged["split"] = split_series.reindex(merged["cell_id"]).values
+
+ # Get train, test, and combined data
+ train_data = merged[merged["split"] == "train"]
+ test_data = merged[merged["split"] == "test"]
+
+ if len(train_data) == 0 or len(test_data) == 0:
+ st.error("❌ Could not properly split data into train and test sets.")
+ return
+
+ # Display scatter plots
+ st.subheader("True vs Predicted Values")
+ st.caption("Scatter plots showing the relationship between true and predicted values")
+
+ cols = st.columns(3)
+
+ with cols[0]:
+ st.markdown("#### Test Set")
+ st.caption("Held-out test set")
+ fig_test = plot_regression_scatter(
+ test_data["y"],
+ test_data["predicted"],
+ title="Test Set",
+ )
+ st.plotly_chart(fig_test, use_container_width=True)
+
+ with cols[1]:
+ st.markdown("#### Training Set")
+ st.caption("Training set")
+ fig_train = plot_regression_scatter(
+ train_data["y"],
+ train_data["predicted"],
+ title="Training Set",
+ )
+ st.plotly_chart(fig_train, use_container_width=True)
+
+ with cols[2]:
+ st.markdown("#### Combined")
+ st.caption("Train + Test sets")
+ fig_combined = plot_regression_scatter(
+ merged["y"],
+ merged["predicted"],
+ title="Combined",
+ )
+ st.plotly_chart(fig_combined, use_container_width=True)
+
+ # Display residual plots
+ st.subheader("Residual Analysis")
+ st.caption("Residual plots to assess model fit and identify patterns in errors")
+
+ cols = st.columns(3)
+
+ with cols[0]:
+ fig_test_res = plot_residuals(test_data["y"], test_data["predicted"], title="Test Set Residuals")
+ st.plotly_chart(fig_test_res, use_container_width=True)
+
+ with cols[1]:
+ fig_train_res = plot_residuals(train_data["y"], train_data["predicted"], title="Training Set Residuals")
+ st.plotly_chart(fig_train_res, use_container_width=True)
+
+ with cols[2]:
+ fig_combined_res = plot_residuals(merged["y"], merged["predicted"], title="Combined Residuals")
+ st.plotly_chart(fig_combined_res, use_container_width=True)
diff --git a/src/entropice/dashboard/utils/class_ordering.py b/src/entropice/dashboard/utils/class_ordering.py
deleted file mode 100644
index 79fd6f4..0000000
--- a/src/entropice/dashboard/utils/class_ordering.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Utilities for ordering predicted classes consistently across visualizations.
-
-This module leverages the canonical class labels defined in the ML dataset module
-to ensure consistent ordering across all visualizations.
-"""
-
-import pandas as pd
-
-from entropice.utils.types import Task
-
-# Canonical orderings imported from the ML pipeline
-# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"}
-# Count/Density labels are defined in the bin_values function
-BINARY_LABELS = ["No RTS", "RTS"]
-COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]
-DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"]
-
-CLASS_ORDERINGS: dict[Task | str, list[str]] = {
- "binary": BINARY_LABELS,
- "count": COUNT_LABELS,
- "density": DENSITY_LABELS,
-}
-
-
-def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]:
- """Get properly ordered class labels for a given task.
-
- This uses the same canonical ordering as defined in the ML dataset module,
- ensuring consistency between training and inference visualizations.
-
- Args:
- task: Task type ('binary', 'count', 'density').
- available_classes: Optional list of available classes to filter and order.
- If None, returns all canonical classes for the task.
-
- Returns:
- List of class labels in proper order.
-
- Examples:
- >>> get_ordered_classes("binary")
- ['No RTS', 'RTS']
- >>> get_ordered_classes("count", ["None", "Few", "Several"])
- ['None', 'Few', 'Several']
-
- """
- canonical_order = CLASS_ORDERINGS[task]
-
- if available_classes is None:
- return canonical_order
-
- # Filter canonical order to only include available classes, preserving order
- return [cls for cls in canonical_order if cls in available_classes]
-
-
-def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series:
- """Sort a pandas Series with class labels according to canonical ordering.
-
- Args:
- series: Pandas Series with class labels as index.
- task: Task type ('binary', 'count', 'density').
-
- Returns:
- Sorted Series with classes in canonical order.
-
- """
- available_classes = series.index.tolist()
- ordered_classes = get_ordered_classes(task, available_classes)
-
- # Reindex to get proper order
- return series.reindex(ordered_classes)
diff --git a/src/entropice/dashboard/utils/formatters.py b/src/entropice/dashboard/utils/formatters.py
index bad2965..4408314 100644
--- a/src/entropice/dashboard/utils/formatters.py
+++ b/src/entropice/dashboard/utils/formatters.py
@@ -59,7 +59,7 @@ task_display_infos: dict[Task, TaskDisplayInfo] = {
class TrainingResultDisplayInfo:
task: Task
target: TargetDataset
- model: Model
+ model: Model | Literal["autogluon"]
grid: Grid
level: int
timestamp: datetime
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index fe34802..a0e6fbe 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -17,6 +17,7 @@ from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
+from entropice.ml.autogluon_training import AutoGluonTrainingSettings
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.ml.training import TrainingSettings
from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks
@@ -215,14 +216,18 @@ class TrainingResult:
return pd.DataFrame.from_records(records)
-@st.cache_data
+@st.cache_data(ttl=300) # Cache for 5 minutes
def load_all_training_results() -> list[TrainingResult]:
"""Load all training results from the results directory."""
results_dir = entropice.utils.paths.RESULTS_DIR
training_results: list[TrainingResult] = []
+ incomplete_results: list[tuple[Path, Exception]] = []
for result_path in results_dir.iterdir():
if not result_path.is_dir():
continue
+ # Skip AutoGluon results directory
+ if "autogluon" in result_path.name.lower():
+ continue
try:
training_result = TrainingResult.from_path(result_path)
training_results.append(training_result)
@@ -237,10 +242,159 @@ def load_all_training_results() -> list[TrainingResult]:
training_results.append(training_result)
is_experiment_dir = True
except FileNotFoundError as e2:
- st.warning(f"Skipping incomplete training result: {e2}")
+ incomplete_results.append((experiment_path, e2))
if not is_experiment_dir:
- st.warning(f"Skipping incomplete training result: {e}")
+ incomplete_results.append((result_path, e))
+ if len(incomplete_results) > 0:
+ st.warning(
+ f"Found {len(incomplete_results)} incomplete training results that were skipped:\n - "
+ + "\n - ".join(f"{p}: {e}" for p, e in incomplete_results)
+ )
+ # Sort by creation time (most recent first)
+ training_results.sort(key=lambda tr: tr.created_at, reverse=True)
+ return training_results
+
+
+@dataclass
+class AutogluonTrainingResult:
+    """Wrapper for AutoGluon training result data and metadata."""
+
+ path: Path
+ experiment: str
+ settings: AutoGluonTrainingSettings
+ test_metrics: dict[str, float | dict | pd.DataFrame]
+ leaderboard: pd.DataFrame
+ feature_importance: pd.DataFrame | None
+ created_at: float
+ files: list[Path]
+
+ @classmethod
+ def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "AutogluonTrainingResult":
+ """Load an AutogluonTrainingResult from a given result directory path."""
+ settings_file = result_path / "training_settings.toml"
+ metrics_file = result_path / "test_metrics.pickle"
+ leaderboard_file = result_path / "leaderboard.parquet"
+ feature_importance_file = result_path / "feature_importance.parquet"
+ all_files = list(result_path.iterdir())
+ if not settings_file.exists():
+ raise FileNotFoundError(f"Missing settings file in {result_path}")
+ if not metrics_file.exists():
+ raise FileNotFoundError(f"Missing metrics file in {result_path}")
+ if not leaderboard_file.exists():
+ raise FileNotFoundError(f"Missing leaderboard file in {result_path}")
+
+ created_at = result_path.stat().st_ctime
+ settings_dict = toml.load(settings_file)["settings"]
+ settings = AutoGluonTrainingSettings(**settings_dict)
+ with open(metrics_file, "rb") as f:
+ metrics = pickle.load(f)
+ leaderboard = pd.read_parquet(leaderboard_file)
+
+ if feature_importance_file.exists():
+ feature_importance = pd.read_parquet(feature_importance_file)
+ else:
+ feature_importance = None
+
+ return cls(
+ path=result_path,
+ experiment=experiment_name or "N/A",
+ settings=settings,
+ test_metrics=metrics,
+ leaderboard=leaderboard,
+ feature_importance=feature_importance,
+ created_at=created_at,
+ files=all_files,
+ )
+
+ @property
+ def test_confusion_matrix(self) -> pd.DataFrame | None:
+ """Get the test confusion matrix."""
+ if "confusion_matrix" not in self.test_metrics:
+ return None
+ assert isinstance(self.test_metrics["confusion_matrix"], pd.DataFrame)
+ return self.test_metrics["confusion_matrix"]
+
+ @property
+ def display_info(self) -> TrainingResultDisplayInfo:
+ """Get display information for the training result."""
+ return TrainingResultDisplayInfo(
+ task=self.settings.task,
+ target=self.settings.target,
+ model="autogluon",
+ grid=self.settings.grid,
+ level=self.settings.level,
+ timestamp=datetime.fromtimestamp(self.created_at),
+ )
+
+ def _get_best_metric_name(self) -> str:
+        """Get the primary metric name for this result's task."""
+ match self.settings.task:
+ case "binary":
+ return "f1"
+ case "count_regimes" | "density_regimes":
+ return "f1_weighted"
+ case _: # regression tasks
+ return "root_mean_squared_error"
+
+ @staticmethod
+ def to_dataframe(training_results: list["AutogluonTrainingResult"]) -> pd.DataFrame:
+ """Convert a list of AutogluonTrainingResult objects to a DataFrame for display."""
+ records = []
+ for tr in training_results:
+ info = tr.display_info
+ best_metric_name = tr._get_best_metric_name()
+
+ record = {
+ "Experiment": tr.experiment if tr.experiment else "N/A",
+ "Task": info.task,
+ "Target": info.target,
+ "Model": info.model,
+ "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
+ "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
+ "Score-Metric": best_metric_name.title(),
+ "Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
+ "Path": str(tr.path.name),
+ }
+ records.append(record)
+ return pd.DataFrame.from_records(records)
+
+
+@st.cache_data(ttl=300) # Cache for 5 minutes
+def load_all_autogluon_training_results() -> list[AutogluonTrainingResult]:
+    """Load all AutoGluon training results from the results directory."""
+ results_dir = entropice.utils.paths.RESULTS_DIR
+ training_results: list[AutogluonTrainingResult] = []
+ incomplete_results: list[tuple[Path, Exception]] = []
+ for result_path in results_dir.iterdir():
+ if not result_path.is_dir():
+ continue
+        # Skip non-AutoGluon results directories
+ if "autogluon" not in result_path.name.lower():
+ continue
+ try:
+ training_result = AutogluonTrainingResult.from_path(result_path)
+ training_results.append(training_result)
+ except FileNotFoundError as e:
+ is_experiment_dir = False
+ for experiment_path in result_path.iterdir():
+ if not experiment_path.is_dir():
+ continue
+ try:
+ experiment_name = experiment_path.parent.name
+ training_result = AutogluonTrainingResult.from_path(experiment_path, experiment_name)
+ training_results.append(training_result)
+ is_experiment_dir = True
+ except FileNotFoundError as e2:
+ incomplete_results.append((experiment_path, e2))
+ if not is_experiment_dir:
+ incomplete_results.append((result_path, e))
+
+ if len(incomplete_results) > 0:
+ st.warning(
+ f"Found {len(incomplete_results)} incomplete autogluon training results that were skipped:\n - "
+ + "\n - ".join(f"{p}: {e}" for p, e in incomplete_results)
+ )
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
diff --git a/src/entropice/dashboard/views/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py
index f589fd6..af2e1e4 100644
--- a/src/entropice/dashboard/views/model_state_page.py
+++ b/src/entropice/dashboard/views/model_state_page.py
@@ -369,6 +369,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result: Trainin
options=["gain", "weight", "cover", "total_gain", "total_cover"],
index=0,
help="Choose which importance metric to visualize",
+ key="model_state_importance_type",
)
# Top N slider
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py
index dded4fd..e594cee 100644
--- a/src/entropice/dashboard/views/overview_page.py
+++ b/src/entropice/dashboard/views/overview_page.py
@@ -9,7 +9,7 @@ from entropice.dashboard.sections.experiment_results import (
render_training_results_summary,
)
from entropice.dashboard.sections.storage_statistics import render_storage_statistics
-from entropice.dashboard.utils.loaders import load_all_training_results
+from entropice.dashboard.utils.loaders import load_all_autogluon_training_results, load_all_training_results
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
@@ -27,6 +27,9 @@ def render_overview_page():
)
# Load training results
training_results = load_all_training_results()
+ autogluon_results = load_all_autogluon_training_results()
+ if len(autogluon_results) > 0:
+ training_results.extend(autogluon_results)
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
diff --git a/src/entropice/dashboard/views/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py
index b1ad430..4aba169 100644
--- a/src/entropice/dashboard/views/training_analysis_page.py
+++ b/src/entropice/dashboard/views/training_analysis_page.py
@@ -2,150 +2,22 @@
from typing import cast
-import geopandas as gpd
import streamlit as st
-import xarray as xr
-from stopuhr import stopwatch
-from entropice.dashboard.plots.hyperparameter_analysis import (
- render_binned_parameter_space,
- render_confusion_matrix_heatmap,
- render_confusion_matrix_map,
- render_espa_binned_parameter_space,
- render_multi_metric_comparison,
- render_parameter_correlation,
- render_parameter_distributions,
- render_performance_summary,
- render_top_configurations,
+from entropice.dashboard.sections.cv_result import (
+ render_confusion_matrices,
+ render_cv_statistics_section,
+ render_metrics_section,
+ render_run_information,
)
+from entropice.dashboard.sections.hparam_space import render_hparam_space_section
+from entropice.dashboard.sections.regression_analysis import render_regression_analysis
from entropice.dashboard.utils.formatters import format_metric_name
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
-from entropice.dashboard.utils.stats import CVResultsStatistics
-from entropice.utils.types import GridConfig
+from entropice.dashboard.utils.stats import CVMetricStatistics
-def load_predictions_with_labels(selected_result: TrainingResult) -> gpd.GeoDataFrame | None:
- """Load predictions and merge with training data to get true labels and split info.
-
- Args:
- selected_result: The selected TrainingResult object.
-
- Returns:
- GeoDataFrame with predictions, true labels, and split information, or None if unavailable.
-
- """
- from sklearn.model_selection import train_test_split
-
- from entropice.ml.dataset import DatasetEnsemble, bin_values, taskcol
-
- # Load predictions
- preds_gdf = selected_result.load_predictions()
- if preds_gdf is None:
- return None
-
- # Create a minimal dataset ensemble to access target data
- settings = selected_result.settings
- dataset_ensemble = DatasetEnsemble(
- grid=settings.grid,
- level=settings.level,
- target=settings.target,
- members=[], # No feature data needed, just targets
- )
-
- # Load target dataset (just labels, no features)
- with st.spinner("Loading target labels..."):
- targets = dataset_ensemble._read_target()
-
- # Get coverage and task columns
- task_col = taskcol[settings.task][settings.target]
-
- # Filter for valid labels (same as in _cat_and_split)
- valid_labels = targets[task_col].notna()
- filtered_targets = targets.loc[valid_labels].copy()
-
- # Apply binning to get class labels (same logic as _cat_and_split)
- if settings.task == "binary":
- binned = filtered_targets[task_col].map({False: "No RTS", True: "RTS"}).astype("category")
- elif settings.task == "count":
- binned = bin_values(filtered_targets[task_col].astype(int), task=settings.task)
- elif settings.task == "density":
- binned = bin_values(filtered_targets[task_col], task=settings.task)
- else:
- raise ValueError(f"Invalid task: {settings.task}")
-
- filtered_targets["true_class"] = binned.to_numpy()
-
- # Recreate the train/test split deterministically (same random_state=42 as in _cat_and_split)
- _train_idx, test_idx = train_test_split(
- filtered_targets.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True
- )
- filtered_targets["split"] = "train"
- filtered_targets.loc[test_idx, "split"] = "test"
- filtered_targets["split"] = filtered_targets["split"].astype("category")
-
- # Ensure cell_id is available as a column for merging
- # Check if cell_id already exists, otherwise use the index
- if "cell_id" not in filtered_targets.columns:
- filtered_targets = filtered_targets.reset_index().rename(columns={"index": "cell_id"})
-
- # Merge predictions with labels (inner join to keep only cells with predictions)
- merged = filtered_targets.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
- merged_gdf = gpd.GeoDataFrame(merged, geometry="geometry", crs=targets.crs)
-
- return merged_gdf
-
-
-def compute_confusion_matrix_from_merged_data(
- merged_data: gpd.GeoDataFrame,
- split_type: str,
- label_names: list[str],
-) -> xr.DataArray | None:
- """Compute confusion matrix from merged predictions and labels.
-
- Args:
- merged_data: GeoDataFrame with 'true_class', 'predicted_class', and 'split' columns.
- split_type: One of 'test', 'train', or 'all'.
- label_names: List of class label names in order.
-
- Returns:
- xarray.DataArray with confusion matrix or None if data unavailable.
-
- """
- from sklearn.metrics import confusion_matrix
-
- # Filter by split type
- if split_type == "train":
- data = merged_data[merged_data["split"] == "train"]
- elif split_type == "test":
- data = merged_data[merged_data["split"] == "test"]
- elif split_type == "all":
- data = merged_data
- else:
- raise ValueError(f"Invalid split_type: {split_type}")
-
- if len(data) == 0:
- st.warning(f"No data available for {split_type} split.")
- return None
-
- # Get true and predicted labels
- y_true = data["true_class"].to_numpy()
- y_pred = data["predicted_class"].to_numpy()
-
- # Compute confusion matrix
- cm = confusion_matrix(y_true, y_pred, labels=label_names)
-
- # Create xarray DataArray
- cm_xr = xr.DataArray(
- cm,
- dims=["true_label", "predicted_label"],
- coords={"true_label": label_names, "predicted_label": label_names},
- name="confusion_matrix",
- )
-
- return cm_xr
-
-
-def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
+def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str]:
"""Render sidebar for training run and analysis settings selection.
Args:
@@ -155,351 +27,63 @@ def render_analysis_settings_sidebar(training_results: list[TrainingResult]) ->
Tuple of (selected_result, selected_metric, refit_metric, top_n).
"""
- st.header("Select Training Run")
+ with st.sidebar.form("training_analysis_settings_form"):
+ st.header("Select Training Run")
- # Create selection options with task-first naming
- training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
+ # Create selection options with task-first naming
+ training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
- selected_name = st.selectbox(
- "Training Run",
- options=list(training_options.keys()),
- index=0,
- help="Select a training run to analyze",
- key="training_run_select",
- )
+ selected_name = st.selectbox(
+ "Training Run",
+ options=list(training_options.keys()),
+ index=0,
+ help="Select a training run to analyze",
+ key="training_run_select",
+ )
- selected_result = cast(TrainingResult, training_options[selected_name])
+ selected_result = cast(TrainingResult, training_options[selected_name])
- st.divider()
-
- # Metric selection for detailed analysis
- st.subheader("Analysis Settings")
-
- available_metrics = selected_result.available_metrics
-
- # Try to get refit metric from settings
- refit_metric = "f1" if selected_result.settings.task == "binary" else "f1_weighted"
-
- if refit_metric in available_metrics:
- default_metric_idx = available_metrics.index(refit_metric)
- else:
- default_metric_idx = 0
-
- selected_metric = st.selectbox(
- "Primary Metric for Analysis",
- options=available_metrics,
- index=default_metric_idx,
- format_func=format_metric_name,
- help="Select the metric to focus on for detailed analysis",
- key="metric_select",
- )
-
- # Top N configurations
- top_n = st.slider(
- "Top N Configurations",
- min_value=5,
- max_value=50,
- value=10,
- step=5,
- help="Number of top configurations to display",
- key="top_n_slider",
- )
-
- return selected_result, selected_metric, refit_metric, top_n
-
-
-def render_run_information(selected_result: TrainingResult, refit_metric):
- """Render training run configuration overview.
-
- Args:
- selected_result: The selected TrainingResult object.
- refit_metric: The refit metric used for model selection.
-
- """
- st.header("📋 Run Information")
-
- grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type]
-
- col1, col2, col3, col4, col5 = st.columns(5)
- with col1:
- st.metric("Task", selected_result.settings.task.capitalize())
- with col2:
- st.metric("Target", selected_result.settings.target.capitalize())
- with col3:
- st.metric("Grid", grid_config.display_name)
- with col4:
- st.metric("Model", selected_result.settings.model.upper())
- with col5:
- st.metric("Trials", len(selected_result.results))
-
- st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
-
-
-def render_test_metrics_section(selected_result: TrainingResult):
- """Render test metrics overview showing final model performance.
-
- Args:
- selected_result: The selected TrainingResult object.
-
- """
- st.header("🎯 Test Set Performance")
- st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)")
-
- test_metrics = selected_result.metrics
-
- if not test_metrics:
- st.warning("No test metrics available for this training run.")
- return
-
- # Display metrics in columns based on task type
- task = selected_result.settings.task
-
- if task == "binary":
- # Binary classification metrics
- col1, col2, col3, col4, col5 = st.columns(5)
-
- with col1:
- st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
- with col2:
- st.metric("F1 Score", f"{test_metrics.get('f1', 0):.4f}")
- with col3:
- st.metric("Precision", f"{test_metrics.get('precision', 0):.4f}")
- with col4:
- st.metric("Recall", f"{test_metrics.get('recall', 0):.4f}")
- with col5:
- st.metric("Jaccard", f"{test_metrics.get('jaccard', 0):.4f}")
- else:
- # Multiclass metrics
- col1, col2, col3 = st.columns(3)
-
- with col1:
- st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
- with col2:
- st.metric("F1 (Macro)", f"{test_metrics.get('f1_macro', 0):.4f}")
- with col3:
- st.metric("F1 (Weighted)", f"{test_metrics.get('f1_weighted', 0):.4f}")
-
- col4, col5, col6 = st.columns(3)
-
- with col4:
- st.metric("Precision (Macro)", f"{test_metrics.get('precision_macro', 0):.4f}")
- with col5:
- st.metric("Precision (Weighted)", f"{test_metrics.get('precision_weighted', 0):.4f}")
- with col6:
- st.metric("Recall (Macro)", f"{test_metrics.get('recall_macro', 0):.4f}")
-
- col7, col8, col9 = st.columns(3)
-
- with col7:
- st.metric("Jaccard (Micro)", f"{test_metrics.get('jaccard_micro', 0):.4f}")
- with col8:
- st.metric("Jaccard (Macro)", f"{test_metrics.get('jaccard_macro', 0):.4f}")
- with col9:
- st.metric("Jaccard (Weighted)", f"{test_metrics.get('jaccard_weighted', 0):.4f}")
-
-
-def render_cv_statistics_section(selected_result, selected_metric):
- """Render cross-validation statistics for selected metric.
-
- Args:
- selected_result: The selected TrainingResult object.
- selected_metric: The metric to display statistics for.
-
- """
- st.header("📈 Cross-Validation Statistics")
- st.caption("Performance during hyperparameter search (averaged across CV folds)")
-
- from entropice.dashboard.utils.stats import CVMetricStatistics
-
- cv_stats = CVMetricStatistics.compute(selected_result, selected_metric)
-
- col1, col2, col3, col4, col5 = st.columns(5)
-
- with col1:
- st.metric("Best Score", f"{cv_stats.best_score:.4f}")
-
- with col2:
- st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
-
- with col3:
- st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
-
- with col4:
- st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
-
- with col5:
- st.metric("Median Score", f"{cv_stats.median_score:.4f}")
-
- if cv_stats.mean_cv_std is not None:
- st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
-
- # Compare with test metric if available
- if selected_metric in selected_result.metrics:
- test_score = selected_result.metrics[selected_metric]
st.divider()
- st.subheader("CV vs Test Performance")
- col1, col2, col3 = st.columns(3)
- with col1:
- st.metric("Best CV Score", f"{cv_stats.best_score:.4f}")
- with col2:
- st.metric("Test Score", f"{test_score:.4f}")
- with col3:
- delta = test_score - cv_stats.best_score
- delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0
- st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%")
+ # Metric selection for detailed analysis
+ st.subheader("Analysis Settings")
- if abs(delta) > cv_stats.std_score:
- st.warning(
- "⚠️ Test performance differs significantly from CV performance. "
- "This may indicate overfitting or data distribution mismatch."
- )
+ available_metrics = selected_result.available_metrics
-
-@st.fragment
-def render_confusion_matrix_section(selected_result: TrainingResult, merged_predictions: gpd.GeoDataFrame | None):
- """Render confusion matrix visualization and analysis.
-
- Args:
- selected_result: The selected TrainingResult object.
- merged_predictions: GeoDataFrame with predictions merged with true labels and split info.
-
- """
- st.header("🎲 Confusion Matrix")
- st.caption("Detailed breakdown of predictions")
-
- # Add selector for confusion matrix type
- cm_type = st.selectbox(
- "Select Data Split",
- options=["test", "train", "all"],
- format_func=lambda x: {"test": "Test Set", "train": "CV Set (Train Split)", "all": "All Available Data"}[x],
- help="Choose which data split to display the confusion matrix for",
- key="cm_split_select",
- )
-
- # Get label names from settings
- label_names = selected_result.settings.classes
-
- # Compute or load confusion matrix based on selection
- if cm_type == "test":
- if selected_result.confusion_matrix is None:
- st.warning("No confusion matrix available for the test set.")
- return
- cm = selected_result.confusion_matrix
- st.info("📊 Showing confusion matrix for the **Test Set** (held-out data, never used during training)")
- else:
- if merged_predictions is None:
- st.warning("Predictions data not available. Cannot compute confusion matrix.")
- return
-
- with st.spinner(f"Computing confusion matrix for {cm_type} split..."):
- cm = compute_confusion_matrix_from_merged_data(merged_predictions, cm_type, label_names)
- if cm is None:
- return
-
- if cm_type == "train":
- st.info(
- "📊 Showing confusion matrix for the **CV Set (Train Split)** "
- "(data used during hyperparameter search cross-validation)"
- )
- else: # all
- st.info("📊 Showing confusion matrix for **All Available Data** (combined train and test splits)")
-
- render_confusion_matrix_heatmap(cm, selected_result.settings.task)
-
-
-def render_parameter_space_section(selected_result, selected_metric):
- """Render parameter space analysis section.
-
- Args:
- selected_result: The selected TrainingResult object.
- selected_metric: The metric to analyze parameters against.
-
- """
- st.header("🔍 Parameter Space Analysis")
-
- # Compute CV results statistics
- cv_results_stats = CVResultsStatistics.compute(selected_result)
-
- # Show parameter space summary
- with st.expander("📋 Parameter Space Summary", expanded=False):
- param_summary_df = cv_results_stats.parameters_to_dataframe()
- if not param_summary_df.empty:
- st.dataframe(param_summary_df, hide_index=True, width="stretch")
+ # Try to get refit metric from settings
+ if selected_result.settings.task == "binary":
+ refit_metric = "f1"
+ elif selected_result.settings.task in ["count_regimes", "density_regimes"]:
+ refit_metric = "f1_weighted"
else:
- st.info("No parameter information available.")
+ refit_metric = "r2"
- results = selected_result.results
- settings = selected_result.settings
+ if refit_metric in available_metrics:
+ default_metric_idx = available_metrics.index(refit_metric)
+ else:
+ default_metric_idx = 0
- # Parameter distributions
- st.subheader("📈 Parameter Distributions")
- render_parameter_distributions(results, settings)
+ selected_metric = st.selectbox(
+ "Primary Metric for Analysis",
+ options=available_metrics,
+ index=default_metric_idx,
+ format_func=format_metric_name,
+ help="Select the metric to focus on for detailed analysis",
+ key="metric_select",
+ )
- # Binned parameter space plots
- st.subheader("🎨 Binned Parameter Space")
+ # Form submit button
+ submitted = st.form_submit_button(
+ "Load Training Result",
+ type="primary",
+ use_container_width=True,
+ )
- # Check if this is an ESPA model and show ESPA-specific plots
- model_type = settings.model
- if model_type == "espa":
- # Show ESPA-specific binned plots (eps_cl vs eps_e binned by K)
- render_espa_binned_parameter_space(results, selected_metric)
+ if not submitted:
+ st.info("👆 Click 'Load Training Result' to apply changes.")
+ st.stop()
- # Optionally show the generic binned plots in an expander
- with st.expander("📊 All Parameter Combinations", expanded=False):
- st.caption("Generic parameter space exploration (all pairwise combinations)")
- render_binned_parameter_space(results, selected_metric)
- else:
- # For non-ESPA models, show the generic binned plots
- render_binned_parameter_space(results, selected_metric)
-
-
-def render_data_export_section(results, selected_result):
- """Render data export section with download buttons.
-
- Args:
- results: DataFrame with CV results.
- selected_result: The selected TrainingResult object.
-
- """
- with st.expander("💾 Export Data", expanded=False):
- st.subheader("Download Results")
-
- col1, col2 = st.columns(2)
-
- with col1:
- # Download full results as CSV
- csv_data = results.to_csv(index=False)
- st.download_button(
- label="📥 Download Full Results (CSV)",
- data=csv_data,
- file_name=f"{selected_result.path.name}_results.csv",
- mime="text/csv",
- )
-
- with col2:
- # Download settings as JSON
- import json
-
- settings_dict = {
- "task": selected_result.settings.task,
- "grid": selected_result.settings.grid,
- "level": selected_result.settings.level,
- "model": selected_result.settings.model,
- "cv_splits": selected_result.settings.cv_splits,
- "classes": selected_result.settings.classes,
- }
- settings_json = json.dumps(settings_dict, indent=2)
- st.download_button(
- label="⚙️ Download Settings (JSON)",
- data=settings_json,
- file_name=f"{selected_result.path.name}_settings.json",
- mime="application/json",
- )
-
- # Show raw data preview
- st.subheader("Raw Data Preview")
- st.dataframe(results.head(100), width="stretch")
+ return selected_result, selected_metric, refit_metric
def render_training_analysis_page():
@@ -513,91 +97,47 @@ def render_training_analysis_page():
"""
)
- # Load all available training results
+ # Load training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
- st.info("Run training using: `pixi run python -m entropice.ml.training`")
+ st.stop()
return
- st.success(f"Found **{len(training_results)}** training result(s)")
+ st.write(f"Found **{len(training_results)}** training result(s)")
st.divider()
+ selected_result, selected_metric, refit_metric = render_analysis_settings_sidebar(training_results)
- # Sidebar: Training run selection
- with st.sidebar:
- selection_result = render_analysis_settings_sidebar(training_results)
- if selection_result[0] is None:
- return
- selected_result, selected_metric, refit_metric, top_n = selection_result
+ cv_statistics = CVMetricStatistics.compute(selected_result, selected_metric)
- # Load predictions with labels once (used by confusion matrix and map)
- merged_predictions = load_predictions_with_labels(selected_result)
-
- # Main content area
- results = selected_result.results
- settings = selected_result.settings
-
- # Run Information
render_run_information(selected_result, refit_metric)
st.divider()
- # Test Metrics Section
- render_test_metrics_section(selected_result)
+ render_metrics_section(selected_result)
st.divider()
- # Confusion Matrix Section
- render_confusion_matrix_section(selected_result, merged_predictions)
+ # Render confusion matrices for classification, regression analysis for regression
+ if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]:
+ render_confusion_matrices(selected_result)
+ else:
+ render_regression_analysis(selected_result)
st.divider()
- # Performance Summary Section
- st.header("📊 CV Performance Overview")
- st.caption("Summary of hyperparameter search results across all configurations")
- render_performance_summary(results, refit_metric)
+ render_cv_statistics_section(cv_statistics, selected_result.test_metrics.get(selected_metric, float("nan")))
st.divider()
- # Prediction Analysis Map Section
- st.header("🗺️ Model Performance Map")
- st.caption("Interactive 3D map showing prediction correctness across the training dataset")
- render_confusion_matrix_map(selected_result.path, settings, merged_predictions)
+ render_hparam_space_section(selected_result, selected_metric)
st.divider()
- # Cross-Validation Statistics
- render_cv_statistics_section(selected_result, selected_metric)
-
- st.divider()
-
- # Parameter Space Analysis
- render_parameter_space_section(selected_result, selected_metric)
-
- st.divider()
-
- # Parameter Correlation
- st.header("🔗 Parameter Correlation")
- render_parameter_correlation(results, selected_metric)
-
- st.divider()
-
- # Multi-Metric Comparison
- if len(selected_result.available_metrics) >= 2:
- st.header("📊 Multi-Metric Comparison")
- render_multi_metric_comparison(results)
- st.divider()
-
- # Top Configurations
- st.header("🏆 Top Performing Configurations")
- render_top_configurations(results, selected_metric, top_n)
-
- st.divider()
-
- # Raw Data Export
- render_data_export_section(results, selected_result)
+ # List all results at the end
+ st.header("📄 All Training Results")
+ st.dataframe(selected_result.results)
st.balloons()
- stopwatch.summary()
diff --git a/src/entropice/ml/autogluon_training.py b/src/entropice/ml/autogluon_training.py
index 69ba846..ed33a67 100644
--- a/src/entropice/ml/autogluon_training.py
+++ b/src/entropice/ml/autogluon_training.py
@@ -44,8 +44,8 @@ class AutoGluonSettings:
class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings):
"""Combined settings for AutoGluon training."""
- classes: list[str] | None
- problem_type: str
+ classes: list[str] | None = None
+ problem_type: str = "binary"
def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]:
@@ -177,6 +177,8 @@ def autogluon_train(
toml.dump({"settings": asdict(combined_settings)}, f)
# Save test metrics
+ # We need to use pickle here, because the confusion matrix is stored as a dataframe
+ # This only matters for classification tasks
test_metrics_file = results_dir / "test_metrics.pickle"
print(f"💾 Saving test metrics to {test_metrics_file}")
with open(test_metrics_file, "wb") as f: