diff --git a/pixi.lock b/pixi.lock index 7635c77..c1369b7 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9,7 +9,6 @@ environments: - https://pypi.org/simple options: channel-priority: disabled - pypi-prerelease-mode: if-necessary-or-explicit packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -2869,7 +2868,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: 6488240242ab4091686c27ee2e721c81143e0a158bdff086518da6ccdc2971c8 + sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -2933,6 +2932,7 @@ packages: - ruff>=0.14.9,<0.15 - pandas-stubs>=2.3.3.251201,<3 requires_python: '>=3.13,<3.14' + editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 name: entropy version: 0.1.0 diff --git a/pyproject.toml b/pyproject.toml index 5ec1bff..87d1773 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,8 @@ dependencies = [ "pydeck>=0.9.1,<0.10", "pypalettes>=0.2.1,<0.3", "ty>=0.0.2,<0.0.3", - "ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3", + "ruff>=0.14.9,<0.15", + "pandas-stubs>=2.3.3.251201,<3", ] [project.scripts] @@ -118,6 +119,7 @@ platforms = ["linux-64"] [tool.pixi.activation.env] SCIPY_ARRAY_API = "1" +FAST_DATA_DIR = "./data" [tool.pixi.system-requirements] cuda = "12" diff --git a/src/entropice/dashboard/model_state_page.py b/src/entropice/dashboard/model_state_page.py index 88558fa..3e424dd 100644 --- a/src/entropice/dashboard/model_state_page.py +++ b/src/entropice/dashboard/model_state_page.py @@ -1,6 +1,7 @@ """Model State page for the Entropice dashboard.""" import streamlit as st +import xarray as xr from entropice.dashboard.plots.colors import generate_unified_colormap from entropice.dashboard.plots.model_state import ( @@ -43,6 +44,9 @@ def render_model_state_page(): ) selected_result = result_options[selected_name] + # Get the 
model type from settings + model_type = selected_result.settings.get("model", "espa") + # Load model state with st.spinner("Loading model state..."): model_state = load_model_state(selected_result) @@ -50,34 +54,40 @@ def render_model_state_page(): st.error("Could not load model state for this result.") return - # Scale feature weights by number of features - n_features = model_state.sizes["feature"] - model_state["feature_weights"] *= n_features - - # Extract different feature types - embedding_feature_array = extract_embedding_features(model_state) - era5_feature_array = extract_era5_features(model_state) - common_feature_array = extract_common_features(model_state) - - # Generate unified colormaps - _, _, altair_colors = generate_unified_colormap(selected_result.settings) - # Display basic model state info with st.expander("Model State Information", expanded=False): + st.write(f"**Model Type:** {model_type.upper()}") st.write(f"**Variables:** {list(model_state.data_vars)}") st.write(f"**Dimensions:** {dict(model_state.sizes)}") st.write(f"**Coordinates:** {list(model_state.coords)}") + st.write(f"**Attributes:** {dict(model_state.attrs)}") - # Show statistics - st.write("**Feature Weight Statistics:**") - feature_weights = model_state["feature_weights"].to_pandas() - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Mean Weight", f"{feature_weights.mean():.4f}") - with col2: - st.metric("Max Weight", f"{feature_weights.max():.4f}") - with col3: - st.metric("Total Features", len(feature_weights)) + # Render model-specific visualizations + if model_type == "espa": + render_espa_model_state(model_state, selected_result) + elif model_type == "xgboost": + render_xgboost_model_state(model_state, selected_result) + elif model_type == "rf": + render_rf_model_state(model_state, selected_result) + elif model_type == "knn": + render_knn_model_state(model_state, selected_result) + else: + st.warning(f"Visualization for model type '{model_type}' is not yet 
implemented.") + + +def render_espa_model_state(model_state: xr.Dataset, selected_result): + """Render visualizations for ESPA model.""" + # Scale feature weights by number of features + n_features = model_state.sizes["feature"] + model_state["feature_weights"] *= n_features + + # Extract different feature types + embedding_feature_array = extract_embedding_features(model_state) + era5_feature_array = extract_era5_features(model_state) + common_feature_array = extract_common_features(model_state) + + # Generate unified colormaps + _, _, altair_colors = generate_unified_colormap(selected_result.settings) # Feature importance section st.header("Feature Importance") @@ -167,156 +177,362 @@ def render_model_state_page(): # Embedding features analysis (if present) if embedding_feature_array is not None: - with st.container(border=True): - st.header("🛰️ Embedding Feature Analysis") - st.markdown( - """ - Analysis of embedding features showing which aggregations, bands, and years - are most important for the model predictions. 
- """ - ) - - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating dimension summaries..."): - chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array) - col1, col2, col3 = st.columns(3) - with col1: - st.altair_chart(chart_agg, use_container_width=True) - with col2: - st.altair_chart(chart_band, use_container_width=True) - with col3: - st.altair_chart(chart_year, use_container_width=True) - - # Detailed heatmap - st.markdown("### Detailed Heatmap by Aggregation") - st.markdown("Shows the weight of each band-year combination for each aggregation type.") - with st.spinner("Generating heatmap..."): - heatmap_chart = plot_embedding_heatmap(embedding_feature_array) - st.altair_chart(heatmap_chart, use_container_width=True) - - # Statistics - with st.expander("Embedding Feature Statistics"): - st.write("**Overall Statistics:**") - n_emb_features = embedding_feature_array.size - mean_weight = float(embedding_feature_array.mean().values) - max_weight = float(embedding_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total Embedding Features", n_emb_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - - # Show top embedding features - st.write("**Top 10 Embedding Features:**") - emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() - top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] - st.dataframe(top_emb, width="stretch") - else: - st.info("No embedding features found in this model.") + render_embedding_features(embedding_feature_array) # ERA5 features analysis (if present) if era5_feature_array is not None: - with st.container(border=True): - st.header("⛅ ERA5 Feature Analysis") - st.markdown( - """ - Analysis of ERA5 climate features showing which variables and time periods - are most important for the model predictions. 
- """ - ) - - # Summary bar charts - st.markdown("### Importance by Dimension") - with st.spinner("Generating ERA5 dimension summaries..."): - chart_variable, chart_time = plot_era5_summary(era5_feature_array) - col1, col2 = st.columns(2) - with col1: - st.altair_chart(chart_variable, use_container_width=True) - with col2: - st.altair_chart(chart_time, use_container_width=True) - - # Detailed heatmap - st.markdown("### Detailed Heatmap") - st.markdown("Shows the weight of each variable-time combination.") - with st.spinner("Generating ERA5 heatmap..."): - era5_heatmap_chart = plot_era5_heatmap(era5_feature_array) - st.altair_chart(era5_heatmap_chart, use_container_width=True) - - # Statistics - with st.expander("ERA5 Feature Statistics"): - st.write("**Overall Statistics:**") - n_era5_features = era5_feature_array.size - mean_weight = float(era5_feature_array.mean().values) - max_weight = float(era5_feature_array.max().values) - col1, col2, col3 = st.columns(3) - with col1: - st.metric("Total ERA5 Features", n_era5_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - - # Show top ERA5 features - st.write("**Top 10 ERA5 Features:**") - era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() - top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] - st.dataframe(top_era5, width="stretch") - else: - st.info("No ERA5 features found in this model.") + render_era5_features(era5_feature_array) # Common features analysis (if present) if common_feature_array is not None: - with st.container(border=True): - st.header("🗺️ Common Feature Analysis") - st.markdown( - """ - Analysis of common features including cell area, water area, land area, land ratio, - longitude, and latitude. These features provide spatial and geographic context. 
- """ - ) + render_common_features(common_feature_array) - # Bar chart showing all common feature weights - with st.spinner("Generating common features chart..."): - common_chart = plot_common_features(common_feature_array) - st.altair_chart(common_chart, use_container_width=True) - # Statistics - with st.expander("Common Feature Statistics"): - st.write("**Overall Statistics:**") - n_common_features = common_feature_array.size - mean_weight = float(common_feature_array.mean().values) - max_weight = float(common_feature_array.max().values) - min_weight = float(common_feature_array.min().values) - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Common Features", n_common_features) - with col2: - st.metric("Mean Weight", f"{mean_weight:.4f}") - with col3: - st.metric("Max Weight", f"{max_weight:.4f}") - with col4: - st.metric("Min Weight", f"{min_weight:.4f}") +def render_xgboost_model_state(model_state: xr.Dataset, selected_result): + """Render visualizations for XGBoost model.""" + from entropice.dashboard.plots.model_state import ( + plot_xgboost_feature_importance, + plot_xgboost_importance_comparison, + ) - # Show all common features sorted by importance - st.write("**All Common Features (by absolute weight):**") - common_df = common_feature_array.to_dataframe(name="weight").reset_index() - common_df["abs_weight"] = common_df["weight"].abs() - common_df = common_df.sort_values("abs_weight", ascending=False) - st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") + st.header("🌲 XGBoost Model Analysis") + st.markdown( + f""" + XGBoost gradient boosted tree model with **{model_state.attrs.get("n_trees", "N/A")} trees**. 
- st.markdown( - """ - **Interpretation:** - - **cell_area, water_area, land_area**: Spatial extent features that may indicate - size-related patterns - - **land_ratio**: Proportion of land vs water in each cell - - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns - - Positive weights indicate features that increase the probability of the positive class - - Negative weights indicate features that decrease the probability of the positive class - """ - ) - else: - st.info("No common features found in this model.") + **Objective:** {model_state.attrs.get("objective", "N/A")} + """ + ) + + # Feature importance with different types + st.subheader("Feature Importance Analysis") + st.markdown( + """ + XGBoost provides multiple ways to measure feature importance: + - **Weight**: Number of times a feature is used to split the data + - **Gain**: Average gain across all splits using the feature + - **Cover**: Average coverage across all splits using the feature + - **Total Gain**: Total gain across all splits + - **Total Cover**: Total coverage across all splits + """ + ) + + # Importance type selector + importance_type = st.selectbox( + "Select Importance Type", + options=["gain", "weight", "cover", "total_gain", "total_cover"], + index=0, + help="Choose which importance metric to visualize", + ) + + # Top N slider + top_n = st.slider( + "Number of top features to display", + min_value=5, + max_value=50, + value=20, + step=5, + help="Select how many of the most important features to visualize", + ) + + with st.spinner("Generating feature importance plot..."): + importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n) + st.altair_chart(importance_chart, use_container_width=True) + + # Comparison of importance types + st.subheader("Importance Type Comparison") + st.markdown("Compare the top features across different importance metrics.") + + with st.spinner("Generating importance 
comparison..."): + comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15) + st.altair_chart(comparison_chart, use_container_width=True) + + # Statistics + with st.expander("Model Statistics"): + st.write("**Overall Statistics:**") + col1, col2 = st.columns(2) + with col1: + st.metric("Number of Trees", model_state.attrs.get("n_trees", "N/A")) + with col2: + st.metric("Total Features", model_state.sizes.get("feature", "N/A")) + + +def render_rf_model_state(model_state: xr.Dataset, selected_result): + """Render visualizations for Random Forest model.""" + from entropice.dashboard.plots.model_state import plot_rf_feature_importance + + st.header("🌳 Random Forest Model Analysis") + + # Check if using cuML (which doesn't provide tree statistics) + is_cuml = "cuML" in model_state.attrs.get("description", "") + + st.markdown( + f""" + Random Forest ensemble with **{model_state.attrs.get("n_estimators", "N/A")} trees** + (max depth: {model_state.attrs.get("max_depth", "N/A")}). + """ + ) + + if is_cuml: + st.info("ℹ️ Using cuML GPU-accelerated Random Forest. Individual tree statistics are not available.") + + # Display OOB score if available + oob_score = model_state.attrs.get("oob_score") + if oob_score is not None: + st.info(f"**Out-of-Bag Score:** {oob_score:.4f}") + + # Feature importance + st.subheader("Feature Importance (Gini Importance)") + st.markdown( + """ + Random Forest uses Gini impurity to measure feature importance. Features with higher + importance values contribute more to the model's predictions. 
+ """ + ) + + # Top N slider + top_n = st.slider( + "Number of top features to display", + min_value=5, + max_value=50, + value=20, + step=5, + help="Select how many of the most important features to visualize", + ) + + with st.spinner("Generating feature importance plot..."): + importance_chart = plot_rf_feature_importance(model_state, top_n=top_n) + st.altair_chart(importance_chart, use_container_width=True) + + # Tree statistics (only if available - sklearn RF has them, cuML RF doesn't) + if not is_cuml and "tree_depths" in model_state: + from entropice.dashboard.plots.model_state import plot_rf_tree_statistics + + st.subheader("Tree Structure Statistics") + st.markdown("Distribution of tree properties across the forest.") + + with st.spinner("Generating tree statistics..."): + chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state) + col1, col2, col3 = st.columns(3) + with col1: + st.altair_chart(chart_depths, use_container_width=True) + with col2: + st.altair_chart(chart_leaves, use_container_width=True) + with col3: + st.altair_chart(chart_nodes, use_container_width=True) + + # Statistics + with st.expander("Forest Statistics"): + st.write("**Overall Statistics:**") + depths = model_state["tree_depths"].to_pandas() + leaves = model_state["tree_n_leaves"].to_pandas() + nodes = model_state["tree_n_nodes"].to_pandas() + + col1, col2, col3 = st.columns(3) + with col1: + st.write("**Tree Depths:**") + st.metric("Mean Depth", f"{depths.mean():.2f}") + st.metric("Max Depth", f"{depths.max()}") + st.metric("Min Depth", f"{depths.min()}") + with col2: + st.write("**Leaf Counts:**") + st.metric("Mean Leaves", f"{leaves.mean():.2f}") + st.metric("Max Leaves", f"{leaves.max()}") + st.metric("Min Leaves", f"{leaves.min()}") + with col3: + st.write("**Node Counts:**") + st.metric("Mean Nodes", f"{nodes.mean():.2f}") + st.metric("Max Nodes", f"{nodes.max()}") + st.metric("Min Nodes", f"{nodes.min()}") + + +def render_knn_model_state(model_state: 
xr.Dataset, selected_result): + """Render visualizations for KNN model.""" + st.header("🔍 K-Nearest Neighbors Model Analysis") + st.markdown( + """ + K-Nearest Neighbors is a non-parametric, instance-based learning algorithm. + Unlike tree-based or parametric models, KNN doesn't learn feature weights or build + a model structure. Instead, it memorizes the training data and makes predictions + based on the k nearest neighbors. + """ + ) + + # Display model metadata + st.subheader("Model Configuration") + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Number of Neighbors (k)", model_state.attrs.get("n_neighbors", "N/A")) + st.metric("Training Samples", model_state.attrs.get("n_samples_fit", "N/A")) + with col2: + st.metric("Weights", model_state.attrs.get("weights", "N/A")) + st.metric("Algorithm", model_state.attrs.get("algorithm", "N/A")) + with col3: + st.metric("Metric", model_state.attrs.get("metric", "N/A")) + + st.info( + """ + **Note:** KNN doesn't have traditional feature importance or model parameters to visualize. + The model's behavior depends entirely on: + - The number of neighbors (k) + - The distance metric used + - The weighting scheme for neighbors + + To understand the model better, consider visualizing the decision boundaries on a + reduced-dimensional representation of your data (e.g., using PCA or t-SNE). + """ + ) + + +# Helper functions for embedding/era5/common features +def render_embedding_features(embedding_feature_array): + """Render embedding feature visualizations.""" + with st.container(border=True): + st.header("🛰️ Embedding Feature Analysis") + st.markdown( + """ + Analysis of embedding features showing which aggregations, bands, and years + are most important for the model predictions. 
+ """ + ) + + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating dimension summaries..."): + chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array) + col1, col2, col3 = st.columns(3) + with col1: + st.altair_chart(chart_agg, use_container_width=True) + with col2: + st.altair_chart(chart_band, use_container_width=True) + with col3: + st.altair_chart(chart_year, use_container_width=True) + + # Detailed heatmap + st.markdown("### Detailed Heatmap by Aggregation") + st.markdown("Shows the weight of each band-year combination for each aggregation type.") + with st.spinner("Generating heatmap..."): + heatmap_chart = plot_embedding_heatmap(embedding_feature_array) + st.altair_chart(heatmap_chart, use_container_width=True) + + # Statistics + with st.expander("Embedding Feature Statistics"): + st.write("**Overall Statistics:**") + n_emb_features = embedding_feature_array.size + mean_weight = float(embedding_feature_array.mean().values) + max_weight = float(embedding_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total Embedding Features", n_emb_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + + # Show top embedding features + st.write("**Top 10 Embedding Features:**") + emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index() + top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]] + st.dataframe(top_emb, width="stretch") + + +def render_era5_features(era5_feature_array): + """Render ERA5 feature visualizations.""" + with st.container(border=True): + st.header("⛅ ERA5 Feature Analysis") + st.markdown( + """ + Analysis of ERA5 climate features showing which variables and time periods + are most important for the model predictions. 
+ """ + ) + + # Summary bar charts + st.markdown("### Importance by Dimension") + with st.spinner("Generating ERA5 dimension summaries..."): + chart_variable, chart_time = plot_era5_summary(era5_feature_array) + col1, col2 = st.columns(2) + with col1: + st.altair_chart(chart_variable, use_container_width=True) + with col2: + st.altair_chart(chart_time, use_container_width=True) + + # Detailed heatmap + st.markdown("### Detailed Heatmap") + st.markdown("Shows the weight of each variable-time combination.") + with st.spinner("Generating ERA5 heatmap..."): + era5_heatmap_chart = plot_era5_heatmap(era5_feature_array) + st.altair_chart(era5_heatmap_chart, use_container_width=True) + + # Statistics + with st.expander("ERA5 Feature Statistics"): + st.write("**Overall Statistics:**") + n_era5_features = era5_feature_array.size + mean_weight = float(era5_feature_array.mean().values) + max_weight = float(era5_feature_array.max().values) + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Total ERA5 Features", n_era5_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + + # Show top ERA5 features + st.write("**Top 10 ERA5 Features:**") + era5_df = era5_feature_array.to_dataframe(name="weight").reset_index() + top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]] + st.dataframe(top_era5, width="stretch") + + +def render_common_features(common_feature_array): + """Render common feature visualizations.""" + with st.container(border=True): + st.header("🗺️ Common Feature Analysis") + st.markdown( + """ + Analysis of common features including cell area, water area, land area, land ratio, + longitude, and latitude. These features provide spatial and geographic context. 
+ """ + ) + + # Bar chart showing all common feature weights + with st.spinner("Generating common features chart..."): + common_chart = plot_common_features(common_feature_array) + st.altair_chart(common_chart, use_container_width=True) + + # Statistics + with st.expander("Common Feature Statistics"): + st.write("**Overall Statistics:**") + n_common_features = common_feature_array.size + mean_weight = float(common_feature_array.mean().values) + max_weight = float(common_feature_array.max().values) + min_weight = float(common_feature_array.min().values) + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Common Features", n_common_features) + with col2: + st.metric("Mean Weight", f"{mean_weight:.4f}") + with col3: + st.metric("Max Weight", f"{max_weight:.4f}") + with col4: + st.metric("Min Weight", f"{min_weight:.4f}") + + # Show all common features sorted by importance + st.write("**All Common Features (by absolute weight):**") + common_df = common_feature_array.to_dataframe(name="weight").reset_index() + common_df["abs_weight"] = common_df["weight"].abs() + common_df = common_df.sort_values("abs_weight", ascending=False) + st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch") + + st.markdown( + """ + **Interpretation:** + - **cell_area, water_area, land_area**: Spatial extent features that may indicate + size-related patterns + - **land_ratio**: Proportion of land vs water in each cell + - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns + - Positive weights indicate features that increase the probability of the positive class + - Negative weights indicate features that decrease the probability of the positive class + """ + ) diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index eee8ca6..dd05665 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ 
b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -1,11 +1,12 @@ """Hyperparameter analysis plotting functions for RandomizedSearchCV results.""" import altair as alt +import matplotlib.colors as mcolors import numpy as np import pandas as pd import streamlit as st -from entropice.dashboard.plots.colors import get_cmap +from entropice.dashboard.plots.colors import get_cmap, get_palette def render_performance_summary(results: pd.DataFrame, refit_metric: str): @@ -135,8 +136,6 @@ def render_parameter_distributions(results: pd.DataFrame): # Get colormap from colors module cmap = get_cmap("parameter_distribution") - import matplotlib.colors as mcolors - bar_color = mcolors.rgb2hex(cmap(0.5)) # Create histograms for each parameter @@ -159,37 +158,77 @@ def render_parameter_distributions(results: pd.DataFrame): param_values = results[param_col].dropna() if pd.api.types.is_numeric_dtype(param_values): + # Check if we have enough unique values for a histogram + n_unique = param_values.nunique() + + if n_unique == 1: + # Only one unique value - show as text + value = param_values.iloc[0] + formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}" + st.metric(param_name, formatted) + st.caption(f"All {len(param_values)} samples have the same value") + continue + # Numeric parameter - use histogram - df_plot = pd.DataFrame({param_name: param_values}) + df_plot = pd.DataFrame({param_name: param_values.to_numpy()}) - # Use log scale if the range spans multiple orders of magnitude - value_range = param_values.max() / (param_values.min() + 1e-10) - use_log = value_range > 100 + # Use log scale if the range spans multiple orders of magnitude OR values are very small + max_val = param_values.max() - if use_log: - chart = ( - alt.Chart(df_plot) - .mark_bar(color=bar_color) - .encode( - alt.X( - param_name, - bin=alt.Bin(maxbins=20), - scale=alt.Scale(type="log"), - title=param_name, - ), - alt.Y("count()", title="Count"), - tooltip=[alt.Tooltip(param_name, 
format=".2e"), "count()"], + # Determine number of bins based on unique values + n_bins = min(20, max(5, n_unique)) + + # For very small values or large ranges, use a simple bar chart instead of histogram + if n_unique <= 10: + # Few unique values - use bar chart + value_counts = param_values.value_counts().reset_index() + value_counts.columns = [param_name, "count"] + value_counts = value_counts.sort_values(param_name) + + # Format x-axis values + if max_val < 0.01: + formatted_col = f"{param_name}_formatted" + value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}") + chart = ( + alt.Chart(value_counts) + .mark_bar(color=bar_color) + .encode( + alt.X(f"{formatted_col}:N", title=param_name, sort=None), + alt.Y("count:Q", title="Count"), + tooltip=[ + alt.Tooltip(param_name, format=".2e"), + alt.Tooltip("count", title="Count"), + ], + ) + .properties(height=250, title=param_name) + ) + else: + chart = ( + alt.Chart(value_counts) + .mark_bar(color=bar_color) + .encode( + alt.X(f"{param_name}:Q", title=param_name), + alt.Y("count:Q", title="Count"), + tooltip=[ + alt.Tooltip(param_name, format=".3f"), + alt.Tooltip("count", title="Count"), + ], + ) + .properties(height=250, title=param_name) ) - .properties(height=250, title=param_name) - ) else: + # Many unique values - use binned histogram + # Avoid log scale for binning as it can cause issues chart = ( alt.Chart(df_plot) .mark_bar(color=bar_color) .encode( - alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name), + alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name), alt.Y("count()", title="Count"), - tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"], + tooltip=[ + alt.Tooltip(f"{param_name}:Q", format=".2e" if max_val < 0.01 else ".3f", bin=True), + "count()", + ], ) .properties(height=250, title=param_name) ) @@ -276,7 +315,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): alt.Y(metric, title=metric.replace("_", " ").title()), 
alt.Color( metric, - scale=alt.Scale(scheme="viridis"), + scale=alt.Scale(range=get_palette(metric, n_colors=256)), legend=None, ), tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")], @@ -292,7 +331,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str): alt.Y(metric, title=metric.replace("_", " ").title()), alt.Color( metric, - scale=alt.Scale(scheme="viridis"), + scale=alt.Scale(range=get_palette(metric, n_colors=256)), legend=None, ), tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")], @@ -350,13 +389,8 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str): corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False) - # Get colormap from colors module - cmap = get_cmap("correlation") - # Sample colors for red-blue diverging scheme - colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)] - import matplotlib.colors as mcolors - - hex_colors = [mcolors.rgb2hex(c) for c in colors] + # Get colormap from colors module (use diverging colormap for correlation) + hex_colors = get_palette("correlation", n_colors=256) # Create bar chart chart = ( @@ -367,7 +401,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str): alt.Y("Parameter", sort="-x", title="Parameter"), alt.Color( "Correlation", - scale=alt.Scale(scheme="redblue", domain=[-1, 1]), + scale=alt.Scale(range=hex_colors, domain=[-1, 1]), legend=None, ), tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")], @@ -402,7 +436,7 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str): st.info("Need at least 2 numeric parameters for binned parameter space analysis.") return - # Prepare binned data + # Prepare binned data and gather parameter info results_binned = results.copy() bin_info = {} @@ -414,146 +448,160 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str): continue # Determine if we should use log bins - value_range = 
param_values.max() / (param_values.min() + 1e-10) - use_log = value_range > 100 or param_values.max() < 1.0 + min_val = param_values.min() + max_val = param_values.max() - if use_log: + if min_val > 0: + value_range = max_val / min_val + use_log = value_range > 100 or max_val < 0.01 + else: + use_log = False + + if use_log and min_val > 0: # Logarithmic binning for parameters spanning many orders of magnitude - log_min = np.log10(param_values.min()) - log_max = np.log10(param_values.max()) - n_bins = min(10, int(log_max - log_min) + 1) + log_min = np.log10(min_val) + log_max = np.log10(max_val) + n_bins = min(10, max(3, int(log_max - log_min) + 1)) bins = np.logspace(log_min, log_max, num=n_bins) else: # Linear binning - n_bins = min(10, int(np.sqrt(len(param_values)))) - bins = np.linspace(param_values.min(), param_values.max(), num=n_bins) + n_bins = min(10, max(3, int(np.sqrt(len(param_values))))) + bins = np.linspace(min_val, max_val, num=n_bins) results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins) - bin_info[param_name] = {"use_log": use_log, "bins": bins} + bin_info[param_name] = { + "use_log": use_log, + "bins": bins, + "min": min_val, + "max": max_val, + "n_unique": param_values.nunique(), + } # Get colormap from colors module - cmap = get_cmap(metric) + hex_colors = get_palette(metric, n_colors=256) - # Create binned plots for pairs of parameters - # Show the most meaningful combinations + # Get parameter names sorted by scale type (log parameters first, then linear) param_names = [col.replace("param_", "") for col in numeric_params] + param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p)) - # If we have 3+ parameters, create multiple plots - if len(param_names) >= 3: - st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values") + st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values") - # Plot first parameter vs others, binned 
by third - for i in range(1, min(3, len(param_names))): - x_param = param_names[0] - y_param = param_names[i] + # Create all pairwise combinations systematically + if len(param_names) == 2: + # Simple case: just one pair + x_param, y_param = param_names_sorted + _render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500) + else: + # Multiple parameters: create structured plots + st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}") - # Find a third parameter for faceting - facet_param = None - for p in param_names: - if p != x_param and p != y_param: - facet_param = p - break + # Strategy: Show all combinations (including duplicates with swapped axes) + # Group by first parameter to create organized sections - if facet_param and f"{facet_param}_binned" in results_binned.columns: - # Create faceted plot - plot_data = results_binned[ - [f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col] - ].dropna() + for i, x_param in enumerate(param_names_sorted): + # Get all other parameters to pair with this one + other_params = [p for p in param_names_sorted if p != x_param] - # Convert binned column to string with sorted categories - plot_data = plot_data.sort_values(f"{facet_param}_binned") - plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str) - bin_order = plot_data[f"{facet_param}_binned"].unique().tolist() + if not other_params: + continue - x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale() - y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale() + # Create expander for each primary parameter + with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)): + # Create plots in rows of 2 + n_others = len(other_params) + for row_idx in range(0, n_others, 2): + cols = st.columns(2) - # Get colormap colors - cmap_colors = get_cmap(metric) - import matplotlib.colors as 
mcolors + for col_idx in range(2): + other_idx = row_idx + col_idx + if other_idx >= n_others: + break - # Sample colors from colormap - hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)] + y_param = other_params[other_idx] - chart = ( - alt.Chart(plot_data) - .mark_circle(size=60, opacity=0.7) - .encode( - x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale), - y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale), - color=alt.Color( - f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title() - ), - tooltip=[ - alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"), - alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"), - alt.Tooltip(f"{score_col}:Q", format=".4f"), - f"{facet_param}_binned:N", - ], - ) - .properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})") - .facet( - facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order), - columns=4, - ) - ) - st.altair_chart(chart, use_container_width=True) + with cols[col_idx]: + _render_2d_param_plot( + results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350 + ) - # Also create a simple 2D plot without faceting - plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna() - x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale() - y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale() +def _render_2d_param_plot( + results_binned: pd.DataFrame, + x_param: str, + y_param: str, + score_col: str, + bin_info: dict, + hex_colors: list, + metric: str, + height: int = 400, +): + """Render a 2D parameter space plot. 
- chart = ( - alt.Chart(plot_data) - .mark_circle(size=80, opacity=0.6) - .encode( - x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale), - y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale), - color=alt.Color( - f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title() - ), - tooltip=[ - alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"), - alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"), - alt.Tooltip(f"{score_col}:Q", format=".4f"), - ], - ) - .properties(height=400, title=f"{x_param} vs {y_param}") - ) - st.altair_chart(chart, use_container_width=True) + Args: + results_binned: DataFrame with binned results. + x_param: Name of x-axis parameter. + y_param: Name of y-axis parameter. + score_col: Column name for the score. + bin_info: Dictionary with binning information for each parameter. + hex_colors: List of hex colors for the colormap. + metric: Name of the metric being visualized. + height: Height of the plot in pixels. 
- elif len(param_names) == 2: - # Simple 2-parameter plot - st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values") + """ + plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna() - x_param = param_names[0] - y_param = param_names[1] + if len(plot_data) == 0: + st.warning(f"No data available for {x_param} vs {y_param}") + return - plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna() + x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale() + y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale() - x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale() - y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale() + # Determine marker size based on number of points + n_points = len(plot_data) + marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40 - chart = ( - alt.Chart(plot_data) - .mark_circle(size=100, opacity=0.6) - .encode( - x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale), - y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale), - color=alt.Color( - f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title() + chart = ( + alt.Chart(plot_data) + .mark_circle(size=marker_size, opacity=0.7) + .encode( + x=alt.X( + f"param_{x_param}:Q", + title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}", + scale=x_scale, + ), + y=alt.Y( + f"param_{y_param}:Q", + title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}", + scale=y_scale, + ), + color=alt.Color( + f"{score_col}:Q", + scale=alt.Scale(range=hex_colors), + title=metric.replace("_", " ").title(), + ), + tooltip=[ + alt.Tooltip( + f"param_{x_param}:Q", + title=x_param, + format=".2e" if bin_info[x_param]["use_log"] else ".3f", ), - tooltip=[ - alt.Tooltip(f"param_{x_param}:Q", format=".2e" if 
bin_info[x_param]["use_log"] else ".3f"), - alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"), - alt.Tooltip(f"{score_col}:Q", format=".4f"), - ], - ) - .properties(height=500, title=f"{x_param} vs {y_param}") + alt.Tooltip( + f"param_{y_param}:Q", + title=y_param, + format=".2e" if bin_info[y_param]["use_log"] else ".3f", + ), + alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"), + ], ) - st.altair_chart(chart, use_container_width=True) + .properties( + height=height, + title=f"{x_param} vs {y_param}", + ) + .interactive() + ) + + st.altair_chart(chart, use_container_width=True) def render_score_evolution(results: pd.DataFrame, metric: str): @@ -580,6 +628,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str): # Reshape for Altair df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type") + # Get colormap for score evolution + evolution_cmap = get_cmap("score_evolution") + evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))] + # Create line chart chart = ( alt.Chart(df_long) @@ -587,7 +639,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str): .encode( alt.X("Iteration", title="Iteration"), alt.Y("value", title=metric.replace("_", " ").title()), - alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")), + alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)), strokeDash=alt.StrokeDash( "Type", legend=None, @@ -614,6 +666,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str): st.metric("Best at Iteration", best_iter) +@st.fragment def render_multi_metric_comparison(results: pd.DataFrame): """Render comparison of multiple metrics. 
@@ -628,8 +681,16 @@ def render_multi_metric_comparison(results: pd.DataFrame): st.warning("Need at least 2 metrics for comparison.") return - # Let user select two metrics to compare - col1, col2 = st.columns(2) + # Get parameter columns for color options + param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] + param_names = [col.replace("param_", "") for col in param_cols] + + if not param_names: + st.warning("No parameters found for coloring.") + return + + # Let user select two metrics and color parameter + col1, col2, col3 = st.columns(3) with col1: metric1 = st.selectbox( "Select First Metric", @@ -646,21 +707,57 @@ def render_multi_metric_comparison(results: pd.DataFrame): key="metric2_select", ) + with col3: + color_by = st.selectbox( + "Color By", + options=param_names, + index=0, + key="color_select", + ) + if metric1 == metric2: st.warning("Please select different metrics.") return - # Create scatter plot + # Create scatter plot data df_plot = pd.DataFrame( { metric1: results[f"mean_test_{metric1}"], metric2: results[f"mean_test_{metric2}"], - "Iteration": range(len(results)), } ) - # Get colormap from colors module - cmap = get_cmap("iteration") + # Add color parameter to dataframe + param_col = f"param_{color_by}" + df_plot[color_by] = results[param_col] + color_col = color_by + + # Check if parameter is numeric and should use log scale + param_values = results[param_col].dropna() + if pd.api.types.is_numeric_dtype(param_values): + value_range = param_values.max() / (param_values.min() + 1e-10) + use_log = value_range > 100 + + if use_log: + color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256)) + color_format = ".2e" + else: + color_scale = alt.Scale(range=get_palette(color_by, n_colors=256)) + color_format = ".3f" + else: + # Categorical parameter + color_scale = alt.Scale(scheme="category20") + color_format = None + + # Build tooltip list + tooltip_list = [ + alt.Tooltip(metric1, 
format=".4f"), + alt.Tooltip(metric2, format=".4f"), + ] + if color_format: + tooltip_list.append(alt.Tooltip(color_col, format=color_format)) + else: + tooltip_list.append(color_col) chart = ( alt.Chart(df_plot) @@ -668,12 +765,8 @@ def render_multi_metric_comparison(results: pd.DataFrame): .encode( alt.X(metric1, title=metric1.replace("_", " ").title()), alt.Y(metric2, title=metric2.replace("_", " ").title()), - alt.Color("Iteration", scale=alt.Scale(scheme="viridis")), - tooltip=[ - alt.Tooltip(metric1, format=".4f"), - alt.Tooltip(metric2, format=".4f"), - "Iteration", - ], + alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()), + tooltip=tooltip_list, ) .properties(height=500) ) diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py index 94f017d..113bfac 100644 --- a/src/entropice/dashboard/plots/model_state.py +++ b/src/entropice/dashboard/plots/model_state.py @@ -447,3 +447,212 @@ def plot_common_features(common_array: xr.DataArray) -> alt.Chart: ) return chart + + +# XGBoost-specific visualizations +def plot_xgboost_feature_importance( + model_state: xr.Dataset, importance_type: str = "gain", top_n: int = 20 +) -> alt.Chart: + """Plot XGBoost feature importance for a specific importance type. + + Args: + model_state: The xarray Dataset containing the XGBoost model state. + importance_type: Type of importance ('weight', 'gain', 'cover', 'total_gain', 'total_cover'). + top_n: Number of top features to display. + + Returns: + Altair chart showing the top features by importance. 
+ + """ + # Get the importance array + importance_key = f"feature_importance_{importance_type}" + if importance_key not in model_state: + raise ValueError(f"Importance type '{importance_type}' not found in model state") + + importance = model_state[importance_key].to_pandas() + + # Sort and take top N + top_features = importance.nlargest(top_n).sort_values(ascending=True) + + # Create DataFrame for plotting + plot_data = pd.DataFrame({"feature": top_features.index, "importance": top_features.to_numpy()}) + + # Create bar chart + chart = ( + alt.Chart(plot_data) + .mark_bar() + .encode( + y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), + x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"), + color=alt.value("steelblue"), + tooltip=[ + alt.Tooltip("feature:N", title="Feature"), + alt.Tooltip("importance:Q", format=".4f", title="Importance"), + ], + ) + .properties( + width=600, + height=400, + title=f"Top {top_n} Features by {importance_type.replace('_', ' ').title()} Importance (XGBoost)", + ) + ) + + return chart + + +def plot_xgboost_importance_comparison(model_state: xr.Dataset, top_n: int = 15) -> alt.Chart: + """Compare different importance types for XGBoost side-by-side. + + Args: + model_state: The xarray Dataset containing the XGBoost model state. + top_n: Number of top features to display. + + Returns: + Altair chart comparing importance types. 
+ + """ + # Collect all importance types + importance_types = ["weight", "gain", "cover"] + all_data = [] + + for imp_type in importance_types: + importance_key = f"feature_importance_{imp_type}" + if importance_key in model_state: + importance = model_state[importance_key].to_pandas() + # Get top features by this importance type + top_features = importance.nlargest(top_n) + for feature, value in top_features.items(): + all_data.append( + { + "feature": feature, + "importance": value, + "type": imp_type.replace("_", " ").title(), + } + ) + + df = pd.DataFrame(all_data) + + # Create faceted bar chart + chart = ( + alt.Chart(df) + .mark_bar() + .encode( + y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=200)), + x=alt.X("importance:Q", title="Importance Value"), + color=alt.Color("type:N", title="Importance Type", scale=alt.Scale(scheme="category10")), + tooltip=[ + alt.Tooltip("feature:N", title="Feature"), + alt.Tooltip("type:N", title="Type"), + alt.Tooltip("importance:Q", format=".4f", title="Importance"), + ], + ) + .properties(width=250, height=300) + .facet(facet=alt.Facet("type:N", title="Importance Type"), columns=3) + ) + + return chart + + +# Random Forest-specific visualizations +def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.Chart: + """Plot Random Forest feature importance (Gini importance). + + Args: + model_state: The xarray Dataset containing the Random Forest model state. + top_n: Number of top features to display. + + Returns: + Altair chart showing the top features by importance. 
+ + """ + importance = model_state["feature_importance"].to_pandas() + + # Sort and take top N + top_features = importance.nlargest(top_n).sort_values(ascending=True) + + # Create DataFrame for plotting + plot_data = pd.DataFrame({"feature": top_features.index, "importance": top_features.to_numpy()}) + + # Create bar chart + chart = ( + alt.Chart(plot_data) + .mark_bar() + .encode( + y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)), + x=alt.X("importance:Q", title="Gini Importance"), + color=alt.value("forestgreen"), + tooltip=[ + alt.Tooltip("feature:N", title="Feature"), + alt.Tooltip("importance:Q", format=".4f", title="Importance"), + ], + ) + .properties( + width=600, + height=400, + title=f"Top {top_n} Features by Gini Importance (Random Forest)", + ) + ) + + return chart + + +def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]: + """Plot Random Forest tree statistics. + + Args: + model_state: The xarray Dataset containing the Random Forest model state. + + Returns: + Tuple of three Altair charts (depth histogram, leaves histogram, nodes histogram). 
+ + """ + # Extract tree statistics + depths = model_state["tree_depths"].to_pandas() + n_leaves = model_state["tree_n_leaves"].to_pandas() + n_nodes = model_state["tree_n_nodes"].to_pandas() + + # Create DataFrames + df_depths = pd.DataFrame({"value": depths.to_numpy()}) + df_leaves = pd.DataFrame({"value": n_leaves.to_numpy()}) + df_nodes = pd.DataFrame({"value": n_nodes.to_numpy()}) + + # Histogram for tree depths + chart_depths = ( + alt.Chart(df_depths) + .mark_bar() + .encode( + x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"), + y=alt.Y("count()", title="Number of Trees"), + color=alt.value("steelblue"), + tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")], + ) + .properties(width=300, height=200, title="Distribution of Tree Depths") + ) + + # Histogram for number of leaves + chart_leaves = ( + alt.Chart(df_leaves) + .mark_bar() + .encode( + x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"), + y=alt.Y("count()", title="Number of Trees"), + color=alt.value("forestgreen"), + tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")], + ) + .properties(width=300, height=200, title="Distribution of Leaf Counts") + ) + + # Histogram for number of nodes + chart_nodes = ( + alt.Chart(df_nodes) + .mark_bar() + .encode( + x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"), + y=alt.Y("count()", title="Number of Trees"), + color=alt.value("darkorange"), + tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")], + ) + .properties(width=300, height=200, title="Distribution of Node Counts") + ) + + return chart_depths, chart_leaves, chart_nodes diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py index c73245b..294fc65 100644 --- a/src/entropice/dashboard/plots/source_data.py +++ b/src/entropice/dashboard/plots/source_data.py 
@@ -761,6 +761,117 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str): st.warning("No valid data available for selected parameters") +@st.fragment +def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str): + """Render interactive pydeck map for grid cell areas. + + Args: + grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio. + grid: Grid type ('hex' or 'healpix'). + + """ + st.subheader("🗺️ Grid Cell Areas Distribution") + + # Controls + col1, col2 = st.columns([3, 1]) + + with col1: + area_metric = st.selectbox( + "Area Metric", + options=["cell_area", "land_area", "water_area", "land_ratio"], + format_func=lambda x: x.replace("_", " ").title(), + key="areas_metric", + ) + + with col2: + opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity") + + # Create GeoDataFrame + gdf = grid_gdf.copy() + + # Convert to WGS84 first + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix geometries after CRS conversion + if grid == "hex": + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry) + + # Get values for the selected metric + values = gdf_wgs84[area_metric].to_numpy() + + # Normalize values for color mapping + vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers + normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1) + + # Apply colormap based on metric type + if area_metric == "land_ratio": + cmap = get_cmap("terrain") # Different colormap for ratio + else: + cmap = get_cmap("terrain") + + colors = [cmap(val) for val in normalized] + gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors] + + # Set elevation based on normalized values for 3D visualization + gdf_wgs84["elevation"] = normalized + + # Create GeoJSON + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + 
"properties": { + "cell_area": f"{float(row['cell_area']):.2f}", + "land_area": f"{float(row['land_area']):.2f}", + "water_area": f"{float(row['water_area']):.2f}", + "land_ratio": f"{float(row['land_ratio']):.2%}", + "fill_color": row["fill_color"], + "elevation": float(row["elevation"]), + }, + } + geojson_data.append(feature) + + # Create pydeck layer with 3D elevation + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=opacity, + stroked=True, + filled=True, + extruded=True, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + get_elevation="properties.elevation", + elevation_scale=500000, + line_width_min_pixels=0.5, + pickable=True, + ) + + view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0) + + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": "Cell Area: {cell_area} km²
" + "Land Area: {land_area} km²
" + "Water Area: {water_area} km²
" + "Land Ratio: {land_ratio}", + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + st.pydeck_chart(deck) + + # Show statistics + st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}") + + # Show additional info + st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.") + + @st.fragment def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str): """Render interactive pydeck map for ERA5 climate data. diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py index 88dc5f5..1e8f805 100644 --- a/src/entropice/dashboard/training_data_page.py +++ b/src/entropice/dashboard/training_data_page.py @@ -2,6 +2,7 @@ import streamlit as st +from entropice import grids from entropice.dashboard.plots.source_data import ( render_alphaearth_map, render_alphaearth_overview, @@ -9,6 +10,7 @@ from entropice.dashboard.plots.source_data import ( render_arcticdem_map, render_arcticdem_overview, render_arcticdem_plots, + render_areas_map, render_era5_map, render_era5_overview, render_era5_plots, @@ -184,7 +186,7 @@ def render_training_data_page(): st.markdown("---") # Create tabs for different data views - tab_names = ["📊 Labels"] + tab_names = ["📊 Labels", "📐 Areas"] # Add tabs for each member for member in ensemble.members: @@ -228,8 +230,50 @@ def render_training_data_page(): render_spatial_map(train_data_dict) + # Areas tab + with tabs[1]: + st.markdown("### Grid Cell Areas and Land/Water Distribution") + + st.markdown( + "This visualization shows the spatial distribution of cell areas, land areas, " + "water areas, and land ratio across the grid. The grid has been filtered to " + "include only cells in the permafrost region (>50° latitude, <85° latitude) " + "with >10% land coverage." 
+ ) + + # Load grid data + grid_gdf = grids.open(ensemble.grid, ensemble.level) + + st.success( + f"Loaded {len(grid_gdf)} grid cells with areas ranging from " + f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²" + ) + + # Show summary statistics + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Cells", f"{len(grid_gdf):,}") + with col2: + st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²") + with col3: + st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}") + with col4: + total_land = grid_gdf["land_area"].sum() + st.metric("Total Land Area", f"{total_land:,.0f} km²") + + st.markdown("---") + + if (ensemble.grid == "hex" and ensemble.level == 6) or ( + ensemble.grid == "healpix" and ensemble.level == 10 + ): + st.warning( + "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." + ) + else: + render_areas_map(grid_gdf, ensemble.grid) + # AlphaEarth tab - tab_idx = 1 + tab_idx = 2 if "AlphaEarth" in ensemble.members: with tabs[tab_idx]: st.markdown("### AlphaEarth Embeddings Analysis") diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/data.py index c0ce468..d64a70f 100644 --- a/src/entropice/dashboard/utils/data.py +++ b/src/entropice/dashboard/utils/data.py @@ -29,10 +29,9 @@ class TrainingResult: def from_path(cls, result_path: Path) -> "TrainingResult": """Load a TrainingResult from a given result directory path.""" result_file = result_path / "search_results.parquet" - state_file = result_path / "best_estimator_state.nc" preds_file = result_path / "predicted_probabilities.parquet" settings_file = result_path / "search_settings.toml" - if not all([result_file.exists(), state_file.exists(), preds_file.exists(), settings_file.exists()]): + if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]): raise FileNotFoundError(f"Missing required files in 
{result_path}") created_at = result_path.stat().st_ctime @@ -96,9 +95,9 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD """ return { - "binary": e.create_cat_training_dataset("binary"), - "count": e.create_cat_training_dataset("count"), - "density": e.create_cat_training_dataset("density"), + "binary": e.create_cat_training_dataset("binary", device="cpu"), + "count": e.create_cat_training_dataset("count", device="cpu"), + "density": e.create_cat_training_dataset("density", device="cpu"), } diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index 5003728..b80cd4a 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -14,12 +14,15 @@ Naming conventions: import hashlib import json +from collections.abc import Generator from dataclasses import asdict, dataclass, field from functools import cached_property, lru_cache -from typing import Literal +from typing import Literal, TypedDict +import cupy as cp import cyclopts import geopandas as gpd +import numpy as np import pandas as pd import seaborn as sns import torch @@ -110,8 +113,8 @@ def bin_values( @dataclass(frozen=True, eq=False) class DatasetLabels: binned: pd.Series - train: torch.Tensor - test: torch.Tensor + train: torch.Tensor | np.ndarray | cp.ndarray + test: torch.Tensor | np.ndarray | cp.ndarray raw_values: pd.Series @cached_property @@ -135,8 +138,8 @@ class DatasetLabels: @dataclass(frozen=True, eq=False) class DatasetInputs: data: pd.DataFrame - train: torch.Tensor - test: torch.Tensor + train: torch.Tensor | np.ndarray | cp.ndarray + test: torch.Tensor | np.ndarray | cp.ndarray @dataclass(frozen=True) @@ -151,6 +154,13 @@ class CategoricalTrainingDataset: return len(self.z) +class DatasetStats(TypedDict): + target: str + num_target_samples: int + members: dict[str, dict[str, object]] + total_features: int + + @cyclopts.Parameter("*") @dataclass(frozen=True) class DatasetEnsemble: @@ -283,15 +293,15 @@ class DatasetEnsemble: 
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns] return arcticdem_df - def get_stats(self) -> dict: + def get_stats(self) -> DatasetStats: """Get dataset statistics. Returns: - dict: Dictionary containing target stats, member stats, and total features count. + DatasetStats: Dictionary containing target stats, member stats, and total features count. """ targets = self._read_target() - stats = { + stats: DatasetStats = { "target": self.target, "num_target_samples": len(targets), "members": {}, @@ -365,11 +375,64 @@ class DatasetEnsemble: print(f"Saved dataset to cache at {cache_file}.") return dataset - def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset: + def create_batches( + self, + batch_size: int, + filter_target_col: str | None = None, + cache_mode: Literal["n", "o", "r"] = "r", + ) -> Generator[pd.DataFrame]: + targets = self._read_target() + if len(targets) == 0: + raise ValueError("No target samples found.") + elif len(targets) < batch_size: + yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode) + return + + if filter_target_col is not None: + targets = targets.loc[targets[filter_target_col]] + + for i in range(0, len(targets), batch_size): + # n: no cache, o: overwrite cache, r: read cache if exists + cache_file = entropice.paths.get_dataset_cache( + self.id(), subset=filter_target_col, batch=(i, i + batch_size) + ) + if cache_mode == "r" and cache_file.exists(): + dataset = gpd.read_parquet(cache_file) + print( + f"Loaded cached dataset from {cache_file} with {len(dataset)} samples" + f" and {len(dataset.columns)} features." 
+ ) + yield dataset + else: + targets_batch = targets.iloc[i : i + batch_size] + + member_dfs = [] + for member in self.members: + if member.startswith("ERA5"): + era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] + member_dfs.append(self._prep_era5(targets_batch, era5_agg)) + elif member == "AlphaEarth": + member_dfs.append(self._prep_embeddings(targets_batch)) + elif member == "ArcticDEM": + member_dfs.append(self._prep_arcticdem(targets_batch)) + else: + raise NotImplementedError(f"Member {member} not implemented.") + + dataset = targets_batch.set_index("cell_id").join(member_dfs) + print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") + if cache_mode in ["o", "r"]: + dataset.to_parquet(cache_file) + print(f"Saved dataset to cache at {cache_file}.") + yield dataset + + def create_cat_training_dataset( + self, task: Task, device: Literal["cpu", "cuda", "torch"] + ) -> CategoricalTrainingDataset: """Create a categorical dataset for training. Args: task (Task): Task type. + device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto. Returns: CategoricalTrainingDataset: The prepared categorical training dataset. 
@@ -414,10 +477,23 @@ class DatasetEnsemble: split.loc[test_idx] = "test" split = split.astype("category") - X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0) - X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0) - y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0) - y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0) + X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64") + X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64") + y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64") + y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64") + + if device == "cuda": + X_train = cp.asarray(X_train) + X_test = cp.asarray(X_test) + y_train = cp.asarray(y_train) + y_test = cp.asarray(y_test) + elif device == "torch": + X_train = torch.from_numpy(X_train) + X_test = torch.from_numpy(X_test) + y_train = torch.from_numpy(y_train) + y_test = torch.from_numpy(y_test) + else: + assert device == "cpu", "Invalid device specified." return CategoricalTrainingDataset( dataset=dataset.to_crs("EPSG:4326"), diff --git a/src/entropice/inference.py b/src/entropice/inference.py index 816dd66..b2d87dc 100644 --- a/src/entropice/inference.py +++ b/src/entropice/inference.py @@ -1,6 +1,7 @@ # ruff: noqa: N806 """Inference runs on trained models.""" +import cupy as cp import geopandas as gpd import pandas as pd import torch @@ -32,35 +33,32 @@ def predict_proba( list: A list of predicted probabilities for each cell. 
""" - data = e.create() - print(f"Predicting probabilities for {len(data)} cells...") - # Predict in batches to avoid memory issues - batch_size = 10_000 + batch_size = 10000 preds = [] - cols_to_drop = ["geometry"] - if e.target == "darts_mllabels": - cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")] - else: - cols_to_drop += [col for col in data.columns if col.startswith("darts_")] - for i in range(0, len(data), batch_size): - batch = data.iloc[i : i + batch_size] + for batch in e.create_batches(batch_size=batch_size): + cols_to_drop = ["geometry"] + if e.target == "darts_mllabels": + cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")] + else: + cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] X_batch = batch.drop(columns=cols_to_drop).dropna() cell_ids = X_batch.index.to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() X_batch = X_batch.to_numpy(dtype="float64") X_batch = torch.asarray(X_batch, device=0) - batch_preds = clf.predict(X_batch).cpu().numpy() + batch_preds = clf.predict(X_batch) + if isinstance(batch_preds, cp.ndarray): + batch_preds = batch_preds.get() + elif torch.is_tensor(batch_preds): + batch_preds = batch_preds.cpu().numpy() batch_preds = gpd.GeoDataFrame( { "cell_id": cell_ids, "predicted_class": [classes[i] for i in batch_preds], "geometry": cell_geoms, }, - crs="epsg:3413", - ) + ).set_crs(epsg=3413, inplace=False) preds.append(batch_preds) - preds = gpd.GeoDataFrame(pd.concat(preds)) - - return preds + return gpd.GeoDataFrame(pd.concat(preds, ignore_index=True)) diff --git a/src/entropice/paths.py b/src/entropice/paths.py index bb6a197..6ca18c7 100644 --- a/src/entropice/paths.py +++ b/src/entropice/paths.py @@ -9,7 +9,7 @@ from typing import Literal DATA_DIR = ( Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice" ) -DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # 
Temporary hardcoding for FAST cluster +# DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster GRIDS_DIR = DATA_DIR / "grids" FIGURES_DIR = Path("figures") @@ -110,7 +110,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: return dataset_file -def get_dataset_cache(eid: str, subset: str | None = None) -> Path: +def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int] | None = None) -> Path: + if batch is not None: + eid = f"{eid}_batch{batch[0]}-{batch[1]}" if subset is None: cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet" else: diff --git a/src/entropice/training.py b/src/entropice/training.py index 61b0c46..7b0cc87 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -4,6 +4,7 @@ import pickle from dataclasses import asdict, dataclass from typing import Literal +import cupy as cp import cyclopts import pandas as pd import toml @@ -14,7 +15,6 @@ from entropy import ESPAClassifier from rich import pretty, traceback from scipy.stats import loguniform, randint from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen -from sklearn import set_config from sklearn.model_selection import KFold, RandomizedSearchCV from stopuhr import stopwatch from xgboost.sklearn import XGBClassifier @@ -26,7 +26,9 @@ from entropice.paths import get_cv_results_dir traceback.install() pretty.install() -set_config(array_api_dispatch=True) +# Disabled array_api_dispatch to avoid namespace conflicts between NumPy and CuPy +# when using XGBoost with device="cuda" +# set_config(array_api_dispatch=True) cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type] @@ -85,8 +87,8 @@ def _create_clf( objective="multi:softprob" if settings.task != "binary" else "binary:logistic", eval_metric="mlogloss" if settings.task != "binary" else "logloss", random_state=42, - 
tree_method="gpu_hist",
-        device="cuda",
+        tree_method="hist",
+        device="gpu",  # "gpu" aliases "cuda": training stays on GPU; labels are converted to NumPy before sklearn metrics to avoid CuPy/NumPy namespace conflicts
        )
        fit_params = {}
    elif settings.model == "rf":
@@ -98,11 +100,10 @@ def _create_clf(
        fit_params = {}
    elif settings.model == "knn":
        param_grid = {
-            "n_neighbors": randint(3, 15),
+            "n_neighbors": randint(10, 200),
            "weights": ["uniform", "distance"],
-            "algorithm": ["brute", "kd_tree", "ball_tree"],
        }
-        clf = KNeighborsClassifier(random_state=42)
+        clf = KNeighborsClassifier()
        fit_params = {}
    else:
        raise ValueError(f"Unknown model: {settings.model}")
@@ -148,8 +149,9 @@ def random_cv(
        task (Literal["binary", "count", "density"], optional): The classification task type.
            Defaults to "binary".
    """
+    device = "torch" if settings.model in ["espa"] else "cuda"
    print("Creating training data...")
-    training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
+    training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task, device=device)
    clf, param_grid, fit_params = _create_clf(settings)
    print(f"Using model: {settings.model} with parameters: {param_grid}")
@@ -159,7 +161,7 @@ def random_cv(
        clf,
        param_grid,
        n_iter=settings.n_iter,
-        n_jobs=8,
+        n_jobs=1,
        cv=cv,
        random_state=42,
        verbose=10,
@@ -169,14 +171,25 @@ def random_cv(
    print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
    with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
-        search.fit(training_data.X.train, training_data.y.train, **fit_params)
+        y_train = (
+            training_data.y.train.get()
+            if settings.model == "xgboost" and isinstance(training_data.y.train, cp.ndarray)
+            else training_data.y.train
+        )
+        search.fit(training_data.X.train, y_train, **fit_params)
    print("Best parameters combination found:")
-    best_parameters = search.best_estimator_.get_params()
+    best_estimator = search.best_estimator_
+    best_parameters = best_estimator.get_params()
    for param_name in 
sorted(param_grid.keys()): print(f"{param_name}: {best_parameters[param_name]}") - test_accuracy = search.score(training_data.X.test, training_data.y.test) + y_test = ( + training_data.y.test.get() + if settings.model == "xgboost" and isinstance(training_data.y.test, cp.ndarray) + else training_data.y.test + ) + test_accuracy = search.score(training_data.X.test, y_test) print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}") @@ -207,7 +220,7 @@ def random_cv( best_model_file = results_dir / "best_estimator_model.pkl" print(f"Storing best estimator model to {best_model_file}") with open(best_model_file, "wb") as f: - pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(best_estimator, f, protocol=pickle.HIGHEST_PROTOCOL) # Store the search results results = pd.DataFrame(search.cv_results_) @@ -220,10 +233,10 @@ def random_cv( results.to_parquet(results_file) # Get the inner state of the best estimator + features = training_data.X.data.columns.tolist() + if settings.model == "espa": - best_estimator = search.best_estimator_ # Annotate the state with xarray metadata - features = training_data.X.data.columns.tolist() boxes = list(range(best_estimator.K_)) box_centers = xr.DataArray( best_estimator.S_.cpu().numpy(), @@ -260,9 +273,154 @@ def random_cv( print(f"Storing best estimator state to {state_file}") state.to_netcdf(state_file, engine="h5netcdf") + elif settings.model == "xgboost": + # Extract XGBoost-specific information + # Get the underlying booster + booster = best_estimator.get_booster() + + # Feature importance with different importance types + importance_weight = booster.get_score(importance_type="weight") + importance_gain = booster.get_score(importance_type="gain") + importance_cover = booster.get_score(importance_type="cover") + importance_total_gain = booster.get_score(importance_type="total_gain") + 
importance_total_cover = booster.get_score(importance_type="total_cover") + + # Create aligned arrays for all features (including zero-importance) + def align_importance(importance_dict, features): + """Align importance dict to feature list, filling missing with 0.""" + return [importance_dict.get(f, 0.0) for f in features] + + feature_importance_weight = xr.DataArray( + align_importance(importance_weight, features), + dims=["feature"], + coords={"feature": features}, + name="feature_importance_weight", + attrs={"description": "Number of times a feature is used to split the data across all trees."}, + ) + feature_importance_gain = xr.DataArray( + align_importance(importance_gain, features), + dims=["feature"], + coords={"feature": features}, + name="feature_importance_gain", + attrs={"description": "Average gain across all splits the feature is used in."}, + ) + feature_importance_cover = xr.DataArray( + align_importance(importance_cover, features), + dims=["feature"], + coords={"feature": features}, + name="feature_importance_cover", + attrs={"description": "Average coverage across all splits the feature is used in."}, + ) + feature_importance_total_gain = xr.DataArray( + align_importance(importance_total_gain, features), + dims=["feature"], + coords={"feature": features}, + name="feature_importance_total_gain", + attrs={"description": "Total gain across all splits the feature is used in."}, + ) + feature_importance_total_cover = xr.DataArray( + align_importance(importance_total_cover, features), + dims=["feature"], + coords={"feature": features}, + name="feature_importance_total_cover", + attrs={"description": "Total coverage across all splits the feature is used in."}, + ) + + # Store tree information + n_trees = booster.num_boosted_rounds() + + state = xr.Dataset( + { + "feature_importance_weight": feature_importance_weight, + "feature_importance_gain": feature_importance_gain, + "feature_importance_cover": feature_importance_cover, + 
"feature_importance_total_gain": feature_importance_total_gain, + "feature_importance_total_cover": feature_importance_total_cover, + }, + attrs={ + "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.", + "n_trees": n_trees, + "objective": str(best_estimator.objective), + }, + ) + state_file = results_dir / "best_estimator_state.nc" + print(f"Storing best estimator state to {state_file}") + state.to_netcdf(state_file, engine="h5netcdf") + + elif settings.model == "rf": + # Extract Random Forest-specific information + # Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_) + # like sklearn does, so we can only extract feature importances and model parameters + + # Feature importances (Gini importance) + feature_importances = best_estimator.feature_importances_ + + feature_importance = xr.DataArray( + feature_importances, + dims=["feature"], + coords={"feature": features}, + name="feature_importance", + attrs={"description": "Gini importance (impurity-based feature importance)."}, + ) + + # cuML RF doesn't expose individual trees, so we store model parameters instead + n_estimators = best_estimator.n_estimators + max_depth = best_estimator.max_depth + + # OOB score if available + oob_score = None + if hasattr(best_estimator, "oob_score_") and best_estimator.oob_score: + oob_score = float(best_estimator.oob_score_) + + # cuML RandomForest doesn't provide per-tree statistics like sklearn + # Store what we have: feature importances and model configuration + attrs = { + "description": "Inner state of the best RandomForestClassifier from RandomizedSearchCV (cuML).", + "n_estimators": int(n_estimators), + "note": "cuML RandomForest does not expose individual tree statistics like sklearn", + } + + # Only add optional attributes if they have values + if max_depth is not None: + attrs["max_depth"] = int(max_depth) + if oob_score is not None: + attrs["oob_score"] = oob_score + + state = xr.Dataset( + { + 
"feature_importance": feature_importance, + }, + attrs=attrs, + ) + state_file = results_dir / "best_estimator_state.nc" + print(f"Storing best estimator state to {state_file}") + state.to_netcdf(state_file, engine="h5netcdf") + + elif settings.model == "knn": + # KNN doesn't have traditional feature importance + # Store information about the training data and neighbors + n_neighbors = best_estimator.n_neighbors + + # We can't extract meaningful "state" from KNN in the same way + # but we can store metadata about the model + state = xr.Dataset( + attrs={ + "description": "Metadata of the best KNeighborsClassifier from RandomizedSearchCV.", + "n_neighbors": n_neighbors, + "weights": str(best_estimator.weights), + "algorithm": str(best_estimator.algorithm), + "metric": str(best_estimator.metric), + "n_samples_fit": best_estimator.n_samples_fit_, + }, + ) + state_file = results_dir / "best_estimator_state.nc" + print(f"Storing best estimator metadata to {state_file}") + state.to_netcdf(state_file, engine="h5netcdf") + # Predict probabilities for all cells print("Predicting probabilities for all cells...") preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels) + print(f"Predicted probabilities DataFrame with {len(preds)} entries.") preds_file = results_dir / "predicted_probabilities.parquet" print(f"Storing predicted probabilities to {preds_file}") preds.to_parquet(preds_file)