def render_model_state_page():
    """Render the Model State page of the dashboard.

    Loads a user-selected training result, loads its model state (an xarray
    Dataset), and renders feature-importance, box-assignment, embedding, ERA5,
    and common-feature visualizations. Renders an error and returns early if
    no results or no model state are available.
    """
    st.title("Model State")
    st.markdown("Comprehensive visualization of the best model's internal state and feature importance")

    # Load available training results
    training_results = load_all_training_results()

    if not training_results:
        st.error("No training results found. Please run a training search first.")
        return

    # Result selection
    result_options = {tr.name: tr for tr in training_results}
    selected_name = st.selectbox(
        "Select Training Result",
        options=list(result_options.keys()),
        help="Choose a training result to visualize model state",
    )
    selected_result = result_options[selected_name]

    # Load model state
    with st.spinner("Loading model state..."):
        model_state = load_model_state(selected_result)
        if model_state is None:
            st.error("Could not load model state for this result.")
            return

    # Scale feature weights by number of features
    # NOTE(review): this mutates the loaded Dataset in place; if load_model_state
    # caches its return value, re-running this page re-scales the weights — confirm.
    n_features = model_state.sizes["feature"]
    model_state["feature_weights"] *= n_features

    # Extract different feature types (each helper returns None when the model
    # has no features of that kind; the sections below render conditionally)
    embedding_feature_array = extract_embedding_features(model_state)
    era5_feature_array = extract_era5_features(model_state)
    common_feature_array = extract_common_features(model_state)

    # Generate unified colormaps (only the Altair color list is used here)
    _, _, altair_colors = generate_unified_colormap(selected_result.settings)

    # Display basic model state info
    with st.expander("Model State Information", expanded=False):
        st.write(f"**Variables:** {list(model_state.data_vars)}")
        st.write(f"**Dimensions:** {dict(model_state.sizes)}")
        st.write(f"**Coordinates:** {list(model_state.coords)}")

        # Show statistics
        st.write("**Feature Weight Statistics:**")
        feature_weights = model_state["feature_weights"].to_pandas()
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
        with col2:
            st.metric("Max Weight", f"{feature_weights.max():.4f}")
        with col3:
            st.metric("Total Features", len(feature_weights))

    # Feature importance section
    st.header("Feature Importance")
    st.markdown("The most important features based on learned feature weights from the best estimator.")

    # st.fragment so moving the slider reruns only this section, not the page
    @st.fragment
    def render_feature_importance():
        # Slider to control number of features to display
        top_n = st.slider(
            "Number of top features to display",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Select how many of the most important features to visualize",
        )

        with st.spinner("Generating feature importance plot..."):
            feature_chart = plot_top_features(model_state, top_n=top_n)
            st.altair_chart(feature_chart, use_container_width=True)

        st.markdown(
            """
            **Interpretation:**
            - **Magnitude**: Larger absolute values indicate more important features
            - **Color**: Blue bars indicate positive weights, coral bars indicate negative weights
            """
        )

    render_feature_importance()

    # Box-to-Label Assignment Visualization
    st.header("Box-to-Label Assignments")
    st.markdown(
        """
        This visualization shows how the learned boxes (prototypes in feature space) are
        assigned to different class labels. The ESPA classifier learns K boxes and assigns
        them to classes through the Lambda matrix. Higher values indicate stronger assignment
        of a box to a particular class.
        """
    )

    with st.spinner("Generating box assignment visualizations..."):
        col1, col2 = st.columns([0.7, 0.3])

        with col1:
            st.markdown("### Assignment Heatmap")
            box_assignment_heatmap = plot_box_assignments(model_state)
            st.altair_chart(box_assignment_heatmap, use_container_width=True)

        with col2:
            st.markdown("### Box Count by Class")
            box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
            st.altair_chart(box_assignment_bars, use_container_width=True)

    # Show statistics
    with st.expander("Box Assignment Statistics"):
        # box_assignments: rows are class labels, columns are boxes
        box_assignments = model_state["box_assignments"].to_pandas()
        st.write("**Assignment Matrix Statistics:**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Boxes", len(box_assignments.columns))
        with col2:
            st.metric("Number of Classes", len(box_assignments.index))
        with col3:
            st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
        with col4:
            st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")

        # Show which boxes are most strongly assigned to each class
        st.write("**Top Box Assignments per Class:**")
        for class_label in box_assignments.index:
            top_boxes = box_assignments.loc[class_label].nlargest(5)
            st.write(
                f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
                f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
            )

    st.markdown(
        """
        **Interpretation:**
        - Each box can be assigned to multiple classes with different strengths
        - Boxes with higher assignment values for a class contribute more to that class's predictions
        - The distribution shows how the model partitions the feature space for classification
        """
    )

    # Embedding features analysis (if present)
    if embedding_feature_array is not None:
        with st.container(border=True):
            st.header("🛰️ Embedding Feature Analysis")
            st.markdown(
                """
                Analysis of embedding features showing which aggregations, bands, and years
                are most important for the model predictions.
                """
            )

            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating dimension summaries..."):
                chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.altair_chart(chart_agg, use_container_width=True)
                with col2:
                    st.altair_chart(chart_band, use_container_width=True)
                with col3:
                    st.altair_chart(chart_year, use_container_width=True)

            # Detailed heatmap
            st.markdown("### Detailed Heatmap by Aggregation")
            st.markdown("Shows the weight of each band-year combination for each aggregation type.")
            with st.spinner("Generating heatmap..."):
                heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
                st.altair_chart(heatmap_chart, use_container_width=True)

            # Statistics
            with st.expander("Embedding Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_emb_features = embedding_feature_array.size
                mean_weight = float(embedding_feature_array.mean().values)
                max_weight = float(embedding_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Embedding Features", n_emb_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")

                # Show top embedding features
                st.write("**Top 10 Embedding Features:**")
                emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
                top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
                st.dataframe(top_emb, width="stretch")
    else:
        st.info("No embedding features found in this model.")

    # ERA5 features analysis (if present)
    if era5_feature_array is not None:
        with st.container(border=True):
            st.header("⛅ ERA5 Feature Analysis")
            st.markdown(
                """
                Analysis of ERA5 climate features showing which variables and time periods
                are most important for the model predictions.
                """
            )

            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating ERA5 dimension summaries..."):
                chart_variable, chart_time = plot_era5_summary(era5_feature_array)
                col1, col2 = st.columns(2)
                with col1:
                    st.altair_chart(chart_variable, use_container_width=True)
                with col2:
                    st.altair_chart(chart_time, use_container_width=True)

            # Detailed heatmap
            st.markdown("### Detailed Heatmap")
            st.markdown("Shows the weight of each variable-time combination.")
            with st.spinner("Generating ERA5 heatmap..."):
                era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
                st.altair_chart(era5_heatmap_chart, use_container_width=True)

            # Statistics
            with st.expander("ERA5 Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_era5_features = era5_feature_array.size
                mean_weight = float(era5_feature_array.mean().values)
                max_weight = float(era5_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total ERA5 Features", n_era5_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")

                # Show top ERA5 features
                st.write("**Top 10 ERA5 Features:**")
                era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
                top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
                st.dataframe(top_era5, width="stretch")
    else:
        st.info("No ERA5 features found in this model.")

    # Common features analysis (if present)
    if common_feature_array is not None:
        with st.container(border=True):
            st.header("🗺️ Common Feature Analysis")
            st.markdown(
                """
                Analysis of common features including cell area, water area, land area, land ratio,
                longitude, and latitude. These features provide spatial and geographic context.
                """
            )

            # Bar chart showing all common feature weights
            with st.spinner("Generating common features chart..."):
                common_chart = plot_common_features(common_feature_array)
                st.altair_chart(common_chart, use_container_width=True)

            # Statistics
            with st.expander("Common Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_common_features = common_feature_array.size
                mean_weight = float(common_feature_array.mean().values)
                max_weight = float(common_feature_array.max().values)
                min_weight = float(common_feature_array.min().values)
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Common Features", n_common_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                with col4:
                    st.metric("Min Weight", f"{min_weight:.4f}")

                # Show all common features sorted by importance
                st.write("**All Common Features (by absolute weight):**")
                common_df = common_feature_array.to_dataframe(name="weight").reset_index()
                common_df["abs_weight"] = common_df["weight"].abs()
                common_df = common_df.sort_values("abs_weight", ascending=False)
                st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")

            st.markdown(
                """
                **Interpretation:**
                - **cell_area, water_area, land_area**: Spatial extent features that may indicate
                  size-related patterns
                - **land_ratio**: Proportion of land vs water in each cell
                - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
                - Positive weights indicate features that increase the probability of the positive class
                - Negative weights indicate features that decrease the probability of the positive class
                """
            )
    else:
        st.info("No common features found in this model.")
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
    """Generate unified colormaps for all plotting libraries.

    This function creates consistent color schemes across Matplotlib/Ultraplot,
    Folium/Leaflet, and Altair/Vega-Lite by determining the task type and number
    of classes from the settings, then generating appropriate colormaps for each library.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair

    """
    # Determine task type and number of classes from settings
    task = settings.get("task", "binary")
    n_classes = len(settings.get("classes", []))

    # Check theme
    is_dark_theme = st.context.theme.type == "dark"

    # Define base colormaps for different tasks
    if task == "binary":
        # For binary: use a simple two-color scheme
        if is_dark_theme:
            base_colors = ["#1f77b4", "#ff7f0e"]  # Blue and orange for dark theme
        else:
            base_colors = ["#3498db", "#e74c3c"]  # Brighter blue and red for light theme
    else:
        # For multi-class: sample matplotlib's viridis colormap evenly.
        # Guard against a missing/empty "classes" list: np.linspace(..., 0) would
        # yield no colors and ListedColormap([]) raises, so sample at least one.
        cmap = plt.get_cmap("viridis")
        indices = np.linspace(0.1, 0.9, max(n_classes, 1))  # Avoid extreme ends
        base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]

    # Create matplotlib colormap (for ultraplot and geopandas)
    matplotlib_cmap = mcolors.ListedColormap(base_colors)

    # Create Folium/Leaflet colormap (geopandas.explore uses matplotlib colormaps)
    folium_cmap = mcolors.ListedColormap(base_colors)

    # Create Altair color list (Altair uses hex color strings in range)
    altair_colors = base_colors

    return matplotlib_cmap, folium_cmap, altair_colors
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
    """Render summary statistics of model performance.

    Shows best scores and mean±std per metric, then the best parameter
    combination ranked by the refit metric (falling back to the first
    available metric when the refit column is absent).

    Args:
        results: DataFrame with CV results.
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').

    """
    st.subheader("📊 Performance Summary")

    # Get all test score columns
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]

    if not score_cols:
        st.warning("No test score columns found in results.")
        return

    # Calculate statistics for each metric
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### Best Scores")
        best_scores = []
        for col in score_cols:
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            best_score = results[col].max()
            best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})

        st.dataframe(pd.DataFrame(best_scores), hide_index=True, use_container_width=True)

    with col2:
        st.markdown("#### Score Statistics")
        score_stats = []
        for col in score_cols:
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            mean_score = results[col].mean()
            std_score = results[col].std()
            score_stats.append(
                {
                    "Metric": metric_name,
                    "Mean ± Std": f"{mean_score:.4f} ± {std_score:.4f}",
                }
            )

        st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)

    # Show best parameter combination
    st.markdown("#### 🏆 Best Parameter Combination")
    refit_col = f"mean_test_{refit_metric}"

    # Check if refit metric exists in results; fall back to the first metric
    if refit_col not in results.columns:
        st.warning(
            f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
        )
        # Use the first available metric as fallback
        refit_col = score_cols[0]
        refit_metric = refit_col.replace("mean_test_", "")
        st.info(f"Using '{refit_metric}' as fallback metric.")

    best_idx = results[refit_col].idxmax()
    best_row = results.loc[best_idx]

    # Extract parameter columns (skip the aggregate "params" dict column)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if param_cols:
        best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}

        # Display in a nice formatted way
        param_df = pd.DataFrame([best_params]).T
        param_df.columns = ["Value"]
        param_df.index.name = "Parameter"

        col1, col2 = st.columns([1, 1])
        with col1:
            st.dataframe(param_df, use_container_width=True)

        with col2:
            st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
            rank_col = "rank_test_" + refit_metric
            if rank_col in best_row.index:
                try:
                    # Handle potential Series or scalar values
                    rank_val = best_row[rank_col]
                    if hasattr(rank_val, "item"):
                        rank_val = rank_val.item()
                    rank_display = str(int(float(rank_val)))
                except (ValueError, TypeError, AttributeError):
                    rank_display = "N/A"
            else:
                rank_display = "N/A"
            st.metric("Rank", rank_display)
def render_parameter_distributions(results: pd.DataFrame):
    """Render histograms of parameter distributions explored.

    Lays out one chart per `param_*` column in a grid of up to 3 columns:
    histograms for numeric parameters (log-scaled x-axis when the value range
    spans >2 orders of magnitude) and bar charts for categorical parameters.

    Args:
        results: DataFrame with CV results.

    """
    st.subheader("📈 Parameter Space Exploration")

    # Get parameter columns (skip the aggregate "params" dict column)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Create histograms for each parameter, arranged in a grid
    n_params = len(param_cols)
    n_cols = min(3, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols  # ceil division

    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                break

            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")

            with cols[col_idx]:
                # Check if parameter is numeric or categorical
                param_values = results[param_col].dropna()

                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - use histogram
                    df_plot = pd.DataFrame({param_name: param_values})

                    # Use log scale if the range spans multiple orders of magnitude
                    # (+1e-10 guards against division by zero when min is 0)
                    value_range = param_values.max() / (param_values.min() + 1e-10)
                    use_log = value_range > 100

                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(
                                    param_name,
                                    bin=alt.Bin(maxbins=30),
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
                            )
                            .properties(height=250, title=f"{param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
                            )
                            .properties(height=250, title=param_name)
                        )

                    st.altair_chart(chart, use_container_width=True)

                else:
                    # Categorical parameter - use bar chart of value counts
                    value_counts = param_values.value_counts().reset_index()
                    value_counts.columns = [param_name, "count"]

                    chart = (
                        alt.Chart(value_counts)
                        .mark_bar()
                        .encode(
                            alt.X(param_name, title=param_name, sort="-y"),
                            alt.Y("count", title="Count"),
                            tooltip=[param_name, "count"],
                        )
                        .properties(height=250, title=param_name)
                    )

                    st.altair_chart(chart, use_container_width=True)
def render_score_vs_parameter(results: pd.DataFrame, metric: str):
    """Render scatter plots of score vs each parameter.

    Lays out one chart per `param_*` column in a grid of up to 2 columns:
    scatter plots for numeric parameters (log-scaled x-axis when the value
    range spans >2 orders of magnitude) and box plots for categorical ones.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Get parameter columns (skip the aggregate "params" dict column)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Create scatter plots for each parameter, arranged in a grid
    n_params = len(param_cols)
    n_cols = min(2, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols  # ceil division

    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                break

            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")

            with cols[col_idx]:
                param_values = results[param_col].dropna()

                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - scatter plot
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})

                    # Use log scale if needed (+1e-10 guards against zero min)
                    value_range = param_values.max() / (param_values.min() + 1e-10)
                    use_log = value_range > 100

                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(
                                    param_name,
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(param_name, title=param_name),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name}")
                        )

                    st.altair_chart(chart, use_container_width=True)

                else:
                    # Categorical parameter - box plot of score per category
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})

                    chart = (
                        alt.Chart(df_plot)
                        .mark_boxplot()
                        .encode(
                            alt.X(param_name, title=param_name),
                            alt.Y(metric, title=metric.replace("_", " ").title()),
                            tooltip=[param_name, alt.Tooltip(metric, format=".4f")],
                        )
                        .properties(height=300, title=f"{metric} vs {param_name}")
                    )

                    st.altair_chart(chart, use_container_width=True)
def render_parameter_correlation(results: pd.DataFrame, metric: str):
    """Render correlation heatmap between parameters and score.

    Computes the Pearson correlation between each numeric `param_*` column and
    the chosen metric, then shows a diverging bar chart plus an expandable
    gradient-styled table.

    Args:
        results: DataFrame with CV results.
        metric: The metric to analyze (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Get numeric parameter columns (categorical params have no Pearson correlation)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]

    if not numeric_params:
        st.warning("No numeric parameters found for correlation analysis.")
        return

    # Calculate correlations
    correlations = []
    for param_col in numeric_params:
        param_name = param_col.replace("param_", "")
        corr = results[[param_col, score_col]].corr().iloc[0, 1]
        correlations.append({"Parameter": param_name, "Correlation": corr})

    corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)

    # Create bar chart (diverging red/blue, fixed [-1, 1] domain)
    chart = (
        alt.Chart(corr_df)
        .mark_bar()
        .encode(
            alt.X("Correlation", title="Correlation with Score"),
            alt.Y("Parameter", sort="-x", title="Parameter"),
            alt.Color(
                "Correlation",
                scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
                legend=None,
            ),
            tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
        )
        .properties(height=max(200, len(correlations) * 30))  # grow with parameter count
    )

    st.altair_chart(chart, use_container_width=True)

    # Show correlation table
    with st.expander("📋 Correlation Table"):
        st.dataframe(
            corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
            hide_index=True,
            use_container_width=True,
        )
def render_score_evolution(results: pd.DataFrame, metric: str):
    """Render evolution of scores during search.

    Plots the per-iteration score alongside the running best ("Best So Far"),
    followed by summary metrics (best, mean, std, and the iteration at which
    the best score was first reached).

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Create a copy with iteration number
    df_plot = results[[score_col]].copy()
    df_plot["Iteration"] = range(len(df_plot))
    df_plot["Best So Far"] = df_plot[score_col].cummax()
    df_plot = df_plot.rename(columns={score_col: "Score"})

    # Reshape for Altair
    df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")

    # Create line chart (solid line for raw scores, dashed for the running best)
    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            alt.X("Iteration", title="Iteration"),
            alt.Y("value", title=metric.replace("_", " ").title()),
            alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
            strokeDash=alt.StrokeDash(
                "Type",
                legend=None,
                scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
            ),
            tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
        )
        .properties(height=400)
    )

    st.altair_chart(chart, use_container_width=True)

    # Show statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}")
    with col2:
        st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}")
    with col3:
        st.metric("Std Dev", f"{df_plot['Score'].std():.4f}")
    with col4:
        # Find iteration where best was found. Use the positional argmax:
        # idxmax() returns the *index label* of `results` (df_plot keeps the
        # source index), which only matches the plotted Iteration when the
        # index happens to be a default RangeIndex.
        best_iter = int(df_plot["Score"].to_numpy().argmax())
        st.metric("Best at Iteration", best_iter)
def render_multi_metric_comparison(results: pd.DataFrame):
    """Render comparison of multiple metrics.

    Lets the user pick two metrics, shows a scatter plot of one vs the other
    (colored by search iteration), and reports their Pearson correlation.

    Args:
        results: DataFrame with CV results.

    """
    st.subheader("📊 Multi-Metric Comparison")

    # Get all test score columns
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]

    if len(score_cols) < 2:
        st.warning("Need at least 2 metrics for comparison.")
        return

    # Let user select two metrics to compare
    col1, col2 = st.columns(2)
    with col1:
        metric1 = st.selectbox(
            "Select First Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            index=0,
            key="metric1_select",
        )

    with col2:
        metric2 = st.selectbox(
            "Select Second Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            # default to the second metric when one exists
            index=min(1, len(score_cols) - 1),
            key="metric2_select",
        )

    if metric1 == metric2:
        st.warning("Please select different metrics.")
        return

    # Create scatter plot
    df_plot = pd.DataFrame(
        {
            metric1: results[f"mean_test_{metric1}"],
            metric2: results[f"mean_test_{metric2}"],
            "Iteration": range(len(results)),
        }
    )

    chart = (
        alt.Chart(df_plot)
        .mark_circle(size=60, opacity=0.6)
        .encode(
            alt.X(metric1, title=metric1.replace("_", " ").title()),
            alt.Y(metric2, title=metric2.replace("_", " ").title()),
            alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
            tooltip=[
                alt.Tooltip(metric1, format=".4f"),
                alt.Tooltip(metric2, format=".4f"),
                "Iteration",
            ],
        )
        .properties(height=500)
    )

    st.altair_chart(chart, use_container_width=True)

    # Calculate correlation
    corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
    st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}")
def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
    """Render table of top N configurations.

    Selects the `top_n` rows with the highest value of the chosen metric and
    shows their rank, score, and parameter values in a cleaned-up table.

    Args:
        results: DataFrame with CV results.
        metric: The metric to rank by (e.g., 'f1', 'accuracy').
        top_n: Number of top configurations to show.

    """
    st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Get parameter columns (skip the aggregate "params" dict column)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Get top N configurations
    top_configs = results.nlargest(top_n, score_col)

    # Create display dataframe; drop any columns absent from the results
    display_cols = ["rank_test_" + metric, score_col, *param_cols]
    display_cols = [col for col in display_cols if col in top_configs.columns]

    display_df = top_configs[display_cols].copy()

    # Rename columns for better display
    display_df = display_df.rename(
        columns={
            "rank_test_" + metric: "Rank",
            score_col: metric.replace("_", " ").title(),
        }
    )

    # Rename parameter columns (strip the "param_" prefix)
    display_df.columns = [col.replace("param_", "") if col.startswith("param_") else col for col in display_df.columns]

    # Format score column as fixed-precision strings
    score_col_display = metric.replace("_", " ").title()
    display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")

    st.dataframe(display_df, hide_index=True, use_container_width=True)
def plot_top_features(model_state: xr.Dataset, top_n: int = 10) -> alt.Chart:
    """Plot the top N most important features based on feature weights.

    Ranks features by absolute weight, keeps the strongest `top_n`, and draws
    a horizontal bar chart of their signed weights (blue positive, coral
    negative).

    Args:
        model_state: The xarray Dataset containing the model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.

    """
    # Pull the weight vector out as a pandas Series
    weights = model_state["feature_weights"].to_pandas()

    # Rank by magnitude, keep the strongest top_n, weakest-first for plotting
    ranked = weights.abs().nlargest(top_n).sort_values(ascending=True)

    # Assemble the plot frame: signed weight alongside its magnitude
    plot_data = pd.DataFrame(
        {
            "feature": ranked.index,
            "weight": weights.loc[ranked.index].to_numpy(),
            "abs_weight": ranked.to_numpy(),
        }
    )

    # Sign-dependent bar color: blue for positive, coral for negative
    bar_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),
        alt.value("coral"),
    )

    hover_fields = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
        alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
    ]

    base = alt.Chart(plot_data).mark_bar()
    encoded = base.encode(
        y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
        x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
        color=bar_color,
        tooltip=hover_fields,
    )

    return encoded.properties(
        width=600,
        height=400,
        title=f"Top {top_n} Most Important Features",
    )
+ + Returns: + Altair chart showing the top features by importance. + + """ + # Extract feature weights + feature_weights = model_state["feature_weights"].to_pandas() + + # Sort by absolute weight and take top N + top_features = feature_weights.abs().nlargest(top_n).sort_values(ascending=True) + + # Create DataFrame for plotting with original (signed) weights + plot_data = pd.DataFrame( + { + "feature": top_features.index, + "weight": feature_weights.loc[top_features.index].to_numpy(), + "abs_weight": top_features.to_numpy(), + } + ) + + # Create horizontal bar chart + chart = ( + alt.Chart(plot_data) + .mark_bar() + .encode( + y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), + x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"), + color=alt.condition( + alt.datum.weight > 0, + alt.value("steelblue"), # Positive weights + alt.value("coral"), # Negative weights + ), + tooltip=[ + alt.Tooltip("feature:N", title="Feature"), + alt.Tooltip("weight:Q", format=".4f", title="Weight"), + alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"), + ], + ) + .properties( + width=600, + height=400, + title=f"Top {top_n} Most Important Features", + ) + ) + + return chart + + +def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart: + """Create a heatmap showing embedding feature weights across bands and years. + + Args: + embedding_array: DataArray with dimensions (agg, band, year) containing feature weights. + + Returns: + Altair chart showing the heatmap. 
+ + """ + # Convert to DataFrame for plotting + df = embedding_array.to_dataframe(name="weight").reset_index() + + # Create faceted heatmap + chart = ( + alt.Chart(df) + .mark_rect() + .encode( + x=alt.X("year:O", title="Year"), + y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")), + color=alt.Color( + "weight:Q", + scale=alt.Scale(scheme="redblue", domainMid=0), + title="Weight", + ), + tooltip=[ + alt.Tooltip("agg:N", title="Aggregation"), + alt.Tooltip("band:N", title="Band"), + alt.Tooltip("year:O", title="Year"), + alt.Tooltip("weight:Q", format=".4f", title="Weight"), + ], + ) + .properties(width=200, height=200) + .facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11) + ) + + return chart + + +def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]: + """Create bar charts summarizing embedding weights by aggregation, band, and year. + + Args: + embedding_array: DataArray with dimensions (agg, band, year) containing feature weights. + + Returns: + Tuple of three Altair charts (by_agg, by_band, by_year). 
+ + """ + # Aggregate by different dimensions + by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs() + by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs() + by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs() + + # Create DataFrames + df_agg = pd.DataFrame({"dimension": by_agg.index, "mean_abs_weight": by_agg.to_numpy()}) + df_band = pd.DataFrame({"dimension": by_band.index, "mean_abs_weight": by_band.to_numpy()}) + df_year = pd.DataFrame({"dimension": by_year.index, "mean_abs_weight": by_year.to_numpy()}) + + # Sort by weight + df_agg = df_agg.sort_values("mean_abs_weight", ascending=True) + df_band = df_band.sort_values("mean_abs_weight", ascending=True) + df_year = df_year.sort_values("mean_abs_weight", ascending=True) + + # Create charts with different colors + chart_agg = ( + alt.Chart(df_agg) + .mark_bar() + .encode( + y=alt.Y("dimension:N", title="Aggregation", sort="-x"), + x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), + color=alt.Color( + "mean_abs_weight:Q", + scale=alt.Scale(scheme="blues"), + legend=None, + ), + tooltip=[ + alt.Tooltip("dimension:N", title="Aggregation"), + alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), + ], + ) + .properties(width=250, height=200, title="By Aggregation") + ) + + chart_band = ( + alt.Chart(df_band) + .mark_bar() + .encode( + y=alt.Y("dimension:N", title="Band", sort="-x"), + x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), + color=alt.Color( + "mean_abs_weight:Q", + scale=alt.Scale(scheme="greens"), + legend=None, + ), + tooltip=[ + alt.Tooltip("dimension:N", title="Band"), + alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), + ], + ) + .properties(width=250, height=200, title="By Band") + ) + + chart_year = ( + alt.Chart(df_year) + .mark_bar() + .encode( + y=alt.Y("dimension:O", title="Year", sort="-x"), + x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), + color=alt.Color( + 
"mean_abs_weight:Q", + scale=alt.Scale(scheme="oranges"), + legend=None, + ), + tooltip=[ + alt.Tooltip("dimension:O", title="Year"), + alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), + ], + ) + .properties(width=250, height=200, title="By Year") + ) + + return chart_agg, chart_band, chart_year + + +def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart: + """Create a heatmap showing ERA5 feature weights across variables and time. + + Args: + era5_array: DataArray with dimensions (variable, time) containing feature weights. + + Returns: + Altair chart showing the heatmap. + + """ + # Convert to DataFrame for plotting + df = era5_array.to_dataframe(name="weight").reset_index() + + # Create heatmap + chart = ( + alt.Chart(df) + .mark_rect() + .encode( + x=alt.X("time:N", title="Time", sort=None), + y=alt.Y("variable:N", title="Variable", sort="-color"), + color=alt.Color( + "weight:Q", + scale=alt.Scale(scheme="redblue", domainMid=0), + title="Weight", + ), + tooltip=[ + alt.Tooltip("variable:N", title="Variable"), + alt.Tooltip("time:N", title="Time"), + alt.Tooltip("weight:Q", format=".4f", title="Weight"), + ], + ) + .properties( + height=400, + title="ERA5 Feature Weights Heatmap", + ) + ) + + return chart + + +def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]: + """Create bar charts summarizing ERA5 weights by variable and time. + + Args: + era5_array: DataArray with dimensions (variable, time) containing feature weights. + + Returns: + Tuple of two Altair charts (by_variable, by_time). 
+ + """ + # Aggregate by different dimensions + by_variable = era5_array.mean(dim="time").to_pandas().abs() + by_time = era5_array.mean(dim="variable").to_pandas().abs() + + # Create DataFrames + df_variable = pd.DataFrame({"dimension": by_variable.index, "mean_abs_weight": by_variable.to_numpy()}) + df_time = pd.DataFrame({"dimension": by_time.index, "mean_abs_weight": by_time.to_numpy()}) + + # Sort by weight + df_variable = df_variable.sort_values("mean_abs_weight", ascending=True) + df_time = df_time.sort_values("mean_abs_weight", ascending=True) + + # Create charts with different colors + chart_variable = ( + alt.Chart(df_variable) + .mark_bar() + .encode( + y=alt.Y("dimension:N", title="Variable", sort="-x", axis=alt.Axis(labelLimit=300)), + x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), + color=alt.Color( + "mean_abs_weight:Q", + scale=alt.Scale(scheme="purples"), + legend=None, + ), + tooltip=[ + alt.Tooltip("dimension:N", title="Variable"), + alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), + ], + ) + .properties(width=400, height=300, title="By Variable") + ) + + chart_time = ( + alt.Chart(df_time) + .mark_bar() + .encode( + y=alt.Y("dimension:N", title="Time", sort="-x", axis=alt.Axis(labelLimit=200)), + x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"), + color=alt.Color( + "mean_abs_weight:Q", + scale=alt.Scale(scheme="teals"), + legend=None, + ), + tooltip=[ + alt.Tooltip("dimension:N", title="Time"), + alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"), + ], + ) + .properties(width=400, height=300, title="By Time") + ) + + return chart_variable, chart_time + + +def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart: + """Create a heatmap showing which boxes are assigned to which labels/classes. + + Args: + model_state: The xarray Dataset containing the model state with box_assignments. + + Returns: + Altair chart showing the box-to-label assignment heatmap. 
+ + """ + # Extract box assignments + box_assignments = model_state["box_assignments"] + + # Convert to DataFrame for plotting + df = box_assignments.to_dataframe(name="assignment").reset_index() + + # Create heatmap + chart = ( + alt.Chart(df) + .mark_rect() + .encode( + x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)), + y=alt.Y("class:N", title="Class Label"), + color=alt.Color( + "assignment:Q", + scale=alt.Scale(scheme="viridis"), + title="Assignment Strength", + ), + tooltip=[ + alt.Tooltip("class:N", title="Class"), + alt.Tooltip("box:O", title="Box"), + alt.Tooltip("assignment:Q", format=".4f", title="Assignment"), + ], + ) + .properties( + height=150, + title="Box-to-Label Assignments (Lambda Matrix)", + ) + ) + + return chart + + +def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]) -> alt.Chart: + """Create a bar chart showing how many boxes are assigned to each class. + + Args: + model_state: The xarray Dataset containing the model state with box_assignments. + altair_colors: List of hex color strings for altair. + + Returns: + Altair chart showing count of boxes per class. 
+ + """ + # Extract box assignments + box_assignments = model_state["box_assignments"] + + # Convert to DataFrame + df = box_assignments.to_dataframe(name="assignment").reset_index() + + # For each box, find which class it's most strongly assigned to + box_to_class = df.groupby("box")["assignment"].idxmax() + primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True) + + # Count boxes per class + counts = primary_classes.groupby("class").size().reset_index(name="count") + + # Replace the special (-1, 0] interval with "No RTS" if present + counts["class"] = counts["class"].replace("(-1, 0]", "No RTS") + + # Sort the classes: "No RTS" first, then by the lower bound of intervals + def sort_key(class_str): + if class_str == "No RTS": + return -1 # Put "No RTS" first + # Parse interval string like "(0, 4]" or "(4, 36]" + try: + lower = float(str(class_str).split(",")[0].strip("([ ")) + return lower + except (ValueError, IndexError): + return float("inf") # Put unparseable values at the end + + # Sort counts by the same key + counts["sort_key"] = counts["class"].apply(sort_key) + counts = counts.sort_values("sort_key") + + # Create an ordered list of classes for consistent color mapping + class_order = counts["class"].tolist() + + # Create bar chart + chart = ( + alt.Chart(counts) + .mark_bar() + .encode( + x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)), + y=alt.Y("count:Q", title="Number of Boxes"), + color=alt.Color( + "class:N", + title="Class", + scale=alt.Scale(domain=class_order, range=altair_colors), + legend=None, + ), + tooltip=[ + alt.Tooltip("class:N", title="Class"), + alt.Tooltip("count:Q", title="Number of Boxes"), + ], + ) + .properties( + width=600, + height=300, + title="Number of Boxes Assigned to Each Class (by Primary Assignment)", + ) + ) + + return chart + + +def plot_common_features(common_array: xr.DataArray) -> alt.Chart: + """Create a bar chart showing the weights of 
def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
    """Create a bar chart showing the weights of common features.

    Args:
        common_array: DataArray with dimension (feature) containing feature weights.

    Returns:
        Altair chart showing the common feature weights.

    """
    weights_df = common_array.to_dataframe(name="weight").reset_index()

    # Order bars by magnitude; keep signed weight for the bar length/color.
    weights_df["abs_weight"] = weights_df["weight"].abs()
    weights_df = weights_df.sort_values("abs_weight", ascending=True)

    return (
        alt.Chart(weights_df)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x"),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=alt.condition(
                alt.datum.weight > 0,
                alt.value("steelblue"),  # Positive weights
                alt.value("coral"),  # Negative weights
            ),
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(
            width=600,
            height=300,
            title="Common Feature Weights",
        )
    )


import streamlit as st

from entropice.dashboard.plots.hyperparameter_analysis import (
    render_multi_metric_comparison,
    render_parameter_correlation,
    render_parameter_distributions,
    render_performance_summary,
    render_score_evolution,
    render_score_vs_parameter,
    render_top_configurations,
)
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.utils.training import (
    format_metric_name,
    get_available_metrics,
    get_cv_statistics,
    get_parameter_space_summary,
)


def render_training_analysis_page():
    """Render the Training Results Analysis page of the dashboard.

    Loads every stored training run, lets the user select one in the sidebar,
    and renders cross-validation statistics, parameter-space analyses, metric
    comparisons, top configurations, and data-export controls for that run.
    """
    st.title("🦾 Training Results Analysis")

    # Load all available training results
    training_results = load_all_training_results()

    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.training`")
        return

    # Sidebar: Training run selection
    with st.sidebar:
        st.header("Select Training Run")

        training_options = {tr.name: tr for tr in training_results}

        selected_name = st.selectbox(
            "Training Run",
            options=list(training_options.keys()),
            index=0,
            help="Select a training run to analyze",
        )

        selected_result = training_options[selected_name]

        st.divider()

        # Display selected run info
        st.subheader("Run Information")
        st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
        st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
        st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
        st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
        st.write(f"**Trials:** {len(selected_result.results)}")
        st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")

        # Refit metric - determine from available metrics
        available_metrics = get_available_metrics(selected_result.results)

        # Try to get refit metric from settings
        refit_metric = selected_result.settings.get("refit_metric")

        if not refit_metric or refit_metric not in available_metrics:
            # Infer from task or use first available metric
            task = selected_result.settings.get("task", "binary")
            if task == "binary" and "f1" in available_metrics:
                refit_metric = "f1"
            elif "f1_weighted" in available_metrics:
                refit_metric = "f1_weighted"
            elif "accuracy" in available_metrics:
                refit_metric = "accuracy"
            elif available_metrics:
                refit_metric = available_metrics[0]
            else:
                st.error("No metrics found in results.")
                return

        st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")

        st.divider()

        # Metric selection for detailed analysis
        st.subheader("Analysis Settings")

        # BUG FIX: available_metrics was recomputed here with a second
        # get_available_metrics() call on the same results; reuse the list
        # computed above instead.
        if refit_metric in available_metrics:
            default_metric_idx = available_metrics.index(refit_metric)
        else:
            default_metric_idx = 0

        selected_metric = st.selectbox(
            "Primary Metric for Analysis",
            options=available_metrics,
            index=default_metric_idx,
            format_func=format_metric_name,
            help="Select the metric to focus on for detailed analysis",
        )

        # Top N configurations
        top_n = st.slider(
            "Top N Configurations",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Number of top configurations to display",
        )

    # Main content area
    results = selected_result.results
    settings = selected_result.settings

    # Performance Summary Section
    st.header("πŸ“Š Performance Overview")
    render_performance_summary(results, refit_metric)

    st.divider()

    # Quick Statistics
    st.header("πŸ“ˆ Cross-Validation Statistics")
    cv_stats = get_cv_statistics(results, selected_metric)

    if cv_stats:
        col1, col2, col3, col4, col5 = st.columns(5)

        with col1:
            st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
        with col2:
            st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
        with col3:
            st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
        with col4:
            st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
        with col5:
            st.metric("Median Score", f"{cv_stats['median_score']:.4f}")

        if "mean_cv_std" in cv_stats:
            st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")

    st.divider()

    # Score Evolution
    st.header("πŸ“‰ Training Progress")
    render_score_evolution(results, selected_metric)

    st.divider()

    # Parameter Space Exploration
    st.header("πŸ” Parameter Space Analysis")

    with st.expander("πŸ“‹ Parameter Space Summary", expanded=False):
        param_summary = get_parameter_space_summary(results)
        if not param_summary.empty:
            st.dataframe(param_summary, hide_index=True, use_container_width=True)
        else:
            st.info("No parameter information available.")

    # Parameter distributions
    render_parameter_distributions(results)

    st.divider()

    # Score vs Parameters
    st.header("🎯 Parameter Impact Analysis")
    render_score_vs_parameter(results, selected_metric)

    st.divider()

    # Parameter Correlation
    st.header("πŸ”— Parameter Correlation Analysis")
    render_parameter_correlation(results, selected_metric)

    st.divider()

    # Multi-Metric Comparison (only meaningful with at least two metrics)
    if len(available_metrics) >= 2:
        st.header("πŸ“Š Multi-Metric Analysis")
        render_multi_metric_comparison(results)

        st.divider()

    # Top Configurations
    st.header("πŸ† Top Performing Configurations")
    render_top_configurations(results, selected_metric, top_n)

    st.divider()

    # Raw Data Export
    with st.expander("πŸ’Ύ Export Data", expanded=False):
        st.subheader("Download Results")

        col1, col2 = st.columns(2)

        with col1:
            # Download full results as CSV
            csv_data = results.to_csv(index=False)
            st.download_button(
                label="πŸ“₯ Download Full Results (CSV)",
                data=csv_data,
                file_name=f"{selected_result.path.name}_results.csv",
                mime="text/csv",
                use_container_width=True,
            )

        with col2:
            # Download settings as JSON; local import kept from the original
            import json

            settings_json = json.dumps(settings, indent=2)
            st.download_button(
                label="βš™οΈ Download Settings (JSON)",
                data=settings_json,
                file_name=f"{selected_result.path.name}_settings.json",
                mime="application/json",
                use_container_width=True,
            )

        # Show raw data preview
        st.subheader("Raw Data Preview")
        st.dataframe(results.head(100), use_container_width=True)
def extract_embedding_features(model_state) -> "xr.DataArray | None":
    """Extract embedding features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted embedding features. This DataArray has dimensions
        ('agg', 'band', 'year') corresponding to the different components of the embedding features.
        Returns None if no embedding features are found.

    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("embedding_")]
    if not names:
        return None

    # Feature names follow the pattern "embedding_<agg>_<band>_<year>";
    # split once and reuse the parts for all three coordinates.
    parts = [name.split("_") for name in names]

    weights = model_state.sel(feature=names)["feature_weights"]
    weights = weights.assign_coords(
        agg=("feature", [p[1] for p in parts]),
        band=("feature", [p[2] for p in parts]),
        year=("feature", [p[3] for p in parts]),
    )
    # Promote the three name components to real dimensions.
    return weights.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
+ + """ + + def _is_era5_feature(feature: str) -> bool: + return feature.startswith("era5_") + + def _extract_var_name(feature: str) -> str: + parts = feature.split("_") + # era5_variablename_timetype format + return "_".join(parts[1:-1]) + + def _extract_time_name(feature: str) -> str: + parts = feature.split("_") + # Last part is the time type + return parts[-1] + + era5_features = [f for f in model_state.feature.to_numpy() if _is_era5_feature(f)] + if len(era5_features) == 0: + return None + # Split the single feature dimension of era5 features into separate dimensions (variable, time) + era5_features_array = model_state.sel(feature=era5_features)["feature_weights"] + era5_features_array = era5_features_array.assign_coords( + variable=("feature", [_extract_var_name(f) for f in era5_features]), + time=("feature", [_extract_time_name(f) for f in era5_features]), + ) + era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 + return era5_features_array + + +def extract_common_features(model_state) -> xr.DataArray | None: + """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state. + + Args: + model_state: The xarray Dataset containing the model state. + + Returns: + xr.DataArray: The extracted common features with a single 'feature' dimension. + Returns None if no common features are found. 
+ + """ + common_feature_names = ["cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"] + + def _is_common_feature(feature: str) -> bool: + return feature in common_feature_names + + common_features = [f for f in model_state.feature.to_numpy() if _is_common_feature(f)] + if len(common_features) == 0: + return None + + # Extract the feature weights for common features + common_feature_array = model_state.sel(feature=common_features)["feature_weights"] + return common_feature_array diff --git a/src/entropice/dashboard/utils/training.py b/src/entropice/dashboard/utils/training.py new file mode 100644 index 0000000..70c09b3 --- /dev/null +++ b/src/entropice/dashboard/utils/training.py @@ -0,0 +1,232 @@ +"""Training utilities for dashboard.""" + +import pickle + +import numpy as np +import pandas as pd +import streamlit as st +import xarray as xr + +from entropice.dashboard.utils.data import TrainingResult + + +def format_metric_name(metric: str) -> str: + """Format metric name for display. + + Args: + metric: Raw metric name (e.g., 'f1_micro', 'precision_macro'). + + Returns: + Formatted metric name (e.g., 'F1 Micro', 'Precision Macro'). + + """ + # Split by underscore and capitalize each part + parts = metric.split("_") + # Special handling for F1 + formatted_parts = [] + for part in parts: + if part.lower() == "f1": + formatted_parts.append("F1") + else: + formatted_parts.append(part.capitalize()) + return " ".join(formatted_parts) + + +def get_available_metrics(results: pd.DataFrame) -> list[str]: + """Get list of available metrics from results. + + Args: + results: DataFrame with CV results. + + Returns: + List of metric names (without 'mean_test_' prefix). + + """ + score_cols = [col for col in results.columns if col.startswith("mean_test_")] + return [col.replace("mean_test_", "") for col in score_cols] + + +def load_best_model(result: TrainingResult): + """Load the best model from a training result. + + Args: + result: TrainingResult object. 
def load_best_model(result: "TrainingResult"):
    """Load the best model from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        The loaded model object, or None if loading fails.

    """
    model_file = result.path / "best_estimator_model.pkl"
    if not model_file.exists():
        return None

    try:
        # NOTE: pickle is only safe here because the artifact is produced by
        # our own training runs, never untrusted input.
        with model_file.open("rb") as fh:
            return pickle.load(fh)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None


def load_model_state(result: "TrainingResult") -> "xr.Dataset | None":
    """Load the model state from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        xarray Dataset with model state, or None if not available.

    """
    state_file = result.path / "best_estimator_state.nc"
    if not state_file.exists():
        return None

    try:
        return xr.open_dataset(state_file, engine="h5netcdf")
    except Exception as e:
        st.error(f"Error loading model state: {e}")
        return None


def load_predictions(result: "TrainingResult") -> "pd.DataFrame | None":
    """Load predictions from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        DataFrame with predictions, or None if not available.

    """
    preds_file = result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        return None

    try:
        return pd.read_parquet(preds_file)
    except Exception as e:
        st.error(f"Error loading predictions: {e}")
        return None
+ + """ + param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] + + summary_data = [] + for param_col in param_cols: + param_name = param_col.replace("param_", "") + param_values = results[param_col].dropna() + + if pd.api.types.is_numeric_dtype(param_values): + summary_data.append( + { + "Parameter": param_name, + "Type": "Numeric", + "Min": f"{param_values.min():.2e}", + "Max": f"{param_values.max():.2e}", + "Mean": f"{param_values.mean():.2e}", + "Unique Values": param_values.nunique(), + } + ) + else: + unique_vals = param_values.unique() + summary_data.append( + { + "Parameter": param_name, + "Type": "Categorical", + "Min": "-", + "Max": "-", + "Mean": "-", + "Unique Values": len(unique_vals), + } + ) + + return pd.DataFrame(summary_data) + + +def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict: + """Get cross-validation statistics for a metric. + + Args: + results: DataFrame with CV results. + metric: Metric name (without 'mean_test_' prefix). + + Returns: + Dictionary with CV statistics. + + """ + score_col = f"mean_test_{metric}" + std_col = f"std_test_{metric}" + + if score_col not in results.columns: + return {} + + stats = { + "best_score": results[score_col].max(), + "mean_score": results[score_col].mean(), + "std_score": results[score_col].std(), + "worst_score": results[score_col].min(), + "median_score": results[score_col].median(), + } + + if std_col in results.columns: + stats["mean_cv_std"] = results[std_col].mean() + + return stats + + +def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame: + """Prepare results dataframe with binned columns for plotting. + + Args: + results: DataFrame with CV results. + k_bin_width: Width of bins for initial_K parameter. + + Returns: + DataFrame with added binned columns. 
+ + """ + results_copy = results.copy() + + # Check if we have the parameters + if "param_initial_K" in results.columns: + # Bin initial_K + k_values = results["param_initial_K"].dropna() + if len(k_values) > 0: + k_min = k_values.min() + k_max = k_values.max() + k_bins = range(int(k_min), int(k_max) + k_bin_width, k_bin_width) + results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False) + + if "param_eps_cl" in results.columns: + # Create logarithmic bins for eps_cl + eps_cl_values = results["param_eps_cl"].dropna() + if len(eps_cl_values) > 0 and eps_cl_values.min() > 0: + eps_cl_min = eps_cl_values.min() + eps_cl_max = eps_cl_values.max() + eps_cl_bins = np.logspace(np.log10(eps_cl_min), np.log10(eps_cl_max), num=10) + results_copy["eps_cl_binned"] = pd.cut(results["param_eps_cl"], bins=eps_cl_bins) + + if "param_eps_e" in results.columns: + # Create logarithmic bins for eps_e + eps_e_values = results["param_eps_e"].dropna() + if len(eps_e_values) > 0 and eps_e_values.min() > 0: + eps_e_min = eps_e_values.min() + eps_e_max = eps_e_values.max() + eps_e_bins = np.logspace(np.log10(eps_e_min), np.log10(eps_e_max), num=10) + results_copy["eps_e_binned"] = pd.cut(results["param_eps_e"], bins=eps_e_bins) + + return results_copy diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index 324f61a..5003728 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -283,25 +283,52 @@ class DatasetEnsemble: arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns] return arcticdem_df - def print_stats(self): - targets = self._read_target() - print(f"=== Target: {self.target}") - print(f"\tNumber of target samples: {len(targets)}") + def get_stats(self) -> dict: + """Get dataset statistics. + + Returns: + dict: Dictionary containing target stats, member stats, and total features count. 
+ + """ + targets = self._read_target() + stats = { + "target": self.target, + "num_target_samples": len(targets), + "members": {}, + "total_features": 2 if self.add_lonlat else 0, # Lat and Lon + } - n_cols = 2 if self.add_lonlat else 0 # Lat and Lon for member in self.members: ds = self._read_member(member, targets, lazy=True) - print(f"=== Member: {member}") - print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}") - print(f"\tDimensions: {dict(ds.sizes)}") - print(f"\tCoordinates: {list(ds.coords)}") n_cols_member = len(ds.data_vars) for dim in ds.sizes: if dim != "cell_ids": n_cols_member *= ds.sizes[dim] - print(f"\tNumber of features from member: {n_cols_member}") - n_cols += n_cols_member - print(f"=== Total number of features in dataset: {n_cols}") + + stats["members"][member] = { + "variables": list(ds.data_vars), + "num_variables": len(ds.data_vars), + "dimensions": dict(ds.sizes), + "coordinates": list(ds.coords), + "num_features": n_cols_member, + } + stats["total_features"] += n_cols_member + + return stats + + def print_stats(self): + stats = self.get_stats() + print(f"=== Target: {stats['target']}") + print(f"\tNumber of target samples: {stats['num_target_samples']}") + + for member, member_stats in stats["members"].items(): + print(f"=== Member: {member}") + print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}") + print(f"\tDimensions: {member_stats['dimensions']}") + print(f"\tCoordinates: {member_stats['coordinates']}") + print(f"\tNumber of features from member: {member_stats['num_features']}") + + print(f"=== Total number of features in dataset: {stats['total_features']}") @lru_cache(maxsize=1) def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: