diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 85e3671..dd8e508 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -12,7 +12,6 @@ Pages: import streamlit as st -from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page from entropice.dashboard.views.dataset_page import render_dataset_page from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.model_state_page import render_model_state_page @@ -28,7 +27,6 @@ def main(): overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) data_page = st.Page(render_dataset_page, title="Dataset", icon="📊") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") - autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖") model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") @@ -36,7 +34,7 @@ def main(): { "Overview": [overview_page], "Data": [data_page], - "Experiments": [training_analysis_page, autogluon_page, model_state_page], + "Experiments": [training_analysis_page, model_state_page], "Inference": [inference_page], } ) diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py deleted file mode 100644 index 9ada355..0000000 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ /dev/null @@ -1,1591 +0,0 @@ -"""Hyperparameter analysis plotting functions for RandomizedSearchCV results.""" - -from pathlib import Path - -import altair as alt -import geopandas as gpd -import matplotlib.colors as mcolors -import numpy as np -import pandas as pd -import pydeck as pdk -import streamlit as st -import xarray as xr - -from entropice.dashboard.utils.class_ordering import 
get_ordered_classes -from entropice.dashboard.utils.colors import get_cmap, get_palette -from entropice.dashboard.utils.geometry import fix_hex_geometry -from entropice.ml.training import TrainingSettings - - -def render_performance_summary(results: pd.DataFrame, refit_metric: str): - """Render summary statistics of model performance. - - Args: - results: DataFrame with CV results. - refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted'). - - """ - # Get all test score columns - score_cols = [col for col in results.columns if col.startswith("mean_test_")] - - if not score_cols: - st.warning("No test score columns found in results.") - return - - # Calculate statistics for each metric - col1, col2 = st.columns(2) - - with col1: - st.markdown("#### Best Scores") - best_scores = [] - for col in score_cols: - metric_name = col.replace("mean_test_", "").replace("_", " ").title() - best_score = results[col].max() - best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"}) - - st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch") - - with col2: - st.markdown("#### Score Statistics") - score_stats = [] - for col in score_cols: - metric_name = col.replace("mean_test_", "").replace("_", " ").title() - mean_score = results[col].mean() - std_score = results[col].std() - score_stats.append( - { - "Metric": metric_name, - "Mean ± Std": f"{mean_score:.4f} ± {std_score:.4f}", - } - ) - - st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch") - - # Show best parameter combination in a cleaner format (similar to old dashboard) - st.markdown("#### 🏆 Best Parameter Combination") - refit_col = f"mean_test_{refit_metric}" - - # Check if refit metric exists in results - if refit_col not in results.columns: - available_metrics = [col.replace("mean_test_", "") for col in score_cols] - st.warning(f"Refit metric '{refit_metric}' not found in results. 
Available metrics: {available_metrics}") - # Use the first available metric as fallback - refit_col = score_cols[0] - refit_metric = refit_col.replace("mean_test_", "") - st.info(f"Using '{refit_metric}' as fallback metric.") - - best_idx = results[refit_col].idxmax() - best_row = results.loc[best_idx] - - # Extract parameter columns - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - - if param_cols: - best_params = {col.replace("param_", ""): best_row[col] for col in param_cols} - - # Display in a container with metrics (similar to old dashboard style) - with st.container(border=True): - st.caption(f"Parameters of the best model (selected by {refit_metric.replace('_', ' ').title()} score)") - - # Display parameters as metrics - n_params = len(best_params) - cols = st.columns(n_params) - - for idx, (param_name, param_value) in enumerate(best_params.items()): - with cols[idx]: - # Format value based on type and magnitude - if isinstance(param_value, (int, np.integer)): - formatted_value = f"{param_value:.0f}" - elif isinstance(param_value, (float, np.floating)): - # Use scientific notation for very small numbers - if abs(param_value) < 0.01: - formatted_value = f"{param_value:.2e}" - else: - formatted_value = f"{param_value:.4f}" - else: - formatted_value = str(param_value) - - st.metric(param_name, formatted_value) - - # Show all metrics for the best model - st.divider() - st.caption("Performance across all metrics") - - # Get all metrics from score_cols - all_metrics = [col.replace("mean_test_", "") for col in score_cols] - metric_cols = st.columns(len(all_metrics)) - - for idx, metric in enumerate(all_metrics): - with metric_cols[idx]: - best_score = results.loc[best_idx, f"mean_test_{metric}"] - best_std = results.loc[best_idx, f"std_test_{metric}"] - st.metric( - metric.replace("_", " ").title(), - f"{best_score:.4f}", - delta=f"±{best_std:.4f}", - help="Mean ± std across cross-validation folds", - ) - - -def 
render_parameter_distributions(results: pd.DataFrame, settings: TrainingSettings | None = None): - """Render histograms of parameter distributions explored. - - Args: - results: DataFrame with CV results. - settings: Optional settings dictionary containing param_grid configuration. - - """ - # Get parameter columns - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - - if not param_cols: - st.warning("No parameter columns found in results.") - return - - # Extract scale information from settings if available - param_scales = {} - if settings and hasattr(settings, "param_grid"): - param_grid = settings.param_grid - for param_name, param_config in param_grid.items(): - if isinstance(param_config, dict) and "distribution" in param_config: - # loguniform distribution indicates log scale - if param_config["distribution"] == "loguniform": - param_scales[param_name] = "log" - else: - param_scales[param_name] = "linear" - else: - param_scales[param_name] = "linear" - - # Get colormap from colors module - cmap = get_cmap("parameter_distribution") - bar_color = mcolors.rgb2hex(cmap(0.5)) - - # Create histograms for each parameter - n_params = len(param_cols) - n_cols = min(3, n_params) - n_rows = (n_params + n_cols - 1) // n_cols - - for row in range(n_rows): - cols = st.columns(n_cols) - for col_idx in range(n_cols): - param_idx = row * n_cols + col_idx - if param_idx >= n_params: - break - - param_col = param_cols[param_idx] - param_name = param_col.replace("param_", "") - - with cols[col_idx]: - # Check if parameter is numeric or categorical - param_values = results[param_col].dropna() - - if pd.api.types.is_numeric_dtype(param_values): - # Check if we have enough unique values for a histogram - n_unique = param_values.nunique() - - if n_unique == 1: - # Only one unique value - show as text - value = param_values.iloc[0] - formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}" - st.metric(param_name, formatted) - 
st.caption(f"All {len(param_values)} samples have the same value") - continue - - # Numeric parameter - use histogram - df_plot = pd.DataFrame({param_name: param_values.to_numpy()}) - - # Use log scale if the range spans multiple orders of magnitude OR values are very small - max_val = param_values.max() - - # Determine number of bins based on unique values - n_bins = min(20, max(5, n_unique)) - - # Determine x-axis scale from config or infer from data - use_log_scale = param_scales.get(param_name, "linear") == "log" - if not use_log_scale: - # Fall back to automatic detection if not in config - value_range = max_val / (param_values.min() + 1e-10) - use_log_scale = value_range > 100 or max_val < 0.01 - - # For very small values or large ranges, use a simple bar chart instead of histogram - if n_unique <= 10: - # Few unique values - use bar chart - value_counts = param_values.value_counts().reset_index() - value_counts.columns = [param_name, "count"] - value_counts = value_counts.sort_values(param_name) - - x_scale = alt.Scale(type="log") if use_log_scale else alt.Scale() - title_suffix = " (log scale)" if use_log_scale else "" - - # Format x-axis values - if max_val < 0.01: - formatted_col = f"{param_name}_formatted" - value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}") - chart = ( - alt.Chart(value_counts) - .mark_bar(color=bar_color) - .encode( - alt.X( - f"{formatted_col}:N", - title=param_name, - sort=None, - ), - alt.Y("count:Q", title="Count"), - tooltip=[ - alt.Tooltip(param_name, format=".2e"), - alt.Tooltip("count", title="Count"), - ], - ) - .properties(height=250, title=f"{param_name}{title_suffix}") - ) - else: - chart = ( - alt.Chart(value_counts) - .mark_bar(color=bar_color) - .encode( - alt.X( - f"{param_name}:Q", - title=param_name, - scale=x_scale, - ), - alt.Y("count:Q", title="Count"), - tooltip=[ - alt.Tooltip(param_name, format=".3f"), - alt.Tooltip("count", title="Count"), - ], - ) - .properties(height=250, 
title=f"{param_name}{title_suffix}") - ) - else: - # Many unique values - use binned histogram - title_suffix = " (log scale)" if use_log_scale else "" - - if use_log_scale: - # For log scale parameters, create bins in log space then transform back - # This gives better distribution visualization - log_values = np.log10(param_values.to_numpy()) - log_bins = np.linspace(log_values.min(), log_values.max(), n_bins + 1) - bins_linear = 10**log_bins - - # Manually bin the data - binned = pd.cut(param_values, bins=bins_linear) - bin_counts = binned.value_counts().sort_index() - - # Create dataframe for plotting - bin_data = [] - for interval, count in bin_counts.items(): - bin_mid = (interval.left + interval.right) / 2 - bin_data.append( - { - param_name: bin_mid, - "count": count, - "bin_label": f"{interval.left:.2e} - {interval.right:.2e}", - } - ) - - df_binned = pd.DataFrame(bin_data) - - chart = ( - alt.Chart(df_binned) - .mark_bar(color=bar_color) - .encode( - alt.X( - f"{param_name}:Q", - title=param_name, - scale=alt.Scale(type="log"), - axis=alt.Axis(format=".2e"), - ), - alt.Y("count:Q", title="Count"), - tooltip=[ - alt.Tooltip("bin_label:N", title="Range"), - alt.Tooltip("count:Q", title="Count"), - ], - ) - .properties(height=250, title=f"{param_name}{title_suffix}") - ) - else: - # Linear scale - use standard binning - format_str = ".2e" if max_val < 0.01 else ".3f" - chart = ( - alt.Chart(df_plot) - .mark_bar(color=bar_color) - .encode( - alt.X( - f"{param_name}:Q", - bin=alt.Bin(maxbins=n_bins), - title=param_name, - ), - alt.Y("count()", title="Count"), - tooltip=[ - alt.Tooltip( - f"{param_name}:Q", - format=format_str, - bin=True, - ), - "count()", - ], - ) - .properties(height=250, title=param_name) - ) - - st.altair_chart(chart, width="stretch") - - else: - # Categorical parameter - use bar chart - value_counts = param_values.value_counts().reset_index() - value_counts.columns = [param_name, "count"] - - chart = ( - alt.Chart(value_counts) - 
.mark_bar(color=bar_color) - .encode( - alt.X(param_name, title=param_name, sort="-y"), - alt.Y("count", title="Count"), - tooltip=[param_name, "count"], - ) - .properties(height=250, title=param_name) - ) - - st.altair_chart(chart, width="stretch") - - -def render_score_vs_parameter(results: pd.DataFrame, metric: str): - """Render scatter plots of score vs each parameter. - - Args: - results: DataFrame with CV results. - metric: The metric to plot (e.g., 'f1', 'accuracy'). - - """ - st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters") - - score_col = f"mean_test_{metric}" - if score_col not in results.columns: - st.warning(f"Metric {metric} not found in results.") - return - - # Get parameter columns - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - - if not param_cols: - st.warning("No parameter columns found in results.") - return - - # Create scatter plots for each parameter - n_params = len(param_cols) - n_cols = min(2, n_params) - n_rows = (n_params + n_cols - 1) // n_cols - - for row in range(n_rows): - cols = st.columns(n_cols) - for col_idx in range(n_cols): - param_idx = row * n_cols + col_idx - if param_idx >= n_params: - break - - param_col = param_cols[param_idx] - param_name = param_col.replace("param_", "") - - with cols[col_idx]: - param_values = results[param_col].dropna() - - if pd.api.types.is_numeric_dtype(param_values): - # Numeric parameter - scatter plot - df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]}) - - # Use log scale if needed - value_range = param_values.max() / (param_values.min() + 1e-10) - use_log = value_range > 100 - - if use_log: - chart = ( - alt.Chart(df_plot) - .mark_circle(size=60, opacity=0.6) - .encode( - alt.X( - param_name, - scale=alt.Scale(type="log"), - title=param_name, - ), - alt.Y(metric, title=metric.replace("_", " ").title()), - alt.Color( - metric, - scale=alt.Scale(range=get_palette(metric, n_colors=256)), - 
legend=None, - ), - tooltip=[ - alt.Tooltip(param_name, format=".2e"), - alt.Tooltip(metric, format=".4f"), - ], - ) - .properties( - height=300, - title=f"{metric} vs {param_name} (log scale)", - ) - ) - else: - chart = ( - alt.Chart(df_plot) - .mark_circle(size=60, opacity=0.6) - .encode( - alt.X(param_name, title=param_name), - alt.Y(metric, title=metric.replace("_", " ").title()), - alt.Color( - metric, - scale=alt.Scale(range=get_palette(metric, n_colors=256)), - legend=None, - ), - tooltip=[ - alt.Tooltip(param_name, format=".3f"), - alt.Tooltip(metric, format=".4f"), - ], - ) - .properties(height=300, title=f"{metric} vs {param_name}") - ) - - st.altair_chart(chart, width="stretch") - - else: - # Categorical parameter - box plot - df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]}) - - chart = ( - alt.Chart(df_plot) - .mark_boxplot() - .encode( - alt.X(param_name, title=param_name), - alt.Y(metric, title=metric.replace("_", " ").title()), - tooltip=[param_name, alt.Tooltip(metric, format=".4f")], - ) - .properties(height=300, title=f"{metric} vs {param_name}") - ) - - st.altair_chart(chart, width="stretch") - - -def render_parameter_correlation(results: pd.DataFrame, metric: str): - """Render correlation bar chart between parameters and score. - - Args: - results: DataFrame with CV results. - metric: The metric to analyze (e.g., 'f1', 'accuracy'). 
- - """ - score_col = f"mean_test_{metric}" - if score_col not in results.columns: - st.warning(f"Metric {metric} not found in results.") - return - - # Get numeric parameter columns - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])] - - if not numeric_params: - st.warning("No numeric parameters found for correlation analysis.") - return - - # Calculate correlations - correlations = [] - for param_col in numeric_params: - param_name = param_col.replace("param_", "") - corr = results[[param_col, score_col]].corr().iloc[0, 1] - correlations.append({"Parameter": param_name, "Correlation": corr}) - - corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False) - - # Get colormap from colors module (use diverging colormap for correlation) - hex_colors = get_palette("correlation", n_colors=256) - - # Create bar chart - chart = ( - alt.Chart(corr_df) - .mark_bar() - .encode( - alt.X("Correlation", title="Correlation with Score"), - alt.Y("Parameter", sort="-x", title="Parameter"), - alt.Color( - "Correlation", - scale=alt.Scale(range=hex_colors, domain=[-1, 1]), - legend=None, - ), - tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")], - ) - .properties(height=max(200, len(correlations) * 30)) - ) - - st.altair_chart(chart, width="stretch") - - -def render_binned_parameter_space(results: pd.DataFrame, metric: str): - """Render binned parameter space plots similar to old dashboard. - - This creates plots where parameters are binned and plotted against each other, - showing the metric value as color. Handles different hyperparameters dynamically. - - Args: - results: DataFrame with CV results. - metric: The metric to visualize (e.g., 'f1', 'accuracy'). 
- - """ - score_col = f"mean_test_{metric}" - if score_col not in results.columns: - st.warning(f"Metric {metric} not found in results.") - return - - # Get numeric parameter columns - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])] - - if len(numeric_params) < 2: - st.info("Need at least 2 numeric parameters for binned parameter space analysis.") - return - - # Prepare binned data and gather parameter info - results_binned = results.copy() - bin_info = {} - - for param_col in numeric_params: - param_name = param_col.replace("param_", "") - param_values = results[param_col].dropna() - - if len(param_values) == 0: - continue - - # Determine if we should use log bins - min_val = param_values.min() - max_val = param_values.max() - - if min_val > 0: - value_range = max_val / min_val - use_log = value_range > 100 or max_val < 0.01 - else: - use_log = False - - if use_log and min_val > 0: - # Logarithmic binning for parameters spanning many orders of magnitude - log_min = np.log10(min_val) - log_max = np.log10(max_val) - n_bins = min(10, max(3, int(log_max - log_min) + 1)) - bins = np.logspace(log_min, log_max, num=n_bins) - else: - # Linear binning - n_bins = min(10, max(3, int(np.sqrt(len(param_values))))) - bins = np.linspace(min_val, max_val, num=n_bins) - - results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins) - bin_info[param_name] = { - "use_log": use_log, - "bins": bins, - "min": min_val, - "max": max_val, - "n_unique": param_values.nunique(), - } - - # Get colormap from colors module - hex_colors = get_palette(metric, n_colors=256) - - # Get parameter names sorted by scale type (log parameters first, then linear) - param_names = [col.replace("param_", "") for col in numeric_params] - param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p)) - - st.caption(f"Parameter 
space exploration showing {metric.replace('_', ' ').title()} values") - - # Create all pairwise combinations systematically - if len(param_names) == 2: - # Simple case: just one pair - x_param, y_param = param_names_sorted - _render_2d_param_plot( - results_binned, - x_param, - y_param, - score_col, - bin_info, - hex_colors, - metric, - height=500, - ) - else: - # Multiple parameters: create structured plots - st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}") - - # Strategy: Show all combinations (including duplicates with swapped axes) - # Group by first parameter to create organized sections - - for i, x_param in enumerate(param_names_sorted): - # Get all other parameters to pair with this one - other_params = [p for p in param_names_sorted if p != x_param] - - if not other_params: - continue - - # Create expander for each primary parameter - with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)): - # Create plots in rows of 2 - n_others = len(other_params) - for row_idx in range(0, n_others, 2): - cols = st.columns(2) - - for col_idx in range(2): - other_idx = row_idx + col_idx - if other_idx >= n_others: - break - - y_param = other_params[other_idx] - - with cols[col_idx]: - _render_2d_param_plot( - results_binned, - x_param, - y_param, - score_col, - bin_info, - hex_colors, - metric, - height=350, - ) - - -def _render_2d_param_plot( - results_binned: pd.DataFrame, - x_param: str, - y_param: str, - score_col: str, - bin_info: dict, - hex_colors: list, - metric: str, - height: int = 400, -): - """Render a 2D parameter space plot. - - Args: - results_binned: DataFrame with binned results. - x_param: Name of x-axis parameter. - y_param: Name of y-axis parameter. - score_col: Column name for the score. - bin_info: Dictionary with binning information for each parameter. - hex_colors: List of hex colors for the colormap. - metric: Name of the metric being visualized. - height: Height of the plot in pixels. 
- - """ - plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna() - - if len(plot_data) == 0: - st.warning(f"No data available for {x_param} vs {y_param}") - return - - x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale() - y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale() - - # Determine marker size based on number of points - n_points = len(plot_data) - marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40 - - chart = ( - alt.Chart(plot_data) - .mark_circle(size=marker_size, opacity=0.7) - .encode( - x=alt.X( - f"param_{x_param}:Q", - title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}", - scale=x_scale, - ), - y=alt.Y( - f"param_{y_param}:Q", - title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}", - scale=y_scale, - ), - color=alt.Color( - f"{score_col}:Q", - scale=alt.Scale(range=hex_colors), - title=metric.replace("_", " ").title(), - ), - tooltip=[ - alt.Tooltip( - f"param_{x_param}:Q", - title=x_param, - format=".2e" if bin_info[x_param]["use_log"] else ".3f", - ), - alt.Tooltip( - f"param_{y_param}:Q", - title=y_param, - format=".2e" if bin_info[y_param]["use_log"] else ".3f", - ), - alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"), - ], - ) - .properties( - height=height, - title=f"{x_param} vs {y_param}", - ) - .interactive() - ) - - st.altair_chart(chart, width="stretch") - - -def render_score_evolution(results: pd.DataFrame, metric: str): - """Render evolution of scores during search. - - Args: - results: DataFrame with CV results. - metric: The metric to plot (e.g., 'f1', 'accuracy'). 
- - """ - st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution") - - score_col = f"mean_test_{metric}" - if score_col not in results.columns: - st.warning(f"Metric {metric} not found in results.") - return - - # Create a copy with iteration number - df_plot = results[[score_col]].copy() - df_plot["Iteration"] = range(len(df_plot)) - df_plot["Best So Far"] = df_plot[score_col].cummax() - df_plot = df_plot.rename(columns={score_col: "Score"}) - - # Reshape for Altair - df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type") - - # Get colormap for score evolution - evolution_cmap = get_cmap("score_evolution") - evolution_colors = [ - mcolors.rgb2hex(evolution_cmap(0.3)), - mcolors.rgb2hex(evolution_cmap(0.7)), - ] - - # Create line chart - chart = ( - alt.Chart(df_long) - .mark_line() - .encode( - alt.X("Iteration", title="Iteration"), - alt.Y("value", title=metric.replace("_", " ").title()), - alt.Color( - "Type", - legend=alt.Legend(title=""), - scale=alt.Scale(range=evolution_colors), - ), - strokeDash=alt.StrokeDash( - "Type", - legend=None, - scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]), - ), - tooltip=[ - "Iteration", - "Type", - alt.Tooltip("value", format=".4f", title="Score"), - ], - ) - .properties(height=400) - ) - - st.altair_chart(chart, width="stretch") - - # Show statistics - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}") - with col2: - st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}") - with col3: - st.metric("Std Dev", f"{df_plot['Score'].std():.4f}") - with col4: - # Find iteration where best was found - best_iter = df_plot["Score"].idxmax() - st.metric("Best at Iteration", best_iter) - - -@st.fragment -def render_multi_metric_comparison(results: pd.DataFrame): - """Render comparison of multiple metrics. - - Args: - results: DataFrame with CV results. 
- - """ - # Get all test score columns - score_cols = [col for col in results.columns if col.startswith("mean_test_")] - - if len(score_cols) < 2: - st.warning("Need at least 2 metrics for comparison.") - return - - # Get parameter columns for color options - param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] - param_names = [col.replace("param_", "") for col in param_cols] - - if not param_names: - st.warning("No parameters found for coloring.") - return - - # Let user select two metrics and color parameter - col1, col2, col3 = st.columns(3) - with col1: - metric1 = st.selectbox( - "Select First Metric", - options=[col.replace("mean_test_", "") for col in score_cols], - index=0, - key="metric1_select", - ) - - with col2: - metric2 = st.selectbox( - "Select Second Metric", - options=[col.replace("mean_test_", "") for col in score_cols], - index=min(1, len(score_cols) - 1), - key="metric2_select", - ) - - with col3: - color_by = st.selectbox( - "Color By", - options=param_names, - index=0, - key="color_select", - ) - - if metric1 == metric2: - st.warning("Please select different metrics.") - return - - # Create scatter plot data - df_plot = pd.DataFrame( - { - metric1: results[f"mean_test_{metric1}"], - metric2: results[f"mean_test_{metric2}"], - } - ) - - # Add color parameter to dataframe - param_col = f"param_{color_by}" - df_plot[color_by] = results[param_col] - color_col = color_by - - # Check if parameter is numeric and should use log scale - param_values = results[param_col].dropna() - if pd.api.types.is_numeric_dtype(param_values): - value_range = param_values.max() / (param_values.min() + 1e-10) - use_log = value_range > 100 - - if use_log: - color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256)) - color_format = ".2e" - else: - color_scale = alt.Scale(range=get_palette(color_by, n_colors=256)) - color_format = ".3f" - else: - # Categorical parameter - color_scale = 
alt.Scale(scheme="category20") - color_format = None - - # Build tooltip list - tooltip_list = [ - alt.Tooltip(metric1, format=".4f"), - alt.Tooltip(metric2, format=".4f"), - ] - if color_format: - tooltip_list.append(alt.Tooltip(color_col, format=color_format)) - else: - tooltip_list.append(color_col) - - # Calculate axis domains starting from min values - x_min = df_plot[metric1].min() - x_max = df_plot[metric1].max() - y_min = df_plot[metric2].min() - y_max = df_plot[metric2].max() - - # Add small padding (2% of range) for better visualization - x_padding = (x_max - x_min) * 0.02 - y_padding = (y_max - y_min) * 0.02 - - chart = ( - alt.Chart(df_plot) - .mark_circle(size=60, opacity=0.6) - .encode( - alt.X( - metric1, - title=metric1.replace("_", " ").title(), - scale=alt.Scale(domain=[x_min - x_padding, x_max + x_padding]), - ), - alt.Y( - metric2, - title=metric2.replace("_", " ").title(), - scale=alt.Scale(domain=[y_min - y_padding, y_max + y_padding]), - ), - alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()), - tooltip=tooltip_list, - ) - .properties(height=500) - ) - - st.altair_chart(chart, width="stretch") - - # Calculate correlation - corr = df_plot[[metric1, metric2]].corr().iloc[0, 1] - st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}") - - -def render_espa_binned_parameter_space(results: pd.DataFrame, metric: str, k_bin_width: int = 40): - """Render ESPA-specific binned parameter space plots. - - Creates faceted plots for all combinations of the three ESPA parameters: - - eps_cl vs eps_e (binned by initial_K) - - eps_cl vs initial_K (binned by eps_e) - - eps_e vs initial_K (binned by eps_cl) - - Args: - results: DataFrame with CV results. - metric: The metric to visualize (e.g., 'f1', 'accuracy'). - k_bin_width: Width of bins for initial_K parameter. 
- - """ - score_col = f"mean_test_{metric}" - if score_col not in results.columns: - st.warning(f"Metric {metric} not found in results.") - return - - # Check if this is an ESPA model with the required parameters - required_params = ["param_initial_K", "param_eps_cl", "param_eps_e"] - if not all(param in results.columns for param in required_params): - st.info("ESPA-specific parameters not found. This visualization is only for ESPA models.") - return - - # Get colormap from colors module - hex_colors = get_palette(metric, n_colors=256) - - # Prepare base plot data - base_data = results[["param_eps_e", "param_eps_cl", "param_initial_K", score_col]].copy() - base_data = base_data.dropna() - - if len(base_data) == 0: - st.warning("No data available for ESPA binned parameter space.") - return - - # Configuration for each plot combination - plot_configs = [ - { - "x_param": "param_eps_e", - "y_param": "param_eps_cl", - "bin_param": "param_initial_K", - "x_label": "eps_e", - "y_label": "eps_cl", - "bin_label": "initial_K", - "x_scale": "log", - "y_scale": "log", - "bin_type": "linear", - "bin_width": k_bin_width, - "title": "eps_cl vs eps_e (binned by initial_K)", - }, - { - "x_param": "param_eps_cl", - "y_param": "param_initial_K", - "bin_param": "param_eps_e", - "x_label": "eps_cl", - "y_label": "initial_K", - "bin_label": "eps_e", - "x_scale": "log", - "y_scale": "linear", - "bin_type": "log", - "bin_width": None, # Will use log bins - "title": "initial_K vs eps_cl (binned by eps_e)", - }, - { - "x_param": "param_eps_e", - "y_param": "param_initial_K", - "bin_param": "param_eps_cl", - "x_label": "eps_e", - "y_label": "initial_K", - "bin_label": "eps_cl", - "x_scale": "log", - "y_scale": "linear", - "bin_type": "log", - "bin_width": None, # Will use log bins - "title": "initial_K vs eps_e (binned by eps_cl)", - }, - ] - - # Create each plot - for config in plot_configs: - st.markdown(f"**{config['title']}**") - - plot_data = base_data.copy() - - # Create bins for the 
binning parameter - bin_values = plot_data[config["bin_param"]].dropna() - if len(bin_values) == 0: - st.warning(f"No {config['bin_label']} values found.") - continue - - if config["bin_type"] == "log": - # Logarithmic binning for epsilon parameters - log_min = np.log10(bin_values.min()) - log_max = np.log10(bin_values.max()) - n_bins = min(10, max(5, int(log_max - log_min) + 1)) - bins = np.logspace(log_min, log_max, num=n_bins) - # Adjust bins to ensure all values are captured - bins[0] = bins[0] * 0.999 # Extend first bin to capture minimum - bins[-1] = bins[-1] * 1.001 # Extend last bin to capture maximum - else: - # Linear binning for initial_K - bin_min = bin_values.min() - bin_max = bin_values.max() - bins = np.arange(bin_min, bin_max + config["bin_width"], config["bin_width"]) - - # Bin the parameter - plot_data["binned_param"] = pd.cut(plot_data[config["bin_param"]], bins=bins, right=False) - - # Remove any NaN bins (shouldn't happen, but just in case) - plot_data = plot_data.dropna(subset=["binned_param"]) - - # Sort bins and convert to string for ordering - plot_data = plot_data.sort_values("binned_param") - plot_data["binned_param_str"] = plot_data["binned_param"].astype(str) - bin_order = plot_data["binned_param_str"].unique().tolist() - - # Create faceted scatter plot - x_scale = alt.Scale(type=config["x_scale"]) if config["x_scale"] == "log" else alt.Scale() - # For initial_K plots, set domain to start from the minimum value in the search space - if config["y_label"] == "initial_K": - y_min = plot_data[config["y_param"]].min() - y_max = plot_data[config["y_param"]].max() - y_scale = alt.Scale(domain=[y_min, y_max]) - else: - y_scale = alt.Scale(type=config["y_scale"]) if config["y_scale"] == "log" else alt.Scale() - - chart = ( - alt.Chart(plot_data) - .mark_circle(size=60, opacity=0.7) - .encode( - x=alt.X( - f"{config['x_param']}:Q", - scale=x_scale, - axis=alt.Axis(title=config["x_label"], grid=True, gridOpacity=0.5), - ), - y=alt.Y( - 
def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
    """Render table of top N configurations.

    Args:
        results: DataFrame with CV results.
        metric: The metric to rank by (e.g., 'f1', 'accuracy').
        top_n: Number of top configurations to show.

    """
    metric_col = f"mean_test_{metric}"
    if metric_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Hyperparameter columns; the aggregate "params" dict column is excluded.
    hyperparam_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not hyperparam_cols:
        st.warning("No parameter columns found in results.")
        return

    # Rank by the requested metric and keep only the columns we want to show.
    best_rows = results.nlargest(top_n, metric_col)
    wanted = [f"rank_test_{metric}", metric_col, *hyperparam_cols]
    table = best_rows[[c for c in wanted if c in best_rows.columns]].copy()

    # Human-friendly column headers.
    pretty_metric = metric.replace("_", " ").title()
    table = table.rename(columns={f"rank_test_{metric}": "Rank", metric_col: pretty_metric})
    table.columns = [c.replace("param_", "") if c.startswith("param_") else c for c in table.columns]

    # Fixed-precision score formatting for display.
    table[pretty_metric] = table[pretty_metric].apply(lambda v: f"{v:.4f}")

    st.dataframe(table, hide_index=True, width="stretch")
Cannot display map.") - return - - # Get grid type and task from settings - grid = settings.grid - task = settings.task - - # Use the merged predictions which already have true labels, predictions, and split info - merged = merged_predictions.copy() - merged["is_correct"] = merged["true_class"] == merged["predicted_class"] - - if len(merged) == 0: - st.warning("No predictions found for labeled cells.") - return - - # Get ordered class labels for the task - ordered_classes = get_ordered_classes(task) - - # Create controls - col1, col2, col3 = st.columns([2, 1, 1]) - - with col1: - # Split selector (similar to confusion matrix) - split_type = st.selectbox( - "Select Data Split", - options=["test", "train", "all"], - format_func=lambda x: {"test": "Test Set", "train": "Training Set (CV)", "all": "All Data"}[x], - help="Choose which data split to display on the map", - key="prediction_map_split_select", - ) - - with col2: - # Color scheme selector - show_only_incorrect = st.checkbox( - "Highlight Errors Only", - value=False, - help="Show only incorrect predictions in red, hide correct ones", - key="prediction_map_errors_only", - ) - - with col3: - opacity = st.slider( - "Opacity", - min_value=0.1, - max_value=1.0, - value=0.7, - step=0.1, - key="prediction_map_opacity", - ) - - # Filter data by split - if split_type == "test": - display_gdf = merged[merged["split"] == "test"].copy() - split_caption = "Test Set (held-out data)" - elif split_type == "train": - display_gdf = merged[merged["split"] == "train"].copy() - split_caption = "Training Set (CV data)" - else: # "all" - display_gdf = merged.copy() - split_caption = "All Available Data" - - # Optionally filter to show only incorrect predictions - if show_only_incorrect: - display_gdf = display_gdf[~display_gdf["is_correct"]].copy() - - if len(display_gdf) == 0: - st.warning(f"No cells found for {split_caption}.") - return - - st.caption(f"📍 Showing {len(display_gdf)} cells from {split_caption}") - - # Convert to 
WGS84 for pydeck - display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326") - - # Fix antimeridian issues for hex grids - if grid == "hex": - display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry) - - # Get red material colormap for incorrect predictions - red_cmap = get_cmap("red_predictions") # Use red_material palette - n_classes = len(ordered_classes) - - # Assign colors based on correctness - def get_color(row): - if row["is_correct"]: - # Green for correct predictions - return [46, 204, 113] - else: - # Different shades of red for each predicted class (ordered) - pred_class = row["predicted_class"] - if pred_class in ordered_classes: - class_idx = ordered_classes.index(pred_class) - # Sample from red colormap based on class index - color_value = red_cmap(class_idx / max(n_classes - 1, 1)) - return [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)] - else: - # Fallback red if class not found - return [231, 76, 60] - - display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1) - - # Add line color based on split: blue for test, orange for train - def get_line_color(row): - if row["split"] == "test": - return [52, 152, 219] # Blue for test split - else: - return [230, 126, 34] # Orange for train split - - display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1) - - # Add elevation based on TRUE label (not predicted) - # Map each true class to a height based on its position in the ordered list - def get_elevation(row): - true_class = row["true_class"] - if true_class in ordered_classes: - class_idx = ordered_classes.index(true_class) - # Normalize to 0-1 range based on class position - return (class_idx + 1) / n_classes - else: - return 0.5 # Default elevation - - display_gdf_wgs84["elevation"] = display_gdf_wgs84.apply(get_elevation, axis=1) - - # Convert to GeoJSON format - geojson_data = [] - for _, row in display_gdf_wgs84.iterrows(): - # Determine split and 
status for tooltip - split_name = "Test" if row["split"] == "test" else "Training (CV)" - status = "✓ Correct" if row["is_correct"] else "✗ Incorrect" - - feature = { - "type": "Feature", - "geometry": row["geometry"].__geo_interface__, - "properties": { - "true_label": str(row["true_class"]), - "predicted_label": str(row["predicted_class"]), - "is_correct": bool(row["is_correct"]), - "split": split_name, - "status": status, - "fill_color": row["fill_color"], - "line_color": row["line_color"], - "elevation": float(row["elevation"]), - }, - } - geojson_data.append(feature) - - # Create pydeck layer - layer = pdk.Layer( - "GeoJsonLayer", - geojson_data, - opacity=opacity, - stroked=True, - filled=True, - extruded=True, - wireframe=False, - get_fill_color="properties.fill_color", - get_line_color="properties.line_color", - line_width_min_pixels=2, - get_elevation="properties.elevation", - elevation_scale=500000, - pickable=True, - ) - - # Set initial view state (centered on the Arctic) - view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) - - # Create deck - deck = pdk.Deck( - layers=[layer], - initial_view_state=view_state, - tooltip={ - "html": "Status: {status}
" - "True Label: {true_label}
" - "Predicted Label: {predicted_label}
" - "Split: {split}", - "style": {"backgroundColor": "steelblue", "color": "white"}, - }, - map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", - ) - - # Render the map - st.pydeck_chart(deck) - - # Show statistics for displayed data - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Cells Displayed", len(display_gdf)) - - with col2: - correct = len(display_gdf[display_gdf["is_correct"]]) - st.metric("Correct Predictions", correct) - - with col3: - if len(display_gdf) > 0: - accuracy = correct / len(display_gdf) - st.metric("Accuracy", f"{accuracy:.2%}") - else: - st.metric("Accuracy", "N/A") - - # Add legend - with st.expander("Legend", expanded=True): - st.markdown("**Fill Color (Prediction Correctness):**") - - # Correct predictions - correct_count = len(display_gdf[display_gdf["is_correct"]]) - incorrect_count = len(display_gdf[~display_gdf["is_correct"]]) - - st.markdown( - f'
' - f'
' - f"Correct Predictions ({correct_count} cells, " - f"{correct_count / len(display_gdf) * 100 if len(display_gdf) > 0 else 0:.1f}%)
", - unsafe_allow_html=True, - ) - - # Incorrect predictions by predicted class (shades of red) - if incorrect_count > 0: - st.markdown( - f"Incorrect Predictions by Predicted Class ({incorrect_count} cells):", unsafe_allow_html=True - ) - - for class_idx, class_label in enumerate(ordered_classes): - # Get count of incorrect predictions for this predicted class - count = len(display_gdf[(~display_gdf["is_correct"]) & (display_gdf["predicted_class"] == class_label)]) - if count > 0: - # Get color for this predicted class - color_value = red_cmap(class_idx / max(n_classes - 1, 1)) - rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)] - - percentage = count / incorrect_count * 100 - - st.markdown( - f'
' - f'
' - f"Predicted as {class_label}: {count} ({percentage:.1f}%)
", - unsafe_allow_html=True, - ) - - st.markdown("---") - st.markdown("**Border Color (Data Split):**") - - # Count by split in displayed data - test_in_display = len(display_gdf[display_gdf["split"] == "test"]) - train_in_display = len(display_gdf[display_gdf["split"] == "train"]) - - if test_in_display > 0: - st.markdown( - f'
' - f'
' - f"Test Split ({test_in_display} cells)
", - unsafe_allow_html=True, - ) - - if train_in_display > 0: - st.markdown( - f'
' - f'
' - f"Training Split ({train_in_display} cells)
", - unsafe_allow_html=True, - ) - - st.markdown("---") - st.markdown("**Elevation (3D Height):**") - - # Show elevation mapping for each true class - st.markdown("Height represents the true label:", unsafe_allow_html=True) - for class_idx, class_label in enumerate(ordered_classes): - elevation_value = (class_idx + 1) / n_classes - height_km = elevation_value * 500 # Since elevation_scale is 500000 - st.markdown( - f'
{class_label}: {height_km:.0f} km
', - unsafe_allow_html=True, - ) - st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.") - - -def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str): - """Render confusion matrix as an interactive heatmap. - - Args: - confusion_matrix: xarray DataArray with dimensions (true_label, predicted_label). - task: Task type ('binary' or 'multiclass'). - - """ - import plotly.express as px - - # Convert to DataFrame for plotting - cm_df = confusion_matrix.to_pandas() - - # Get labels (convert numeric labels to semantic labels if possible) - true_labels = confusion_matrix.coords["true_label"].values - pred_labels = confusion_matrix.coords["predicted_label"].values - - # Check if labels are already strings (from predictions) or numeric (from stored confusion matrices) - first_true_label = true_labels[0] - is_string_labels = isinstance(first_true_label, str) or ( - hasattr(first_true_label, "dtype") and first_true_label.dtype.kind in ("U", "O") - ) - - if is_string_labels: - # Labels are already string labels, use them directly - true_labels_str = [str(label) for label in true_labels] - pred_labels_str = [str(label) for label in pred_labels] - elif task == "binary": - # Numeric binary labels - map 0/1 to No-RTS/RTS - label_map = {0: "No-RTS", 1: "RTS"} - true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels] - pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels] - else: - # Numeric multiclass labels - use as is - true_labels_str = [str(label) for label in true_labels] - pred_labels_str = [str(label) for label in pred_labels] - - # Rename DataFrame indices and columns for display - cm_df.index = true_labels_str - cm_df.columns = pred_labels_str - - # Store raw counts for annotations - cm_counts = cm_df.copy() - - # Normalize by row (true label) to get percentages - cm_normalized = cm_df.div(cm_df.sum(axis=1), axis=0) - - # Create custom text annotations showing both percentage and count 
- text_annotations = [] - for i, true_label in enumerate(true_labels_str): - row_annotations = [] - for j, pred_label in enumerate(pred_labels_str): - count = int(cm_counts.iloc[i, j]) - percentage = cm_normalized.iloc[i, j] * 100 - row_annotations.append(f"{percentage:.1f}%
({count:,})") - text_annotations.append(row_annotations) - - # Create heatmap with normalized values - fig = px.imshow( - cm_normalized, - labels=dict(x="Predicted Label", y="True Label", color="Proportion"), - x=pred_labels_str, - y=true_labels_str, - color_continuous_scale="Blues", - aspect="auto", - zmin=0, - zmax=1, - ) - - # Update with custom annotations - fig.update_traces( - text=text_annotations, - texttemplate="%{text}", - textfont={"size": 12}, - ) - - # Update layout for better readability - fig.update_layout( - title="Confusion Matrix (Normalized by True Label)", - xaxis_title="Predicted Label", - yaxis_title="True Label", - height=500, - ) - - # Update colorbar to show percentage - fig.update_coloraxes( - colorbar=dict( - title="Proportion", - tickformat=".0%", - ) - ) - - st.plotly_chart(fig, width="stretch") - - st.caption( - "📊 Values show **row-normalized percentages** (percentage of each true class predicted as each label). " - "Raw counts shown in parentheses." - ) - - # Calculate and display metrics from confusion matrix - col1, col2, col3 = st.columns(3) - - total_samples = int(cm_df.values.sum()) - correct_predictions = int(np.trace(cm_df.values)) - accuracy = correct_predictions / total_samples if total_samples > 0 else 0 - - with col1: - st.metric("Total Samples", f"{total_samples:,}") - - with col2: - st.metric("Correct Predictions", f"{correct_predictions:,}") - - with col3: - st.metric("Accuracy", f"{accuracy:.2%}") - - # Add detailed breakdown for binary classification - if task == "binary": - st.markdown("#### Binary Classification Metrics") - - # Extract TP, TN, FP, FN from confusion matrix - # Assuming 0=No-RTS (negative), 1=RTS (positive) - tn = int(cm_df.iloc[0, 0]) - fp = int(cm_df.iloc[0, 1]) - fn = int(cm_df.iloc[1, 0]) - tp = int(cm_df.iloc[1, 1]) - - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric("True Positives (TP)", f"{tp:,}") - - with col2: - st.metric("True Negatives (TN)", f"{tn:,}") - - with col3: - 
def plot_performance_summary(results: pd.DataFrame, refit_metric: str) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Compute performance summary statistics.

    Args:
        results: DataFrame with CV results.
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').

    Returns:
        Tuple of (best_scores_df, score_stats_df, best_params_dict).

    """
    metric_cols = [c for c in results.columns if c.startswith("mean_test_")]
    if not metric_cols:
        # No CV scores at all: nothing to summarize.
        return pd.DataFrame(), pd.DataFrame(), {}

    def pretty(col: str) -> str:
        # e.g. "mean_test_f1_weighted" -> "F1 Weighted"
        return col.replace("mean_test_", "").replace("_", " ").title()

    best_rows = [{"Metric": pretty(c), "Best Score": f"{results[c].max():.4f}"} for c in metric_cols]
    stat_rows = [
        {"Metric": pretty(c), "Mean ± Std": f"{results[c].mean():.4f} ± {results[c].std():.4f}"}
        for c in metric_cols
    ]

    # Fall back to the first available metric when the refit metric is absent.
    refit_col = f"mean_test_{refit_metric}"
    if refit_col not in results.columns:
        refit_col = metric_cols[0]

    # Row with the best refit-metric score; its param_* columns are the winners.
    winner = results.loc[results[refit_col].idxmax()]
    hyperparam_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    best_params = {c.replace("param_", ""): winner[c] for c in hyperparam_cols}

    return pd.DataFrame(best_rows), pd.DataFrame(stat_rows), best_params
def plot_score_vs_parameters(
    results: pd.DataFrame, metric: str, param_grid: dict | None = None
) -> dict[str, go.Figure]:
    """Create scatter plots of score vs each parameter.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
        param_grid: Optional parameter grid with distribution information.

    Returns:
        Dictionary mapping parameter names to Plotly figures.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return {}

    hyperparam_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not hyperparam_cols:
        return {}

    # Colour points by their score using the metric's palette.
    palette = get_palette(metric, n_colors=256)
    colorscale = [[pos / 255, colour] for pos, colour in enumerate(palette)]
    pretty_metric = metric.replace("_", " ").title()

    def wants_log_axis(name: str) -> bool:
        # Log-uniform search distributions are plotted on a log axis.
        cfg = param_grid.get(name) if param_grid else None
        return isinstance(cfg, dict) and cfg.get("distribution") == "loguniform"

    figures: dict[str, go.Figure] = {}
    for col in hyperparam_cols:
        name = col.replace("param_", "")
        # Skip parameters with no recorded values at all.
        if results[col].dropna().empty:
            continue

        hover_text = [
            f"{name}: {val}<br>Score: {score:.4f}"
            for val, score in zip(results[col], results[score_col])
        ]
        fig = go.Figure(
            go.Scatter(
                x=results[col],
                y=results[score_col],
                mode="markers",
                marker={
                    "size": 8,
                    "color": results[score_col],
                    "colorscale": colorscale,
                    "showscale": False,
                    "opacity": 0.6,
                },
                text=hover_text,
                hovertemplate="%{text}",
            )
        )
        fig.update_layout(
            title=f"{pretty_metric} vs {name}",
            xaxis_title=name,
            xaxis_type="log" if wants_log_axis(name) else "linear",
            yaxis_title=pretty_metric,
            height=400,
            showlegend=False,
        )
        figures[name] = fig

    return figures
def plot_parameter_interactions(results: pd.DataFrame, metric: str, param_grid: dict | None = None) -> list[go.Figure]:
    """Create scatter plots showing parameter interactions.

    Args:
        results: DataFrame with CV results.
        metric: The metric to visualize (e.g., 'f1', 'accuracy').
        param_grid: Optional parameter grid with distribution information.

    Returns:
        List of Plotly figures showing parameter interactions.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return []

    # Only numeric hyperparameters can be placed on scatter axes.
    numeric_names = [
        c.replace("param_", "")
        for c in results.columns
        if c.startswith("param_") and c != "params" and pd.api.types.is_numeric_dtype(results[c])
    ]
    if len(numeric_names) < 2:
        return []

    palette = get_palette(metric, n_colors=256)
    colorscale = [[pos / 255, colour] for pos, colour in enumerate(palette)]
    pretty_metric = metric.replace("_", " ").title()

    def axis_type(name: str) -> str:
        # Log-uniform search distributions get a log axis.
        cfg = param_grid.get(name) if param_grid else None
        is_log = isinstance(cfg, dict) and cfg.get("distribution") == "loguniform"
        return "log" if is_log else "linear"

    figures: list[go.Figure] = []
    # One figure per unordered pair of numeric parameters.
    for pos, x_name in enumerate(numeric_names[:-1]):
        for y_name in numeric_names[pos + 1 :]:
            x_col, y_col = f"param_{x_name}", f"param_{y_name}"

            hover_text = [
                f"{x_name}: {a}<br>{y_name}: {b}<br>Score: {s:.4f}"
                for a, b, s in zip(results[x_col], results[y_col], results[score_col])
            ]
            fig = go.Figure(
                go.Scatter(
                    x=results[x_col],
                    y=results[y_col],
                    mode="markers",
                    marker={
                        "size": 8,
                        "color": results[score_col],
                        "colorscale": colorscale,
                        "showscale": True,
                        "colorbar": {"title": pretty_metric},
                        "opacity": 0.7,
                    },
                    text=hover_text,
                    hovertemplate="%{text}",
                )
            )
            fig.update_layout(
                title=f"{pretty_metric} by {x_name} and {y_name}",
                xaxis_title=x_name,
                xaxis_type=axis_type(x_name),
                yaxis_title=y_name,
                yaxis_type=axis_type(y_name),
                height=500,
                width=500,
            )
            figures.append(fig)

    return figures
Score: %{y:.4f}", + ) + ) + + fig.add_trace( + go.Scatter( + x=iterations, + y=best_so_far, + mode="lines", + name="Best So Far", + line={"color": best_color, "width": 2}, + hovertemplate="Iteration: %{x}
def plot_confusion_matrix(cm_data: xr.DataArray, title: str = "Confusion Matrix", normalize: str = "none") -> go.Figure:
    """Plot an interactive confusion matrix heatmap.

    Args:
        cm_data: XArray DataArray with confusion matrix data (dimensions: true_label, predicted_label).
        title: Title for the plot.
        normalize: Normalization mode - "none", "true", or "pred".

    Returns:
        Plotly figure with the interactive confusion matrix heatmap.

    """
    # Get the data as a float array; axis labels come from the true_label coord
    # (the matrix is assumed square with matching predicted labels).
    cm_array = cm_data.values.astype(float)
    labels = cm_data.coords["true_label"].values.tolist()

    # Keep raw counts for annotations and hover, regardless of normalization.
    cm_counts = cm_data.values

    # Apply normalization.
    if normalize == "true":
        # Normalize over true labels (rows) - each row sums to 1.
        row_sums = cm_array.sum(axis=1, keepdims=True)
        # BUG FIX: `out=` is required here. np.divide with `where=` alone
        # leaves the masked (zero-sum) entries uninitialized garbage instead
        # of zero, producing nondeterministic heatmap values for empty rows.
        cm_normalized = np.divide(cm_array, row_sums, out=np.zeros_like(cm_array), where=row_sums != 0)
        colorbar_title = "Proportion"
    elif normalize == "pred":
        # Normalize over predicted labels (columns) - each column sums to 1.
        col_sums = cm_array.sum(axis=0, keepdims=True)
        cm_normalized = np.divide(cm_array, col_sums, out=np.zeros_like(cm_array), where=col_sums != 0)
        colorbar_title = "Proportion"
    else:
        # No normalization: show raw counts.
        cm_normalized = cm_array
        colorbar_title = "Count"

    # Hoisted loop invariants: overall total and the text-contrast threshold.
    total = cm_counts.sum()
    threshold = cm_normalized.max() / 2 if cm_normalized.max() > 0 else 0.5

    # Create per-cell annotations for the heatmap.
    annotations = []
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            count = int(cm_counts[i, j])
            normalized_val = cm_normalized[i, j]

            # Format text based on normalization mode.
            if normalize == "none":
                # Show count and percentage of total.
                pct = (count / total * 100) if total > 0 else 0
                text = f"{count}<br>({pct:.1f}%)"
            else:
                # Show percentage only for normalized versions.
                text = f"{normalized_val:.1%}"

            # Light text on dark cells, dark text on light cells.
            text_color = "white" if normalized_val > threshold else "black"

            annotations.append(
                {
                    "x": pred_label,
                    "y": true_label,
                    "text": text,
                    "showarrow": False,
                    "font": {"size": 10, "color": text_color},
                }
            )

    # Create the heatmap with normalized values for coloring; hover shows
    # the raw count via customdata.
    fig = go.Figure(
        data=go.Heatmap(
            z=cm_normalized,
            x=labels,
            y=labels,
            colorscale="Blues",
            colorbar={"title": colorbar_title},
            hoverongaps=False,
            hovertemplate="True: %{y}<br>Predicted: %{x}<br>Count: %{customdata}",
            customdata=cm_counts,
        )
    )

    # Add annotations and axis styling.
    fig.update_layout(
        title=title,  # FIX: the title argument was previously accepted but never applied
        annotations=annotations,
        xaxis={"title": "Predicted Label", "side": "bottom"},
        yaxis={"title": "True Label", "autorange": "reversed"},
        width=600,
        height=550,
    )

    return fig
+ + """ + # Convert to numpy arrays if needed + y_true_np = cast(np.ndarray, y_true.to_numpy()) if isinstance(y_true, pd.Series) else y_true + y_pred_np = cast(np.ndarray, y_pred.to_numpy()) if isinstance(y_pred, pd.Series) else y_pred + + # Calculate metrics + mse = np.mean((y_true_np - y_pred_np) ** 2) + mae = np.mean(np.abs(y_true_np - y_pred_np)) + r2 = 1 - (np.sum((y_true_np - y_pred_np) ** 2) / np.sum((y_true_np - np.mean(y_true_np)) ** 2)) + + # Get colormap + hex_colors = get_palette("r2", n_colors=256) + + # Calculate point density for coloring + from scipy.stats import gaussian_kde + + try: + # Create KDE for density estimation + xy = np.vstack([y_true_np, y_pred_np]) + kde = gaussian_kde(xy) + density = kde(xy) + except (np.linalg.LinAlgError, ValueError): + # Fallback if KDE fails (e.g., all points identical) + density = np.ones(len(y_true_np)) + + # Create figure + fig = go.Figure() + + # Add scatter plot + fig.add_trace( + go.Scatter( + x=y_true_np, + y=y_pred_np, + mode="markers", + marker={ + "size": 6, + "color": density, + "colorscale": [[i / 255, c] for i, c in enumerate(hex_colors)], + "showscale": False, + "opacity": 0.6, + }, + text=[f"True: {true:.3f}
Pred: {pred:.3f}" for true, pred in zip(y_true_np, y_pred_np)], + hovertemplate="%{text}", + name="Data", + ) + ) + + # Add diagonal line (perfect prediction) + min_val = min(y_true_np.min(), y_pred_np.min()) + max_val = max(y_true_np.max(), y_pred_np.max()) + fig.add_trace( + go.Scatter( + x=[min_val, max_val], + y=[min_val, max_val], + mode="lines", + line={"color": "red", "dash": "dash", "width": 2}, + name="Perfect Prediction", + hovertemplate="y = x", + ) + ) + + # Add metrics as annotation + metrics_text = f"R² = {r2:.4f}
MSE = {mse:.4f}
def plot_residuals(
    y_true: np.ndarray | pd.Series,
    y_pred: np.ndarray | pd.Series,
    title: str = "Residual Plot",
) -> go.Figure:
    """Create residual plot for regression diagnostics.

    Args:
        y_true: True target values.
        y_pred: Predicted target values.
        title: Title for the plot.

    Returns:
        Plotly figure with residual plot.

    """
    # Normalize inputs to numpy arrays so downstream math is index-agnostic.
    y_true_np = cast(np.ndarray, y_true.to_numpy()) if isinstance(y_true, pd.Series) else y_true
    y_pred_np = cast(np.ndarray, y_pred.to_numpy()) if isinstance(y_pred, pd.Series) else y_pred

    # Residual = true - predicted; positive means the model under-predicted.
    residuals = y_true_np - y_pred_np

    hex_colors = get_palette("r2", n_colors=256)

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            # BUG FIX: use the converted arrays, not the raw inputs — the old code
            # computed y_pred_np but then passed the original (possibly pd.Series)
            # y_pred to the trace and the hover-text zip.
            x=y_pred_np,
            y=residuals,
            mode="markers",
            marker={
                "size": 6,
                "color": np.abs(residuals),
                "colorscale": [[i / (len(hex_colors) - 1), c] for i, c in enumerate(hex_colors)],
                "showscale": True,
                "colorbar": {"title": "Abs Residual"},
                "opacity": 0.6,
            },
            text=[f"Pred: {pred:.3f}<br>Residual: {res:.3f}" for pred, res in zip(y_pred_np, residuals)],
            hovertemplate="%{text}",
        )
    )

    # Zero line: spread above/below reveals systematic under/over-prediction.
    fig.add_hline(y=0, line_dash="dash", line_color="red", line_width=2)

    fig.update_layout(
        title=title,
        xaxis_title="Predicted Values",
        yaxis_title="Residuals (True - Predicted)",
        height=400,
        showlegend=False,
    )

    return fig
def render_run_information(selected_result: TrainingResult, refit_metric):
    """Render training run configuration overview.

    Args:
        selected_result: The selected TrainingResult object.
        refit_metric: The refit metric used for model selection.

    """
    st.header("📋 Run Information")

    settings = selected_result.settings
    grid_config = GridConfig.from_grid_level(f"{settings.grid}{settings.level}")  # ty:ignore[invalid-argument-type]

    # One st.metric per headline fact about the run, laid out side by side.
    headline = [
        ("Task", settings.task.capitalize()),
        ("Target", settings.target.capitalize()),
        ("Grid", grid_config.display_name),
        ("Model", settings.model.upper()),
        ("Trials", len(selected_result.results)),
    ]
    for column, (label, value) in zip(st.columns(len(headline)), headline):
        with column:
            st.metric(label, value)

    st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
def _render_metrics(metrics: dict[str, float]):
    """Render a set of metrics in a row layout with up to five columns.

    Args:
        metrics: Dictionary of metric names and their values.

    """
    # Guard: st.columns(0) raises, and the old docstring wrongly said
    # "two-column layout" although up to five columns are used.
    if not metrics:
        st.info("No metrics available.")
        return

    ncols = min(5, len(metrics))
    cols = st.columns(ncols)
    for idx, (metric_name, metric_value) in enumerate(metrics.items()):
        # Wrap onto a new visual row every ncols metrics.
        with cols[idx % ncols]:
            st.metric(format_metric_name(metric_name), f"{metric_value:.4f}")
@st.fragment
def render_confusion_matrices(selected_result: TrainingResult):
    """Render confusion matrices for classification tasks.

    Args:
        selected_result: The selected TrainingResult object.

    """
    st.header("🎭 Confusion Matrices")

    # Only classification tasks produce confusion matrices.
    if selected_result.settings.task not in ("binary", "count_regimes", "density_regimes"):
        st.info(
            "📊 Confusion matrices are only available for classification tasks "
            "(binary, count_regimes, density_regimes)."
        )
        st.caption("Coming soon for regression tasks: residual plots and error distributions.")
        return

    cm = selected_result.confusion_matrix
    if cm is None:
        st.warning("⚠️ No confusion matrix data found for this training result.")
        return

    # Normalization selector; labels map onto plot_confusion_matrix modes.
    st.subheader("Display Options")
    normalize_map = {
        "No normalization": "none",
        "Normalize over True Labels": "true",
        "Normalize over Predicted Labels": "pred",
    }
    normalize_option = st.radio(
        "Normalization",
        options=list(normalize_map),
        horizontal=True,
        help="Choose how to normalize the confusion matrix values",
    )
    normalize_mode = normalize_map[normalize_option]

    # One panel per data split, rendered side by side.
    panels = [
        ("Test Set", "Held-out test set", "test"),
        ("Training Set", "Training set", "train"),
        ("Combined", "Train + Test sets", "combined"),
    ]
    for column, (subtitle, caption, split) in zip(st.columns(3), panels):
        with column:
            st.subheader(subtitle)
            st.caption(caption)
            fig = plot_confusion_matrix(cm[split], title=subtitle, normalize=normalize_mode)
            st.plotly_chart(fig, width="stretch")
def render_cv_statistics_section(cv_stats: CVMetricStatistics, test_score: float):
    """Render cross-validation statistics for selected metric.

    Args:
        cv_stats: CVMetricStatistics object containing cross-validation statistics.
        test_score: The test set score for the selected metric.

    """
    st.header("📈 Cross-Validation Statistics")
    st.caption("Performance during hyperparameter search (averaged across CV folds)")

    # Five summary statistics of the CV score distribution, side by side.
    summary = [
        ("Best Score", cv_stats.best_score),
        ("Mean Score", cv_stats.mean_score),
        ("Std Dev", cv_stats.std_score),
        ("Worst Score", cv_stats.worst_score),
        ("Median Score", cv_stats.median_score),
    ]
    for column, (label, value) in zip(st.columns(5), summary):
        with column:
            st.metric(label, f"{value:.4f}")

    if cv_stats.mean_cv_std is not None:
        st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")

    # Compare with test metric
    st.subheader("CV vs Test Performance")

    delta = test_score - cv_stats.best_score
    # Avoid division by zero when the best CV score is exactly 0.
    delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Best CV Score", f"{cv_stats.best_score:.4f}")
    with col2:
        st.metric("Test Score", f"{test_score:.4f}")
    with col3:
        st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%")

    # A gap larger than one CV standard deviation is worth flagging.
    if abs(delta) > cv_stats.std_score:
        st.warning(
            "⚠️ Test performance differs significantly (larger than the CV standard deviation) from CV performance. "
            "This may indicate overfitting or data distribution mismatch between training and test sets."
        )
+ ) diff --git a/src/entropice/dashboard/sections/experiment_results.py b/src/entropice/dashboard/sections/experiment_results.py index c8a59a4..ecfb448 100644 --- a/src/entropice/dashboard/sections/experiment_results.py +++ b/src/entropice/dashboard/sections/experiment_results.py @@ -2,15 +2,16 @@ from datetime import datetime +import pandas as pd import streamlit as st -from entropice.dashboard.utils.loaders import TrainingResult +from entropice.dashboard.utils.loaders import AutogluonTrainingResult, TrainingResult from entropice.utils.types import ( GridConfig, ) -def render_training_results_summary(training_results: list[TrainingResult]): +def render_training_results_summary(training_results: list[TrainingResult | AutogluonTrainingResult]): """Render summary metrics for training results.""" st.header("📊 Training Results Summary") col1, col2, col3, col4 = st.columns(4) @@ -23,7 +24,7 @@ def render_training_results_summary(training_results: list[TrainingResult]): st.metric("Total Runs", len(training_results)) with col3: - models = {tr.settings.model for tr in training_results} + models = {tr.settings.model for tr in training_results if hasattr(tr.settings, "model")} st.metric("Model Types", len(models)) with col4: @@ -33,14 +34,14 @@ def render_training_results_summary(training_results: list[TrainingResult]): @st.fragment -def render_experiment_results(training_results: list[TrainingResult]): # noqa: C901 +def render_experiment_results(training_results: list[TrainingResult | AutogluonTrainingResult]): # noqa: C901 """Render detailed experiment results table and expandable details.""" st.header("🎯 Experiment Results") # Filters experiments = sorted({tr.experiment for tr in training_results if tr.experiment}) tasks = sorted({tr.settings.task for tr in training_results}) - models = sorted({tr.settings.model for tr in training_results}) + models = sorted({tr.settings.model if isinstance(tr, TrainingResult) else "autogluon" for tr in training_results}) grids = 
sorted({f"{tr.settings.grid}-{tr.settings.level}" for tr in training_results}) # Create filter columns @@ -87,14 +88,26 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: filtered_results = [tr for tr in filtered_results if tr.experiment == selected_experiment] if selected_task != "All": filtered_results = [tr for tr in filtered_results if tr.settings.task == selected_task] - if selected_model != "All": - filtered_results = [tr for tr in filtered_results if tr.settings.model == selected_model] + if selected_model != "All" and selected_model != "autogluon": + filtered_results = [ + tr for tr in filtered_results if isinstance(tr, TrainingResult) and tr.settings.model == selected_model + ] + elif selected_model == "autogluon": + filtered_results = [tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)] if selected_grid != "All": filtered_results = [tr for tr in filtered_results if f"{tr.settings.grid}-{tr.settings.level}" == selected_grid] st.subheader("Results Table") - summary_df = TrainingResult.to_dataframe(filtered_results) + summary_df = TrainingResult.to_dataframe([tr for tr in filtered_results if isinstance(tr, TrainingResult)]) + autogluon_df = AutogluonTrainingResult.to_dataframe( + [tr for tr in filtered_results if isinstance(tr, AutogluonTrainingResult)] + ) + if len(summary_df) == 0: + summary_df = autogluon_df + elif len(autogluon_df) > 0: + summary_df = pd.concat([summary_df, autogluon_df], ignore_index=True) + # Display with color coding for best scores st.dataframe( summary_df, @@ -107,6 +120,8 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: for tr in filtered_results: tr_info = tr.display_info display_name = tr_info.get_display_name("model_first") + model = "autogluon" if isinstance(tr, AutogluonTrainingResult) else tr.settings.model + cv_splits = tr.settings.cv_splits if hasattr(tr.settings, "cv_splits") else "N/A" with st.expander(display_name): col1, col2 = 
st.columns([1, 2]) @@ -117,12 +132,12 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: f"- **Experiment:** {tr.experiment}\n" f"- **Task:** {tr.settings.task}\n" f"- **Target:** {tr.settings.target}\n" - f"- **Model:** {tr.settings.model}\n" + f"- **Model:** {model}\n" f"- **Grid:** {grid_config.display_name}\n" f"- **Created At:** {tr_info.timestamp.strftime('%Y-%m-%d %H:%M')}\n" f"- **Temporal Mode:** {tr.settings.temporal_mode}\n" f"- **Members:** {', '.join(tr.settings.members)}\n" - f"- **CV Splits:** {tr.settings.cv_splits}\n" + f"- **CV Splits:** {cv_splits}\n" f"- **Classes:** {tr.settings.classes}\n" ) @@ -140,26 +155,29 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa: file_str += f"- 📄 `{file.name}`\n" st.write(file_str) with col2: - st.write("**CV Score Summary:**") - - # Extract all test scores - metric_df = tr.get_metric_dataframe() - if metric_df is not None: - st.dataframe(metric_df, width="stretch", hide_index=True) + if isinstance(tr, AutogluonTrainingResult): + st.write("**Leaderboard:**") + st.dataframe(tr.leaderboard, width="stretch", hide_index=True) else: - st.write("No test scores found in results.") + st.write("**CV Score Summary:**") + # Extract all test scores + metric_df = tr.get_metric_dataframe() + if metric_df is not None: + st.dataframe(metric_df, width="stretch", hide_index=True) + else: + st.write("No test scores found in results.") - # Show parameter space explored - if "initial_K" in tr.results.columns: # Common parameter - st.write("\n**Parameter Ranges Explored:**") - for param in ["initial_K", "eps_cl", "eps_e"]: - if param in tr.results.columns: - min_val = tr.results[param].min() - max_val = tr.results[param].max() - unique_vals = tr.results[param].nunique() - st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") + # Show parameter space explored + if "initial_K" in tr.results.columns: # Common parameter + st.write("\n**Parameter 
Ranges Explored:**") + for param in ["initial_K", "eps_cl", "eps_e"]: + if param in tr.results.columns: + min_val = tr.results[param].min() + max_val = tr.results[param].max() + unique_vals = tr.results[param].nunique() + st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})") - st.write("**CV Results DataFrame:**") - st.dataframe(tr.results, width="stretch", hide_index=True) + st.write("**CV Results DataFrame:**") + st.dataframe(tr.results, width="stretch", hide_index=True) st.write(f"\n**Path:** `{tr.path}`") diff --git a/src/entropice/dashboard/sections/hparam_space.py b/src/entropice/dashboard/sections/hparam_space.py new file mode 100644 index 0000000..194202a --- /dev/null +++ b/src/entropice/dashboard/sections/hparam_space.py @@ -0,0 +1,172 @@ +"""Hyperparameter Space Visualization Section.""" + +import streamlit as st + +from entropice.dashboard.plots.hyperparameter_space import ( + plot_parameter_correlations, + plot_parameter_distributions, + plot_parameter_interactions, + plot_score_evolution, + plot_score_vs_parameters, +) +from entropice.dashboard.utils.formatters import format_metric_name +from entropice.dashboard.utils.loaders import TrainingResult + + +def _render_performance_summary(results, refit_metric: str): + """Render performance summary subsection.""" + best_idx = results[f"mean_test_{refit_metric}"].idxmax() + best_row = results.loc[best_idx] + # Extract parameter columns + param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] + best_params = {col.replace("param_", ""): best_row[col] for col in param_cols} + + # Display best parameter combination + if not best_params: + return + + with st.container(border=True): + st.subheader("🏆 Best Parameter Combination") + st.caption(f"Parameters of the best model (selected by {format_metric_name(refit_metric)} score)") + n_params = len(best_params) + cols = st.columns(n_params) + for idx, (param_name, param_value) in 
def _render_parameter_distributions(results, param_grid: dict | None):
    """Render parameter distributions subsection."""
    st.subheader("Parameter Distributions")
    st.caption("Distribution of hyperparameter values explored during random search")

    param_charts = plot_parameter_distributions(results, param_grid)
    if not param_charts:
        st.info("No parameter distribution data available.")
        return

    # Lay the charts out in a grid of up to three columns, chunking by row.
    charts = list(param_charts.values())
    n_cols = min(3, len(charts))
    for row_start in range(0, len(charts), n_cols):
        row_charts = charts[row_start : row_start + n_cols]
        for column, chart in zip(st.columns(n_cols), row_charts):
            with column:
                st.plotly_chart(chart, width="stretch")
def _render_parameter_correlations(results, selected_metric: str):
    """Render parameter correlations subsection."""
    st.subheader("Parameter-Score Correlations")
    st.caption(f"Correlation between numeric parameters and {format_metric_name(selected_metric)}")

    corr_chart = plot_parameter_correlations(results, selected_metric)
    # Guard clause: nothing to correlate against when no numeric parameters exist.
    if not corr_chart:
        st.info("No numeric parameters found for correlation analysis.")
        return
    st.plotly_chart(corr_chart, width="stretch")
def render_hparam_space_section(selected_result: TrainingResult, selected_metric: str):
    """Render the hyperparameter space visualization section.

    Args:
        selected_result: The selected TrainingResult object.
        selected_metric: The metric to focus analysis on.

    """
    st.header("🧩 Hyperparameter Space Exploration")

    results = selected_result.results
    refit_metric = selected_result._get_best_metric_name()  # noqa: SLF001 — internal helper of the result object
    param_grid = selected_result.settings.param_grid

    # Subsections in reading order: headline summary first, then search detail.
    _render_performance_summary(results, refit_metric)
    _render_parameter_distributions(results, param_grid)
    _render_score_evolution(results, selected_metric)
    _render_score_vs_parameters(results, selected_metric, param_grid)
    _render_parameter_correlations(results, selected_metric)
    _render_parameter_interactions(results, selected_metric, param_grid)
def render_regression_analysis(selected_result: TrainingResult):
    """Render regression analysis with true vs predicted scatter plots.

    Args:
        selected_result: The selected TrainingResult object.

    """
    st.header("📊 Regression Analysis")

    # Only regression tasks have continuous predictions to analyse.
    if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]:
        st.info("📈 Regression analysis is only available for regression tasks (count, density).")
        return

    predictions_df = selected_result.load_predictions()
    if predictions_df is None:
        st.warning("⚠️ No prediction data found for this training result.")
        return

    # Rebuild the dataset from the stored settings to recover the true targets
    # and the original train/test split.
    with st.spinner("Loading training data to get true values..."):
        ensemble = DatasetEnsemble(
            grid=selected_result.settings.grid,
            level=selected_result.settings.level,
            members=selected_result.settings.members,
            temporal_mode=selected_result.settings.temporal_mode,
            dimension_filters=selected_result.settings.dimension_filters,
            variable_filters=selected_result.settings.variable_filters,
            add_lonlat=selected_result.settings.add_lonlat,
        )
        training_set = ensemble.create_training_set(
            task=selected_result.settings.task,
            target=selected_result.settings.target,
            device="cpu",
            cache_mode="read",
        )

    split_series = training_set.split

    # Join predictions with true values on cell_id, then attach the split label.
    true_values = training_set.targets[["y"]].reset_index()
    merged = predictions_df.merge(true_values, on="cell_id", how="inner")
    merged["split"] = split_series.reindex(merged["cell_id"]).values

    train_data = merged[merged["split"] == "train"]
    test_data = merged[merged["split"] == "test"]
    if len(train_data) == 0 or len(test_data) == 0:
        st.error("❌ Could not properly split data into train and test sets.")
        return

    # One panel per data split, mirroring the confusion-matrix layout.
    panels = [
        ("Test Set", "Held-out test set", test_data),
        ("Training Set", "Training set", train_data),
        ("Combined", "Train + Test sets", merged),
    ]

    st.subheader("True vs Predicted Values")
    st.caption("Scatter plots showing the relationship between true and predicted values")
    for column, (name, caption, data) in zip(st.columns(3), panels):
        with column:
            st.markdown(f"#### {name}")
            st.caption(caption)
            fig = plot_regression_scatter(data["y"], data["predicted"], title=name)
            # CONSISTENCY FIX: width="stretch" matches the other dashboard
            # sections (replaces the deprecated use_container_width=True).
            st.plotly_chart(fig, width="stretch")

    st.subheader("Residual Analysis")
    st.caption("Residual plots to assess model fit and identify patterns in errors")
    for column, (name, _caption, data) in zip(st.columns(3), panels):
        with column:
            fig = plot_residuals(data["y"], data["predicted"], title=f"{name} Residuals")
            st.plotly_chart(fig, width="stretch")
- -This module leverages the canonical class labels defined in the ML dataset module -to ensure consistent ordering across all visualizations. -""" - -import pandas as pd - -from entropice.utils.types import Task - -# Canonical orderings imported from the ML pipeline -# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"} -# Count/Density labels are defined in the bin_values function -BINARY_LABELS = ["No RTS", "RTS"] -COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"] -DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"] - -CLASS_ORDERINGS: dict[Task | str, list[str]] = { - "binary": BINARY_LABELS, - "count": COUNT_LABELS, - "density": DENSITY_LABELS, -} - - -def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]: - """Get properly ordered class labels for a given task. - - This uses the same canonical ordering as defined in the ML dataset module, - ensuring consistency between training and inference visualizations. - - Args: - task: Task type ('binary', 'count', 'density'). - available_classes: Optional list of available classes to filter and order. - If None, returns all canonical classes for the task. - - Returns: - List of class labels in proper order. - - Examples: - >>> get_ordered_classes("binary") - ['No RTS', 'RTS'] - >>> get_ordered_classes("count", ["None", "Few", "Several"]) - ['None', 'Few', 'Several'] - - """ - canonical_order = CLASS_ORDERINGS[task] - - if available_classes is None: - return canonical_order - - # Filter canonical order to only include available classes, preserving order - return [cls for cls in canonical_order if cls in available_classes] - - -def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series: - """Sort a pandas Series with class labels according to canonical ordering. - - Args: - series: Pandas Series with class labels as index. - task: Task type ('binary', 'count', 'density'). 
- - Returns: - Sorted Series with classes in canonical order. - - """ - available_classes = series.index.tolist() - ordered_classes = get_ordered_classes(task, available_classes) - - # Reindex to get proper order - return series.reindex(ordered_classes) diff --git a/src/entropice/dashboard/utils/formatters.py b/src/entropice/dashboard/utils/formatters.py index bad2965..4408314 100644 --- a/src/entropice/dashboard/utils/formatters.py +++ b/src/entropice/dashboard/utils/formatters.py @@ -59,7 +59,7 @@ task_display_infos: dict[Task, TaskDisplayInfo] = { class TrainingResultDisplayInfo: task: Task target: TargetDataset - model: Model + model: Model | Literal["autogluon"] grid: Grid level: int timestamp: datetime diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index fe34802..a0e6fbe 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -17,6 +17,7 @@ from shapely.geometry import shape import entropice.spatial.grids import entropice.utils.paths from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo +from entropice.ml.autogluon_training import AutoGluonTrainingSettings from entropice.ml.dataset import DatasetEnsemble, TrainingSet from entropice.ml.training import TrainingSettings from entropice.utils.types import GridConfig, TargetDataset, Task, all_target_datasets, all_tasks @@ -215,14 +216,18 @@ class TrainingResult: return pd.DataFrame.from_records(records) -@st.cache_data +@st.cache_data(ttl=300) # Cache for 5 minutes def load_all_training_results() -> list[TrainingResult]: """Load all training results from the results directory.""" results_dir = entropice.utils.paths.RESULTS_DIR training_results: list[TrainingResult] = [] + incomplete_results: list[tuple[Path, Exception]] = [] for result_path in results_dir.iterdir(): if not result_path.is_dir(): continue + # Skip AutoGluon results directory + if "autogluon" in result_path.name.lower(): + continue 
@dataclass
class AutogluonTrainingResult:
    """Wrapper for training result data and metadata."""

    path: Path  # Directory the result was loaded from.
    experiment: str  # Parent experiment name, or "N/A" for standalone runs.
    settings: AutoGluonTrainingSettings  # Parsed from training_settings.toml.
    test_metrics: dict[str, float | dict | pd.DataFrame]  # Unpickled test_metrics.pickle.
    leaderboard: pd.DataFrame  # AutoGluon model leaderboard (leaderboard.parquet).
    feature_importance: pd.DataFrame | None  # Optional feature_importance.parquet.
    created_at: float  # POSIX timestamp taken from the directory's st_ctime.
    files: list[Path]  # All entries found in the result directory.

    @classmethod
    def from_path(cls, result_path: Path, experiment_name: str | None = None) -> "AutogluonTrainingResult":
        """Load an AutogluonTrainingResult from a given result directory path.

        Args:
            result_path: Directory containing the AutoGluon run artifacts.
            experiment_name: Optional parent experiment name; "N/A" when absent.

        Raises:
            FileNotFoundError: If the settings, metrics, or leaderboard file is missing.

        """
        settings_file = result_path / "training_settings.toml"
        metrics_file = result_path / "test_metrics.pickle"
        leaderboard_file = result_path / "leaderboard.parquet"
        feature_importance_file = result_path / "feature_importance.parquet"
        all_files = list(result_path.iterdir())
        # The first three artifacts are mandatory; feature importance is optional.
        if not settings_file.exists():
            raise FileNotFoundError(f"Missing settings file in {result_path}")
        if not metrics_file.exists():
            raise FileNotFoundError(f"Missing metrics file in {result_path}")
        if not leaderboard_file.exists():
            raise FileNotFoundError(f"Missing leaderboard file in {result_path}")

        created_at = result_path.stat().st_ctime
        settings_dict = toml.load(settings_file)["settings"]
        settings = AutoGluonTrainingSettings(**settings_dict)
        # NOTE(review): pickle.load assumes locally-produced, trusted artifacts.
        with open(metrics_file, "rb") as f:
            metrics = pickle.load(f)
        leaderboard = pd.read_parquet(leaderboard_file)

        if feature_importance_file.exists():
            feature_importance = pd.read_parquet(feature_importance_file)
        else:
            feature_importance = None

        return cls(
            path=result_path,
            experiment=experiment_name or "N/A",
            settings=settings,
            test_metrics=metrics,
            leaderboard=leaderboard,
            feature_importance=feature_importance,
            created_at=created_at,
            files=all_files,
        )

    @property
    def test_confusion_matrix(self) -> pd.DataFrame | None:
        """Get the test confusion matrix, or None when the metrics lack one."""
        if "confusion_matrix" not in self.test_metrics:
            return None
        assert isinstance(self.test_metrics["confusion_matrix"], pd.DataFrame)
        return self.test_metrics["confusion_matrix"]

    @property
    def display_info(self) -> TrainingResultDisplayInfo:
        """Get display information for the training result."""
        # Model is always reported as "autogluon" — the concrete model varies
        # per leaderboard row rather than per run.
        return TrainingResultDisplayInfo(
            task=self.settings.task,
            target=self.settings.target,
            model="autogluon",
            grid=self.settings.grid,
            level=self.settings.level,
            timestamp=datetime.fromtimestamp(self.created_at),
        )

    def _get_best_metric_name(self) -> str:
        """Get the primary metric name for a given task."""
        match self.settings.task:
            case "binary":
                return "f1"
            case "count_regimes" | "density_regimes":
                return "f1_weighted"
            case _:  # regression tasks
                return "root_mean_squared_error"

    @staticmethod
    def to_dataframe(training_results: list["AutogluonTrainingResult"]) -> pd.DataFrame:
        """Convert a list of AutogluonTrainingResult objects to a DataFrame for display."""
        records = []
        for tr in training_results:
            info = tr.display_info
            best_metric_name = tr._get_best_metric_name()

            # One summary row per run; columns mirror TrainingResult.to_dataframe
            # so the two tables can be concatenated in the dashboard.
            record = {
                "Experiment": tr.experiment if tr.experiment else "N/A",
                "Task": info.task,
                "Target": info.target,
                "Model": info.model,
                "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name,
                "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"),
                "Score-Metric": best_metric_name.title(),
                "Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name),
                "Path": str(tr.path.name),
            }
            records.append(record)
        return pd.DataFrame.from_records(records)
 info.target, + "Model": info.model, + "Grid": GridConfig.from_grid_level((info.grid, info.level)).display_name, + "Created At": info.timestamp.strftime("%Y-%m-%d %H:%M"), + "Score-Metric": best_metric_name.title(), + "Best Models Score (Test-Set)": tr.test_metrics.get(best_metric_name), + "Path": str(tr.path.name), + } + records.append(record) + return pd.DataFrame.from_records(records) + + +@st.cache_data(ttl=300) # Cache for 5 minutes +def load_all_autogluon_training_results() -> list[AutogluonTrainingResult]: + """Load all AutoGluon training results from the results directory.""" + results_dir = entropice.utils.paths.RESULTS_DIR + training_results: list[AutogluonTrainingResult] = [] + incomplete_results: list[tuple[Path, Exception]] = [] + for result_path in results_dir.iterdir(): + if not result_path.is_dir(): + continue + # Skip non-AutoGluon results directories + if "autogluon" not in result_path.name.lower(): + continue + try: + training_result = AutogluonTrainingResult.from_path(result_path) + training_results.append(training_result) + except FileNotFoundError as e: + is_experiment_dir = False + for experiment_path in result_path.iterdir(): + if not experiment_path.is_dir(): + continue + try: + experiment_name = experiment_path.parent.name + training_result = AutogluonTrainingResult.from_path(experiment_path, experiment_name) + training_results.append(training_result) + is_experiment_dir = True + except FileNotFoundError as e2: + incomplete_results.append((experiment_path, e2)) + if not is_experiment_dir: + incomplete_results.append((result_path, e)) + + if len(incomplete_results) > 0: + st.warning( + f"Found {len(incomplete_results)} incomplete autogluon training results that were skipped:\n - " + + "\n - ".join(f"{p}: {e}" for p, e in incomplete_results) + ) # Sort by creation time (most recent first) training_results.sort(key=lambda tr: tr.created_at, reverse=True) return training_results diff --git a/src/entropice/dashboard/views/model_state_page.py 
b/src/entropice/dashboard/views/model_state_page.py index f589fd6..af2e1e4 100644 --- a/src/entropice/dashboard/views/model_state_page.py +++ b/src/entropice/dashboard/views/model_state_page.py @@ -369,6 +369,7 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result: Trainin options=["gain", "weight", "cover", "total_gain", "total_cover"], index=0, help="Choose which importance metric to visualize", + key="model_state_importance_type", ) # Top N slider diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index dded4fd..e594cee 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -9,7 +9,7 @@ from entropice.dashboard.sections.experiment_results import ( render_training_results_summary, ) from entropice.dashboard.sections.storage_statistics import render_storage_statistics -from entropice.dashboard.utils.loaders import load_all_training_results +from entropice.dashboard.utils.loaders import load_all_autogluon_training_results, load_all_training_results from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics @@ -27,6 +27,9 @@ def render_overview_page(): ) # Load training results training_results = load_all_training_results() + autogluon_results = load_all_autogluon_training_results() + if len(autogluon_results) > 0: + training_results.extend(autogluon_results) if not training_results: st.warning("No training results found. 
Please run some training experiments first.") diff --git a/src/entropice/dashboard/views/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py index b1ad430..4aba169 100644 --- a/src/entropice/dashboard/views/training_analysis_page.py +++ b/src/entropice/dashboard/views/training_analysis_page.py @@ -2,150 +2,22 @@ from typing import cast -import geopandas as gpd import streamlit as st -import xarray as xr -from stopuhr import stopwatch -from entropice.dashboard.plots.hyperparameter_analysis import ( - render_binned_parameter_space, - render_confusion_matrix_heatmap, - render_confusion_matrix_map, - render_espa_binned_parameter_space, - render_multi_metric_comparison, - render_parameter_correlation, - render_parameter_distributions, - render_performance_summary, - render_top_configurations, +from entropice.dashboard.sections.cv_result import ( + render_confusion_matrices, + render_cv_statistics_section, + render_metrics_section, + render_run_information, ) +from entropice.dashboard.sections.hparam_space import render_hparam_space_section +from entropice.dashboard.sections.regression_analysis import render_regression_analysis from entropice.dashboard.utils.formatters import format_metric_name from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results -from entropice.dashboard.utils.stats import CVResultsStatistics -from entropice.utils.types import GridConfig +from entropice.dashboard.utils.stats import CVMetricStatistics -def load_predictions_with_labels(selected_result: TrainingResult) -> gpd.GeoDataFrame | None: - """Load predictions and merge with training data to get true labels and split info. - - Args: - selected_result: The selected TrainingResult object. - - Returns: - GeoDataFrame with predictions, true labels, and split information, or None if unavailable. 
- - """ - from sklearn.model_selection import train_test_split - - from entropice.ml.dataset import DatasetEnsemble, bin_values, taskcol - - # Load predictions - preds_gdf = selected_result.load_predictions() - if preds_gdf is None: - return None - - # Create a minimal dataset ensemble to access target data - settings = selected_result.settings - dataset_ensemble = DatasetEnsemble( - grid=settings.grid, - level=settings.level, - target=settings.target, - members=[], # No feature data needed, just targets - ) - - # Load target dataset (just labels, no features) - with st.spinner("Loading target labels..."): - targets = dataset_ensemble._read_target() - - # Get coverage and task columns - task_col = taskcol[settings.task][settings.target] - - # Filter for valid labels (same as in _cat_and_split) - valid_labels = targets[task_col].notna() - filtered_targets = targets.loc[valid_labels].copy() - - # Apply binning to get class labels (same logic as _cat_and_split) - if settings.task == "binary": - binned = filtered_targets[task_col].map({False: "No RTS", True: "RTS"}).astype("category") - elif settings.task == "count": - binned = bin_values(filtered_targets[task_col].astype(int), task=settings.task) - elif settings.task == "density": - binned = bin_values(filtered_targets[task_col], task=settings.task) - else: - raise ValueError(f"Invalid task: {settings.task}") - - filtered_targets["true_class"] = binned.to_numpy() - - # Recreate the train/test split deterministically (same random_state=42 as in _cat_and_split) - _train_idx, test_idx = train_test_split( - filtered_targets.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True - ) - filtered_targets["split"] = "train" - filtered_targets.loc[test_idx, "split"] = "test" - filtered_targets["split"] = filtered_targets["split"].astype("category") - - # Ensure cell_id is available as a column for merging - # Check if cell_id already exists, otherwise use the index - if "cell_id" not in filtered_targets.columns: - 
filtered_targets = filtered_targets.reset_index().rename(columns={"index": "cell_id"}) - - # Merge predictions with labels (inner join to keep only cells with predictions) - merged = filtered_targets.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner") - merged_gdf = gpd.GeoDataFrame(merged, geometry="geometry", crs=targets.crs) - - return merged_gdf - - -def compute_confusion_matrix_from_merged_data( - merged_data: gpd.GeoDataFrame, - split_type: str, - label_names: list[str], -) -> xr.DataArray | None: - """Compute confusion matrix from merged predictions and labels. - - Args: - merged_data: GeoDataFrame with 'true_class', 'predicted_class', and 'split' columns. - split_type: One of 'test', 'train', or 'all'. - label_names: List of class label names in order. - - Returns: - xarray.DataArray with confusion matrix or None if data unavailable. - - """ - from sklearn.metrics import confusion_matrix - - # Filter by split type - if split_type == "train": - data = merged_data[merged_data["split"] == "train"] - elif split_type == "test": - data = merged_data[merged_data["split"] == "test"] - elif split_type == "all": - data = merged_data - else: - raise ValueError(f"Invalid split_type: {split_type}") - - if len(data) == 0: - st.warning(f"No data available for {split_type} split.") - return None - - # Get true and predicted labels - y_true = data["true_class"].to_numpy() - y_pred = data["predicted_class"].to_numpy() - - # Compute confusion matrix - cm = confusion_matrix(y_true, y_pred, labels=label_names) - - # Create xarray DataArray - cm_xr = xr.DataArray( - cm, - dims=["true_label", "predicted_label"], - coords={"true_label": label_names, "predicted_label": label_names}, - name="confusion_matrix", - ) - - return cm_xr - - -def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]: +def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, 
str]: """Render sidebar for training run and analysis settings selection. Args: @@ -155,351 +27,63 @@ def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> Tuple of (selected_result, selected_metric, refit_metric, top_n). """ - st.header("Select Training Run") + with st.sidebar.form("training_analysis_settings_form"): + st.header("Select Training Run") - # Create selection options with task-first naming - training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results} + # Create selection options with task-first naming + training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results} - selected_name = st.selectbox( - "Training Run", - options=list(training_options.keys()), - index=0, - help="Select a training run to analyze", - key="training_run_select", - ) + selected_name = st.selectbox( + "Training Run", + options=list(training_options.keys()), + index=0, + help="Select a training run to analyze", + key="training_run_select", + ) - selected_result = cast(TrainingResult, training_options[selected_name]) + selected_result = cast(TrainingResult, training_options[selected_name]) - st.divider() - - # Metric selection for detailed analysis - st.subheader("Analysis Settings") - - available_metrics = selected_result.available_metrics - - # Try to get refit metric from settings - refit_metric = "f1" if selected_result.settings.task == "binary" else "f1_weighted" - - if refit_metric in available_metrics: - default_metric_idx = available_metrics.index(refit_metric) - else: - default_metric_idx = 0 - - selected_metric = st.selectbox( - "Primary Metric for Analysis", - options=available_metrics, - index=default_metric_idx, - format_func=format_metric_name, - help="Select the metric to focus on for detailed analysis", - key="metric_select", - ) - - # Top N configurations - top_n = st.slider( - "Top N Configurations", - min_value=5, - max_value=50, - value=10, - step=5, - 
help="Number of top configurations to display", - key="top_n_slider", - ) - - return selected_result, selected_metric, refit_metric, top_n - - -def render_run_information(selected_result: TrainingResult, refit_metric): - """Render training run configuration overview. - - Args: - selected_result: The selected TrainingResult object. - refit_metric: The refit metric used for model selection. - - """ - st.header("📋 Run Information") - - grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type] - - col1, col2, col3, col4, col5 = st.columns(5) - with col1: - st.metric("Task", selected_result.settings.task.capitalize()) - with col2: - st.metric("Target", selected_result.settings.target.capitalize()) - with col3: - st.metric("Grid", grid_config.display_name) - with col4: - st.metric("Model", selected_result.settings.model.upper()) - with col5: - st.metric("Trials", len(selected_result.results)) - - st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}") - - -def render_test_metrics_section(selected_result: TrainingResult): - """Render test metrics overview showing final model performance. - - Args: - selected_result: The selected TrainingResult object. 
- - """ - st.header("🎯 Test Set Performance") - st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)") - - test_metrics = selected_result.metrics - - if not test_metrics: - st.warning("No test metrics available for this training run.") - return - - # Display metrics in columns based on task type - task = selected_result.settings.task - - if task == "binary": - # Binary classification metrics - col1, col2, col3, col4, col5 = st.columns(5) - - with col1: - st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}") - with col2: - st.metric("F1 Score", f"{test_metrics.get('f1', 0):.4f}") - with col3: - st.metric("Precision", f"{test_metrics.get('precision', 0):.4f}") - with col4: - st.metric("Recall", f"{test_metrics.get('recall', 0):.4f}") - with col5: - st.metric("Jaccard", f"{test_metrics.get('jaccard', 0):.4f}") - else: - # Multiclass metrics - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}") - with col2: - st.metric("F1 (Macro)", f"{test_metrics.get('f1_macro', 0):.4f}") - with col3: - st.metric("F1 (Weighted)", f"{test_metrics.get('f1_weighted', 0):.4f}") - - col4, col5, col6 = st.columns(3) - - with col4: - st.metric("Precision (Macro)", f"{test_metrics.get('precision_macro', 0):.4f}") - with col5: - st.metric("Precision (Weighted)", f"{test_metrics.get('precision_weighted', 0):.4f}") - with col6: - st.metric("Recall (Macro)", f"{test_metrics.get('recall_macro', 0):.4f}") - - col7, col8, col9 = st.columns(3) - - with col7: - st.metric("Jaccard (Micro)", f"{test_metrics.get('jaccard_micro', 0):.4f}") - with col8: - st.metric("Jaccard (Macro)", f"{test_metrics.get('jaccard_macro', 0):.4f}") - with col9: - st.metric("Jaccard (Weighted)", f"{test_metrics.get('jaccard_weighted', 0):.4f}") - - -def render_cv_statistics_section(selected_result, selected_metric): - """Render cross-validation statistics for selected metric. 
- - Args: - selected_result: The selected TrainingResult object. - selected_metric: The metric to display statistics for. - - """ - st.header("📈 Cross-Validation Statistics") - st.caption("Performance during hyperparameter search (averaged across CV folds)") - - from entropice.dashboard.utils.stats import CVMetricStatistics - - cv_stats = CVMetricStatistics.compute(selected_result, selected_metric) - - col1, col2, col3, col4, col5 = st.columns(5) - - with col1: - st.metric("Best Score", f"{cv_stats.best_score:.4f}") - - with col2: - st.metric("Mean Score", f"{cv_stats.mean_score:.4f}") - - with col3: - st.metric("Std Dev", f"{cv_stats.std_score:.4f}") - - with col4: - st.metric("Worst Score", f"{cv_stats.worst_score:.4f}") - - with col5: - st.metric("Median Score", f"{cv_stats.median_score:.4f}") - - if cv_stats.mean_cv_std is not None: - st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds") - - # Compare with test metric if available - if selected_metric in selected_result.metrics: - test_score = selected_result.metrics[selected_metric] st.divider() - st.subheader("CV vs Test Performance") - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Best CV Score", f"{cv_stats.best_score:.4f}") - with col2: - st.metric("Test Score", f"{test_score:.4f}") - with col3: - delta = test_score - cv_stats.best_score - delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0 - st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%") + # Metric selection for detailed analysis + st.subheader("Analysis Settings") - if abs(delta) > cv_stats.std_score: - st.warning( - "⚠️ Test performance differs significantly from CV performance. " - "This may indicate overfitting or data distribution mismatch." 
- ) + available_metrics = selected_result.available_metrics - -@st.fragment -def render_confusion_matrix_section(selected_result: TrainingResult, merged_predictions: gpd.GeoDataFrame | None): - """Render confusion matrix visualization and analysis. - - Args: - selected_result: The selected TrainingResult object. - merged_predictions: GeoDataFrame with predictions merged with true labels and split info. - - """ - st.header("🎲 Confusion Matrix") - st.caption("Detailed breakdown of predictions") - - # Add selector for confusion matrix type - cm_type = st.selectbox( - "Select Data Split", - options=["test", "train", "all"], - format_func=lambda x: {"test": "Test Set", "train": "CV Set (Train Split)", "all": "All Available Data"}[x], - help="Choose which data split to display the confusion matrix for", - key="cm_split_select", - ) - - # Get label names from settings - label_names = selected_result.settings.classes - - # Compute or load confusion matrix based on selection - if cm_type == "test": - if selected_result.confusion_matrix is None: - st.warning("No confusion matrix available for the test set.") - return - cm = selected_result.confusion_matrix - st.info("📊 Showing confusion matrix for the **Test Set** (held-out data, never used during training)") - else: - if merged_predictions is None: - st.warning("Predictions data not available. 
Cannot compute confusion matrix.") - return - - with st.spinner(f"Computing confusion matrix for {cm_type} split..."): - cm = compute_confusion_matrix_from_merged_data(merged_predictions, cm_type, label_names) - if cm is None: - return - - if cm_type == "train": - st.info( - "📊 Showing confusion matrix for the **CV Set (Train Split)** " - "(data used during hyperparameter search cross-validation)" - ) - else: # all - st.info("📊 Showing confusion matrix for **All Available Data** (combined train and test splits)") - - render_confusion_matrix_heatmap(cm, selected_result.settings.task) - - -def render_parameter_space_section(selected_result, selected_metric): - """Render parameter space analysis section. - - Args: - selected_result: The selected TrainingResult object. - selected_metric: The metric to analyze parameters against. - - """ - st.header("🔍 Parameter Space Analysis") - - # Compute CV results statistics - cv_results_stats = CVResultsStatistics.compute(selected_result) - - # Show parameter space summary - with st.expander("📋 Parameter Space Summary", expanded=False): - param_summary_df = cv_results_stats.parameters_to_dataframe() - if not param_summary_df.empty: - st.dataframe(param_summary_df, hide_index=True, width="stretch") + # Try to get refit metric from settings + if selected_result.settings.task == "binary": + refit_metric = "f1" + elif selected_result.settings.task in ["count_regimes", "density_regimes"]: + refit_metric = "f1_weighted" else: - st.info("No parameter information available.") + refit_metric = "r2" - results = selected_result.results - settings = selected_result.settings + if refit_metric in available_metrics: + default_metric_idx = available_metrics.index(refit_metric) + else: + default_metric_idx = 0 - # Parameter distributions - st.subheader("📈 Parameter Distributions") - render_parameter_distributions(results, settings) + selected_metric = st.selectbox( + "Primary Metric for Analysis", + options=available_metrics, + 
index=default_metric_idx, + format_func=format_metric_name, + help="Select the metric to focus on for detailed analysis", + key="metric_select", + ) - # Binned parameter space plots - st.subheader("🎨 Binned Parameter Space") + # Form submit button + submitted = st.form_submit_button( + "Load Training Result", + type="primary", + use_container_width=True, + ) - # Check if this is an ESPA model and show ESPA-specific plots - model_type = settings.model - if model_type == "espa": - # Show ESPA-specific binned plots (eps_cl vs eps_e binned by K) - render_espa_binned_parameter_space(results, selected_metric) + if not submitted: + st.info("👆 Click 'Load Training Result' to apply changes.") + st.stop() - # Optionally show the generic binned plots in an expander - with st.expander("📊 All Parameter Combinations", expanded=False): - st.caption("Generic parameter space exploration (all pairwise combinations)") - render_binned_parameter_space(results, selected_metric) - else: - # For non-ESPA models, show the generic binned plots - render_binned_parameter_space(results, selected_metric) - - -def render_data_export_section(results, selected_result): - """Render data export section with download buttons. - - Args: - results: DataFrame with CV results. - selected_result: The selected TrainingResult object. 
- - """ - with st.expander("💾 Export Data", expanded=False): - st.subheader("Download Results") - - col1, col2 = st.columns(2) - - with col1: - # Download full results as CSV - csv_data = results.to_csv(index=False) - st.download_button( - label="📥 Download Full Results (CSV)", - data=csv_data, - file_name=f"{selected_result.path.name}_results.csv", - mime="text/csv", - ) - - with col2: - # Download settings as JSON - import json - - settings_dict = { - "task": selected_result.settings.task, - "grid": selected_result.settings.grid, - "level": selected_result.settings.level, - "model": selected_result.settings.model, - "cv_splits": selected_result.settings.cv_splits, - "classes": selected_result.settings.classes, - } - settings_json = json.dumps(settings_dict, indent=2) - st.download_button( - label="⚙️ Download Settings (JSON)", - data=settings_json, - file_name=f"{selected_result.path.name}_settings.json", - mime="application/json", - ) - - # Show raw data preview - st.subheader("Raw Data Preview") - st.dataframe(results.head(100), width="stretch") + return selected_result, selected_metric, refit_metric def render_training_analysis_page(): @@ -513,91 +97,47 @@ def render_training_analysis_page(): """ ) - # Load all available training results + # Load training results training_results = load_all_training_results() if not training_results: st.warning("No training results found. 
Please run some training experiments first.") - st.info("Run training using: `pixi run python -m entropice.ml.training`") + st.stop() return - st.success(f"Found **{len(training_results)}** training result(s)") + st.write(f"Found **{len(training_results)}** training result(s)") st.divider() + selected_result, selected_metric, refit_metric = render_analysis_settings_sidebar(training_results) - # Sidebar: Training run selection - with st.sidebar: - selection_result = render_analysis_settings_sidebar(training_results) - if selection_result[0] is None: - return - selected_result, selected_metric, refit_metric, top_n = selection_result + cv_statistics = CVMetricStatistics.compute(selected_result, selected_metric) - # Load predictions with labels once (used by confusion matrix and map) - merged_predictions = load_predictions_with_labels(selected_result) - - # Main content area - results = selected_result.results - settings = selected_result.settings - - # Run Information render_run_information(selected_result, refit_metric) st.divider() - # Test Metrics Section - render_test_metrics_section(selected_result) + render_metrics_section(selected_result) st.divider() - # Confusion Matrix Section - render_confusion_matrix_section(selected_result, merged_predictions) + # Render confusion matrices for classification, regression analysis for regression + if selected_result.settings.task in ["binary", "count_regimes", "density_regimes"]: + render_confusion_matrices(selected_result) + else: + render_regression_analysis(selected_result) st.divider() - # Performance Summary Section - st.header("📊 CV Performance Overview") - st.caption("Summary of hyperparameter search results across all configurations") - render_performance_summary(results, refit_metric) + render_cv_statistics_section(cv_statistics, selected_result.test_metrics.get(selected_metric, float("nan"))) st.divider() - # Prediction Analysis Map Section - st.header("🗺️ Model Performance Map") - st.caption("Interactive 3D map 
showing prediction correctness across the training dataset") - render_confusion_matrix_map(selected_result.path, settings, merged_predictions) + render_hparam_space_section(selected_result, selected_metric) st.divider() - # Cross-Validation Statistics - render_cv_statistics_section(selected_result, selected_metric) - - st.divider() - - # Parameter Space Analysis - render_parameter_space_section(selected_result, selected_metric) - - st.divider() - - # Parameter Correlation - st.header("🔗 Parameter Correlation") - render_parameter_correlation(results, selected_metric) - - st.divider() - - # Multi-Metric Comparison - if len(selected_result.available_metrics) >= 2: - st.header("📊 Multi-Metric Comparison") - render_multi_metric_comparison(results) - st.divider() - - # Top Configurations - st.header("🏆 Top Performing Configurations") - render_top_configurations(results, selected_metric, top_n) - - st.divider() - - # Raw Data Export - render_data_export_section(results, selected_result) + # List all results at the end + st.header("📄 All Training Results") + st.dataframe(selected_result.results) st.balloons() - stopwatch.summary() diff --git a/src/entropice/ml/autogluon_training.py b/src/entropice/ml/autogluon_training.py index 69ba846..ed33a67 100644 --- a/src/entropice/ml/autogluon_training.py +++ b/src/entropice/ml/autogluon_training.py @@ -44,8 +44,8 @@ class AutoGluonSettings: class AutoGluonTrainingSettings(DatasetEnsemble, AutoGluonSettings): """Combined settings for AutoGluon training.""" - classes: list[str] | None - problem_type: str + classes: list[str] | None = None + problem_type: str = "binary" def _determine_problem_type_and_metric(task: Task) -> tuple[str, str]: @@ -177,6 +177,8 @@ def autogluon_train( toml.dump({"settings": asdict(combined_settings)}, f) # Save test metrics + # We need to use pickle here, because the confusion matrix is stored as a dataframe + # This only matters for classification tasks test_metrics_file = results_dir / 
"test_metrics.pickle" print(f"💾 Saving test metrics to {test_metrics_file}") with open(test_metrics_file, "wb") as f: