diff --git a/pixi.lock b/pixi.lock
index 7635c77..c1369b7 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -9,7 +9,6 @@ environments:
- https://pypi.org/simple
options:
channel-priority: disabled
- pypi-prerelease-mode: if-necessary-or-explicit
packages:
linux-64:
- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
@@ -2869,7 +2868,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
- sha256: 6488240242ab4091686c27ee2e721c81143e0a158bdff086518da6ccdc2971c8
+ sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@@ -2933,6 +2932,7 @@ packages:
- ruff>=0.14.9,<0.15
- pandas-stubs>=2.3.3.251201,<3
requires_python: '>=3.13,<3.14'
+ editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
version: 0.1.0
diff --git a/pyproject.toml b/pyproject.toml
index 5ec1bff..87d1773 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,8 @@ dependencies = [
"pydeck>=0.9.1,<0.10",
"pypalettes>=0.2.1,<0.3",
"ty>=0.0.2,<0.0.3",
- "ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3",
+ "ruff>=0.14.9,<0.15",
+ "pandas-stubs>=2.3.3.251201,<3",
]
[project.scripts]
@@ -118,6 +119,7 @@ platforms = ["linux-64"]
[tool.pixi.activation.env]
SCIPY_ARRAY_API = "1"
+FAST_DATA_DIR = "./data"
[tool.pixi.system-requirements]
cuda = "12"
diff --git a/src/entropice/dashboard/model_state_page.py b/src/entropice/dashboard/model_state_page.py
index 88558fa..3e424dd 100644
--- a/src/entropice/dashboard/model_state_page.py
+++ b/src/entropice/dashboard/model_state_page.py
@@ -1,6 +1,7 @@
"""Model State page for the Entropice dashboard."""
import streamlit as st
+import xarray as xr
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import (
@@ -43,6 +44,9 @@ def render_model_state_page():
)
selected_result = result_options[selected_name]
+ # Get the model type from settings
+ model_type = selected_result.settings.get("model", "espa")
+
# Load model state
with st.spinner("Loading model state..."):
model_state = load_model_state(selected_result)
@@ -50,34 +54,40 @@ def render_model_state_page():
st.error("Could not load model state for this result.")
return
- # Scale feature weights by number of features
- n_features = model_state.sizes["feature"]
- model_state["feature_weights"] *= n_features
-
- # Extract different feature types
- embedding_feature_array = extract_embedding_features(model_state)
- era5_feature_array = extract_era5_features(model_state)
- common_feature_array = extract_common_features(model_state)
-
- # Generate unified colormaps
- _, _, altair_colors = generate_unified_colormap(selected_result.settings)
-
# Display basic model state info
with st.expander("Model State Information", expanded=False):
+ st.write(f"**Model Type:** {model_type.upper()}")
st.write(f"**Variables:** {list(model_state.data_vars)}")
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
st.write(f"**Coordinates:** {list(model_state.coords)}")
+ st.write(f"**Attributes:** {dict(model_state.attrs)}")
- # Show statistics
- st.write("**Feature Weight Statistics:**")
- feature_weights = model_state["feature_weights"].to_pandas()
- col1, col2, col3 = st.columns(3)
- with col1:
- st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
- with col2:
- st.metric("Max Weight", f"{feature_weights.max():.4f}")
- with col3:
- st.metric("Total Features", len(feature_weights))
+ # Render model-specific visualizations
+ if model_type == "espa":
+ render_espa_model_state(model_state, selected_result)
+ elif model_type == "xgboost":
+ render_xgboost_model_state(model_state, selected_result)
+ elif model_type == "rf":
+ render_rf_model_state(model_state, selected_result)
+ elif model_type == "knn":
+ render_knn_model_state(model_state, selected_result)
+ else:
+ st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
+
+
+def render_espa_model_state(model_state: xr.Dataset, selected_result):
+ """Render visualizations for ESPA model."""
+ # Scale feature weights by number of features
+ n_features = model_state.sizes["feature"]
+ model_state["feature_weights"] *= n_features
+
+ # Extract different feature types
+ embedding_feature_array = extract_embedding_features(model_state)
+ era5_feature_array = extract_era5_features(model_state)
+ common_feature_array = extract_common_features(model_state)
+
+ # Generate unified colormaps
+ _, _, altair_colors = generate_unified_colormap(selected_result.settings)
# Feature importance section
st.header("Feature Importance")
@@ -167,156 +177,362 @@ def render_model_state_page():
# Embedding features analysis (if present)
if embedding_feature_array is not None:
- with st.container(border=True):
- st.header("🛰️ Embedding Feature Analysis")
- st.markdown(
- """
- Analysis of embedding features showing which aggregations, bands, and years
- are most important for the model predictions.
- """
- )
-
- # Summary bar charts
- st.markdown("### Importance by Dimension")
- with st.spinner("Generating dimension summaries..."):
- chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
- col1, col2, col3 = st.columns(3)
- with col1:
- st.altair_chart(chart_agg, use_container_width=True)
- with col2:
- st.altair_chart(chart_band, use_container_width=True)
- with col3:
- st.altair_chart(chart_year, use_container_width=True)
-
- # Detailed heatmap
- st.markdown("### Detailed Heatmap by Aggregation")
- st.markdown("Shows the weight of each band-year combination for each aggregation type.")
- with st.spinner("Generating heatmap..."):
- heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
- st.altair_chart(heatmap_chart, use_container_width=True)
-
- # Statistics
- with st.expander("Embedding Feature Statistics"):
- st.write("**Overall Statistics:**")
- n_emb_features = embedding_feature_array.size
- mean_weight = float(embedding_feature_array.mean().values)
- max_weight = float(embedding_feature_array.max().values)
- col1, col2, col3 = st.columns(3)
- with col1:
- st.metric("Total Embedding Features", n_emb_features)
- with col2:
- st.metric("Mean Weight", f"{mean_weight:.4f}")
- with col3:
- st.metric("Max Weight", f"{max_weight:.4f}")
-
- # Show top embedding features
- st.write("**Top 10 Embedding Features:**")
- emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
- top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
- st.dataframe(top_emb, width="stretch")
- else:
- st.info("No embedding features found in this model.")
+ render_embedding_features(embedding_feature_array)
# ERA5 features analysis (if present)
if era5_feature_array is not None:
- with st.container(border=True):
- st.header("⛅ ERA5 Feature Analysis")
- st.markdown(
- """
- Analysis of ERA5 climate features showing which variables and time periods
- are most important for the model predictions.
- """
- )
-
- # Summary bar charts
- st.markdown("### Importance by Dimension")
- with st.spinner("Generating ERA5 dimension summaries..."):
- chart_variable, chart_time = plot_era5_summary(era5_feature_array)
- col1, col2 = st.columns(2)
- with col1:
- st.altair_chart(chart_variable, use_container_width=True)
- with col2:
- st.altair_chart(chart_time, use_container_width=True)
-
- # Detailed heatmap
- st.markdown("### Detailed Heatmap")
- st.markdown("Shows the weight of each variable-time combination.")
- with st.spinner("Generating ERA5 heatmap..."):
- era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
- st.altair_chart(era5_heatmap_chart, use_container_width=True)
-
- # Statistics
- with st.expander("ERA5 Feature Statistics"):
- st.write("**Overall Statistics:**")
- n_era5_features = era5_feature_array.size
- mean_weight = float(era5_feature_array.mean().values)
- max_weight = float(era5_feature_array.max().values)
- col1, col2, col3 = st.columns(3)
- with col1:
- st.metric("Total ERA5 Features", n_era5_features)
- with col2:
- st.metric("Mean Weight", f"{mean_weight:.4f}")
- with col3:
- st.metric("Max Weight", f"{max_weight:.4f}")
-
- # Show top ERA5 features
- st.write("**Top 10 ERA5 Features:**")
- era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
- top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
- st.dataframe(top_era5, width="stretch")
- else:
- st.info("No ERA5 features found in this model.")
+ render_era5_features(era5_feature_array)
# Common features analysis (if present)
if common_feature_array is not None:
- with st.container(border=True):
- st.header("🗺️ Common Feature Analysis")
- st.markdown(
- """
- Analysis of common features including cell area, water area, land area, land ratio,
- longitude, and latitude. These features provide spatial and geographic context.
- """
- )
+ render_common_features(common_feature_array)
- # Bar chart showing all common feature weights
- with st.spinner("Generating common features chart..."):
- common_chart = plot_common_features(common_feature_array)
- st.altair_chart(common_chart, use_container_width=True)
- # Statistics
- with st.expander("Common Feature Statistics"):
- st.write("**Overall Statistics:**")
- n_common_features = common_feature_array.size
- mean_weight = float(common_feature_array.mean().values)
- max_weight = float(common_feature_array.max().values)
- min_weight = float(common_feature_array.min().values)
- col1, col2, col3, col4 = st.columns(4)
- with col1:
- st.metric("Total Common Features", n_common_features)
- with col2:
- st.metric("Mean Weight", f"{mean_weight:.4f}")
- with col3:
- st.metric("Max Weight", f"{max_weight:.4f}")
- with col4:
- st.metric("Min Weight", f"{min_weight:.4f}")
+def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
+ """Render visualizations for XGBoost model."""
+ from entropice.dashboard.plots.model_state import (
+ plot_xgboost_feature_importance,
+ plot_xgboost_importance_comparison,
+ )
- # Show all common features sorted by importance
- st.write("**All Common Features (by absolute weight):**")
- common_df = common_feature_array.to_dataframe(name="weight").reset_index()
- common_df["abs_weight"] = common_df["weight"].abs()
- common_df = common_df.sort_values("abs_weight", ascending=False)
- st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
+ st.header("🌲 XGBoost Model Analysis")
+ st.markdown(
+ f"""
+ XGBoost gradient boosted tree model with **{model_state.attrs.get("n_trees", "N/A")} trees**.
- st.markdown(
- """
- **Interpretation:**
- - **cell_area, water_area, land_area**: Spatial extent features that may indicate
- size-related patterns
- - **land_ratio**: Proportion of land vs water in each cell
- - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- - Positive weights indicate features that increase the probability of the positive class
- - Negative weights indicate features that decrease the probability of the positive class
- """
- )
- else:
- st.info("No common features found in this model.")
+ **Objective:** {model_state.attrs.get("objective", "N/A")}
+ """
+ )
+
+ # Feature importance with different types
+ st.subheader("Feature Importance Analysis")
+ st.markdown(
+ """
+ XGBoost provides multiple ways to measure feature importance:
+ - **Weight**: Number of times a feature is used to split the data
+ - **Gain**: Average gain across all splits using the feature
+ - **Cover**: Average coverage across all splits using the feature
+ - **Total Gain**: Total gain across all splits
+ - **Total Cover**: Total coverage across all splits
+ """
+ )
+
+ # Importance type selector
+ importance_type = st.selectbox(
+ "Select Importance Type",
+ options=["gain", "weight", "cover", "total_gain", "total_cover"],
+ index=0,
+ help="Choose which importance metric to visualize",
+ )
+
+ # Top N slider
+ top_n = st.slider(
+ "Number of top features to display",
+ min_value=5,
+ max_value=50,
+ value=20,
+ step=5,
+ help="Select how many of the most important features to visualize",
+ )
+
+ with st.spinner("Generating feature importance plot..."):
+ importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
+ st.altair_chart(importance_chart, use_container_width=True)
+
+ # Comparison of importance types
+ st.subheader("Importance Type Comparison")
+ st.markdown("Compare the top features across different importance metrics.")
+
+ with st.spinner("Generating importance comparison..."):
+ comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
+ st.altair_chart(comparison_chart, use_container_width=True)
+
+ # Statistics
+ with st.expander("Model Statistics"):
+ st.write("**Overall Statistics:**")
+ col1, col2 = st.columns(2)
+ with col1:
+ st.metric("Number of Trees", model_state.attrs.get("n_trees", "N/A"))
+ with col2:
+ st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
+
+
+def render_rf_model_state(model_state: xr.Dataset, selected_result):
+ """Render visualizations for Random Forest model."""
+ from entropice.dashboard.plots.model_state import plot_rf_feature_importance
+
+ st.header("🌳 Random Forest Model Analysis")
+
+ # Check if using cuML (which doesn't provide tree statistics)
+ is_cuml = "cuML" in model_state.attrs.get("description", "")
+
+ st.markdown(
+ f"""
+ Random Forest ensemble with **{model_state.attrs.get("n_estimators", "N/A")} trees**
+ (max depth: {model_state.attrs.get("max_depth", "N/A")}).
+ """
+ )
+
+ if is_cuml:
+ st.info("ℹ️ Using cuML GPU-accelerated Random Forest. Individual tree statistics are not available.")
+
+ # Display OOB score if available
+ oob_score = model_state.attrs.get("oob_score")
+ if oob_score is not None:
+ st.info(f"**Out-of-Bag Score:** {oob_score:.4f}")
+
+ # Feature importance
+ st.subheader("Feature Importance (Gini Importance)")
+ st.markdown(
+ """
+ Random Forest uses Gini impurity to measure feature importance. Features with higher
+ importance values contribute more to the model's predictions.
+ """
+ )
+
+ # Top N slider
+ top_n = st.slider(
+ "Number of top features to display",
+ min_value=5,
+ max_value=50,
+ value=20,
+ step=5,
+ help="Select how many of the most important features to visualize",
+ )
+
+ with st.spinner("Generating feature importance plot..."):
+ importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
+ st.altair_chart(importance_chart, use_container_width=True)
+
+ # Tree statistics (only if available - sklearn RF has them, cuML RF doesn't)
+ if not is_cuml and "tree_depths" in model_state:
+ from entropice.dashboard.plots.model_state import plot_rf_tree_statistics
+
+ st.subheader("Tree Structure Statistics")
+ st.markdown("Distribution of tree properties across the forest.")
+
+ with st.spinner("Generating tree statistics..."):
+ chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.altair_chart(chart_depths, use_container_width=True)
+ with col2:
+ st.altair_chart(chart_leaves, use_container_width=True)
+ with col3:
+ st.altair_chart(chart_nodes, use_container_width=True)
+
+ # Statistics
+ with st.expander("Forest Statistics"):
+ st.write("**Overall Statistics:**")
+ depths = model_state["tree_depths"].to_pandas()
+ leaves = model_state["tree_n_leaves"].to_pandas()
+ nodes = model_state["tree_n_nodes"].to_pandas()
+
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.write("**Tree Depths:**")
+ st.metric("Mean Depth", f"{depths.mean():.2f}")
+ st.metric("Max Depth", f"{depths.max()}")
+ st.metric("Min Depth", f"{depths.min()}")
+ with col2:
+ st.write("**Leaf Counts:**")
+ st.metric("Mean Leaves", f"{leaves.mean():.2f}")
+ st.metric("Max Leaves", f"{leaves.max()}")
+ st.metric("Min Leaves", f"{leaves.min()}")
+ with col3:
+ st.write("**Node Counts:**")
+ st.metric("Mean Nodes", f"{nodes.mean():.2f}")
+ st.metric("Max Nodes", f"{nodes.max()}")
+ st.metric("Min Nodes", f"{nodes.min()}")
+
+
+def render_knn_model_state(model_state: xr.Dataset, selected_result):
+ """Render visualizations for KNN model."""
+ st.header("🔍 K-Nearest Neighbors Model Analysis")
+ st.markdown(
+ """
+ K-Nearest Neighbors is a non-parametric, instance-based learning algorithm.
+ Unlike tree-based or parametric models, KNN doesn't learn feature weights or build
+ a model structure. Instead, it memorizes the training data and makes predictions
+ based on the k nearest neighbors.
+ """
+ )
+
+ # Display model metadata
+ st.subheader("Model Configuration")
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Number of Neighbors (k)", model_state.attrs.get("n_neighbors", "N/A"))
+ st.metric("Training Samples", model_state.attrs.get("n_samples_fit", "N/A"))
+ with col2:
+ st.metric("Weights", model_state.attrs.get("weights", "N/A"))
+ st.metric("Algorithm", model_state.attrs.get("algorithm", "N/A"))
+ with col3:
+ st.metric("Metric", model_state.attrs.get("metric", "N/A"))
+
+ st.info(
+ """
+ **Note:** KNN doesn't have traditional feature importance or model parameters to visualize.
+ The model's behavior depends entirely on:
+ - The number of neighbors (k)
+ - The distance metric used
+ - The weighting scheme for neighbors
+
+ To understand the model better, consider visualizing the decision boundaries on a
+ reduced-dimensional representation of your data (e.g., using PCA or t-SNE).
+ """
+ )
+
+
+# Helper functions for embedding/era5/common features
+def render_embedding_features(embedding_feature_array):
+ """Render embedding feature visualizations."""
+ with st.container(border=True):
+ st.header("🛰️ Embedding Feature Analysis")
+ st.markdown(
+ """
+ Analysis of embedding features showing which aggregations, bands, and years
+ are most important for the model predictions.
+ """
+ )
+
+ # Summary bar charts
+ st.markdown("### Importance by Dimension")
+ with st.spinner("Generating dimension summaries..."):
+ chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.altair_chart(chart_agg, use_container_width=True)
+ with col2:
+ st.altair_chart(chart_band, use_container_width=True)
+ with col3:
+ st.altair_chart(chart_year, use_container_width=True)
+
+ # Detailed heatmap
+ st.markdown("### Detailed Heatmap by Aggregation")
+ st.markdown("Shows the weight of each band-year combination for each aggregation type.")
+ with st.spinner("Generating heatmap..."):
+ heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
+ st.altair_chart(heatmap_chart, use_container_width=True)
+
+ # Statistics
+ with st.expander("Embedding Feature Statistics"):
+ st.write("**Overall Statistics:**")
+ n_emb_features = embedding_feature_array.size
+ mean_weight = float(embedding_feature_array.mean().values)
+ max_weight = float(embedding_feature_array.max().values)
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Total Embedding Features", n_emb_features)
+ with col2:
+ st.metric("Mean Weight", f"{mean_weight:.4f}")
+ with col3:
+ st.metric("Max Weight", f"{max_weight:.4f}")
+
+ # Show top embedding features
+ st.write("**Top 10 Embedding Features:**")
+ emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
+ top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
+ st.dataframe(top_emb, width="stretch")
+
+
+def render_era5_features(era5_feature_array):
+ """Render ERA5 feature visualizations."""
+ with st.container(border=True):
+ st.header("⛅ ERA5 Feature Analysis")
+ st.markdown(
+ """
+ Analysis of ERA5 climate features showing which variables and time periods
+ are most important for the model predictions.
+ """
+ )
+
+ # Summary bar charts
+ st.markdown("### Importance by Dimension")
+ with st.spinner("Generating ERA5 dimension summaries..."):
+ chart_variable, chart_time = plot_era5_summary(era5_feature_array)
+ col1, col2 = st.columns(2)
+ with col1:
+ st.altair_chart(chart_variable, use_container_width=True)
+ with col2:
+ st.altair_chart(chart_time, use_container_width=True)
+
+ # Detailed heatmap
+ st.markdown("### Detailed Heatmap")
+ st.markdown("Shows the weight of each variable-time combination.")
+ with st.spinner("Generating ERA5 heatmap..."):
+ era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
+ st.altair_chart(era5_heatmap_chart, use_container_width=True)
+
+ # Statistics
+ with st.expander("ERA5 Feature Statistics"):
+ st.write("**Overall Statistics:**")
+ n_era5_features = era5_feature_array.size
+ mean_weight = float(era5_feature_array.mean().values)
+ max_weight = float(era5_feature_array.max().values)
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Total ERA5 Features", n_era5_features)
+ with col2:
+ st.metric("Mean Weight", f"{mean_weight:.4f}")
+ with col3:
+ st.metric("Max Weight", f"{max_weight:.4f}")
+
+ # Show top ERA5 features
+ st.write("**Top 10 ERA5 Features:**")
+ era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
+ top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
+ st.dataframe(top_era5, width="stretch")
+
+
+def render_common_features(common_feature_array):
+ """Render common feature visualizations."""
+ with st.container(border=True):
+ st.header("🗺️ Common Feature Analysis")
+ st.markdown(
+ """
+ Analysis of common features including cell area, water area, land area, land ratio,
+ longitude, and latitude. These features provide spatial and geographic context.
+ """
+ )
+
+ # Bar chart showing all common feature weights
+ with st.spinner("Generating common features chart..."):
+ common_chart = plot_common_features(common_feature_array)
+ st.altair_chart(common_chart, use_container_width=True)
+
+ # Statistics
+ with st.expander("Common Feature Statistics"):
+ st.write("**Overall Statistics:**")
+ n_common_features = common_feature_array.size
+ mean_weight = float(common_feature_array.mean().values)
+ max_weight = float(common_feature_array.max().values)
+ min_weight = float(common_feature_array.min().values)
+ col1, col2, col3, col4 = st.columns(4)
+ with col1:
+ st.metric("Total Common Features", n_common_features)
+ with col2:
+ st.metric("Mean Weight", f"{mean_weight:.4f}")
+ with col3:
+ st.metric("Max Weight", f"{max_weight:.4f}")
+ with col4:
+ st.metric("Min Weight", f"{min_weight:.4f}")
+
+ # Show all common features sorted by importance
+ st.write("**All Common Features (by absolute weight):**")
+ common_df = common_feature_array.to_dataframe(name="weight").reset_index()
+ common_df["abs_weight"] = common_df["weight"].abs()
+ common_df = common_df.sort_values("abs_weight", ascending=False)
+ st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
+
+ st.markdown(
+ """
+ **Interpretation:**
+ - **cell_area, water_area, land_area**: Spatial extent features that may indicate
+ size-related patterns
+ - **land_ratio**: Proportion of land vs water in each cell
+ - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
+ - Positive weights indicate features that increase the probability of the positive class
+ - Negative weights indicate features that decrease the probability of the positive class
+ """
+ )
diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py
index eee8ca6..dd05665 100644
--- a/src/entropice/dashboard/plots/hyperparameter_analysis.py
+++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py
@@ -1,11 +1,12 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt
+import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import streamlit as st
-from entropice.dashboard.plots.colors import get_cmap
+from entropice.dashboard.plots.colors import get_cmap, get_palette
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@@ -135,8 +136,6 @@ def render_parameter_distributions(results: pd.DataFrame):
# Get colormap from colors module
cmap = get_cmap("parameter_distribution")
- import matplotlib.colors as mcolors
-
bar_color = mcolors.rgb2hex(cmap(0.5))
# Create histograms for each parameter
@@ -159,37 +158,77 @@ def render_parameter_distributions(results: pd.DataFrame):
param_values = results[param_col].dropna()
if pd.api.types.is_numeric_dtype(param_values):
+ # Check if we have enough unique values for a histogram
+ n_unique = param_values.nunique()
+
+ if n_unique == 1:
+ # Only one unique value - show as text
+ value = param_values.iloc[0]
+ formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}"
+ st.metric(param_name, formatted)
+ st.caption(f"All {len(param_values)} samples have the same value")
+ continue
+
# Numeric parameter - use histogram
- df_plot = pd.DataFrame({param_name: param_values})
+ df_plot = pd.DataFrame({param_name: param_values.to_numpy()})
- # Use log scale if the range spans multiple orders of magnitude
- value_range = param_values.max() / (param_values.min() + 1e-10)
- use_log = value_range > 100
+ # Use log scale if the range spans multiple orders of magnitude OR values are very small
+ max_val = param_values.max()
- if use_log:
- chart = (
- alt.Chart(df_plot)
- .mark_bar(color=bar_color)
- .encode(
- alt.X(
- param_name,
- bin=alt.Bin(maxbins=20),
- scale=alt.Scale(type="log"),
- title=param_name,
- ),
- alt.Y("count()", title="Count"),
- tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
+ # Determine number of bins based on unique values
+ n_bins = min(20, max(5, n_unique))
+
+ # For very small values or large ranges, use a simple bar chart instead of histogram
+ if n_unique <= 10:
+ # Few unique values - use bar chart
+ value_counts = param_values.value_counts().reset_index()
+ value_counts.columns = [param_name, "count"]
+ value_counts = value_counts.sort_values(param_name)
+
+ # Format x-axis values
+ if max_val < 0.01:
+ formatted_col = f"{param_name}_formatted"
+ value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}")
+ chart = (
+ alt.Chart(value_counts)
+ .mark_bar(color=bar_color)
+ .encode(
+ alt.X(f"{formatted_col}:N", title=param_name, sort=None),
+ alt.Y("count:Q", title="Count"),
+ tooltip=[
+ alt.Tooltip(param_name, format=".2e"),
+ alt.Tooltip("count", title="Count"),
+ ],
+ )
+ .properties(height=250, title=param_name)
+ )
+ else:
+ chart = (
+ alt.Chart(value_counts)
+ .mark_bar(color=bar_color)
+ .encode(
+ alt.X(f"{param_name}:Q", title=param_name),
+ alt.Y("count:Q", title="Count"),
+ tooltip=[
+ alt.Tooltip(param_name, format=".3f"),
+ alt.Tooltip("count", title="Count"),
+ ],
+ )
+ .properties(height=250, title=param_name)
)
- .properties(height=250, title=param_name)
- )
else:
+ # Many unique values - use binned histogram
+ # Avoid log scale for binning as it can cause issues
chart = (
alt.Chart(df_plot)
.mark_bar(color=bar_color)
.encode(
- alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
+ alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name),
alt.Y("count()", title="Count"),
- tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
+ tooltip=[
+ alt.Tooltip(f"{param_name}:Q", format=".2e" if max_val < 0.01 else ".3f", bin=True),
+ "count()",
+ ],
)
.properties(height=250, title=param_name)
)
@@ -276,7 +315,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
alt.Y(metric, title=metric.replace("_", " ").title()),
alt.Color(
metric,
- scale=alt.Scale(scheme="viridis"),
+ scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
@@ -292,7 +331,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
alt.Y(metric, title=metric.replace("_", " ").title()),
alt.Color(
metric,
- scale=alt.Scale(scheme="viridis"),
+ scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
@@ -350,13 +389,8 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
- # Get colormap from colors module
- cmap = get_cmap("correlation")
- # Sample colors for red-blue diverging scheme
- colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
- import matplotlib.colors as mcolors
-
- hex_colors = [mcolors.rgb2hex(c) for c in colors]
+ # Get colormap from colors module (use diverging colormap for correlation)
+ hex_colors = get_palette("correlation", n_colors=256)
# Create bar chart
chart = (
@@ -367,7 +401,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
alt.Y("Parameter", sort="-x", title="Parameter"),
alt.Color(
"Correlation",
- scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
+ scale=alt.Scale(range=hex_colors, domain=[-1, 1]),
legend=None,
),
tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
@@ -402,7 +436,7 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
return
- # Prepare binned data
+ # Prepare binned data and gather parameter info
results_binned = results.copy()
bin_info = {}
@@ -414,146 +448,160 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
continue
# Determine if we should use log bins
- value_range = param_values.max() / (param_values.min() + 1e-10)
- use_log = value_range > 100 or param_values.max() < 1.0
+ min_val = param_values.min()
+ max_val = param_values.max()
- if use_log:
+ if min_val > 0:
+ value_range = max_val / min_val
+ use_log = value_range > 100 or max_val < 0.01
+ else:
+ use_log = False
+
+ if use_log and min_val > 0:
# Logarithmic binning for parameters spanning many orders of magnitude
- log_min = np.log10(param_values.min())
- log_max = np.log10(param_values.max())
- n_bins = min(10, int(log_max - log_min) + 1)
+ log_min = np.log10(min_val)
+ log_max = np.log10(max_val)
+ n_bins = min(10, max(3, int(log_max - log_min) + 1))
bins = np.logspace(log_min, log_max, num=n_bins)
else:
# Linear binning
- n_bins = min(10, int(np.sqrt(len(param_values))))
- bins = np.linspace(param_values.min(), param_values.max(), num=n_bins)
+ n_bins = min(10, max(3, int(np.sqrt(len(param_values)))))
+ bins = np.linspace(min_val, max_val, num=n_bins)
results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
- bin_info[param_name] = {"use_log": use_log, "bins": bins}
+ bin_info[param_name] = {
+ "use_log": use_log,
+ "bins": bins,
+ "min": min_val,
+ "max": max_val,
+ "n_unique": param_values.nunique(),
+ }
# Get colormap from colors module
- cmap = get_cmap(metric)
+ hex_colors = get_palette(metric, n_colors=256)
- # Create binned plots for pairs of parameters
- # Show the most meaningful combinations
+ # Get parameter names sorted by scale type (log parameters first, then linear)
param_names = [col.replace("param_", "") for col in numeric_params]
+ param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p))
- # If we have 3+ parameters, create multiple plots
- if len(param_names) >= 3:
- st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
+ st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
- # Plot first parameter vs others, binned by third
- for i in range(1, min(3, len(param_names))):
- x_param = param_names[0]
- y_param = param_names[i]
+ # Create all pairwise combinations systematically
+ if len(param_names) == 2:
+ # Simple case: just one pair
+ x_param, y_param = param_names_sorted
+ _render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500)
+ else:
+ # Multiple parameters: create structured plots
+ st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
- # Find a third parameter for faceting
- facet_param = None
- for p in param_names:
- if p != x_param and p != y_param:
- facet_param = p
- break
+ # Strategy: Show all combinations (including duplicates with swapped axes)
+ # Group by first parameter to create organized sections
- if facet_param and f"{facet_param}_binned" in results_binned.columns:
- # Create faceted plot
- plot_data = results_binned[
- [f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
- ].dropna()
+ for i, x_param in enumerate(param_names_sorted):
+ # Get all other parameters to pair with this one
+ other_params = [p for p in param_names_sorted if p != x_param]
- # Convert binned column to string with sorted categories
- plot_data = plot_data.sort_values(f"{facet_param}_binned")
- plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
- bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()
+ if not other_params:
+ continue
- x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
- y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
+ # Create expander for each primary parameter
+ with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)):
+ # Create plots in rows of 2
+ n_others = len(other_params)
+ for row_idx in range(0, n_others, 2):
+ cols = st.columns(2)
- # Get colormap colors
- cmap_colors = get_cmap(metric)
- import matplotlib.colors as mcolors
+ for col_idx in range(2):
+ other_idx = row_idx + col_idx
+ if other_idx >= n_others:
+ break
- # Sample colors from colormap
- hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)]
+ y_param = other_params[other_idx]
- chart = (
- alt.Chart(plot_data)
- .mark_circle(size=60, opacity=0.7)
- .encode(
- x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
- y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
- color=alt.Color(
- f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
- ),
- tooltip=[
- alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
- alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
- alt.Tooltip(f"{score_col}:Q", format=".4f"),
- f"{facet_param}_binned:N",
- ],
- )
- .properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
- .facet(
- facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
- columns=4,
- )
- )
- st.altair_chart(chart, use_container_width=True)
+ with cols[col_idx]:
+ _render_2d_param_plot(
+ results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350
+ )
- # Also create a simple 2D plot without faceting
- plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
- x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
- y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
+def _render_2d_param_plot(
+ results_binned: pd.DataFrame,
+ x_param: str,
+ y_param: str,
+ score_col: str,
+ bin_info: dict,
+ hex_colors: list,
+ metric: str,
+ height: int = 400,
+):
+ """Render a 2D parameter space plot.
- chart = (
- alt.Chart(plot_data)
- .mark_circle(size=80, opacity=0.6)
- .encode(
- x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
- y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
- color=alt.Color(
- f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
- ),
- tooltip=[
- alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
- alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
- alt.Tooltip(f"{score_col}:Q", format=".4f"),
- ],
- )
- .properties(height=400, title=f"{x_param} vs {y_param}")
- )
- st.altair_chart(chart, use_container_width=True)
+ Args:
+ results_binned: DataFrame with binned results.
+ x_param: Name of x-axis parameter.
+ y_param: Name of y-axis parameter.
+ score_col: Column name for the score.
+ bin_info: Dictionary with binning information for each parameter.
+ hex_colors: List of hex colors for the colormap.
+ metric: Name of the metric being visualized.
+ height: Height of the plot in pixels.
- elif len(param_names) == 2:
- # Simple 2-parameter plot
- st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
+ """
+ plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
- x_param = param_names[0]
- y_param = param_names[1]
+ if len(plot_data) == 0:
+ st.warning(f"No data available for {x_param} vs {y_param}")
+ return
- plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
+ x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
+ y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
- x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
- y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
+ # Determine marker size based on number of points
+ n_points = len(plot_data)
+ marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40
- chart = (
- alt.Chart(plot_data)
- .mark_circle(size=100, opacity=0.6)
- .encode(
- x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
- y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
- color=alt.Color(
- f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
+ chart = (
+ alt.Chart(plot_data)
+ .mark_circle(size=marker_size, opacity=0.7)
+ .encode(
+ x=alt.X(
+ f"param_{x_param}:Q",
+ title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}",
+ scale=x_scale,
+ ),
+ y=alt.Y(
+ f"param_{y_param}:Q",
+ title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}",
+ scale=y_scale,
+ ),
+ color=alt.Color(
+ f"{score_col}:Q",
+ scale=alt.Scale(range=hex_colors),
+ title=metric.replace("_", " ").title(),
+ ),
+ tooltip=[
+ alt.Tooltip(
+ f"param_{x_param}:Q",
+ title=x_param,
+ format=".2e" if bin_info[x_param]["use_log"] else ".3f",
),
- tooltip=[
- alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
- alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
- alt.Tooltip(f"{score_col}:Q", format=".4f"),
- ],
- )
- .properties(height=500, title=f"{x_param} vs {y_param}")
+ alt.Tooltip(
+ f"param_{y_param}:Q",
+ title=y_param,
+ format=".2e" if bin_info[y_param]["use_log"] else ".3f",
+ ),
+ alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"),
+ ],
)
- st.altair_chart(chart, use_container_width=True)
+ .properties(
+ height=height,
+ title=f"{x_param} vs {y_param}",
+ )
+ .interactive()
+ )
+
+ st.altair_chart(chart, use_container_width=True)
def render_score_evolution(results: pd.DataFrame, metric: str):
@@ -580,6 +628,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
# Reshape for Altair
df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")
+ # Get colormap for score evolution
+ evolution_cmap = get_cmap("score_evolution")
+ evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))]
+
# Create line chart
chart = (
alt.Chart(df_long)
@@ -587,7 +639,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
.encode(
alt.X("Iteration", title="Iteration"),
alt.Y("value", title=metric.replace("_", " ").title()),
- alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
+ alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)),
strokeDash=alt.StrokeDash(
"Type",
legend=None,
@@ -614,6 +666,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
st.metric("Best at Iteration", best_iter)
+@st.fragment
def render_multi_metric_comparison(results: pd.DataFrame):
"""Render comparison of multiple metrics.
@@ -628,8 +681,16 @@ def render_multi_metric_comparison(results: pd.DataFrame):
st.warning("Need at least 2 metrics for comparison.")
return
- # Let user select two metrics to compare
- col1, col2 = st.columns(2)
+ # Get parameter columns for color options
+ param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
+ param_names = [col.replace("param_", "") for col in param_cols]
+
+ if not param_names:
+ st.warning("No parameters found for coloring.")
+ return
+
+ # Let user select two metrics and color parameter
+ col1, col2, col3 = st.columns(3)
with col1:
metric1 = st.selectbox(
"Select First Metric",
@@ -646,21 +707,57 @@ def render_multi_metric_comparison(results: pd.DataFrame):
key="metric2_select",
)
+ with col3:
+ color_by = st.selectbox(
+ "Color By",
+ options=param_names,
+ index=0,
+ key="color_select",
+ )
+
if metric1 == metric2:
st.warning("Please select different metrics.")
return
- # Create scatter plot
+ # Create scatter plot data
df_plot = pd.DataFrame(
{
metric1: results[f"mean_test_{metric1}"],
metric2: results[f"mean_test_{metric2}"],
- "Iteration": range(len(results)),
}
)
- # Get colormap from colors module
- cmap = get_cmap("iteration")
+ # Add color parameter to dataframe
+ param_col = f"param_{color_by}"
+ df_plot[color_by] = results[param_col]
+ color_col = color_by
+
+ # Check if parameter is numeric and should use log scale
+ param_values = results[param_col].dropna()
+ if pd.api.types.is_numeric_dtype(param_values):
+ value_range = param_values.max() / (param_values.min() + 1e-10)
+ use_log = value_range > 100
+
+ if use_log:
+ color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256))
+ color_format = ".2e"
+ else:
+ color_scale = alt.Scale(range=get_palette(color_by, n_colors=256))
+ color_format = ".3f"
+ else:
+ # Categorical parameter
+ color_scale = alt.Scale(scheme="category20")
+ color_format = None
+
+ # Build tooltip list
+ tooltip_list = [
+ alt.Tooltip(metric1, format=".4f"),
+ alt.Tooltip(metric2, format=".4f"),
+ ]
+ if color_format:
+ tooltip_list.append(alt.Tooltip(color_col, format=color_format))
+ else:
+ tooltip_list.append(color_col)
chart = (
alt.Chart(df_plot)
@@ -668,12 +765,8 @@ def render_multi_metric_comparison(results: pd.DataFrame):
.encode(
alt.X(metric1, title=metric1.replace("_", " ").title()),
alt.Y(metric2, title=metric2.replace("_", " ").title()),
- alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
- tooltip=[
- alt.Tooltip(metric1, format=".4f"),
- alt.Tooltip(metric2, format=".4f"),
- "Iteration",
- ],
+ alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
+ tooltip=tooltip_list,
)
.properties(height=500)
)
diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py
index 94f017d..113bfac 100644
--- a/src/entropice/dashboard/plots/model_state.py
+++ b/src/entropice/dashboard/plots/model_state.py
@@ -447,3 +447,212 @@ def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
)
return chart
+
+
+# XGBoost-specific visualizations
+def plot_xgboost_feature_importance(
+ model_state: xr.Dataset, importance_type: str = "gain", top_n: int = 20
+) -> alt.Chart:
+ """Plot XGBoost feature importance for a specific importance type.
+
+ Args:
+ model_state: The xarray Dataset containing the XGBoost model state.
+ importance_type: Type of importance ('weight', 'gain', 'cover', 'total_gain', 'total_cover').
+ top_n: Number of top features to display.
+
+ Returns:
+ Altair chart showing the top features by importance.
+
+ """
+ # Get the importance array
+ importance_key = f"feature_importance_{importance_type}"
+ if importance_key not in model_state:
+ raise ValueError(f"Importance type '{importance_type}' not found in model state")
+
+ importance = model_state[importance_key].to_pandas()
+
+ # Sort and take top N
+ top_features = importance.nlargest(top_n).sort_values(ascending=True)
+
+ # Create DataFrame for plotting
+ plot_data = pd.DataFrame({"feature": top_features.index, "importance": top_features.to_numpy()})
+
+ # Create bar chart
+ chart = (
+ alt.Chart(plot_data)
+ .mark_bar()
+ .encode(
+ y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
+ x=alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance"),
+ color=alt.value("steelblue"),
+ tooltip=[
+ alt.Tooltip("feature:N", title="Feature"),
+ alt.Tooltip("importance:Q", format=".4f", title="Importance"),
+ ],
+ )
+ .properties(
+ width=600,
+ height=400,
+ title=f"Top {top_n} Features by {importance_type.replace('_', ' ').title()} Importance (XGBoost)",
+ )
+ )
+
+ return chart
+
+
+def plot_xgboost_importance_comparison(model_state: xr.Dataset, top_n: int = 15) -> alt.Chart:
+ """Compare different importance types for XGBoost side-by-side.
+
+ Args:
+ model_state: The xarray Dataset containing the XGBoost model state.
+ top_n: Number of top features to display.
+
+ Returns:
+ Altair chart comparing importance types.
+
+ """
+ # Collect all importance types
+ importance_types = ["weight", "gain", "cover"]
+ all_data = []
+
+ for imp_type in importance_types:
+ importance_key = f"feature_importance_{imp_type}"
+ if importance_key in model_state:
+ importance = model_state[importance_key].to_pandas()
+ # Get top features by this importance type
+ top_features = importance.nlargest(top_n)
+ for feature, value in top_features.items():
+ all_data.append(
+ {
+ "feature": feature,
+ "importance": value,
+ "type": imp_type.replace("_", " ").title(),
+ }
+ )
+
+ df = pd.DataFrame(all_data)
+
+ # Create faceted bar chart
+ chart = (
+ alt.Chart(df)
+ .mark_bar()
+ .encode(
+ y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=200)),
+ x=alt.X("importance:Q", title="Importance Value"),
+ color=alt.Color("type:N", title="Importance Type", scale=alt.Scale(scheme="category10")),
+ tooltip=[
+ alt.Tooltip("feature:N", title="Feature"),
+ alt.Tooltip("type:N", title="Type"),
+ alt.Tooltip("importance:Q", format=".4f", title="Importance"),
+ ],
+ )
+ .properties(width=250, height=300)
+ .facet(facet=alt.Facet("type:N", title="Importance Type"), columns=3)
+ )
+
+ return chart
+
+
+# Random Forest-specific visualizations
+def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.Chart:
+ """Plot Random Forest feature importance (Gini importance).
+
+ Args:
+ model_state: The xarray Dataset containing the Random Forest model state.
+ top_n: Number of top features to display.
+
+ Returns:
+ Altair chart showing the top features by importance.
+
+ """
+ importance = model_state["feature_importance"].to_pandas()
+
+ # Sort and take top N
+ top_features = importance.nlargest(top_n).sort_values(ascending=True)
+
+ # Create DataFrame for plotting
+ plot_data = pd.DataFrame({"feature": top_features.index, "importance": top_features.to_numpy()})
+
+ # Create bar chart
+ chart = (
+ alt.Chart(plot_data)
+ .mark_bar()
+ .encode(
+ y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
+ x=alt.X("importance:Q", title="Gini Importance"),
+ color=alt.value("forestgreen"),
+ tooltip=[
+ alt.Tooltip("feature:N", title="Feature"),
+ alt.Tooltip("importance:Q", format=".4f", title="Importance"),
+ ],
+ )
+ .properties(
+ width=600,
+ height=400,
+ title=f"Top {top_n} Features by Gini Importance (Random Forest)",
+ )
+ )
+
+ return chart
+
+
+def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
+ """Plot Random Forest tree statistics.
+
+ Args:
+ model_state: The xarray Dataset containing the Random Forest model state.
+
+ Returns:
+ Tuple of three Altair charts (depth histogram, leaves histogram, nodes histogram).
+
+ """
+ # Extract tree statistics
+ depths = model_state["tree_depths"].to_pandas()
+ n_leaves = model_state["tree_n_leaves"].to_pandas()
+ n_nodes = model_state["tree_n_nodes"].to_pandas()
+
+ # Create DataFrames
+ df_depths = pd.DataFrame({"value": depths.to_numpy()})
+ df_leaves = pd.DataFrame({"value": n_leaves.to_numpy()})
+ df_nodes = pd.DataFrame({"value": n_nodes.to_numpy()})
+
+ # Histogram for tree depths
+ chart_depths = (
+ alt.Chart(df_depths)
+ .mark_bar()
+ .encode(
+ x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Tree Depth"),
+ y=alt.Y("count()", title="Number of Trees"),
+ color=alt.value("steelblue"),
+ tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Depth Range")],
+ )
+ .properties(width=300, height=200, title="Distribution of Tree Depths")
+ )
+
+ # Histogram for number of leaves
+ chart_leaves = (
+ alt.Chart(df_leaves)
+ .mark_bar()
+ .encode(
+ x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Leaves"),
+ y=alt.Y("count()", title="Number of Trees"),
+ color=alt.value("forestgreen"),
+ tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Leaves Range")],
+ )
+ .properties(width=300, height=200, title="Distribution of Leaf Counts")
+ )
+
+ # Histogram for number of nodes
+ chart_nodes = (
+ alt.Chart(df_nodes)
+ .mark_bar()
+ .encode(
+ x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title="Number of Nodes"),
+ y=alt.Y("count()", title="Number of Trees"),
+ color=alt.value("darkorange"),
+ tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title="Nodes Range")],
+ )
+ .properties(width=300, height=200, title="Distribution of Node Counts")
+ )
+
+ return chart_depths, chart_leaves, chart_nodes
diff --git a/src/entropice/dashboard/plots/source_data.py b/src/entropice/dashboard/plots/source_data.py
index c73245b..294fc65 100644
--- a/src/entropice/dashboard/plots/source_data.py
+++ b/src/entropice/dashboard/plots/source_data.py
@@ -761,6 +761,117 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
st.warning("No valid data available for selected parameters")
+@st.fragment
+def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
+ """Render interactive pydeck map for grid cell areas.
+
+ Args:
+ grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
+ grid: Grid type ('hex' or 'healpix').
+
+ """
+ st.subheader("🗺️ Grid Cell Areas Distribution")
+
+ # Controls
+ col1, col2 = st.columns([3, 1])
+
+ with col1:
+ area_metric = st.selectbox(
+ "Area Metric",
+ options=["cell_area", "land_area", "water_area", "land_ratio"],
+ format_func=lambda x: x.replace("_", " ").title(),
+ key="areas_metric",
+ )
+
+ with col2:
+ opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity")
+
+ # Create GeoDataFrame
+ gdf = grid_gdf.copy()
+
+ # Convert to WGS84 first
+ gdf_wgs84 = gdf.to_crs("EPSG:4326")
+
+ # Fix geometries after CRS conversion
+ if grid == "hex":
+ gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
+
+ # Get values for the selected metric
+ values = gdf_wgs84[area_metric].to_numpy()
+
+ # Normalize values for color mapping
+ vmin, vmax = np.nanpercentile(values, [2, 98]) # Use percentiles to avoid outliers
+ normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
+
+ # Apply colormap based on metric type
+ if area_metric == "land_ratio":
+ cmap = get_cmap("terrain") # Different colormap for ratio
+ else:
+ cmap = get_cmap("terrain")
+
+ colors = [cmap(val) for val in normalized]
+ gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
+
+ # Set elevation based on normalized values for 3D visualization
+ gdf_wgs84["elevation"] = normalized
+
+ # Create GeoJSON
+ geojson_data = []
+ for _, row in gdf_wgs84.iterrows():
+ feature = {
+ "type": "Feature",
+ "geometry": row["geometry"].__geo_interface__,
+ "properties": {
+ "cell_area": f"{float(row['cell_area']):.2f}",
+ "land_area": f"{float(row['land_area']):.2f}",
+ "water_area": f"{float(row['water_area']):.2f}",
+ "land_ratio": f"{float(row['land_ratio']):.2%}",
+ "fill_color": row["fill_color"],
+ "elevation": float(row["elevation"]),
+ },
+ }
+ geojson_data.append(feature)
+
+ # Create pydeck layer with 3D elevation
+ layer = pdk.Layer(
+ "GeoJsonLayer",
+ geojson_data,
+ opacity=opacity,
+ stroked=True,
+ filled=True,
+ extruded=True,
+ get_fill_color="properties.fill_color",
+ get_line_color=[80, 80, 80],
+ get_elevation="properties.elevation",
+ elevation_scale=500000,
+ line_width_min_pixels=0.5,
+ pickable=True,
+ )
+
+ view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
+
+ deck = pdk.Deck(
+ layers=[layer],
+ initial_view_state=view_state,
+ tooltip={
+ "html": "Cell Area: {cell_area} km²
"
+ "Land Area: {land_area} km²
"
+ "Water Area: {water_area} km²
"
+ "Land Ratio: {land_ratio}",
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
+ map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+ )
+
+ st.pydeck_chart(deck)
+
+ # Show statistics
+ st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}")
+
+ # Show additional info
+ st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
+
+
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
"""Render interactive pydeck map for ERA5 climate data.
diff --git a/src/entropice/dashboard/training_data_page.py b/src/entropice/dashboard/training_data_page.py
index 88dc5f5..1e8f805 100644
--- a/src/entropice/dashboard/training_data_page.py
+++ b/src/entropice/dashboard/training_data_page.py
@@ -2,6 +2,7 @@
import streamlit as st
+from entropice import grids
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
render_alphaearth_overview,
@@ -9,6 +10,7 @@ from entropice.dashboard.plots.source_data import (
render_arcticdem_map,
render_arcticdem_overview,
render_arcticdem_plots,
+ render_areas_map,
render_era5_map,
render_era5_overview,
render_era5_plots,
@@ -184,7 +186,7 @@ def render_training_data_page():
st.markdown("---")
# Create tabs for different data views
- tab_names = ["📊 Labels"]
+ tab_names = ["📊 Labels", "📐 Areas"]
# Add tabs for each member
for member in ensemble.members:
@@ -228,8 +230,50 @@ def render_training_data_page():
render_spatial_map(train_data_dict)
+ # Areas tab
+ with tabs[1]:
+ st.markdown("### Grid Cell Areas and Land/Water Distribution")
+
+ st.markdown(
+ "This visualization shows the spatial distribution of cell areas, land areas, "
+ "water areas, and land ratio across the grid. The grid has been filtered to "
+ "include only cells in the permafrost region (>50° latitude, <85° latitude) "
+ "with >10% land coverage."
+ )
+
+ # Load grid data
+ grid_gdf = grids.open(ensemble.grid, ensemble.level)
+
+ st.success(
+ f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
+ f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
+ )
+
+ # Show summary statistics
+ col1, col2, col3, col4 = st.columns(4)
+ with col1:
+ st.metric("Total Cells", f"{len(grid_gdf):,}")
+ with col2:
+ st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
+ with col3:
+ st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
+ with col4:
+ total_land = grid_gdf["land_area"].sum()
+ st.metric("Total Land Area", f"{total_land:,.0f} km²")
+
+ st.markdown("---")
+
+ if (ensemble.grid == "hex" and ensemble.level == 6) or (
+ ensemble.grid == "healpix" and ensemble.level == 10
+ ):
+ st.warning(
+ "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
+ )
+ else:
+ render_areas_map(grid_gdf, ensemble.grid)
+
# AlphaEarth tab
- tab_idx = 1
+ tab_idx = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### AlphaEarth Embeddings Analysis")
diff --git a/src/entropice/dashboard/utils/data.py b/src/entropice/dashboard/utils/data.py
index c0ce468..d64a70f 100644
--- a/src/entropice/dashboard/utils/data.py
+++ b/src/entropice/dashboard/utils/data.py
@@ -29,10 +29,9 @@ class TrainingResult:
def from_path(cls, result_path: Path) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
result_file = result_path / "search_results.parquet"
- state_file = result_path / "best_estimator_state.nc"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
- if not all([result_file.exists(), state_file.exists(), preds_file.exists(), settings_file.exists()]):
+ if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
raise FileNotFoundError(f"Missing required files in {result_path}")
created_at = result_path.stat().st_ctime
@@ -96,9 +95,9 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
"""
return {
- "binary": e.create_cat_training_dataset("binary"),
- "count": e.create_cat_training_dataset("count"),
- "density": e.create_cat_training_dataset("density"),
+ "binary": e.create_cat_training_dataset("binary", device="cpu"),
+ "count": e.create_cat_training_dataset("count", device="cpu"),
+ "density": e.create_cat_training_dataset("density", device="cpu"),
}
diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py
index 5003728..b80cd4a 100644
--- a/src/entropice/dataset.py
+++ b/src/entropice/dataset.py
@@ -14,12 +14,15 @@ Naming conventions:
import hashlib
import json
+from collections.abc import Generator
from dataclasses import asdict, dataclass, field
from functools import cached_property, lru_cache
-from typing import Literal
+from typing import Literal, TypedDict
+import cupy as cp
import cyclopts
import geopandas as gpd
+import numpy as np
import pandas as pd
import seaborn as sns
import torch
@@ -110,8 +113,8 @@ def bin_values(
@dataclass(frozen=True, eq=False)
class DatasetLabels:
binned: pd.Series
- train: torch.Tensor
- test: torch.Tensor
+ train: torch.Tensor | np.ndarray | cp.ndarray
+ test: torch.Tensor | np.ndarray | cp.ndarray
raw_values: pd.Series
@cached_property
@@ -135,8 +138,8 @@ class DatasetLabels:
@dataclass(frozen=True, eq=False)
class DatasetInputs:
data: pd.DataFrame
- train: torch.Tensor
- test: torch.Tensor
+ train: torch.Tensor | np.ndarray | cp.ndarray
+ test: torch.Tensor | np.ndarray | cp.ndarray
@dataclass(frozen=True)
@@ -151,6 +154,13 @@ class CategoricalTrainingDataset:
return len(self.z)
+class DatasetStats(TypedDict):
+ target: str
+ num_target_samples: int
+ members: dict[str, dict[str, object]]
+ total_features: int
+
+
@cyclopts.Parameter("*")
@dataclass(frozen=True)
class DatasetEnsemble:
@@ -283,15 +293,15 @@ class DatasetEnsemble:
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
return arcticdem_df
- def get_stats(self) -> dict:
+ def get_stats(self) -> DatasetStats:
"""Get dataset statistics.
Returns:
- dict: Dictionary containing target stats, member stats, and total features count.
+ DatasetStats: Dictionary containing target stats, member stats, and total features count.
"""
targets = self._read_target()
- stats = {
+ stats: DatasetStats = {
"target": self.target,
"num_target_samples": len(targets),
"members": {},
@@ -365,11 +375,64 @@ class DatasetEnsemble:
print(f"Saved dataset to cache at {cache_file}.")
return dataset
- def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset:
+ def create_batches(
+ self,
+ batch_size: int,
+ filter_target_col: str | None = None,
+ cache_mode: Literal["n", "o", "r"] = "r",
+ ) -> Generator[pd.DataFrame]:
+ targets = self._read_target()
+ if len(targets) == 0:
+ raise ValueError("No target samples found.")
+ elif len(targets) < batch_size:
+ yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode)
+ return
+
+ if filter_target_col is not None:
+ targets = targets.loc[targets[filter_target_col]]
+
+ for i in range(0, len(targets), batch_size):
+ # n: no cache, o: overwrite cache, r: read cache if exists
+ cache_file = entropice.paths.get_dataset_cache(
+ self.id(), subset=filter_target_col, batch=(i, i + batch_size)
+ )
+ if cache_mode == "r" and cache_file.exists():
+ dataset = gpd.read_parquet(cache_file)
+ print(
+ f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
+ f" and {len(dataset.columns)} features."
+ )
+ yield dataset
+ else:
+ targets_batch = targets.iloc[i : i + batch_size]
+
+ member_dfs = []
+ for member in self.members:
+ if member.startswith("ERA5"):
+ era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
+ member_dfs.append(self._prep_era5(targets_batch, era5_agg))
+ elif member == "AlphaEarth":
+ member_dfs.append(self._prep_embeddings(targets_batch))
+ elif member == "ArcticDEM":
+ member_dfs.append(self._prep_arcticdem(targets_batch))
+ else:
+ raise NotImplementedError(f"Member {member} not implemented.")
+
+ dataset = targets_batch.set_index("cell_id").join(member_dfs)
+ print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
+ if cache_mode in ["o", "r"]:
+ dataset.to_parquet(cache_file)
+ print(f"Saved dataset to cache at {cache_file}.")
+ yield dataset
+
+ def create_cat_training_dataset(
+ self, task: Task, device: Literal["cpu", "cuda", "torch"]
+ ) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
+ device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
@@ -414,10 +477,23 @@ class DatasetEnsemble:
split.loc[test_idx] = "test"
split = split.astype("category")
- X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0)
- X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0)
- y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0)
- y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0)
+ X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64")
+ X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64")
+ y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64")
+ y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64")
+
+ if device == "cuda":
+ X_train = cp.asarray(X_train)
+ X_test = cp.asarray(X_test)
+ y_train = cp.asarray(y_train)
+ y_test = cp.asarray(y_test)
+ elif device == "torch":
+ X_train = torch.from_numpy(X_train)
+ X_test = torch.from_numpy(X_test)
+ y_train = torch.from_numpy(y_train)
+ y_test = torch.from_numpy(y_test)
+ else:
+ assert device == "cpu", "Invalid device specified."
return CategoricalTrainingDataset(
dataset=dataset.to_crs("EPSG:4326"),
diff --git a/src/entropice/inference.py b/src/entropice/inference.py
index 816dd66..b2d87dc 100644
--- a/src/entropice/inference.py
+++ b/src/entropice/inference.py
@@ -1,6 +1,7 @@
# ruff: noqa: N806
"""Inference runs on trained models."""
+import cupy as cp
import geopandas as gpd
import pandas as pd
import torch
@@ -32,35 +33,32 @@ def predict_proba(
list: A list of predicted probabilities for each cell.
"""
- data = e.create()
- print(f"Predicting probabilities for {len(data)} cells...")
-
# Predict in batches to avoid memory issues
- batch_size = 10_000
+ batch_size = 10000
preds = []
- cols_to_drop = ["geometry"]
- if e.target == "darts_mllabels":
- cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
- else:
- cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
- for i in range(0, len(data), batch_size):
- batch = data.iloc[i : i + batch_size]
+ for batch in e.create_batches(batch_size=batch_size):
+ cols_to_drop = ["geometry"]
+ if e.target == "darts_mllabels":
+ cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
+ else:
+ cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
X_batch = batch.drop(columns=cols_to_drop).dropna()
cell_ids = X_batch.index.to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float64")
X_batch = torch.asarray(X_batch, device=0)
- batch_preds = clf.predict(X_batch).cpu().numpy()
+ batch_preds = clf.predict(X_batch)
+ if isinstance(batch_preds, cp.ndarray):
+ batch_preds = batch_preds.get()
+ elif torch.is_tensor(batch_preds):
+ batch_preds = batch_preds.cpu().numpy()
batch_preds = gpd.GeoDataFrame(
{
"cell_id": cell_ids,
"predicted_class": [classes[i] for i in batch_preds],
"geometry": cell_geoms,
},
- crs="epsg:3413",
- )
+ ).set_crs(epsg=3413, inplace=False)
preds.append(batch_preds)
- preds = gpd.GeoDataFrame(pd.concat(preds))
-
- return preds
+ return gpd.GeoDataFrame(pd.concat(preds, ignore_index=True))
diff --git a/src/entropice/paths.py b/src/entropice/paths.py
index bb6a197..6ca18c7 100644
--- a/src/entropice/paths.py
+++ b/src/entropice/paths.py
@@ -9,7 +9,7 @@ from typing import Literal
DATA_DIR = (
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
)
-DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
+# DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures")
@@ -110,7 +110,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file
-def get_dataset_cache(eid: str, subset: str | None = None) -> Path:
+def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int] | None = None) -> Path:
+ if batch is not None:
+ eid = f"{eid}_batch{batch[0]}-{batch[1]}"
if subset is None:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
else:
diff --git a/src/entropice/training.py b/src/entropice/training.py
index 61b0c46..7b0cc87 100644
--- a/src/entropice/training.py
+++ b/src/entropice/training.py
@@ -4,6 +4,7 @@ import pickle
from dataclasses import asdict, dataclass
from typing import Literal
+import cupy as cp
import cyclopts
import pandas as pd
import toml
@@ -14,7 +15,6 @@ from entropy import ESPAClassifier
from rich import pretty, traceback
from scipy.stats import loguniform, randint
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
-from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV
from stopuhr import stopwatch
from xgboost.sklearn import XGBClassifier
@@ -26,7 +26,9 @@ from entropice.paths import get_cv_results_dir
traceback.install()
pretty.install()
-set_config(array_api_dispatch=True)
+# Disabled array_api_dispatch to avoid namespace conflicts between NumPy and CuPy
+# when using XGBoost with device="cuda"
+# set_config(array_api_dispatch=True)
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
@@ -85,8 +87,8 @@ def _create_clf(
objective="multi:softprob" if settings.task != "binary" else "binary:logistic",
eval_metric="mlogloss" if settings.task != "binary" else "logloss",
random_state=42,
- tree_method="gpu_hist",
- device="cuda",
+ tree_method="hist",
+        device="gpu",  # GPU training; CuPy arrays are moved back to NumPy only for sklearn metrics
)
fit_params = {}
elif settings.model == "rf":
@@ -98,11 +100,10 @@ def _create_clf(
fit_params = {}
elif settings.model == "knn":
param_grid = {
- "n_neighbors": randint(3, 15),
+ "n_neighbors": randint(10, 200),
"weights": ["uniform", "distance"],
- "algorithm": ["brute", "kd_tree", "ball_tree"],
}
- clf = KNeighborsClassifier(random_state=42)
+ clf = KNeighborsClassifier()
fit_params = {}
else:
raise ValueError(f"Unknown model: {settings.model}")
@@ -148,8 +149,9 @@ def random_cv(
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
"""
+ device = "torch" if settings.model in ["espa"] else "cuda"
print("Creating training data...")
- training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
+ training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task, device=device)
clf, param_grid, fit_params = _create_clf(settings)
print(f"Using model: {settings.model} with parameters: {param_grid}")
@@ -159,7 +161,7 @@ def random_cv(
clf,
param_grid,
n_iter=settings.n_iter,
- n_jobs=8,
+ n_jobs=1,
cv=cv,
random_state=42,
verbose=10,
@@ -169,14 +171,25 @@ def random_cv(
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
- search.fit(training_data.X.train, training_data.y.train, **fit_params)
+ y_train = (
+ training_data.y.train.get()
+ if settings.model == "xgboost" and isinstance(training_data.y.train, cp.ndarray)
+ else training_data.y.train
+ )
+ search.fit(training_data.X.train, y_train, **fit_params)
print("Best parameters combination found:")
- best_parameters = search.best_estimator_.get_params()
+ best_estimator = search.best_estimator_
+ best_parameters = best_estimator.get_params()
for param_name in sorted(param_grid.keys()):
print(f"{param_name}: {best_parameters[param_name]}")
- test_accuracy = search.score(training_data.X.test, training_data.y.test)
+ y_test = (
+ training_data.y.test.get()
+ if settings.model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
+ else training_data.y.test
+ )
+ test_accuracy = search.score(training_data.X.test, y_test)
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}")
@@ -207,7 +220,7 @@ def random_cv(
best_model_file = results_dir / "best_estimator_model.pkl"
print(f"Storing best estimator model to {best_model_file}")
with open(best_model_file, "wb") as f:
- pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
+ pickle.dump(best_estimator, f, protocol=pickle.HIGHEST_PROTOCOL)
# Store the search results
results = pd.DataFrame(search.cv_results_)
@@ -220,10 +233,10 @@ def random_cv(
results.to_parquet(results_file)
# Get the inner state of the best estimator
+ features = training_data.X.data.columns.tolist()
+
if settings.model == "espa":
- best_estimator = search.best_estimator_
# Annotate the state with xarray metadata
- features = training_data.X.data.columns.tolist()
boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(),
@@ -260,9 +273,154 @@ def random_cv(
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
+ elif settings.model == "xgboost":
+ # Extract XGBoost-specific information
+ # Get the underlying booster
+ booster = best_estimator.get_booster()
+
+ # Feature importance with different importance types
+ importance_weight = booster.get_score(importance_type="weight")
+ importance_gain = booster.get_score(importance_type="gain")
+ importance_cover = booster.get_score(importance_type="cover")
+ importance_total_gain = booster.get_score(importance_type="total_gain")
+ importance_total_cover = booster.get_score(importance_type="total_cover")
+
+ # Create aligned arrays for all features (including zero-importance)
+ def align_importance(importance_dict, features):
+ """Align importance dict to feature list, filling missing with 0."""
+ return [importance_dict.get(f, 0.0) for f in features]
+
+ feature_importance_weight = xr.DataArray(
+ align_importance(importance_weight, features),
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance_weight",
+ attrs={"description": "Number of times a feature is used to split the data across all trees."},
+ )
+ feature_importance_gain = xr.DataArray(
+ align_importance(importance_gain, features),
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance_gain",
+ attrs={"description": "Average gain across all splits the feature is used in."},
+ )
+ feature_importance_cover = xr.DataArray(
+ align_importance(importance_cover, features),
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance_cover",
+ attrs={"description": "Average coverage across all splits the feature is used in."},
+ )
+ feature_importance_total_gain = xr.DataArray(
+ align_importance(importance_total_gain, features),
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance_total_gain",
+ attrs={"description": "Total gain across all splits the feature is used in."},
+ )
+ feature_importance_total_cover = xr.DataArray(
+ align_importance(importance_total_cover, features),
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance_total_cover",
+ attrs={"description": "Total coverage across all splits the feature is used in."},
+ )
+
+ # Store tree information
+ n_trees = booster.num_boosted_rounds()
+
+ state = xr.Dataset(
+ {
+ "feature_importance_weight": feature_importance_weight,
+ "feature_importance_gain": feature_importance_gain,
+ "feature_importance_cover": feature_importance_cover,
+ "feature_importance_total_gain": feature_importance_total_gain,
+ "feature_importance_total_cover": feature_importance_total_cover,
+ },
+ attrs={
+ "description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
+ "n_trees": n_trees,
+ "objective": str(best_estimator.objective),
+ },
+ )
+ state_file = results_dir / "best_estimator_state.nc"
+ print(f"Storing best estimator state to {state_file}")
+ state.to_netcdf(state_file, engine="h5netcdf")
+
+ elif settings.model == "rf":
+ # Extract Random Forest-specific information
+ # Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_)
+ # like sklearn does, so we can only extract feature importances and model parameters
+
+ # Feature importances (Gini importance)
+ feature_importances = best_estimator.feature_importances_
+
+ feature_importance = xr.DataArray(
+ feature_importances,
+ dims=["feature"],
+ coords={"feature": features},
+ name="feature_importance",
+ attrs={"description": "Gini importance (impurity-based feature importance)."},
+ )
+
+ # cuML RF doesn't expose individual trees, so we store model parameters instead
+ n_estimators = best_estimator.n_estimators
+ max_depth = best_estimator.max_depth
+
+ # OOB score if available
+ oob_score = None
+ if hasattr(best_estimator, "oob_score_") and best_estimator.oob_score:
+ oob_score = float(best_estimator.oob_score_)
+
+ # cuML RandomForest doesn't provide per-tree statistics like sklearn
+ # Store what we have: feature importances and model configuration
+ attrs = {
+ "description": "Inner state of the best RandomForestClassifier from RandomizedSearchCV (cuML).",
+ "n_estimators": int(n_estimators),
+ "note": "cuML RandomForest does not expose individual tree statistics like sklearn",
+ }
+
+ # Only add optional attributes if they have values
+ if max_depth is not None:
+ attrs["max_depth"] = int(max_depth)
+ if oob_score is not None:
+ attrs["oob_score"] = oob_score
+
+ state = xr.Dataset(
+ {
+ "feature_importance": feature_importance,
+ },
+ attrs=attrs,
+ )
+ state_file = results_dir / "best_estimator_state.nc"
+ print(f"Storing best estimator state to {state_file}")
+ state.to_netcdf(state_file, engine="h5netcdf")
+
+ elif settings.model == "knn":
+ # KNN doesn't have traditional feature importance
+ # Store information about the training data and neighbors
+ n_neighbors = best_estimator.n_neighbors
+
+ # We can't extract meaningful "state" from KNN in the same way
+ # but we can store metadata about the model
+ state = xr.Dataset(
+ attrs={
+ "description": "Metadata of the best KNeighborsClassifier from RandomizedSearchCV.",
+ "n_neighbors": n_neighbors,
+ "weights": str(best_estimator.weights),
+ "algorithm": str(best_estimator.algorithm),
+ "metric": str(best_estimator.metric),
+ "n_samples_fit": best_estimator.n_samples_fit_,
+ },
+ )
+ state_file = results_dir / "best_estimator_state.nc"
+ print(f"Storing best estimator metadata to {state_file}")
+ state.to_netcdf(state_file, engine="h5netcdf")
+
# Predict probabilities for all cells
print("Predicting probabilities for all cells...")
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels)
+ print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file)