Update Model State Page
This commit is contained in:
parent
6ed5a9c224
commit
a64e1ac41f
12 changed files with 1288 additions and 380 deletions
4
pixi.lock
generated
4
pixi.lock
generated
|
|
@ -9,7 +9,6 @@ environments:
|
|||
- https://pypi.org/simple
|
||||
options:
|
||||
channel-priority: disabled
|
||||
pypi-prerelease-mode: if-necessary-or-explicit
|
||||
packages:
|
||||
linux-64:
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
|
||||
|
|
@ -2869,7 +2868,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: 6488240242ab4091686c27ee2e721c81143e0a158bdff086518da6ccdc2971c8
|
||||
sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -2933,6 +2932,7 @@ packages:
|
|||
- ruff>=0.14.9,<0.15
|
||||
- pandas-stubs>=2.3.3.251201,<3
|
||||
requires_python: '>=3.13,<3.14'
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
name: entropy
|
||||
version: 0.1.0
|
||||
|
|
|
|||
|
|
@ -65,7 +65,8 @@ dependencies = [
|
|||
"pydeck>=0.9.1,<0.10",
|
||||
"pypalettes>=0.2.1,<0.3",
|
||||
"ty>=0.0.2,<0.0.3",
|
||||
"ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3",
|
||||
"ruff>=0.14.9,<0.15",
|
||||
"pandas-stubs>=2.3.3.251201,<3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
@ -118,6 +119,7 @@ platforms = ["linux-64"]
|
|||
|
||||
[tool.pixi.activation.env]
|
||||
SCIPY_ARRAY_API = "1"
|
||||
FAST_DATA_DIR = "./data"
|
||||
|
||||
[tool.pixi.system-requirements]
|
||||
cuda = "12"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Model State page for the Entropice dashboard."""
|
||||
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.plots.colors import generate_unified_colormap
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
|
|
@ -43,6 +44,9 @@ def render_model_state_page():
|
|||
)
|
||||
selected_result = result_options[selected_name]
|
||||
|
||||
# Get the model type from settings
|
||||
model_type = selected_result.settings.get("model", "espa")
|
||||
|
||||
# Load model state
|
||||
with st.spinner("Loading model state..."):
|
||||
model_state = load_model_state(selected_result)
|
||||
|
|
@ -50,34 +54,40 @@ def render_model_state_page():
|
|||
st.error("Could not load model state for this result.")
|
||||
return
|
||||
|
||||
# Scale feature weights by number of features
|
||||
n_features = model_state.sizes["feature"]
|
||||
model_state["feature_weights"] *= n_features
|
||||
|
||||
# Extract different feature types
|
||||
embedding_feature_array = extract_embedding_features(model_state)
|
||||
era5_feature_array = extract_era5_features(model_state)
|
||||
common_feature_array = extract_common_features(model_state)
|
||||
|
||||
# Generate unified colormaps
|
||||
_, _, altair_colors = generate_unified_colormap(selected_result.settings)
|
||||
|
||||
# Display basic model state info
|
||||
with st.expander("Model State Information", expanded=False):
|
||||
st.write(f"**Model Type:** {model_type.upper()}")
|
||||
st.write(f"**Variables:** {list(model_state.data_vars)}")
|
||||
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
|
||||
st.write(f"**Coordinates:** {list(model_state.coords)}")
|
||||
st.write(f"**Attributes:** {dict(model_state.attrs)}")
|
||||
|
||||
# Show statistics
|
||||
st.write("**Feature Weight Statistics:**")
|
||||
feature_weights = model_state["feature_weights"].to_pandas()
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
|
||||
with col2:
|
||||
st.metric("Max Weight", f"{feature_weights.max():.4f}")
|
||||
with col3:
|
||||
st.metric("Total Features", len(feature_weights))
|
||||
# Render model-specific visualizations
|
||||
if model_type == "espa":
|
||||
render_espa_model_state(model_state, selected_result)
|
||||
elif model_type == "xgboost":
|
||||
render_xgboost_model_state(model_state, selected_result)
|
||||
elif model_type == "rf":
|
||||
render_rf_model_state(model_state, selected_result)
|
||||
elif model_type == "knn":
|
||||
render_knn_model_state(model_state, selected_result)
|
||||
else:
|
||||
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
|
||||
|
||||
|
||||
def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for ESPA model."""
|
||||
# Scale feature weights by number of features
|
||||
n_features = model_state.sizes["feature"]
|
||||
model_state["feature_weights"] *= n_features
|
||||
|
||||
# Extract different feature types
|
||||
embedding_feature_array = extract_embedding_features(model_state)
|
||||
era5_feature_array = extract_era5_features(model_state)
|
||||
common_feature_array = extract_common_features(model_state)
|
||||
|
||||
# Generate unified colormaps
|
||||
_, _, altair_colors = generate_unified_colormap(selected_result.settings)
|
||||
|
||||
# Feature importance section
|
||||
st.header("Feature Importance")
|
||||
|
|
@ -167,156 +177,362 @@ def render_model_state_page():
|
|||
|
||||
# Embedding features analysis (if present)
|
||||
if embedding_feature_array is not None:
|
||||
with st.container(border=True):
|
||||
st.header("🛰️ Embedding Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of embedding features showing which aggregations, bands, and years
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating dimension summaries..."):
|
||||
chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_agg, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_band, use_container_width=True)
|
||||
with col3:
|
||||
st.altair_chart(chart_year, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap by Aggregation")
|
||||
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||
with st.spinner("Generating heatmap..."):
|
||||
heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
|
||||
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Embedding Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_emb_features = embedding_feature_array.size
|
||||
mean_weight = float(embedding_feature_array.mean().values)
|
||||
max_weight = float(embedding_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total Embedding Features", n_emb_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top embedding features
|
||||
st.write("**Top 10 Embedding Features:**")
|
||||
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
|
||||
st.dataframe(top_emb, width="stretch")
|
||||
else:
|
||||
st.info("No embedding features found in this model.")
|
||||
render_embedding_features(embedding_feature_array)
|
||||
|
||||
# ERA5 features analysis (if present)
|
||||
if era5_feature_array is not None:
|
||||
with st.container(border=True):
|
||||
st.header("⛅ ERA5 Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of ERA5 climate features showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||
chart_variable, chart_time = plot_era5_summary(era5_feature_array)
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_time, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
st.markdown("Shows the weight of each variable-time combination.")
|
||||
with st.spinner("Generating ERA5 heatmap..."):
|
||||
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("ERA5 Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_era5_features = era5_feature_array.size
|
||||
mean_weight = float(era5_feature_array.mean().values)
|
||||
max_weight = float(era5_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total ERA5 Features", n_era5_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top ERA5 features
|
||||
st.write("**Top 10 ERA5 Features:**")
|
||||
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||
st.dataframe(top_era5, width="stretch")
|
||||
else:
|
||||
st.info("No ERA5 features found in this model.")
|
||||
render_era5_features(era5_feature_array)
|
||||
|
||||
# Common features analysis (if present)
|
||||
if common_feature_array is not None:
|
||||
with st.container(border=True):
|
||||
st.header("🗺️ Common Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of common features including cell area, water area, land area, land ratio,
|
||||
longitude, and latitude. These features provide spatial and geographic context.
|
||||
"""
|
||||
)
|
||||
render_common_features(common_feature_array)
|
||||
|
||||
# Bar chart showing all common feature weights
|
||||
with st.spinner("Generating common features chart..."):
|
||||
common_chart = plot_common_features(common_feature_array)
|
||||
st.altair_chart(common_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Common Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_common_features = common_feature_array.size
|
||||
mean_weight = float(common_feature_array.mean().values)
|
||||
max_weight = float(common_feature_array.max().values)
|
||||
min_weight = float(common_feature_array.min().values)
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Common Features", n_common_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
with col4:
|
||||
st.metric("Min Weight", f"{min_weight:.4f}")
|
||||
def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for XGBoost model."""
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_xgboost_feature_importance,
|
||||
plot_xgboost_importance_comparison,
|
||||
)
|
||||
|
||||
# Show all common features sorted by importance
|
||||
st.write("**All Common Features (by absolute weight):**")
|
||||
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
|
||||
common_df["abs_weight"] = common_df["weight"].abs()
|
||||
common_df = common_df.sort_values("abs_weight", ascending=False)
|
||||
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
|
||||
st.header("🌲 XGBoost Model Analysis")
|
||||
st.markdown(
|
||||
f"""
|
||||
XGBoost gradient boosted tree model with **{model_state.attrs.get("n_trees", "N/A")} trees**.
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
|
||||
size-related patterns
|
||||
- **land_ratio**: Proportion of land vs water in each cell
|
||||
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||
- Positive weights indicate features that increase the probability of the positive class
|
||||
- Negative weights indicate features that decrease the probability of the positive class
|
||||
"""
|
||||
)
|
||||
else:
|
||||
st.info("No common features found in this model.")
|
||||
**Objective:** {model_state.attrs.get("objective", "N/A")}
|
||||
"""
|
||||
)
|
||||
|
||||
# Feature importance with different types
|
||||
st.subheader("Feature Importance Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
XGBoost provides multiple ways to measure feature importance:
|
||||
- **Weight**: Number of times a feature is used to split the data
|
||||
- **Gain**: Average gain across all splits using the feature
|
||||
- **Cover**: Average coverage across all splits using the feature
|
||||
- **Total Gain**: Total gain across all splits
|
||||
- **Total Cover**: Total coverage across all splits
|
||||
"""
|
||||
)
|
||||
|
||||
# Importance type selector
|
||||
importance_type = st.selectbox(
|
||||
"Select Importance Type",
|
||||
options=["gain", "weight", "cover", "total_gain", "total_cover"],
|
||||
index=0,
|
||||
help="Choose which importance metric to visualize",
|
||||
)
|
||||
|
||||
# Top N slider
|
||||
top_n = st.slider(
|
||||
"Number of top features to display",
|
||||
min_value=5,
|
||||
max_value=50,
|
||||
value=20,
|
||||
step=5,
|
||||
help="Select how many of the most important features to visualize",
|
||||
)
|
||||
|
||||
with st.spinner("Generating feature importance plot..."):
|
||||
importance_chart = plot_xgboost_feature_importance(model_state, importance_type=importance_type, top_n=top_n)
|
||||
st.altair_chart(importance_chart, use_container_width=True)
|
||||
|
||||
# Comparison of importance types
|
||||
st.subheader("Importance Type Comparison")
|
||||
st.markdown("Compare the top features across different importance metrics.")
|
||||
|
||||
with st.spinner("Generating importance comparison..."):
|
||||
comparison_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
|
||||
st.altair_chart(comparison_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Model Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric("Number of Trees", model_state.attrs.get("n_trees", "N/A"))
|
||||
with col2:
|
||||
st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
|
||||
|
||||
|
||||
def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for Random Forest model."""
|
||||
from entropice.dashboard.plots.model_state import plot_rf_feature_importance
|
||||
|
||||
st.header("🌳 Random Forest Model Analysis")
|
||||
|
||||
# Check if using cuML (which doesn't provide tree statistics)
|
||||
is_cuml = "cuML" in model_state.attrs.get("description", "")
|
||||
|
||||
st.markdown(
|
||||
f"""
|
||||
Random Forest ensemble with **{model_state.attrs.get("n_estimators", "N/A")} trees**
|
||||
(max depth: {model_state.attrs.get("max_depth", "N/A")}).
|
||||
"""
|
||||
)
|
||||
|
||||
if is_cuml:
|
||||
st.info("ℹ️ Using cuML GPU-accelerated Random Forest. Individual tree statistics are not available.")
|
||||
|
||||
# Display OOB score if available
|
||||
oob_score = model_state.attrs.get("oob_score")
|
||||
if oob_score is not None:
|
||||
st.info(f"**Out-of-Bag Score:** {oob_score:.4f}")
|
||||
|
||||
# Feature importance
|
||||
st.subheader("Feature Importance (Gini Importance)")
|
||||
st.markdown(
|
||||
"""
|
||||
Random Forest uses Gini impurity to measure feature importance. Features with higher
|
||||
importance values contribute more to the model's predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Top N slider
|
||||
top_n = st.slider(
|
||||
"Number of top features to display",
|
||||
min_value=5,
|
||||
max_value=50,
|
||||
value=20,
|
||||
step=5,
|
||||
help="Select how many of the most important features to visualize",
|
||||
)
|
||||
|
||||
with st.spinner("Generating feature importance plot..."):
|
||||
importance_chart = plot_rf_feature_importance(model_state, top_n=top_n)
|
||||
st.altair_chart(importance_chart, use_container_width=True)
|
||||
|
||||
# Tree statistics (only if available - sklearn RF has them, cuML RF doesn't)
|
||||
if not is_cuml and "tree_depths" in model_state:
|
||||
from entropice.dashboard.plots.model_state import plot_rf_tree_statistics
|
||||
|
||||
st.subheader("Tree Structure Statistics")
|
||||
st.markdown("Distribution of tree properties across the forest.")
|
||||
|
||||
with st.spinner("Generating tree statistics..."):
|
||||
chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_depths, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_leaves, use_container_width=True)
|
||||
with col3:
|
||||
st.altair_chart(chart_nodes, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Forest Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
depths = model_state["tree_depths"].to_pandas()
|
||||
leaves = model_state["tree_n_leaves"].to_pandas()
|
||||
nodes = model_state["tree_n_nodes"].to_pandas()
|
||||
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.write("**Tree Depths:**")
|
||||
st.metric("Mean Depth", f"{depths.mean():.2f}")
|
||||
st.metric("Max Depth", f"{depths.max()}")
|
||||
st.metric("Min Depth", f"{depths.min()}")
|
||||
with col2:
|
||||
st.write("**Leaf Counts:**")
|
||||
st.metric("Mean Leaves", f"{leaves.mean():.2f}")
|
||||
st.metric("Max Leaves", f"{leaves.max()}")
|
||||
st.metric("Min Leaves", f"{leaves.min()}")
|
||||
with col3:
|
||||
st.write("**Node Counts:**")
|
||||
st.metric("Mean Nodes", f"{nodes.mean():.2f}")
|
||||
st.metric("Max Nodes", f"{nodes.max()}")
|
||||
st.metric("Min Nodes", f"{nodes.min()}")
|
||||
|
||||
|
||||
def render_knn_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for KNN model."""
|
||||
st.header("🔍 K-Nearest Neighbors Model Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
K-Nearest Neighbors is a non-parametric, instance-based learning algorithm.
|
||||
Unlike tree-based or parametric models, KNN doesn't learn feature weights or build
|
||||
a model structure. Instead, it memorizes the training data and makes predictions
|
||||
based on the k nearest neighbors.
|
||||
"""
|
||||
)
|
||||
|
||||
# Display model metadata
|
||||
st.subheader("Model Configuration")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Number of Neighbors (k)", model_state.attrs.get("n_neighbors", "N/A"))
|
||||
st.metric("Training Samples", model_state.attrs.get("n_samples_fit", "N/A"))
|
||||
with col2:
|
||||
st.metric("Weights", model_state.attrs.get("weights", "N/A"))
|
||||
st.metric("Algorithm", model_state.attrs.get("algorithm", "N/A"))
|
||||
with col3:
|
||||
st.metric("Metric", model_state.attrs.get("metric", "N/A"))
|
||||
|
||||
st.info(
|
||||
"""
|
||||
**Note:** KNN doesn't have traditional feature importance or model parameters to visualize.
|
||||
The model's behavior depends entirely on:
|
||||
- The number of neighbors (k)
|
||||
- The distance metric used
|
||||
- The weighting scheme for neighbors
|
||||
|
||||
To understand the model better, consider visualizing the decision boundaries on a
|
||||
reduced-dimensional representation of your data (e.g., using PCA or t-SNE).
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Helper functions for embedding/era5/common features
|
||||
def render_embedding_features(embedding_feature_array):
|
||||
"""Render embedding feature visualizations."""
|
||||
with st.container(border=True):
|
||||
st.header("🛰️ Embedding Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of embedding features showing which aggregations, bands, and years
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating dimension summaries..."):
|
||||
chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.altair_chart(chart_agg, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_band, use_container_width=True)
|
||||
with col3:
|
||||
st.altair_chart(chart_year, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap by Aggregation")
|
||||
st.markdown("Shows the weight of each band-year combination for each aggregation type.")
|
||||
with st.spinner("Generating heatmap..."):
|
||||
heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
|
||||
st.altair_chart(heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Embedding Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_emb_features = embedding_feature_array.size
|
||||
mean_weight = float(embedding_feature_array.mean().values)
|
||||
max_weight = float(embedding_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total Embedding Features", n_emb_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top embedding features
|
||||
st.write("**Top 10 Embedding Features:**")
|
||||
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
|
||||
st.dataframe(top_emb, width="stretch")
|
||||
|
||||
|
||||
def render_era5_features(era5_feature_array):
|
||||
"""Render ERA5 feature visualizations."""
|
||||
with st.container(border=True):
|
||||
st.header("⛅ ERA5 Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of ERA5 climate features showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Summary bar charts
|
||||
st.markdown("### Importance by Dimension")
|
||||
with st.spinner("Generating ERA5 dimension summaries..."):
|
||||
chart_variable, chart_time = plot_era5_summary(era5_feature_array)
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.altair_chart(chart_variable, use_container_width=True)
|
||||
with col2:
|
||||
st.altair_chart(chart_time, use_container_width=True)
|
||||
|
||||
# Detailed heatmap
|
||||
st.markdown("### Detailed Heatmap")
|
||||
st.markdown("Shows the weight of each variable-time combination.")
|
||||
with st.spinner("Generating ERA5 heatmap..."):
|
||||
era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
|
||||
st.altair_chart(era5_heatmap_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("ERA5 Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_era5_features = era5_feature_array.size
|
||||
mean_weight = float(era5_feature_array.mean().values)
|
||||
max_weight = float(era5_feature_array.max().values)
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.metric("Total ERA5 Features", n_era5_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
|
||||
# Show top ERA5 features
|
||||
st.write("**Top 10 ERA5 Features:**")
|
||||
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
|
||||
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
|
||||
st.dataframe(top_era5, width="stretch")
|
||||
|
||||
|
||||
def render_common_features(common_feature_array):
|
||||
"""Render common feature visualizations."""
|
||||
with st.container(border=True):
|
||||
st.header("🗺️ Common Feature Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Analysis of common features including cell area, water area, land area, land ratio,
|
||||
longitude, and latitude. These features provide spatial and geographic context.
|
||||
"""
|
||||
)
|
||||
|
||||
# Bar chart showing all common feature weights
|
||||
with st.spinner("Generating common features chart..."):
|
||||
common_chart = plot_common_features(common_feature_array)
|
||||
st.altair_chart(common_chart, use_container_width=True)
|
||||
|
||||
# Statistics
|
||||
with st.expander("Common Feature Statistics"):
|
||||
st.write("**Overall Statistics:**")
|
||||
n_common_features = common_feature_array.size
|
||||
mean_weight = float(common_feature_array.mean().values)
|
||||
max_weight = float(common_feature_array.max().values)
|
||||
min_weight = float(common_feature_array.min().values)
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Common Features", n_common_features)
|
||||
with col2:
|
||||
st.metric("Mean Weight", f"{mean_weight:.4f}")
|
||||
with col3:
|
||||
st.metric("Max Weight", f"{max_weight:.4f}")
|
||||
with col4:
|
||||
st.metric("Min Weight", f"{min_weight:.4f}")
|
||||
|
||||
# Show all common features sorted by importance
|
||||
st.write("**All Common Features (by absolute weight):**")
|
||||
common_df = common_feature_array.to_dataframe(name="weight").reset_index()
|
||||
common_df["abs_weight"] = common_df["weight"].abs()
|
||||
common_df = common_df.sort_values("abs_weight", ascending=False)
|
||||
st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
**Interpretation:**
|
||||
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
|
||||
size-related patterns
|
||||
- **land_ratio**: Proportion of land vs water in each cell
|
||||
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
|
||||
- Positive weights indicate features that increase the probability of the positive class
|
||||
- Negative weights indicate features that decrease the probability of the positive class
|
||||
"""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
|
||||
|
||||
import altair as alt
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.colors import get_cmap
|
||||
from entropice.dashboard.plots.colors import get_cmap, get_palette
|
||||
|
||||
|
||||
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||
|
|
@ -135,8 +136,6 @@ def render_parameter_distributions(results: pd.DataFrame):
|
|||
|
||||
# Get colormap from colors module
|
||||
cmap = get_cmap("parameter_distribution")
|
||||
import matplotlib.colors as mcolors
|
||||
|
||||
bar_color = mcolors.rgb2hex(cmap(0.5))
|
||||
|
||||
# Create histograms for each parameter
|
||||
|
|
@ -159,37 +158,77 @@ def render_parameter_distributions(results: pd.DataFrame):
|
|||
param_values = results[param_col].dropna()
|
||||
|
||||
if pd.api.types.is_numeric_dtype(param_values):
|
||||
# Check if we have enough unique values for a histogram
|
||||
n_unique = param_values.nunique()
|
||||
|
||||
if n_unique == 1:
|
||||
# Only one unique value - show as text
|
||||
value = param_values.iloc[0]
|
||||
formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}"
|
||||
st.metric(param_name, formatted)
|
||||
st.caption(f"All {len(param_values)} samples have the same value")
|
||||
continue
|
||||
|
||||
# Numeric parameter - use histogram
|
||||
df_plot = pd.DataFrame({param_name: param_values})
|
||||
df_plot = pd.DataFrame({param_name: param_values.to_numpy()})
|
||||
|
||||
# Use log scale if the range spans multiple orders of magnitude
|
||||
value_range = param_values.max() / (param_values.min() + 1e-10)
|
||||
use_log = value_range > 100
|
||||
# Use log scale if the range spans multiple orders of magnitude OR values are very small
|
||||
max_val = param_values.max()
|
||||
|
||||
if use_log:
|
||||
chart = (
|
||||
alt.Chart(df_plot)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(
|
||||
param_name,
|
||||
bin=alt.Bin(maxbins=20),
|
||||
scale=alt.Scale(type="log"),
|
||||
title=param_name,
|
||||
),
|
||||
alt.Y("count()", title="Count"),
|
||||
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
|
||||
# Determine number of bins based on unique values
|
||||
n_bins = min(20, max(5, n_unique))
|
||||
|
||||
# For very small values or large ranges, use a simple bar chart instead of histogram
|
||||
if n_unique <= 10:
|
||||
# Few unique values - use bar chart
|
||||
value_counts = param_values.value_counts().reset_index()
|
||||
value_counts.columns = [param_name, "count"]
|
||||
value_counts = value_counts.sort_values(param_name)
|
||||
|
||||
# Format x-axis values
|
||||
if max_val < 0.01:
|
||||
formatted_col = f"{param_name}_formatted"
|
||||
value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}")
|
||||
chart = (
|
||||
alt.Chart(value_counts)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(f"{formatted_col}:N", title=param_name, sort=None),
|
||||
alt.Y("count:Q", title="Count"),
|
||||
tooltip=[
|
||||
alt.Tooltip(param_name, format=".2e"),
|
||||
alt.Tooltip("count", title="Count"),
|
||||
],
|
||||
)
|
||||
.properties(height=250, title=param_name)
|
||||
)
|
||||
else:
|
||||
chart = (
|
||||
alt.Chart(value_counts)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(f"{param_name}:Q", title=param_name),
|
||||
alt.Y("count:Q", title="Count"),
|
||||
tooltip=[
|
||||
alt.Tooltip(param_name, format=".3f"),
|
||||
alt.Tooltip("count", title="Count"),
|
||||
],
|
||||
)
|
||||
.properties(height=250, title=param_name)
|
||||
)
|
||||
.properties(height=250, title=param_name)
|
||||
)
|
||||
else:
|
||||
# Many unique values - use binned histogram
|
||||
# Avoid log scale for binning as it can cause issues
|
||||
chart = (
|
||||
alt.Chart(df_plot)
|
||||
.mark_bar(color=bar_color)
|
||||
.encode(
|
||||
alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
|
||||
alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name),
|
||||
alt.Y("count()", title="Count"),
|
||||
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
|
||||
tooltip=[
|
||||
alt.Tooltip(f"{param_name}:Q", format=".2e" if max_val < 0.01 else ".3f", bin=True),
|
||||
"count()",
|
||||
],
|
||||
)
|
||||
.properties(height=250, title=param_name)
|
||||
)
|
||||
|
|
@ -276,7 +315,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
alt.Y(metric, title=metric.replace("_", " ").title()),
|
||||
alt.Color(
|
||||
metric,
|
||||
scale=alt.Scale(scheme="viridis"),
|
||||
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
|
||||
|
|
@ -292,7 +331,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
|||
alt.Y(metric, title=metric.replace("_", " ").title()),
|
||||
alt.Color(
|
||||
metric,
|
||||
scale=alt.Scale(scheme="viridis"),
|
||||
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
|
||||
|
|
@ -350,13 +389,8 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
|||
|
||||
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
|
||||
|
||||
# Get colormap from colors module
|
||||
cmap = get_cmap("correlation")
|
||||
# Sample colors for red-blue diverging scheme
|
||||
colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
|
||||
import matplotlib.colors as mcolors
|
||||
|
||||
hex_colors = [mcolors.rgb2hex(c) for c in colors]
|
||||
# Get colormap from colors module (use diverging colormap for correlation)
|
||||
hex_colors = get_palette("correlation", n_colors=256)
|
||||
|
||||
# Create bar chart
|
||||
chart = (
|
||||
|
|
@ -367,7 +401,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
|||
alt.Y("Parameter", sort="-x", title="Parameter"),
|
||||
alt.Color(
|
||||
"Correlation",
|
||||
scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
|
||||
scale=alt.Scale(range=hex_colors, domain=[-1, 1]),
|
||||
legend=None,
|
||||
),
|
||||
tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
|
||||
|
|
@ -402,7 +436,7 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
|||
st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
|
||||
return
|
||||
|
||||
# Prepare binned data
|
||||
# Prepare binned data and gather parameter info
|
||||
results_binned = results.copy()
|
||||
bin_info = {}
|
||||
|
||||
|
|
@ -414,146 +448,160 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
|||
continue
|
||||
|
||||
# Determine if we should use log bins
|
||||
value_range = param_values.max() / (param_values.min() + 1e-10)
|
||||
use_log = value_range > 100 or param_values.max() < 1.0
|
||||
min_val = param_values.min()
|
||||
max_val = param_values.max()
|
||||
|
||||
if use_log:
|
||||
if min_val > 0:
|
||||
value_range = max_val / min_val
|
||||
use_log = value_range > 100 or max_val < 0.01
|
||||
else:
|
||||
use_log = False
|
||||
|
||||
if use_log and min_val > 0:
|
||||
# Logarithmic binning for parameters spanning many orders of magnitude
|
||||
log_min = np.log10(param_values.min())
|
||||
log_max = np.log10(param_values.max())
|
||||
n_bins = min(10, int(log_max - log_min) + 1)
|
||||
log_min = np.log10(min_val)
|
||||
log_max = np.log10(max_val)
|
||||
n_bins = min(10, max(3, int(log_max - log_min) + 1))
|
||||
bins = np.logspace(log_min, log_max, num=n_bins)
|
||||
else:
|
||||
# Linear binning
|
||||
n_bins = min(10, int(np.sqrt(len(param_values))))
|
||||
bins = np.linspace(param_values.min(), param_values.max(), num=n_bins)
|
||||
n_bins = min(10, max(3, int(np.sqrt(len(param_values)))))
|
||||
bins = np.linspace(min_val, max_val, num=n_bins)
|
||||
|
||||
results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
|
||||
bin_info[param_name] = {"use_log": use_log, "bins": bins}
|
||||
bin_info[param_name] = {
|
||||
"use_log": use_log,
|
||||
"bins": bins,
|
||||
"min": min_val,
|
||||
"max": max_val,
|
||||
"n_unique": param_values.nunique(),
|
||||
}
|
||||
|
||||
# Get colormap from colors module
|
||||
cmap = get_cmap(metric)
|
||||
hex_colors = get_palette(metric, n_colors=256)
|
||||
|
||||
# Create binned plots for pairs of parameters
|
||||
# Show the most meaningful combinations
|
||||
# Get parameter names sorted by scale type (log parameters first, then linear)
|
||||
param_names = [col.replace("param_", "") for col in numeric_params]
|
||||
param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p))
|
||||
|
||||
# If we have 3+ parameters, create multiple plots
|
||||
if len(param_names) >= 3:
|
||||
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
|
||||
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
|
||||
|
||||
# Plot first parameter vs others, binned by third
|
||||
for i in range(1, min(3, len(param_names))):
|
||||
x_param = param_names[0]
|
||||
y_param = param_names[i]
|
||||
# Create all pairwise combinations systematically
|
||||
if len(param_names) == 2:
|
||||
# Simple case: just one pair
|
||||
x_param, y_param = param_names_sorted
|
||||
_render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500)
|
||||
else:
|
||||
# Multiple parameters: create structured plots
|
||||
st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
|
||||
|
||||
# Find a third parameter for faceting
|
||||
facet_param = None
|
||||
for p in param_names:
|
||||
if p != x_param and p != y_param:
|
||||
facet_param = p
|
||||
break
|
||||
# Strategy: Show all combinations (including duplicates with swapped axes)
|
||||
# Group by first parameter to create organized sections
|
||||
|
||||
if facet_param and f"{facet_param}_binned" in results_binned.columns:
|
||||
# Create faceted plot
|
||||
plot_data = results_binned[
|
||||
[f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
|
||||
].dropna()
|
||||
for i, x_param in enumerate(param_names_sorted):
|
||||
# Get all other parameters to pair with this one
|
||||
other_params = [p for p in param_names_sorted if p != x_param]
|
||||
|
||||
# Convert binned column to string with sorted categories
|
||||
plot_data = plot_data.sort_values(f"{facet_param}_binned")
|
||||
plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
|
||||
bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()
|
||||
if not other_params:
|
||||
continue
|
||||
|
||||
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||
# Create expander for each primary parameter
|
||||
with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)):
|
||||
# Create plots in rows of 2
|
||||
n_others = len(other_params)
|
||||
for row_idx in range(0, n_others, 2):
|
||||
cols = st.columns(2)
|
||||
|
||||
# Get colormap colors
|
||||
cmap_colors = get_cmap(metric)
|
||||
import matplotlib.colors as mcolors
|
||||
for col_idx in range(2):
|
||||
other_idx = row_idx + col_idx
|
||||
if other_idx >= n_others:
|
||||
break
|
||||
|
||||
# Sample colors from colormap
|
||||
hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)]
|
||||
y_param = other_params[other_idx]
|
||||
|
||||
chart = (
|
||||
alt.Chart(plot_data)
|
||||
.mark_circle(size=60, opacity=0.7)
|
||||
.encode(
|
||||
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||
color=alt.Color(
|
||||
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||
f"{facet_param}_binned:N",
|
||||
],
|
||||
)
|
||||
.properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
|
||||
.facet(
|
||||
facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
|
||||
columns=4,
|
||||
)
|
||||
)
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
with cols[col_idx]:
|
||||
_render_2d_param_plot(
|
||||
results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350
|
||||
)
|
||||
|
||||
# Also create a simple 2D plot without faceting
|
||||
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
|
||||
|
||||
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||
def _render_2d_param_plot(
|
||||
results_binned: pd.DataFrame,
|
||||
x_param: str,
|
||||
y_param: str,
|
||||
score_col: str,
|
||||
bin_info: dict,
|
||||
hex_colors: list,
|
||||
metric: str,
|
||||
height: int = 400,
|
||||
):
|
||||
"""Render a 2D parameter space plot.
|
||||
|
||||
chart = (
|
||||
alt.Chart(plot_data)
|
||||
.mark_circle(size=80, opacity=0.6)
|
||||
.encode(
|
||||
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||
color=alt.Color(
|
||||
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||
],
|
||||
)
|
||||
.properties(height=400, title=f"{x_param} vs {y_param}")
|
||||
)
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
Args:
|
||||
results_binned: DataFrame with binned results.
|
||||
x_param: Name of x-axis parameter.
|
||||
y_param: Name of y-axis parameter.
|
||||
score_col: Column name for the score.
|
||||
bin_info: Dictionary with binning information for each parameter.
|
||||
hex_colors: List of hex colors for the colormap.
|
||||
metric: Name of the metric being visualized.
|
||||
height: Height of the plot in pixels.
|
||||
|
||||
elif len(param_names) == 2:
|
||||
# Simple 2-parameter plot
|
||||
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
|
||||
"""
|
||||
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
|
||||
|
||||
x_param = param_names[0]
|
||||
y_param = param_names[1]
|
||||
if len(plot_data) == 0:
|
||||
st.warning(f"No data available for {x_param} vs {y_param}")
|
||||
return
|
||||
|
||||
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
|
||||
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||
|
||||
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||
# Determine marker size based on number of points
|
||||
n_points = len(plot_data)
|
||||
marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40
|
||||
|
||||
chart = (
|
||||
alt.Chart(plot_data)
|
||||
.mark_circle(size=100, opacity=0.6)
|
||||
.encode(
|
||||
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||
color=alt.Color(
|
||||
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||
chart = (
|
||||
alt.Chart(plot_data)
|
||||
.mark_circle(size=marker_size, opacity=0.7)
|
||||
.encode(
|
||||
x=alt.X(
|
||||
f"param_{x_param}:Q",
|
||||
title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}",
|
||||
scale=x_scale,
|
||||
),
|
||||
y=alt.Y(
|
||||
f"param_{y_param}:Q",
|
||||
title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}",
|
||||
scale=y_scale,
|
||||
),
|
||||
color=alt.Color(
|
||||
f"{score_col}:Q",
|
||||
scale=alt.Scale(range=hex_colors),
|
||||
title=metric.replace("_", " ").title(),
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip(
|
||||
f"param_{x_param}:Q",
|
||||
title=x_param,
|
||||
format=".2e" if bin_info[x_param]["use_log"] else ".3f",
|
||||
),
|
||||
tooltip=[
|
||||
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||
],
|
||||
)
|
||||
.properties(height=500, title=f"{x_param} vs {y_param}")
|
||||
alt.Tooltip(
|
||||
f"param_{y_param}:Q",
|
||||
title=y_param,
|
||||
format=".2e" if bin_info[y_param]["use_log"] else ".3f",
|
||||
),
|
||||
alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"),
|
||||
],
|
||||
)
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
.properties(
|
||||
height=height,
|
||||
title=f"{x_param} vs {y_param}",
|
||||
)
|
||||
.interactive()
|
||||
)
|
||||
|
||||
st.altair_chart(chart, use_container_width=True)
|
||||
|
||||
|
||||
def render_score_evolution(results: pd.DataFrame, metric: str):
|
||||
|
|
@ -580,6 +628,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
# Reshape for Altair
|
||||
df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")
|
||||
|
||||
# Get colormap for score evolution
|
||||
evolution_cmap = get_cmap("score_evolution")
|
||||
evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))]
|
||||
|
||||
# Create line chart
|
||||
chart = (
|
||||
alt.Chart(df_long)
|
||||
|
|
@ -587,7 +639,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
.encode(
|
||||
alt.X("Iteration", title="Iteration"),
|
||||
alt.Y("value", title=metric.replace("_", " ").title()),
|
||||
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
|
||||
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)),
|
||||
strokeDash=alt.StrokeDash(
|
||||
"Type",
|
||||
legend=None,
|
||||
|
|
@ -614,6 +666,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
|
|||
st.metric("Best at Iteration", best_iter)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_multi_metric_comparison(results: pd.DataFrame):
|
||||
"""Render comparison of multiple metrics.
|
||||
|
||||
|
|
@ -628,8 +681,16 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
|||
st.warning("Need at least 2 metrics for comparison.")
|
||||
return
|
||||
|
||||
# Let user select two metrics to compare
|
||||
col1, col2 = st.columns(2)
|
||||
# Get parameter columns for color options
|
||||
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
||||
param_names = [col.replace("param_", "") for col in param_cols]
|
||||
|
||||
if not param_names:
|
||||
st.warning("No parameters found for coloring.")
|
||||
return
|
||||
|
||||
# Let user select two metrics and color parameter
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
metric1 = st.selectbox(
|
||||
"Select First Metric",
|
||||
|
|
@ -646,21 +707,57 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
|||
key="metric2_select",
|
||||
)
|
||||
|
||||
with col3:
|
||||
color_by = st.selectbox(
|
||||
"Color By",
|
||||
options=param_names,
|
||||
index=0,
|
||||
key="color_select",
|
||||
)
|
||||
|
||||
if metric1 == metric2:
|
||||
st.warning("Please select different metrics.")
|
||||
return
|
||||
|
||||
# Create scatter plot
|
||||
# Create scatter plot data
|
||||
df_plot = pd.DataFrame(
|
||||
{
|
||||
metric1: results[f"mean_test_{metric1}"],
|
||||
metric2: results[f"mean_test_{metric2}"],
|
||||
"Iteration": range(len(results)),
|
||||
}
|
||||
)
|
||||
|
||||
# Get colormap from colors module
|
||||
cmap = get_cmap("iteration")
|
||||
# Add color parameter to dataframe
|
||||
param_col = f"param_{color_by}"
|
||||
df_plot[color_by] = results[param_col]
|
||||
color_col = color_by
|
||||
|
||||
# Check if parameter is numeric and should use log scale
|
||||
param_values = results[param_col].dropna()
|
||||
if pd.api.types.is_numeric_dtype(param_values):
|
||||
value_range = param_values.max() / (param_values.min() + 1e-10)
|
||||
use_log = value_range > 100
|
||||
|
||||
if use_log:
|
||||
color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256))
|
||||
color_format = ".2e"
|
||||
else:
|
||||
color_scale = alt.Scale(range=get_palette(color_by, n_colors=256))
|
||||
color_format = ".3f"
|
||||
else:
|
||||
# Categorical parameter
|
||||
color_scale = alt.Scale(scheme="category20")
|
||||
color_format = None
|
||||
|
||||
# Build tooltip list
|
||||
tooltip_list = [
|
||||
alt.Tooltip(metric1, format=".4f"),
|
||||
alt.Tooltip(metric2, format=".4f"),
|
||||
]
|
||||
if color_format:
|
||||
tooltip_list.append(alt.Tooltip(color_col, format=color_format))
|
||||
else:
|
||||
tooltip_list.append(color_col)
|
||||
|
||||
chart = (
|
||||
alt.Chart(df_plot)
|
||||
|
|
@ -668,12 +765,8 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
|||
.encode(
|
||||
alt.X(metric1, title=metric1.replace("_", " ").title()),
|
||||
alt.Y(metric2, title=metric2.replace("_", " ").title()),
|
||||
alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
|
||||
tooltip=[
|
||||
alt.Tooltip(metric1, format=".4f"),
|
||||
alt.Tooltip(metric2, format=".4f"),
|
||||
"Iteration",
|
||||
],
|
||||
alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
|
||||
tooltip=tooltip_list,
|
||||
)
|
||||
.properties(height=500)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -447,3 +447,212 @@ def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
|
|||
)
|
||||
|
||||
return chart
|
||||
|
||||
|
||||
# XGBoost-specific visualizations
|
||||
def plot_xgboost_feature_importance(
    model_state: xr.Dataset, importance_type: str = "gain", top_n: int = 20
) -> alt.Chart:
    """Build a horizontal bar chart of the highest-ranked XGBoost features.

    The importance values are read from the ``feature_importance_<importance_type>``
    variable of ``model_state``; only the ``top_n`` largest entries are plotted.

    Args:
        model_state: The xarray Dataset containing the XGBoost model state.
        importance_type: Suffix of the importance variable to read
            (e.g. 'weight', 'gain', 'cover', 'total_gain', 'total_cover').
        top_n: Number of top features to display.

    Returns:
        Altair bar chart with one bar per selected feature.

    Raises:
        ValueError: If the requested importance variable is not in ``model_state``.

    """
    variable = f"feature_importance_{importance_type}"
    if variable not in model_state:
        raise ValueError(f"Importance type '{importance_type}' not found in model state")

    # Keep the top_n largest values, stored in ascending order.
    ranked = model_state[variable].to_pandas().nlargest(top_n).sort_values(ascending=True)
    frame = pd.DataFrame({"feature": ranked.index, "importance": ranked.to_numpy()})

    feature_axis = alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300))
    value_axis = alt.X("importance:Q", title=f"{importance_type.replace('_', ' ').title()} Importance")
    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("importance:Q", format=".4f", title="Importance"),
    ]

    bars = alt.Chart(frame).mark_bar().encode(
        y=feature_axis,
        x=value_axis,
        color=alt.value("steelblue"),
        tooltip=hover,
    )
    return bars.properties(
        width=600,
        height=400,
        title=f"Top {top_n} Features by {importance_type.replace('_', ' ').title()} Importance (XGBoost)",
    )
|
||||
|
||||
|
||||
def plot_xgboost_importance_comparison(model_state: xr.Dataset, top_n: int = 15) -> alt.Chart:
    """Show the 'weight', 'gain' and 'cover' importances of an XGBoost model side-by-side.

    For each importance type whose ``feature_importance_<type>`` variable exists in
    ``model_state``, the ``top_n`` largest features are collected; types that are
    missing are silently skipped. The result is faceted into one panel per type.

    Args:
        model_state: The xarray Dataset containing the XGBoost model state.
        top_n: Number of top features to collect per importance type.

    Returns:
        Faceted Altair bar chart comparing the importance types.

    """
    # One record per (importance type, feature), in the order weight -> gain -> cover.
    records = [
        {
            "feature": feature_name,
            "importance": score,
            "type": kind.replace("_", " ").title(),
        }
        for kind in ("weight", "gain", "cover")
        if f"feature_importance_{kind}" in model_state
        for feature_name, score in model_state[f"feature_importance_{kind}"].to_pandas().nlargest(top_n).items()
    ]
    frame = pd.DataFrame(records)

    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("type:N", title="Type"),
        alt.Tooltip("importance:Q", format=".4f", title="Importance"),
    ]
    panel = (
        alt.Chart(frame)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=200)),
            x=alt.X("importance:Q", title="Importance Value"),
            color=alt.Color("type:N", title="Importance Type", scale=alt.Scale(scheme="category10")),
            tooltip=hover,
        )
        .properties(width=250, height=300)
    )
    return panel.facet(facet=alt.Facet("type:N", title="Importance Type"), columns=3)
|
||||
|
||||
|
||||
# Random Forest-specific visualizations
|
||||
def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.Chart:
    """Build a horizontal bar chart of the highest-ranked Random Forest features.

    Values are read from the ``feature_importance`` variable of ``model_state``
    and labelled as Gini importance on the chart.

    Args:
        model_state: The xarray Dataset containing the Random Forest model state.
        top_n: Number of top features to display.

    Returns:
        Altair bar chart with one bar per selected feature.

    """
    # Keep the top_n largest values, stored in ascending order.
    ranked = model_state["feature_importance"].to_pandas().nlargest(top_n).sort_values(ascending=True)
    frame = pd.DataFrame({"feature": ranked.index, "importance": ranked.to_numpy()})

    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("importance:Q", format=".4f", title="Importance"),
    ]
    bars = alt.Chart(frame).mark_bar().encode(
        y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
        x=alt.X("importance:Q", title="Gini Importance"),
        color=alt.value("forestgreen"),
        tooltip=hover,
    )
    return bars.properties(
        width=600,
        height=400,
        title=f"Top {top_n} Features by Gini Importance (Random Forest)",
    )
|
||||
|
||||
|
||||
def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
    """Build histograms of per-tree statistics for a Random Forest.

    Reads the ``tree_depths``, ``tree_n_leaves`` and ``tree_n_nodes`` variables
    from ``model_state`` and bins each into at most 20 bins.

    Args:
        model_state: The xarray Dataset containing the Random Forest model state.

    Returns:
        Tuple of three Altair charts (depth histogram, leaves histogram, nodes histogram).

    """

    def _histogram(series, axis_title: str, fill: str, range_title: str, chart_title: str) -> alt.Chart:
        # All three histograms share this spec; only titles and fill colour differ.
        frame = pd.DataFrame({"value": series.to_numpy()})
        return (
            alt.Chart(frame)
            .mark_bar()
            .encode(
                x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title=axis_title),
                y=alt.Y("count()", title="Number of Trees"),
                color=alt.value(fill),
                tooltip=[
                    alt.Tooltip("count()", title="Count"),
                    alt.Tooltip("value:Q", bin=True, title=range_title),
                ],
            )
            .properties(width=300, height=200, title=chart_title)
        )

    # Pull every statistic up front so a missing variable fails before any chart is built.
    depth_series = model_state["tree_depths"].to_pandas()
    leaf_series = model_state["tree_n_leaves"].to_pandas()
    node_series = model_state["tree_n_nodes"].to_pandas()

    return (
        _histogram(depth_series, "Tree Depth", "steelblue", "Depth Range", "Distribution of Tree Depths"),
        _histogram(leaf_series, "Number of Leaves", "forestgreen", "Leaves Range", "Distribution of Leaf Counts"),
        _histogram(node_series, "Number of Nodes", "darkorange", "Nodes Range", "Distribution of Node Counts"),
    )
|
||||
|
|
|
|||
|
|
@ -761,6 +761,117 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
|
|||
st.warning("No valid data available for selected parameters")
|
||||
|
||||
|
||||
@st.fragment
def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for grid cell areas.

    The user picks one of the area columns and an opacity; each cell is then
    coloured and extruded according to that column, normalised between its
    2nd and 98th percentile.

    Args:
        grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
        grid: Grid type ('hex' or 'healpix').

    """
    st.subheader("🗺️ Grid Cell Areas Distribution")

    # Controls
    metric_col, opacity_col = st.columns([3, 1])

    with metric_col:
        area_metric = st.selectbox(
            "Area Metric",
            options=["cell_area", "land_area", "water_area", "land_ratio"],
            format_func=lambda x: x.replace("_", " ").title(),
            key="areas_metric",
        )

    with opacity_col:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity")

    # Convert to WGS84 first. to_crs returns a new GeoDataFrame, so the caller's
    # frame is left untouched by the column assignments below.
    gdf_wgs84 = grid_gdf.to_crs("EPSG:4326")

    # Fix geometries after CRS conversion
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Get values for the selected metric
    values = gdf_wgs84[area_metric].to_numpy()

    # Normalize values for color mapping; percentiles keep outliers from
    # compressing the colour range.
    vmin, vmax = np.nanpercentile(values, [2, 98])
    value_span = vmax - vmin
    if value_span > 0:
        normalized = np.clip((values - vmin) / value_span, 0, 1)
    else:
        # Degenerate range (constant metric, or no finite values): dividing by
        # zero would fill colours and elevations with NaN, so render a flat map.
        normalized = np.zeros(len(values), dtype=float)

    # Every metric currently shares the same colormap.
    cmap = get_cmap("terrain")

    colors = [cmap(val) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]

    # Set elevation based on normalized values for 3D visualization
    gdf_wgs84["elevation"] = normalized

    # Create GeoJSON features; numeric properties are pre-formatted as strings
    # because the tooltip template inserts them verbatim.
    geojson_data = [
        {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "cell_area": f"{float(row['cell_area']):.2f}",
                "land_area": f"{float(row['land_area']):.2f}",
                "water_area": f"{float(row['water_area']):.2f}",
                "land_ratio": f"{float(row['land_ratio']):.2%}",
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        for _, row in gdf_wgs84.iterrows()
    ]

    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell Area:</b> {cell_area} km²<br/>"
            "<b>Land Area:</b> {land_area} km²<br/>"
            "<b>Water Area:</b> {water_area} km²<br/>"
            "<b>Land Ratio:</b> {land_ratio}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics
    # NOTE(review): "Min"/"Max" here are the 2nd/98th percentiles used for
    # normalisation, not the true extrema — confirm whether the label should say so.
    st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}")

    # Show additional info
    st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
|
||||
"""Render interactive pydeck map for ERA5 climate data.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from entropice import grids
|
||||
from entropice.dashboard.plots.source_data import (
|
||||
render_alphaearth_map,
|
||||
render_alphaearth_overview,
|
||||
|
|
@ -9,6 +10,7 @@ from entropice.dashboard.plots.source_data import (
|
|||
render_arcticdem_map,
|
||||
render_arcticdem_overview,
|
||||
render_arcticdem_plots,
|
||||
render_areas_map,
|
||||
render_era5_map,
|
||||
render_era5_overview,
|
||||
render_era5_plots,
|
||||
|
|
@ -184,7 +186,7 @@ def render_training_data_page():
|
|||
st.markdown("---")
|
||||
|
||||
# Create tabs for different data views
|
||||
tab_names = ["📊 Labels"]
|
||||
tab_names = ["📊 Labels", "📐 Areas"]
|
||||
|
||||
# Add tabs for each member
|
||||
for member in ensemble.members:
|
||||
|
|
@ -228,8 +230,50 @@ def render_training_data_page():
|
|||
|
||||
render_spatial_map(train_data_dict)
|
||||
|
||||
# Areas tab
|
||||
with tabs[1]:
|
||||
st.markdown("### Grid Cell Areas and Land/Water Distribution")
|
||||
|
||||
st.markdown(
|
||||
"This visualization shows the spatial distribution of cell areas, land areas, "
|
||||
"water areas, and land ratio across the grid. The grid has been filtered to "
|
||||
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
|
||||
"with >10% land coverage."
|
||||
)
|
||||
|
||||
# Load grid data
|
||||
grid_gdf = grids.open(ensemble.grid, ensemble.level)
|
||||
|
||||
st.success(
|
||||
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
|
||||
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
|
||||
)
|
||||
|
||||
# Show summary statistics
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
st.metric("Total Cells", f"{len(grid_gdf):,}")
|
||||
with col2:
|
||||
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
|
||||
with col3:
|
||||
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
|
||||
with col4:
|
||||
total_land = grid_gdf["land_area"].sum()
|
||||
st.metric("Total Land Area", f"{total_land:,.0f} km²")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (
|
||||
ensemble.grid == "healpix" and ensemble.level == 10
|
||||
):
|
||||
st.warning(
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
|
||||
)
|
||||
else:
|
||||
render_areas_map(grid_gdf, ensemble.grid)
|
||||
|
||||
# AlphaEarth tab
|
||||
tab_idx = 1
|
||||
tab_idx = 2
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
st.markdown("### AlphaEarth Embeddings Analysis")
|
||||
|
|
|
|||
|
|
@ -29,10 +29,9 @@ class TrainingResult:
|
|||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
||||
"""Load a TrainingResult from a given result directory path."""
|
||||
result_file = result_path / "search_results.parquet"
|
||||
state_file = result_path / "best_estimator_state.nc"
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
settings_file = result_path / "search_settings.toml"
|
||||
if not all([result_file.exists(), state_file.exists(), preds_file.exists(), settings_file.exists()]):
|
||||
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
|
||||
raise FileNotFoundError(f"Missing required files in {result_path}")
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
|
|
@ -96,9 +95,9 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
|
|||
|
||||
"""
|
||||
return {
|
||||
"binary": e.create_cat_training_dataset("binary"),
|
||||
"count": e.create_cat_training_dataset("count"),
|
||||
"density": e.create_cat_training_dataset("density"),
|
||||
"binary": e.create_cat_training_dataset("binary", device="cpu"),
|
||||
"count": e.create_cat_training_dataset("count", device="cpu"),
|
||||
"density": e.create_cat_training_dataset("density", device="cpu"),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,15 @@ Naming conventions:
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Literal
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import cupy as cp
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import torch
|
||||
|
|
@ -110,8 +113,8 @@ def bin_values(
|
|||
@dataclass(frozen=True, eq=False)
|
||||
class DatasetLabels:
|
||||
binned: pd.Series
|
||||
train: torch.Tensor
|
||||
test: torch.Tensor
|
||||
train: torch.Tensor | np.ndarray | cp.ndarray
|
||||
test: torch.Tensor | np.ndarray | cp.ndarray
|
||||
raw_values: pd.Series
|
||||
|
||||
@cached_property
|
||||
|
|
@ -135,8 +138,8 @@ class DatasetLabels:
|
|||
@dataclass(frozen=True, eq=False)
|
||||
class DatasetInputs:
|
||||
data: pd.DataFrame
|
||||
train: torch.Tensor
|
||||
test: torch.Tensor
|
||||
train: torch.Tensor | np.ndarray | cp.ndarray
|
||||
test: torch.Tensor | np.ndarray | cp.ndarray
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -151,6 +154,13 @@ class CategoricalTrainingDataset:
|
|||
return len(self.z)
|
||||
|
||||
|
||||
class DatasetStats(TypedDict):
    """Typed shape of the statistics dictionary returned by ``DatasetEnsemble.get_stats``.

    Holds the target statistics, the per-member statistics and the total
    feature count of a dataset ensemble.
    """

    # Target of the ensemble (``get_stats`` stores ``self.target`` here).
    target: str
    # Number of target samples (``get_stats`` stores ``len(targets)`` here).
    num_target_samples: int
    # Statistics per ensemble member, keyed by a string; ``get_stats`` starts from an empty dict.
    # NOTE(review): presumably keyed by member name -- confirm against the rest of ``get_stats``.
    members: dict[str, dict[str, object]]
    # Total features count across the ensemble.
    total_features: int
|
||||
|
||||
|
||||
@cyclopts.Parameter("*")
|
||||
@dataclass(frozen=True)
|
||||
class DatasetEnsemble:
|
||||
|
|
@ -283,15 +293,15 @@ class DatasetEnsemble:
|
|||
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
|
||||
return arcticdem_df
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
def get_stats(self) -> DatasetStats:
|
||||
"""Get dataset statistics.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing target stats, member stats, and total features count.
|
||||
DatasetStats: Dictionary containing target stats, member stats, and total features count.
|
||||
|
||||
"""
|
||||
targets = self._read_target()
|
||||
stats = {
|
||||
stats: DatasetStats = {
|
||||
"target": self.target,
|
||||
"num_target_samples": len(targets),
|
||||
"members": {},
|
||||
|
|
@ -365,11 +375,64 @@ class DatasetEnsemble:
|
|||
print(f"Saved dataset to cache at {cache_file}.")
|
||||
return dataset
|
||||
|
||||
def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset:
|
||||
def create_batches(
|
||||
self,
|
||||
batch_size: int,
|
||||
filter_target_col: str | None = None,
|
||||
cache_mode: Literal["n", "o", "r"] = "r",
|
||||
) -> Generator[pd.DataFrame]:
|
||||
targets = self._read_target()
|
||||
if len(targets) == 0:
|
||||
raise ValueError("No target samples found.")
|
||||
elif len(targets) < batch_size:
|
||||
yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode)
|
||||
return
|
||||
|
||||
if filter_target_col is not None:
|
||||
targets = targets.loc[targets[filter_target_col]]
|
||||
|
||||
for i in range(0, len(targets), batch_size):
|
||||
# n: no cache, o: overwrite cache, r: read cache if exists
|
||||
cache_file = entropice.paths.get_dataset_cache(
|
||||
self.id(), subset=filter_target_col, batch=(i, i + batch_size)
|
||||
)
|
||||
if cache_mode == "r" and cache_file.exists():
|
||||
dataset = gpd.read_parquet(cache_file)
|
||||
print(
|
||||
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
|
||||
f" and {len(dataset.columns)} features."
|
||||
)
|
||||
yield dataset
|
||||
else:
|
||||
targets_batch = targets.iloc[i : i + batch_size]
|
||||
|
||||
member_dfs = []
|
||||
for member in self.members:
|
||||
if member.startswith("ERA5"):
|
||||
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||
member_dfs.append(self._prep_era5(targets_batch, era5_agg))
|
||||
elif member == "AlphaEarth":
|
||||
member_dfs.append(self._prep_embeddings(targets_batch))
|
||||
elif member == "ArcticDEM":
|
||||
member_dfs.append(self._prep_arcticdem(targets_batch))
|
||||
else:
|
||||
raise NotImplementedError(f"Member {member} not implemented.")
|
||||
|
||||
dataset = targets_batch.set_index("cell_id").join(member_dfs)
|
||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||
if cache_mode in ["o", "r"]:
|
||||
dataset.to_parquet(cache_file)
|
||||
print(f"Saved dataset to cache at {cache_file}.")
|
||||
yield dataset
|
||||
|
||||
def create_cat_training_dataset(
|
||||
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||
) -> CategoricalTrainingDataset:
|
||||
"""Create a categorical dataset for training.
|
||||
|
||||
Args:
|
||||
task (Task): Task type.
|
||||
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
||||
|
||||
Returns:
|
||||
CategoricalTrainingDataset: The prepared categorical training dataset.
|
||||
|
|
@ -414,10 +477,23 @@ class DatasetEnsemble:
|
|||
split.loc[test_idx] = "test"
|
||||
split = split.astype("category")
|
||||
|
||||
X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0)
|
||||
X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0)
|
||||
y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0)
|
||||
y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0)
|
||||
X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64")
|
||||
X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64")
|
||||
y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64")
|
||||
y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64")
|
||||
|
||||
if device == "cuda":
|
||||
X_train = cp.asarray(X_train)
|
||||
X_test = cp.asarray(X_test)
|
||||
y_train = cp.asarray(y_train)
|
||||
y_test = cp.asarray(y_test)
|
||||
elif device == "torch":
|
||||
X_train = torch.from_numpy(X_train)
|
||||
X_test = torch.from_numpy(X_test)
|
||||
y_train = torch.from_numpy(y_train)
|
||||
y_test = torch.from_numpy(y_test)
|
||||
else:
|
||||
assert device == "cpu", "Invalid device specified."
|
||||
|
||||
return CategoricalTrainingDataset(
|
||||
dataset=dataset.to_crs("EPSG:4326"),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# ruff: noqa: N806
|
||||
"""Inference runs on trained models."""
|
||||
|
||||
import cupy as cp
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
|
@ -32,35 +33,32 @@ def predict_proba(
|
|||
list: A list of predicted probabilities for each cell.
|
||||
|
||||
"""
|
||||
data = e.create()
|
||||
print(f"Predicting probabilities for {len(data)} cells...")
|
||||
|
||||
# Predict in batches to avoid memory issues
|
||||
batch_size = 10_000
|
||||
batch_size = 10000
|
||||
preds = []
|
||||
|
||||
cols_to_drop = ["geometry"]
|
||||
if e.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data.iloc[i : i + batch_size]
|
||||
for batch in e.create_batches(batch_size=batch_size):
|
||||
cols_to_drop = ["geometry"]
|
||||
if e.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
X_batch = batch.drop(columns=cols_to_drop).dropna()
|
||||
cell_ids = X_batch.index.to_numpy()
|
||||
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
|
||||
X_batch = X_batch.to_numpy(dtype="float64")
|
||||
X_batch = torch.asarray(X_batch, device=0)
|
||||
batch_preds = clf.predict(X_batch).cpu().numpy()
|
||||
batch_preds = clf.predict(X_batch)
|
||||
if isinstance(batch_preds, cp.ndarray):
|
||||
batch_preds = batch_preds.get()
|
||||
elif torch.is_tensor(batch_preds):
|
||||
batch_preds = batch_preds.cpu().numpy()
|
||||
batch_preds = gpd.GeoDataFrame(
|
||||
{
|
||||
"cell_id": cell_ids,
|
||||
"predicted_class": [classes[i] for i in batch_preds],
|
||||
"geometry": cell_geoms,
|
||||
},
|
||||
crs="epsg:3413",
|
||||
)
|
||||
).set_crs(epsg=3413, inplace=False)
|
||||
preds.append(batch_preds)
|
||||
preds = gpd.GeoDataFrame(pd.concat(preds))
|
||||
|
||||
return preds
|
||||
return gpd.GeoDataFrame(pd.concat(preds, ignore_index=True))
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Literal
|
|||
DATA_DIR = (
|
||||
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||
)
|
||||
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
|
||||
# DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
|
||||
|
||||
GRIDS_DIR = DATA_DIR / "grids"
|
||||
FIGURES_DIR = Path("figures")
|
||||
|
|
@ -110,7 +110,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
|||
return dataset_file
|
||||
|
||||
|
||||
def get_dataset_cache(eid: str, subset: str | None = None) -> Path:
|
||||
def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int] | None = None) -> Path:
|
||||
if batch is not None:
|
||||
eid = f"{eid}_batch{batch[0]}-{batch[1]}"
|
||||
if subset is None:
|
||||
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pickle
|
|||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
import cupy as cp
|
||||
import cyclopts
|
||||
import pandas as pd
|
||||
import toml
|
||||
|
|
@ -14,7 +15,6 @@ from entropy import ESPAClassifier
|
|||
from rich import pretty, traceback
|
||||
from scipy.stats import loguniform, randint
|
||||
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
|
||||
from sklearn import set_config
|
||||
from sklearn.model_selection import KFold, RandomizedSearchCV
|
||||
from stopuhr import stopwatch
|
||||
from xgboost.sklearn import XGBClassifier
|
||||
|
|
@ -26,7 +26,9 @@ from entropice.paths import get_cv_results_dir
|
|||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
set_config(array_api_dispatch=True)
|
||||
# Disabled array_api_dispatch to avoid namespace conflicts between NumPy and CuPy
|
||||
# when using XGBoost with device="cuda"
|
||||
# set_config(array_api_dispatch=True)
|
||||
|
||||
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
|
||||
|
||||
|
|
@ -85,8 +87,8 @@ def _create_clf(
|
|||
objective="multi:softprob" if settings.task != "binary" else "binary:logistic",
|
||||
eval_metric="mlogloss" if settings.task != "binary" else "logloss",
|
||||
random_state=42,
|
||||
tree_method="gpu_hist",
|
||||
device="cuda",
|
||||
tree_method="hist",
|
||||
device="gpu", # Using CPU to avoid CuPy/NumPy namespace conflicts in sklearn metrics
|
||||
)
|
||||
fit_params = {}
|
||||
elif settings.model == "rf":
|
||||
|
|
@ -98,11 +100,10 @@ def _create_clf(
|
|||
fit_params = {}
|
||||
elif settings.model == "knn":
|
||||
param_grid = {
|
||||
"n_neighbors": randint(3, 15),
|
||||
"n_neighbors": randint(10, 200),
|
||||
"weights": ["uniform", "distance"],
|
||||
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
||||
}
|
||||
clf = KNeighborsClassifier(random_state=42)
|
||||
clf = KNeighborsClassifier()
|
||||
fit_params = {}
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {settings.model}")
|
||||
|
|
@ -148,8 +149,9 @@ def random_cv(
|
|||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
"""
|
||||
device = "torch" if settings.model in ["espa"] else "cuda"
|
||||
print("Creating training data...")
|
||||
training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
|
||||
training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task, device=device)
|
||||
|
||||
clf, param_grid, fit_params = _create_clf(settings)
|
||||
print(f"Using model: {settings.model} with parameters: {param_grid}")
|
||||
|
|
@ -159,7 +161,7 @@ def random_cv(
|
|||
clf,
|
||||
param_grid,
|
||||
n_iter=settings.n_iter,
|
||||
n_jobs=8,
|
||||
n_jobs=1,
|
||||
cv=cv,
|
||||
random_state=42,
|
||||
verbose=10,
|
||||
|
|
@ -169,14 +171,25 @@ def random_cv(
|
|||
|
||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||
search.fit(training_data.X.train, training_data.y.train, **fit_params)
|
||||
y_train = (
|
||||
training_data.y.train.get()
|
||||
if settings.model == "xgboost" and isinstance(training_data.y.train, cp.ndarray)
|
||||
else training_data.y.train
|
||||
)
|
||||
search.fit(training_data.X.train, y_train, **fit_params)
|
||||
|
||||
print("Best parameters combination found:")
|
||||
best_parameters = search.best_estimator_.get_params()
|
||||
best_estimator = search.best_estimator_
|
||||
best_parameters = best_estimator.get_params()
|
||||
for param_name in sorted(param_grid.keys()):
|
||||
print(f"{param_name}: {best_parameters[param_name]}")
|
||||
|
||||
test_accuracy = search.score(training_data.X.test, training_data.y.test)
|
||||
y_test = (
|
||||
training_data.y.test.get()
|
||||
if settings.model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
|
||||
else training_data.y.test
|
||||
)
|
||||
test_accuracy = search.score(training_data.X.test, y_test)
|
||||
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
|
||||
print(f"Accuracy on test set: {test_accuracy:.3f}")
|
||||
|
||||
|
|
@ -207,7 +220,7 @@ def random_cv(
|
|||
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||
print(f"Storing best estimator model to {best_model_file}")
|
||||
with open(best_model_file, "wb") as f:
|
||||
pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickle.dump(best_estimator, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# Store the search results
|
||||
results = pd.DataFrame(search.cv_results_)
|
||||
|
|
@ -220,10 +233,10 @@ def random_cv(
|
|||
results.to_parquet(results_file)
|
||||
|
||||
# Get the inner state of the best estimator
|
||||
features = training_data.X.data.columns.tolist()
|
||||
|
||||
if settings.model == "espa":
|
||||
best_estimator = search.best_estimator_
|
||||
# Annotate the state with xarray metadata
|
||||
features = training_data.X.data.columns.tolist()
|
||||
boxes = list(range(best_estimator.K_))
|
||||
box_centers = xr.DataArray(
|
||||
best_estimator.S_.cpu().numpy(),
|
||||
|
|
@ -260,9 +273,154 @@ def random_cv(
|
|||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
elif settings.model == "xgboost":
|
||||
# Extract XGBoost-specific information
|
||||
# Get the underlying booster
|
||||
booster = best_estimator.get_booster()
|
||||
|
||||
# Feature importance with different importance types
|
||||
importance_weight = booster.get_score(importance_type="weight")
|
||||
importance_gain = booster.get_score(importance_type="gain")
|
||||
importance_cover = booster.get_score(importance_type="cover")
|
||||
importance_total_gain = booster.get_score(importance_type="total_gain")
|
||||
importance_total_cover = booster.get_score(importance_type="total_cover")
|
||||
|
||||
# Create aligned arrays for all features (including zero-importance)
|
||||
def align_importance(importance_dict, features):
    """Return one importance value per entry of ``features``, in the same order.

    Features that do not appear in ``importance_dict`` get ``0.0``.
    """
    lookup = importance_dict.get
    return [lookup(feature_name, 0.0) for feature_name in features]
|
||||
|
||||
feature_importance_weight = xr.DataArray(
|
||||
align_importance(importance_weight, features),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance_weight",
|
||||
attrs={"description": "Number of times a feature is used to split the data across all trees."},
|
||||
)
|
||||
feature_importance_gain = xr.DataArray(
|
||||
align_importance(importance_gain, features),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance_gain",
|
||||
attrs={"description": "Average gain across all splits the feature is used in."},
|
||||
)
|
||||
feature_importance_cover = xr.DataArray(
|
||||
align_importance(importance_cover, features),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance_cover",
|
||||
attrs={"description": "Average coverage across all splits the feature is used in."},
|
||||
)
|
||||
feature_importance_total_gain = xr.DataArray(
|
||||
align_importance(importance_total_gain, features),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance_total_gain",
|
||||
attrs={"description": "Total gain across all splits the feature is used in."},
|
||||
)
|
||||
feature_importance_total_cover = xr.DataArray(
|
||||
align_importance(importance_total_cover, features),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance_total_cover",
|
||||
attrs={"description": "Total coverage across all splits the feature is used in."},
|
||||
)
|
||||
|
||||
# Store tree information
|
||||
n_trees = booster.num_boosted_rounds()
|
||||
|
||||
state = xr.Dataset(
|
||||
{
|
||||
"feature_importance_weight": feature_importance_weight,
|
||||
"feature_importance_gain": feature_importance_gain,
|
||||
"feature_importance_cover": feature_importance_cover,
|
||||
"feature_importance_total_gain": feature_importance_total_gain,
|
||||
"feature_importance_total_cover": feature_importance_total_cover,
|
||||
},
|
||||
attrs={
|
||||
"description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
|
||||
"n_trees": n_trees,
|
||||
"objective": str(best_estimator.objective),
|
||||
},
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
elif settings.model == "rf":
|
||||
# Extract Random Forest-specific information
|
||||
# Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_)
|
||||
# like sklearn does, so we can only extract feature importances and model parameters
|
||||
|
||||
# Feature importances (Gini importance)
|
||||
feature_importances = best_estimator.feature_importances_
|
||||
|
||||
feature_importance = xr.DataArray(
|
||||
feature_importances,
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_importance",
|
||||
attrs={"description": "Gini importance (impurity-based feature importance)."},
|
||||
)
|
||||
|
||||
# cuML RF doesn't expose individual trees, so we store model parameters instead
|
||||
n_estimators = best_estimator.n_estimators
|
||||
max_depth = best_estimator.max_depth
|
||||
|
||||
# OOB score if available
|
||||
oob_score = None
|
||||
if hasattr(best_estimator, "oob_score_") and best_estimator.oob_score:
|
||||
oob_score = float(best_estimator.oob_score_)
|
||||
|
||||
# cuML RandomForest doesn't provide per-tree statistics like sklearn
|
||||
# Store what we have: feature importances and model configuration
|
||||
attrs = {
|
||||
"description": "Inner state of the best RandomForestClassifier from RandomizedSearchCV (cuML).",
|
||||
"n_estimators": int(n_estimators),
|
||||
"note": "cuML RandomForest does not expose individual tree statistics like sklearn",
|
||||
}
|
||||
|
||||
# Only add optional attributes if they have values
|
||||
if max_depth is not None:
|
||||
attrs["max_depth"] = int(max_depth)
|
||||
if oob_score is not None:
|
||||
attrs["oob_score"] = oob_score
|
||||
|
||||
state = xr.Dataset(
|
||||
{
|
||||
"feature_importance": feature_importance,
|
||||
},
|
||||
attrs=attrs,
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
elif settings.model == "knn":
|
||||
# KNN doesn't have traditional feature importance
|
||||
# Store information about the training data and neighbors
|
||||
n_neighbors = best_estimator.n_neighbors
|
||||
|
||||
# We can't extract meaningful "state" from KNN in the same way
|
||||
# but we can store metadata about the model
|
||||
state = xr.Dataset(
|
||||
attrs={
|
||||
"description": "Metadata of the best KNeighborsClassifier from RandomizedSearchCV.",
|
||||
"n_neighbors": n_neighbors,
|
||||
"weights": str(best_estimator.weights),
|
||||
"algorithm": str(best_estimator.algorithm),
|
||||
"metric": str(best_estimator.metric),
|
||||
"n_samples_fit": best_estimator.n_samples_fit_,
|
||||
},
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
print(f"Storing best estimator metadata to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels)
|
||||
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
|
||||
preds_file = results_dir / "predicted_probabilities.parquet"
|
||||
print(f"Storing predicted probabilities to {preds_file}")
|
||||
preds.to_parquet(preds_file)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue