From 591da6992efecef9a4c1349116ccfa5408ebdf6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Tue, 23 Dec 2025 23:33:54 +0100 Subject: [PATCH] More plot refinements --- src/entropice/alphaearth.py | 2 +- .../plots/hyperparameter_analysis.py | 284 ++++++++++++++++-- .../dashboard/training_analysis_page.py | 18 +- src/entropice/dataset.py | 11 +- src/entropice/paths.py | 2 +- src/entropice/training.py | 13 +- 6 files changed, 298 insertions(+), 32 deletions(-) diff --git a/src/entropice/alphaearth.py b/src/entropice/alphaearth.py index 74ec975..65dbaf4 100644 --- a/src/entropice/alphaearth.py +++ b/src/entropice/alphaearth.py @@ -73,7 +73,7 @@ def download(grid: Literal["hex", "healpix"], level: int): print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.") # 2024-2025 for hex-6 - for year in track(range(2022, 2025), total=3, description="Processing years..."): + for year in track(range(2024, 2025), total=1, description="Processing years..."): embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL") embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31") aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"] diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index dd05665..58eafae 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -120,11 +120,12 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str): ) -def render_parameter_distributions(results: pd.DataFrame): +def render_parameter_distributions(results: pd.DataFrame, settings: dict | None = None): """Render histograms of parameter distributions explored. Args: results: DataFrame with CV results. + settings: Optional settings dictionary containing param_grid configuration. """ # Get parameter columns @@ -134,6 +135,20 @@ def render_parameter_distributions(results: pd.DataFrame): st.warning("No parameter columns found in results.") return + # Extract scale information from settings if available + param_scales = {} + if settings and "param_grid" in settings: + param_grid = settings["param_grid"] + for param_name, param_config in param_grid.items(): + if isinstance(param_config, dict) and "distribution" in param_config: + # loguniform distribution indicates log scale + if param_config["distribution"] == "loguniform": + param_scales[param_name] = "log" + else: + param_scales[param_name] = "linear" + else: + param_scales[param_name] = "linear" + # Get colormap from colors module cmap = get_cmap("parameter_distribution") bar_color = mcolors.rgb2hex(cmap(0.5)) @@ -178,6 +193,13 @@ def render_parameter_distributions(results: pd.DataFrame): # Determine number of bins based on unique values n_bins = min(20, max(5, n_unique)) + # Determine x-axis scale from config or infer from data + use_log_scale = param_scales.get(param_name, "linear") == "log" + if not use_log_scale: + # Fall back to automatic detection if not in config + value_range = max_val / (param_values.min() + 1e-10) + use_log_scale = value_range > 100 or max_val < 0.01 + # For very small values or large ranges, use a simple bar chart instead of histogram if n_unique <= 10: # Few unique values - use bar chart @@ -185,6 +207,9 @@ def render_parameter_distributions(results: pd.DataFrame): value_counts.columns = [param_name, "count"] value_counts = value_counts.sort_values(param_name) + x_scale = alt.Scale(type="log") if use_log_scale else alt.Scale() + title_suffix = " (log scale)" if use_log_scale else "" + # Format x-axis values if max_val < 0.01: formatted_col = f"{param_name}_formatted" @@ -200,38 +225,85 @@ def render_parameter_distributions(results: pd.DataFrame): alt.Tooltip("count", title="Count"), ], ) - .properties(height=250, title=param_name) + .properties(height=250, title=f"{param_name}{title_suffix}") ) else: chart = ( alt.Chart(value_counts) .mark_bar(color=bar_color) .encode( - alt.X(f"{param_name}:Q", title=param_name), + alt.X(f"{param_name}:Q", title=param_name, scale=x_scale), alt.Y("count:Q", title="Count"), tooltip=[ alt.Tooltip(param_name, format=".3f"), alt.Tooltip("count", title="Count"), ], ) - .properties(height=250, title=param_name) + .properties(height=250, title=f"{param_name}{title_suffix}") ) else: # Many unique values - use binned histogram - # Avoid log scale for binning as it can cause issues - chart = ( - alt.Chart(df_plot) - .mark_bar(color=bar_color) - .encode( - alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name), - alt.Y("count()", title="Count"), - tooltip=[ - alt.Tooltip(f"{param_name}:Q", format=".2e" if max_val < 0.01 else ".3f", bin=True), - "count()", - ], + title_suffix = " (log scale)" if use_log_scale else "" + + if use_log_scale: + # For log scale parameters, create bins in log space then transform back + # This gives better distribution visualization + log_values = np.log10(param_values.to_numpy()) + log_bins = np.linspace(log_values.min(), log_values.max(), n_bins + 1) + bins_linear = 10**log_bins + + # Manually bin the data + binned = pd.cut(param_values, bins=bins_linear) + bin_counts = binned.value_counts().sort_index() + + # Create dataframe for plotting + bin_data = [] + for interval, count in bin_counts.items(): + bin_mid = (interval.left + interval.right) / 2 + bin_data.append( + { + param_name: bin_mid, + "count": count, + "bin_label": f"{interval.left:.2e} - {interval.right:.2e}", + } + ) + + df_binned = pd.DataFrame(bin_data) + + chart = ( + alt.Chart(df_binned) + .mark_bar(color=bar_color) + .encode( + alt.X( + f"{param_name}:Q", + title=param_name, + scale=alt.Scale(type="log"), + axis=alt.Axis(format=".2e"), + ), + alt.Y("count:Q", title="Count"), + tooltip=[ + alt.Tooltip("bin_label:N", title="Range"), + alt.Tooltip("count:Q", title="Count"), + ], + ) + .properties(height=250, title=f"{param_name}{title_suffix}") + ) + else: + # Linear scale - use standard binning + format_str = ".2e" if max_val < 0.01 else ".3f" + chart = ( + alt.Chart(df_plot) + .mark_bar(color=bar_color) + .encode( + alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name), + alt.Y("count()", title="Count"), + tooltip=[ + alt.Tooltip(f"{param_name}:Q", format=format_str, bin=True), + "count()", + ], + ) + .properties(height=250, title=param_name) ) - .properties(height=250, title=param_name) - ) st.altair_chart(chart, use_container_width=True) @@ -778,6 +850,184 @@ def render_multi_metric_comparison(results: pd.DataFrame): st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}") +def render_espa_binned_parameter_space(results: pd.DataFrame, metric: str, k_bin_width: int = 40): + """Render ESPA-specific binned parameter space plots. + + Creates faceted plots for all combinations of the three ESPA parameters: + - eps_cl vs eps_e (binned by initial_K) + - eps_cl vs initial_K (binned by eps_e) + - eps_e vs initial_K (binned by eps_cl) + + Args: + results: DataFrame with CV results. + metric: The metric to visualize (e.g., 'f1', 'accuracy'). + k_bin_width: Width of bins for initial_K parameter. + + """ + score_col = f"mean_test_{metric}" + if score_col not in results.columns: + st.warning(f"Metric {metric} not found in results.") + return + + # Check if this is an ESPA model with the required parameters + required_params = ["param_initial_K", "param_eps_cl", "param_eps_e"] + if not all(param in results.columns for param in required_params): + st.info("ESPA-specific parameters not found. This visualization is only for ESPA models.") + return + + # Get colormap from colors module + hex_colors = get_palette(metric, n_colors=256) + + # Prepare base plot data + base_data = results[["param_eps_e", "param_eps_cl", "param_initial_K", score_col]].copy() + base_data = base_data.dropna() + + if len(base_data) == 0: + st.warning("No data available for ESPA binned parameter space.") + return + + # Configuration for each plot combination + plot_configs = [ + { + "x_param": "param_eps_e", + "y_param": "param_eps_cl", + "bin_param": "param_initial_K", + "x_label": "eps_e", + "y_label": "eps_cl", + "bin_label": "initial_K", + "x_scale": "log", + "y_scale": "log", + "bin_type": "linear", + "bin_width": k_bin_width, + "title": "eps_cl vs eps_e (binned by initial_K)", + }, + { + "x_param": "param_eps_cl", + "y_param": "param_initial_K", + "bin_param": "param_eps_e", + "x_label": "eps_cl", + "y_label": "initial_K", + "bin_label": "eps_e", + "x_scale": "log", + "y_scale": "linear", + "bin_type": "log", + "bin_width": None, # Will use log bins + "title": "initial_K vs eps_cl (binned by eps_e)", + }, + { + "x_param": "param_eps_e", + "y_param": "param_initial_K", + "bin_param": "param_eps_cl", + "x_label": "eps_e", + "y_label": "initial_K", + "bin_label": "eps_cl", + "x_scale": "log", + "y_scale": "linear", + "bin_type": "log", + "bin_width": None, # Will use log bins + "title": "initial_K vs eps_e (binned by eps_cl)", + }, + ] + + # Create each plot + for config in plot_configs: + st.markdown(f"**{config['title']}**") + + plot_data = base_data.copy() + + # Create bins for the binning parameter + bin_values = plot_data[config["bin_param"]].dropna() + if len(bin_values) == 0: + st.warning(f"No {config['bin_label']} values found.") + continue + + if config["bin_type"] == "log": + # Logarithmic binning for epsilon parameters + log_min = np.log10(bin_values.min()) + log_max = np.log10(bin_values.max()) + n_bins = min(10, max(5, int(log_max - log_min) + 1)) + bins = np.logspace(log_min, log_max, num=n_bins) + # Adjust bins to ensure all values are captured + bins[0] = bins[0] * 0.999 # Extend first bin to capture minimum + bins[-1] = bins[-1] * 1.001 # Extend last bin to capture maximum + else: + # Linear binning for initial_K + bin_min = bin_values.min() + bin_max = bin_values.max() + bins = np.arange(bin_min, bin_max + config["bin_width"], config["bin_width"]) + + # Bin the parameter + plot_data["binned_param"] = pd.cut(plot_data[config["bin_param"]], bins=bins, right=False) + + # Remove any NaN bins (shouldn't happen, but just in case) + plot_data = plot_data.dropna(subset=["binned_param"]) + + # Sort bins and convert to string for ordering + plot_data = plot_data.sort_values("binned_param") + plot_data["binned_param_str"] = plot_data["binned_param"].astype(str) + bin_order = plot_data["binned_param_str"].unique().tolist() + + # Create faceted scatter plot + x_scale = alt.Scale(type=config["x_scale"]) if config["x_scale"] == "log" else alt.Scale() + # For initial_K plots, set domain to start from the minimum value in the search space + if config["y_label"] == "initial_K": + y_min = plot_data[config["y_param"]].min() + y_max = plot_data[config["y_param"]].max() + y_scale = alt.Scale(domain=[y_min, y_max]) + else: + y_scale = alt.Scale(type=config["y_scale"]) if config["y_scale"] == "log" else alt.Scale() + + chart = ( + alt.Chart(plot_data) + .mark_circle(size=60, opacity=0.7) + .encode( + x=alt.X( + f"{config['x_param']}:Q", + scale=x_scale, + axis=alt.Axis(title=config["x_label"], grid=True, gridOpacity=0.5), + ), + y=alt.Y( + f"{config['y_param']}:Q", + scale=y_scale, + axis=alt.Axis(title=config["y_label"], grid=True, gridOpacity=0.5), + ), + color=alt.Color( + f"{score_col}:Q", + scale=alt.Scale(range=hex_colors), + title=metric.replace("_", " ").title(), + ), + tooltip=[ + alt.Tooltip("param_eps_e:Q", title="eps_e", format=".2e"), + alt.Tooltip("param_eps_cl:Q", title="eps_cl", format=".2e"), + alt.Tooltip("param_initial_K:Q", title="initial_K", format=".0f"), + alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"), + alt.Tooltip("binned_param_str:N", title=f"{config['bin_label']} Bin"), + ], + ) + .properties(width=200, height=200) + .facet( + facet=alt.Facet( + "binned_param_str:N", + title=f"{config['bin_label']} (binned)", + sort=bin_order, + ), + columns=5, + ) + ) + + st.altair_chart(chart, use_container_width=True) + + # Show statistics about the binning + n_bins = len(bin_order) + n_total = len(plot_data) + st.caption( + f"Showing {n_total} data points across {n_bins} bins of {config['bin_label']}. " + f"Each facet shows {config['y_label']} vs {config['x_label']} for a range of {config['bin_label']} values." + ) + + st.write("") # Add some spacing between plots + + def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10): """Render table of top N configurations. diff --git a/src/entropice/dashboard/training_analysis_page.py b/src/entropice/dashboard/training_analysis_page.py index 8fc7821..0b506c1 100644 --- a/src/entropice/dashboard/training_analysis_page.py +++ b/src/entropice/dashboard/training_analysis_page.py @@ -4,6 +4,7 @@ import streamlit as st from entropice.dashboard.plots.hyperparameter_analysis import ( render_binned_parameter_space, + render_espa_binned_parameter_space, render_multi_metric_comparison, render_parameter_correlation, render_parameter_distributions, @@ -168,11 +169,24 @@ def render_training_analysis_page(): # Parameter distributions st.subheader("📈 Parameter Distributions") - render_parameter_distributions(results) + render_parameter_distributions(results, settings) # Binned parameter space plots st.subheader("🎨 Binned Parameter Space") - render_binned_parameter_space(results, selected_metric) + + # Check if this is an ESPA model and show ESPA-specific plots + model_type = settings.get("model", "espa") + if model_type == "espa": + # Show ESPA-specific binned plots (eps_cl vs eps_e binned by K) + render_espa_binned_parameter_space(results, selected_metric) + + # Optionally show the generic binned plots in an expander + with st.expander("📊 All Parameter Combinations", expanded=False): + st.caption("Generic parameter space exploration (all pairwise combinations)") + render_binned_parameter_space(results, selected_metric) + else: + # For non-ESPA models, show the generic binned plots + render_binned_parameter_space(results, selected_metric) st.divider() diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index b80cd4a..2160737 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -487,11 +487,14 @@ class DatasetEnsemble: X_test = cp.asarray(X_test) y_train = cp.asarray(y_train) y_test = cp.asarray(y_test) + print(f"Using CUDA device: {cp.cuda.runtime.getDeviceProperties(0)['name'].decode()}") elif device == "torch": - X_train = torch.from_numpy(X_train) - X_test = torch.from_numpy(X_test) - y_train = torch.from_numpy(y_train) - y_test = torch.from_numpy(y_test) + torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + X_train = torch.from_numpy(X_train).to(device=torch_device) + X_test = torch.from_numpy(X_test).to(device=torch_device) + y_train = torch.from_numpy(y_train).to(device=torch_device) + y_test = torch.from_numpy(y_test).to(device=torch_device) + print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}") else: assert device == "cpu", "Invalid device specified." diff --git a/src/entropice/paths.py b/src/entropice/paths.py index 6ca18c7..5a8a5a2 100644 --- a/src/entropice/paths.py +++ b/src/entropice/paths.py @@ -9,7 +9,7 @@ from typing import Literal DATA_DIR = ( Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice" ) -# DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster +DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster GRIDS_DIR = DATA_DIR / "grids" FIGURES_DIR = Path("figures") diff --git a/src/entropice/training.py b/src/entropice/training.py index 7b0cc87..bab3363 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -15,6 +15,7 @@ from entropy import ESPAClassifier from rich import pretty, traceback from scipy.stats import loguniform, randint from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen +from sklearn import set_config from sklearn.model_selection import KFold, RandomizedSearchCV from stopuhr import stopwatch from xgboost.sklearn import XGBClassifier @@ -28,7 +29,7 @@ pretty.install() # Disabled array_api_dispatch to avoid namespace conflicts between NumPy and CuPy # when using XGBoost with device="cuda" -# set_config(array_api_dispatch=True) +set_config(array_api_dispatch=True) cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type] @@ -78,10 +79,8 @@ def _create_clf( elif settings.model == "xgboost": param_grid = { "learning_rate": loguniform(1e-4, 1e-1), - "max_depth": randint(3, 15), - "n_estimators": randint(100, 1000), - "subsample": loguniform(0.5, 1.0), - "colsample_bytree": loguniform(0.5, 1.0), + "max_depth": randint(5, 50), + "n_estimators": randint(50, 1000), } clf = XGBClassifier( objective="multi:softprob" if settings.task != "binary" else "binary:logistic", @@ -94,9 +93,9 @@ def _create_clf( elif settings.model == "rf": param_grid = { "max_depth": randint(5, 50), - "n_estimators": randint(50, 500), + "n_estimators": randint(50, 1000), } - clf = RandomForestClassifier(random_state=42) + clf = RandomForestClassifier(random_state=42, split_criterion="entropy") fit_params = {} elif settings.model == "knn": param_grid = {