Update Model State Page

This commit is contained in:
Tobias Hölzer 2025-12-23 14:29:47 +01:00
parent 6ed5a9c224
commit a64e1ac41f
12 changed files with 1288 additions and 380 deletions

4
pixi.lock generated
View file

@ -9,7 +9,6 @@ environments:
- https://pypi.org/simple
options:
channel-priority: disabled
pypi-prerelease-mode: if-necessary-or-explicit
packages:
linux-64:
- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-7_kmp_llvm.conda
@ -2869,7 +2868,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: 6488240242ab4091686c27ee2e721c81143e0a158bdff086518da6ccdc2971c8
sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -2933,6 +2932,7 @@ packages:
- ruff>=0.14.9,<0.15
- pandas-stubs>=2.3.3.251201,<3
requires_python: '>=3.13,<3.14'
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
version: 0.1.0

View file

@ -65,7 +65,8 @@ dependencies = [
"pydeck>=0.9.1,<0.10",
"pypalettes>=0.2.1,<0.3",
"ty>=0.0.2,<0.0.3",
"ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3",
"ruff>=0.14.9,<0.15",
"pandas-stubs>=2.3.3.251201,<3",
]
[project.scripts]
@ -118,6 +119,7 @@ platforms = ["linux-64"]
[tool.pixi.activation.env]
SCIPY_ARRAY_API = "1"
FAST_DATA_DIR = "./data"
[tool.pixi.system-requirements]
cuda = "12"

View file

@ -1,6 +1,7 @@
"""Model State page for the Entropice dashboard."""
import streamlit as st
import xarray as xr
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import (
@ -43,6 +44,9 @@ def render_model_state_page():
)
selected_result = result_options[selected_name]
# Get the model type from settings
model_type = selected_result.settings.get("model", "espa")
# Load model state
with st.spinner("Loading model state..."):
model_state = load_model_state(selected_result)
@ -50,6 +54,29 @@ def render_model_state_page():
st.error("Could not load model state for this result.")
return
# Display basic model state info
with st.expander("Model State Information", expanded=False):
st.write(f"**Model Type:** {model_type.upper()}")
st.write(f"**Variables:** {list(model_state.data_vars)}")
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
st.write(f"**Coordinates:** {list(model_state.coords)}")
st.write(f"**Attributes:** {dict(model_state.attrs)}")
# Render model-specific visualizations
if model_type == "espa":
render_espa_model_state(model_state, selected_result)
elif model_type == "xgboost":
render_xgboost_model_state(model_state, selected_result)
elif model_type == "rf":
render_rf_model_state(model_state, selected_result)
elif model_type == "knn":
render_knn_model_state(model_state, selected_result)
else:
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
def render_espa_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for ESPA model."""
# Scale feature weights by number of features
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
@ -62,23 +89,6 @@ def render_model_state_page():
# Generate unified colormaps
_, _, altair_colors = generate_unified_colormap(selected_result.settings)
# Display basic model state info
with st.expander("Model State Information", expanded=False):
st.write(f"**Variables:** {list(model_state.data_vars)}")
st.write(f"**Dimensions:** {dict(model_state.sizes)}")
st.write(f"**Coordinates:** {list(model_state.coords)}")
# Show statistics
st.write("**Feature Weight Statistics:**")
feature_weights = model_state["feature_weights"].to_pandas()
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
with col2:
st.metric("Max Weight", f"{feature_weights.max():.4f}")
with col3:
st.metric("Total Features", len(feature_weights))
# Feature importance section
st.header("Feature Importance")
st.markdown("The most important features based on learned feature weights from the best estimator.")
@ -167,6 +177,216 @@ def render_model_state_page():
# Embedding features analysis (if present)
if embedding_feature_array is not None:
render_embedding_features(embedding_feature_array)
# ERA5 features analysis (if present)
if era5_feature_array is not None:
render_era5_features(era5_feature_array)
# Common features analysis (if present)
if common_feature_array is not None:
render_common_features(common_feature_array)
def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
    """Render visualizations for XGBoost model."""
    # Imported lazily so the page module does not pull in plotting helpers at load time.
    from entropice.dashboard.plots.model_state import (
        plot_xgboost_feature_importance,
        plot_xgboost_importance_comparison,
    )

    # Model-level metadata shown in the header blurb.
    n_trees = model_state.attrs.get("n_trees", "N/A")
    objective = model_state.attrs.get("objective", "N/A")

    st.header("🌲 XGBoost Model Analysis")
    st.markdown(
        f"""
        XGBoost gradient boosted tree model with **{n_trees} trees**.

        **Objective:** {objective}
        """
    )

    # --- Feature importance, user-selectable metric ---
    st.subheader("Feature Importance Analysis")
    st.markdown(
        """
        XGBoost provides multiple ways to measure feature importance:

        - **Weight**: Number of times a feature is used to split the data
        - **Gain**: Average gain across all splits using the feature
        - **Cover**: Average coverage across all splits using the feature
        - **Total Gain**: Total gain across all splits
        - **Total Cover**: Total coverage across all splits
        """
    )
    selected_importance = st.selectbox(
        "Select Importance Type",
        options=["gain", "weight", "cover", "total_gain", "total_cover"],
        index=0,
        help="Choose which importance metric to visualize",
    )
    n_top = st.slider(
        "Number of top features to display",
        min_value=5,
        max_value=50,
        value=20,
        step=5,
        help="Select how many of the most important features to visualize",
    )
    with st.spinner("Generating feature importance plot..."):
        bar_chart = plot_xgboost_feature_importance(model_state, importance_type=selected_importance, top_n=n_top)
    st.altair_chart(bar_chart, use_container_width=True)

    # --- Side-by-side comparison of the importance metrics ---
    st.subheader("Importance Type Comparison")
    st.markdown("Compare the top features across different importance metrics.")
    with st.spinner("Generating importance comparison..."):
        cmp_chart = plot_xgboost_importance_comparison(model_state, top_n=15)
    st.altair_chart(cmp_chart, use_container_width=True)

    # --- Summary numbers ---
    with st.expander("Model Statistics"):
        st.write("**Overall Statistics:**")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Number of Trees", model_state.attrs.get("n_trees", "N/A"))
        with col2:
            st.metric("Total Features", model_state.sizes.get("feature", "N/A"))
def render_rf_model_state(model_state: xr.Dataset, selected_result):
    """Render visualizations for Random Forest model."""
    from entropice.dashboard.plots.model_state import plot_rf_feature_importance

    st.header("🌳 Random Forest Model Analysis")

    # cuML-backed forests do not expose per-tree structure, so detect that
    # from the description attribute and skip the tree-statistics section.
    is_cuml = "cuML" in model_state.attrs.get("description", "")

    st.markdown(
        f"""
        Random Forest ensemble with **{model_state.attrs.get("n_estimators", "N/A")} trees**
        (max depth: {model_state.attrs.get("max_depth", "N/A")}).
        """
    )
    if is_cuml:
        st.info(" Using cuML GPU-accelerated Random Forest. Individual tree statistics are not available.")

    # Out-of-bag score is optional; only show it when the model recorded one.
    oob = model_state.attrs.get("oob_score")
    if oob is not None:
        st.info(f"**Out-of-Bag Score:** {oob:.4f}")

    # --- Gini feature importance ---
    st.subheader("Feature Importance (Gini Importance)")
    st.markdown(
        """
        Random Forest uses Gini impurity to measure feature importance. Features with higher
        importance values contribute more to the model's predictions.
        """
    )
    n_top = st.slider(
        "Number of top features to display",
        min_value=5,
        max_value=50,
        value=20,
        step=5,
        help="Select how many of the most important features to visualize",
    )
    with st.spinner("Generating feature importance plot..."):
        importance_chart = plot_rf_feature_importance(model_state, top_n=n_top)
    st.altair_chart(importance_chart, use_container_width=True)

    # --- Per-tree structure (sklearn forests only; cuML lacks these arrays) ---
    if not is_cuml and "tree_depths" in model_state:
        from entropice.dashboard.plots.model_state import plot_rf_tree_statistics

        st.subheader("Tree Structure Statistics")
        st.markdown("Distribution of tree properties across the forest.")
        with st.spinner("Generating tree statistics..."):
            chart_depths, chart_leaves, chart_nodes = plot_rf_tree_statistics(model_state)
        for column, chart in zip(st.columns(3), (chart_depths, chart_leaves, chart_nodes)):
            with column:
                st.altair_chart(chart, use_container_width=True)

        with st.expander("Forest Statistics"):
            st.write("**Overall Statistics:**")
            depths = model_state["tree_depths"].to_pandas()
            leaves = model_state["tree_n_leaves"].to_pandas()
            nodes = model_state["tree_n_nodes"].to_pandas()
            # One column per statistic family: (section title, metric stem, series).
            sections = [
                ("**Tree Depths:**", "Depth", depths),
                ("**Leaf Counts:**", "Leaves", leaves),
                ("**Node Counts:**", "Nodes", nodes),
            ]
            for column, (title, stem, series) in zip(st.columns(3), sections):
                with column:
                    st.write(title)
                    st.metric(f"Mean {stem}", f"{series.mean():.2f}")
                    st.metric(f"Max {stem}", f"{series.max()}")
                    st.metric(f"Min {stem}", f"{series.min()}")
def render_knn_model_state(model_state: xr.Dataset, selected_result):
    """Render visualizations for KNN model."""
    st.header("🔍 K-Nearest Neighbors Model Analysis")
    st.markdown(
        """
        K-Nearest Neighbors is a non-parametric, instance-based learning algorithm.
        Unlike tree-based or parametric models, KNN doesn't learn feature weights or build
        a model structure. Instead, it memorizes the training data and makes predictions
        based on the k nearest neighbors.
        """
    )

    # KNN has no learned parameters to plot, so only show its configuration.
    st.subheader("Model Configuration")
    attrs = model_state.attrs
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Number of Neighbors (k)", attrs.get("n_neighbors", "N/A"))
        st.metric("Training Samples", attrs.get("n_samples_fit", "N/A"))
    with col2:
        st.metric("Weights", attrs.get("weights", "N/A"))
        st.metric("Algorithm", attrs.get("algorithm", "N/A"))
    with col3:
        st.metric("Metric", attrs.get("metric", "N/A"))

    st.info(
        """
        **Note:** KNN doesn't have traditional feature importance or model parameters to visualize.
        The model's behavior depends entirely on:
        - The number of neighbors (k)
        - The distance metric used
        - The weighting scheme for neighbors

        To understand the model better, consider visualizing the decision boundaries on a
        reduced-dimensional representation of your data (e.g., using PCA or t-SNE).
        """
    )
# Helper functions for embedding/era5/common features
def render_embedding_features(embedding_feature_array):
"""Render embedding feature visualizations."""
with st.container(border=True):
st.header("🛰️ Embedding Feature Analysis")
st.markdown(
@ -214,11 +434,10 @@ def render_model_state_page():
emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
st.dataframe(top_emb, width="stretch")
else:
st.info("No embedding features found in this model.")
# ERA5 features analysis (if present)
if era5_feature_array is not None:
def render_era5_features(era5_feature_array):
"""Render ERA5 feature visualizations."""
with st.container(border=True):
st.header("⛅ ERA5 Feature Analysis")
st.markdown(
@ -264,11 +483,10 @@ def render_model_state_page():
era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
st.dataframe(top_era5, width="stretch")
else:
st.info("No ERA5 features found in this model.")
# Common features analysis (if present)
if common_feature_array is not None:
def render_common_features(common_feature_array):
"""Render common feature visualizations."""
with st.container(border=True):
st.header("🗺️ Common Feature Analysis")
st.markdown(
@ -318,5 +536,3 @@ def render_model_state_page():
- Negative weights indicate features that decrease the probability of the positive class
"""
)
else:
st.info("No common features found in this model.")

View file

@ -1,11 +1,12 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import streamlit as st
from entropice.dashboard.plots.colors import get_cmap
from entropice.dashboard.plots.colors import get_cmap, get_palette
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@ -135,8 +136,6 @@ def render_parameter_distributions(results: pd.DataFrame):
# Get colormap from colors module
cmap = get_cmap("parameter_distribution")
import matplotlib.colors as mcolors
bar_color = mcolors.rgb2hex(cmap(0.5))
# Create histograms for each parameter
@ -159,37 +158,77 @@ def render_parameter_distributions(results: pd.DataFrame):
param_values = results[param_col].dropna()
if pd.api.types.is_numeric_dtype(param_values):
# Check if we have enough unique values for a histogram
n_unique = param_values.nunique()
if n_unique == 1:
# Only one unique value - show as text
value = param_values.iloc[0]
formatted = f"{value:.2e}" if value < 0.01 else f"{value:.4f}"
st.metric(param_name, formatted)
st.caption(f"All {len(param_values)} samples have the same value")
continue
# Numeric parameter - use histogram
df_plot = pd.DataFrame({param_name: param_values})
df_plot = pd.DataFrame({param_name: param_values.to_numpy()})
# Use log scale if the range spans multiple orders of magnitude
value_range = param_values.max() / (param_values.min() + 1e-10)
use_log = value_range > 100
# Use log scale if the range spans multiple orders of magnitude OR values are very small
max_val = param_values.max()
if use_log:
# Determine number of bins based on unique values
n_bins = min(20, max(5, n_unique))
# For very small values or large ranges, use a simple bar chart instead of histogram
if n_unique <= 10:
# Few unique values - use bar chart
value_counts = param_values.value_counts().reset_index()
value_counts.columns = [param_name, "count"]
value_counts = value_counts.sort_values(param_name)
# Format x-axis values
if max_val < 0.01:
formatted_col = f"{param_name}_formatted"
value_counts[formatted_col] = value_counts[param_name].apply(lambda x: f"{x:.2e}")
chart = (
alt.Chart(df_plot)
alt.Chart(value_counts)
.mark_bar(color=bar_color)
.encode(
alt.X(
param_name,
bin=alt.Bin(maxbins=20),
scale=alt.Scale(type="log"),
title=param_name,
),
alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
alt.X(f"{formatted_col}:N", title=param_name, sort=None),
alt.Y("count:Q", title="Count"),
tooltip=[
alt.Tooltip(param_name, format=".2e"),
alt.Tooltip("count", title="Count"),
],
)
.properties(height=250, title=param_name)
)
else:
chart = (
alt.Chart(value_counts)
.mark_bar(color=bar_color)
.encode(
alt.X(f"{param_name}:Q", title=param_name),
alt.Y("count:Q", title="Count"),
tooltip=[
alt.Tooltip(param_name, format=".3f"),
alt.Tooltip("count", title="Count"),
],
)
.properties(height=250, title=param_name)
)
else:
# Many unique values - use binned histogram
# Avoid log scale for binning as it can cause issues
chart = (
alt.Chart(df_plot)
.mark_bar(color=bar_color)
.encode(
alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
alt.X(f"{param_name}:Q", bin=alt.Bin(maxbins=n_bins), title=param_name),
alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
tooltip=[
alt.Tooltip(f"{param_name}:Q", format=".2e" if max_val < 0.01 else ".3f", bin=True),
"count()",
],
)
.properties(height=250, title=param_name)
)
@ -276,7 +315,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
alt.Y(metric, title=metric.replace("_", " ").title()),
alt.Color(
metric,
scale=alt.Scale(scheme="viridis"),
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
@ -292,7 +331,7 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
alt.Y(metric, title=metric.replace("_", " ").title()),
alt.Color(
metric,
scale=alt.Scale(scheme="viridis"),
scale=alt.Scale(range=get_palette(metric, n_colors=256)),
legend=None,
),
tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
@ -350,13 +389,8 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
# Get colormap from colors module
cmap = get_cmap("correlation")
# Sample colors for red-blue diverging scheme
colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
import matplotlib.colors as mcolors
hex_colors = [mcolors.rgb2hex(c) for c in colors]
# Get colormap from colors module (use diverging colormap for correlation)
hex_colors = get_palette("correlation", n_colors=256)
# Create bar chart
chart = (
@ -367,7 +401,7 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
alt.Y("Parameter", sort="-x", title="Parameter"),
alt.Color(
"Correlation",
scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
scale=alt.Scale(range=hex_colors, domain=[-1, 1]),
legend=None,
),
tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
@ -402,7 +436,7 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
return
# Prepare binned data
# Prepare binned data and gather parameter info
results_binned = results.copy()
bin_info = {}
@ -414,145 +448,159 @@ def render_binned_parameter_space(results: pd.DataFrame, metric: str):
continue
# Determine if we should use log bins
value_range = param_values.max() / (param_values.min() + 1e-10)
use_log = value_range > 100 or param_values.max() < 1.0
min_val = param_values.min()
max_val = param_values.max()
if use_log:
if min_val > 0:
value_range = max_val / min_val
use_log = value_range > 100 or max_val < 0.01
else:
use_log = False
if use_log and min_val > 0:
# Logarithmic binning for parameters spanning many orders of magnitude
log_min = np.log10(param_values.min())
log_max = np.log10(param_values.max())
n_bins = min(10, int(log_max - log_min) + 1)
log_min = np.log10(min_val)
log_max = np.log10(max_val)
n_bins = min(10, max(3, int(log_max - log_min) + 1))
bins = np.logspace(log_min, log_max, num=n_bins)
else:
# Linear binning
n_bins = min(10, int(np.sqrt(len(param_values))))
bins = np.linspace(param_values.min(), param_values.max(), num=n_bins)
n_bins = min(10, max(3, int(np.sqrt(len(param_values)))))
bins = np.linspace(min_val, max_val, num=n_bins)
results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
bin_info[param_name] = {"use_log": use_log, "bins": bins}
bin_info[param_name] = {
"use_log": use_log,
"bins": bins,
"min": min_val,
"max": max_val,
"n_unique": param_values.nunique(),
}
# Get colormap from colors module
cmap = get_cmap(metric)
hex_colors = get_palette(metric, n_colors=256)
# Create binned plots for pairs of parameters
# Show the most meaningful combinations
# Get parameter names sorted by scale type (log parameters first, then linear)
param_names = [col.replace("param_", "") for col in numeric_params]
param_names_sorted = sorted(param_names, key=lambda p: (not bin_info[p]["use_log"], p))
# If we have 3+ parameters, create multiple plots
if len(param_names) >= 3:
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
# Plot first parameter vs others, binned by third
for i in range(1, min(3, len(param_names))):
x_param = param_names[0]
y_param = param_names[i]
# Create all pairwise combinations systematically
if len(param_names) == 2:
# Simple case: just one pair
x_param, y_param = param_names_sorted
_render_2d_param_plot(results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=500)
else:
# Multiple parameters: create structured plots
st.markdown(f"**Exploring {len(param_names)} parameters:** {', '.join(param_names_sorted)}")
# Find a third parameter for faceting
facet_param = None
for p in param_names:
if p != x_param and p != y_param:
facet_param = p
# Strategy: Show all combinations (including duplicates with swapped axes)
# Group by first parameter to create organized sections
for i, x_param in enumerate(param_names_sorted):
# Get all other parameters to pair with this one
other_params = [p for p in param_names_sorted if p != x_param]
if not other_params:
continue
# Create expander for each primary parameter
with st.expander(f"📊 {x_param} vs other parameters", expanded=(i == 0)):
# Create plots in rows of 2
n_others = len(other_params)
for row_idx in range(0, n_others, 2):
cols = st.columns(2)
for col_idx in range(2):
other_idx = row_idx + col_idx
if other_idx >= n_others:
break
if facet_param and f"{facet_param}_binned" in results_binned.columns:
# Create faceted plot
plot_data = results_binned[
[f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
].dropna()
y_param = other_params[other_idx]
# Convert binned column to string with sorted categories
plot_data = plot_data.sort_values(f"{facet_param}_binned")
plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
# Get colormap colors
cmap_colors = get_cmap(metric)
import matplotlib.colors as mcolors
# Sample colors from colormap
hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)]
chart = (
alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
f"{facet_param}_binned:N",
],
with cols[col_idx]:
_render_2d_param_plot(
results_binned, x_param, y_param, score_col, bin_info, hex_colors, metric, height=350
)
.properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
.facet(
facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
columns=4,
)
)
st.altair_chart(chart, use_container_width=True)
# Also create a simple 2D plot without faceting
def _render_2d_param_plot(
results_binned: pd.DataFrame,
x_param: str,
y_param: str,
score_col: str,
bin_info: dict,
hex_colors: list,
metric: str,
height: int = 400,
):
"""Render a 2D parameter space plot.
Args:
results_binned: DataFrame with binned results.
x_param: Name of x-axis parameter.
y_param: Name of y-axis parameter.
score_col: Column name for the score.
bin_info: Dictionary with binning information for each parameter.
hex_colors: List of hex colors for the colormap.
metric: Name of the metric being visualized.
height: Height of the plot in pixels.
"""
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
chart = (
alt.Chart(plot_data)
.mark_circle(size=80, opacity=0.6)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
],
)
.properties(height=400, title=f"{x_param} vs {y_param}")
)
st.altair_chart(chart, use_container_width=True)
elif len(param_names) == 2:
# Simple 2-parameter plot
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
x_param = param_names[0]
y_param = param_names[1]
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
if len(plot_data) == 0:
st.warning(f"No data available for {x_param} vs {y_param}")
return
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
# Determine marker size based on number of points
n_points = len(plot_data)
marker_size = 100 if n_points < 50 else 60 if n_points < 200 else 40
chart = (
alt.Chart(plot_data)
.mark_circle(size=100, opacity=0.6)
.mark_circle(size=marker_size, opacity=0.7)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
x=alt.X(
f"param_{x_param}:Q",
title=f"{x_param} {'(log scale)' if bin_info[x_param]['use_log'] else ''}",
scale=x_scale,
),
y=alt.Y(
f"param_{y_param}:Q",
title=f"{y_param} {'(log scale)' if bin_info[y_param]['use_log'] else ''}",
scale=y_scale,
),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
f"{score_col}:Q",
scale=alt.Scale(range=hex_colors),
title=metric.replace("_", " ").title(),
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
alt.Tooltip(
f"param_{x_param}:Q",
title=x_param,
format=".2e" if bin_info[x_param]["use_log"] else ".3f",
),
alt.Tooltip(
f"param_{y_param}:Q",
title=y_param,
format=".2e" if bin_info[y_param]["use_log"] else ".3f",
),
alt.Tooltip(f"{score_col}:Q", title=metric, format=".4f"),
],
)
.properties(height=500, title=f"{x_param} vs {y_param}")
.properties(
height=height,
title=f"{x_param} vs {y_param}",
)
.interactive()
)
st.altair_chart(chart, use_container_width=True)
@ -580,6 +628,10 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
# Reshape for Altair
df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")
# Get colormap for score evolution
evolution_cmap = get_cmap("score_evolution")
evolution_colors = [mcolors.rgb2hex(evolution_cmap(0.3)), mcolors.rgb2hex(evolution_cmap(0.7))]
# Create line chart
chart = (
alt.Chart(df_long)
@ -587,7 +639,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
.encode(
alt.X("Iteration", title="Iteration"),
alt.Y("value", title=metric.replace("_", " ").title()),
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(range=evolution_colors)),
strokeDash=alt.StrokeDash(
"Type",
legend=None,
@ -614,6 +666,7 @@ def render_score_evolution(results: pd.DataFrame, metric: str):
st.metric("Best at Iteration", best_iter)
@st.fragment
def render_multi_metric_comparison(results: pd.DataFrame):
"""Render comparison of multiple metrics.
@ -628,8 +681,16 @@ def render_multi_metric_comparison(results: pd.DataFrame):
st.warning("Need at least 2 metrics for comparison.")
return
# Let user select two metrics to compare
col1, col2 = st.columns(2)
# Get parameter columns for color options
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
param_names = [col.replace("param_", "") for col in param_cols]
if not param_names:
st.warning("No parameters found for coloring.")
return
# Let user select two metrics and color parameter
col1, col2, col3 = st.columns(3)
with col1:
metric1 = st.selectbox(
"Select First Metric",
@ -646,21 +707,57 @@ def render_multi_metric_comparison(results: pd.DataFrame):
key="metric2_select",
)
with col3:
color_by = st.selectbox(
"Color By",
options=param_names,
index=0,
key="color_select",
)
if metric1 == metric2:
st.warning("Please select different metrics.")
return
# Create scatter plot
# Create scatter plot data
df_plot = pd.DataFrame(
{
metric1: results[f"mean_test_{metric1}"],
metric2: results[f"mean_test_{metric2}"],
"Iteration": range(len(results)),
}
)
# Get colormap from colors module
cmap = get_cmap("iteration")
# Add color parameter to dataframe
param_col = f"param_{color_by}"
df_plot[color_by] = results[param_col]
color_col = color_by
# Check if parameter is numeric and should use log scale
param_values = results[param_col].dropna()
if pd.api.types.is_numeric_dtype(param_values):
value_range = param_values.max() / (param_values.min() + 1e-10)
use_log = value_range > 100
if use_log:
color_scale = alt.Scale(type="log", range=get_palette(color_by, n_colors=256))
color_format = ".2e"
else:
color_scale = alt.Scale(range=get_palette(color_by, n_colors=256))
color_format = ".3f"
else:
# Categorical parameter
color_scale = alt.Scale(scheme="category20")
color_format = None
# Build tooltip list
tooltip_list = [
alt.Tooltip(metric1, format=".4f"),
alt.Tooltip(metric2, format=".4f"),
]
if color_format:
tooltip_list.append(alt.Tooltip(color_col, format=color_format))
else:
tooltip_list.append(color_col)
chart = (
alt.Chart(df_plot)
@ -668,12 +765,8 @@ def render_multi_metric_comparison(results: pd.DataFrame):
.encode(
alt.X(metric1, title=metric1.replace("_", " ").title()),
alt.Y(metric2, title=metric2.replace("_", " ").title()),
alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
tooltip=[
alt.Tooltip(metric1, format=".4f"),
alt.Tooltip(metric2, format=".4f"),
"Iteration",
],
alt.Color(color_col, scale=color_scale, title=color_col.replace("_", " ").title()),
tooltip=tooltip_list,
)
.properties(height=500)
)

View file

@ -447,3 +447,212 @@ def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
)
return chart
# XGBoost-specific visualizations
def plot_xgboost_feature_importance(
    model_state: xr.Dataset, importance_type: str = "gain", top_n: int = 20
) -> alt.Chart:
    """Plot XGBoost feature importance for a specific importance type.

    Args:
        model_state: The xarray Dataset containing the XGBoost model state.
        importance_type: Type of importance ('weight', 'gain', 'cover', 'total_gain', 'total_cover').
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.
    """
    # Each importance metric is stored as its own data variable.
    importance_key = f"feature_importance_{importance_type}"
    if importance_key not in model_state:
        raise ValueError(f"Importance type '{importance_type}' not found in model state")

    series = model_state[importance_key].to_pandas()
    top = series.nlargest(top_n).sort_values(ascending=True)
    plot_data = pd.DataFrame({"feature": top.index, "importance": top.to_numpy()})

    # Human-readable metric name used in the axis title and chart title.
    label = importance_type.replace("_", " ").title()
    return (
        alt.Chart(plot_data)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("importance:Q", title=f"{label} Importance"),
            color=alt.value("steelblue"),
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("importance:Q", format=".4f", title="Importance"),
            ],
        )
        .properties(
            width=600,
            height=400,
            title=f"Top {top_n} Features by {label} Importance (XGBoost)",
        )
    )
def plot_xgboost_importance_comparison(model_state: xr.Dataset, top_n: int = 15) -> alt.Chart:
    """Compare different importance types for XGBoost side-by-side.

    Args:
        model_state: The xarray Dataset containing the XGBoost model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart comparing importance types, faceted by type.

    Raises:
        ValueError: If none of the expected importance arrays are present in the model state.
    """
    # Collect all importance types that exist in the dataset.
    importance_types = ["weight", "gain", "cover"]
    all_data = []
    for imp_type in importance_types:
        importance_key = f"feature_importance_{imp_type}"
        if importance_key in model_state:
            importance = model_state[importance_key].to_pandas()
            # Top features may differ per importance type; collect them independently.
            for feature, value in importance.nlargest(top_n).items():
                all_data.append(
                    {
                        "feature": feature,
                        "importance": value,
                        "type": imp_type.replace("_", " ").title(),
                    }
                )
    if not all_data:
        # Guard: an empty DataFrame has no columns, so Altair would otherwise fail
        # later with an opaque encoding error instead of a clear message.
        raise ValueError("No feature importance arrays found in model state")

    df = pd.DataFrame(all_data)
    # Create faceted bar chart (one panel per importance type).
    chart = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=200)),
            x=alt.X("importance:Q", title="Importance Value"),
            color=alt.Color("type:N", title="Importance Type", scale=alt.Scale(scheme="category10")),
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("type:N", title="Type"),
                alt.Tooltip("importance:Q", format=".4f", title="Importance"),
            ],
        )
        .properties(width=250, height=300)
        .facet(facet=alt.Facet("type:N", title="Importance Type"), columns=3)
    )
    return chart
# Random Forest-specific visualizations
def plot_rf_feature_importance(model_state: xr.Dataset, top_n: int = 20) -> alt.Chart:
    """Plot Random Forest feature importance (Gini importance).

    Args:
        model_state: The xarray Dataset containing the Random Forest model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.

    Raises:
        ValueError: If the 'feature_importance' array is missing from the model state.
    """
    # Consistent with the XGBoost plot helpers: fail with a clear ValueError
    # instead of a bare KeyError when the expected array is absent.
    if "feature_importance" not in model_state:
        raise ValueError("'feature_importance' not found in model state")
    importance = model_state["feature_importance"].to_pandas()
    # Sort and take top N (ascending so the largest bar renders on top).
    top_features = importance.nlargest(top_n).sort_values(ascending=True)
    # Create DataFrame for plotting
    plot_data = pd.DataFrame({"feature": top_features.index, "importance": top_features.to_numpy()})
    # Create bar chart
    chart = (
        alt.Chart(plot_data)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("importance:Q", title="Gini Importance"),
            color=alt.value("forestgreen"),
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("importance:Q", format=".4f", title="Importance"),
            ],
        )
        .properties(
            width=600,
            height=400,
            title=f"Top {top_n} Features by Gini Importance (Random Forest)",
        )
    )
    return chart
def plot_rf_tree_statistics(model_state: xr.Dataset) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
    """Plot Random Forest tree statistics.

    Args:
        model_state: The xarray Dataset containing the Random Forest model state.

    Returns:
        Tuple of three Altair charts (depth histogram, leaves histogram, nodes histogram).
    """

    def _histogram(values, color: str, x_title: str, range_title: str, title: str) -> alt.Chart:
        # Build one 20-bin histogram of a per-tree statistic; the three output
        # charts differ only in source column, color, and labels.
        df = pd.DataFrame({"value": values.to_numpy()})
        return (
            alt.Chart(df)
            .mark_bar()
            .encode(
                x=alt.X("value:Q", bin=alt.Bin(maxbins=20), title=x_title),
                y=alt.Y("count()", title="Number of Trees"),
                color=alt.value(color),
                tooltip=[alt.Tooltip("count()", title="Count"), alt.Tooltip("value:Q", bin=True, title=range_title)],
            )
            .properties(width=300, height=200, title=title)
        )

    chart_depths = _histogram(
        model_state["tree_depths"].to_pandas(),
        "steelblue",
        "Tree Depth",
        "Depth Range",
        "Distribution of Tree Depths",
    )
    chart_leaves = _histogram(
        model_state["tree_n_leaves"].to_pandas(),
        "forestgreen",
        "Number of Leaves",
        "Leaves Range",
        "Distribution of Leaf Counts",
    )
    chart_nodes = _histogram(
        model_state["tree_n_nodes"].to_pandas(),
        "darkorange",
        "Number of Nodes",
        "Nodes Range",
        "Distribution of Node Counts",
    )
    return chart_depths, chart_leaves, chart_nodes

View file

@ -761,6 +761,117 @@ def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
st.warning("No valid data available for selected parameters")
@st.fragment
def render_areas_map(grid_gdf: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for grid cell areas.

    Args:
        grid_gdf: GeoDataFrame with cell_id, geometry, cell_area, land_area, water_area, land_ratio.
        grid: Grid type ('hex' or 'healpix').
    """
    st.subheader("🗺️ Grid Cell Areas Distribution")
    # Controls
    col1, col2 = st.columns([3, 1])
    with col1:
        area_metric = st.selectbox(
            "Area Metric",
            options=["cell_area", "land_area", "water_area", "land_ratio"],
            format_func=lambda x: x.replace("_", " ").title(),
            key="areas_metric",
        )
    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="areas_map_opacity")
    # Convert to WGS84 (to_crs returns a new GeoDataFrame, so no explicit copy needed)
    gdf_wgs84 = grid_gdf.to_crs("EPSG:4326")
    # Fix geometries after CRS conversion
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)
    # Get values for the selected metric
    values = gdf_wgs84[area_metric].to_numpy()
    # Normalize values for color mapping; percentiles avoid outlier-driven scaling.
    vmin, vmax = np.nanpercentile(values, [2, 98])
    if vmax > vmin:
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)
    else:
        # Degenerate range (constant metric): avoid division by zero / NaN colors.
        normalized = np.zeros_like(values, dtype=float)
    # Single colormap for all metrics (the former per-metric branch selected
    # the same "terrain" map in both arms).
    cmap = get_cmap("terrain")
    colors = [cmap(val) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]
    # Set elevation based on normalized values for 3D visualization
    gdf_wgs84["elevation"] = normalized
    # Create GeoJSON
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "cell_area": f"{float(row['cell_area']):.2f}",
                "land_area": f"{float(row['land_area']):.2f}",
                "water_area": f"{float(row['water_area']):.2f}",
                "land_ratio": f"{float(row['land_ratio']):.2%}",
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]),
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer with 3D elevation
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        get_elevation="properties.elevation",
        elevation_scale=500000,
        line_width_min_pixels=0.5,
        pickable=True,
    )
    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=45, bearing=0)
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Cell Area:</b> {cell_area} km²<br/>"
            "<b>Land Area:</b> {land_area} km²<br/>"
            "<b>Water Area:</b> {water_area} km²<br/>"
            "<b>Land Ratio:</b> {land_ratio}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    st.pydeck_chart(deck)
    # Show statistics
    st.caption(f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values):.2f} | Std: {np.nanstd(values):.2f}")
    # Show additional info
    st.info("💡 3D elevation represents normalized values. Rotate the map by holding Ctrl/Cmd and dragging.")
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
"""Render interactive pydeck map for ERA5 climate data.

View file

@ -2,6 +2,7 @@
import streamlit as st
from entropice import grids
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
render_alphaearth_overview,
@ -9,6 +10,7 @@ from entropice.dashboard.plots.source_data import (
render_arcticdem_map,
render_arcticdem_overview,
render_arcticdem_plots,
render_areas_map,
render_era5_map,
render_era5_overview,
render_era5_plots,
@ -184,7 +186,7 @@ def render_training_data_page():
st.markdown("---")
# Create tabs for different data views
tab_names = ["📊 Labels"]
tab_names = ["📊 Labels", "📐 Areas"]
# Add tabs for each member
for member in ensemble.members:
@ -228,8 +230,50 @@ def render_training_data_page():
render_spatial_map(train_data_dict)
# Areas tab
with tabs[1]:
st.markdown("### Grid Cell Areas and Land/Water Distribution")
st.markdown(
"This visualization shows the spatial distribution of cell areas, land areas, "
"water areas, and land ratio across the grid. The grid has been filtered to "
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
"with >10% land coverage."
)
# Load grid data
grid_gdf = grids.open(ensemble.grid, ensemble.level)
st.success(
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
)
# Show summary statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Cells", f"{len(grid_gdf):,}")
with col2:
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
with col3:
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
with col4:
total_land = grid_gdf["land_area"].sum()
st.metric("Total Land Area", f"{total_land:,.0f} km²")
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_areas_map(grid_gdf, ensemble.grid)
# AlphaEarth tab
tab_idx = 1
tab_idx = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### AlphaEarth Embeddings Analysis")

View file

@ -29,10 +29,9 @@ class TrainingResult:
def from_path(cls, result_path: Path) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
result_file = result_path / "search_results.parquet"
state_file = result_path / "best_estimator_state.nc"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
if not all([result_file.exists(), state_file.exists(), preds_file.exists(), settings_file.exists()]):
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
raise FileNotFoundError(f"Missing required files in {result_path}")
created_at = result_path.stat().st_ctime
@ -96,9 +95,9 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
"""
return {
"binary": e.create_cat_training_dataset("binary"),
"count": e.create_cat_training_dataset("count"),
"density": e.create_cat_training_dataset("density"),
"binary": e.create_cat_training_dataset("binary", device="cpu"),
"count": e.create_cat_training_dataset("count", device="cpu"),
"density": e.create_cat_training_dataset("density", device="cpu"),
}

View file

@ -14,12 +14,15 @@ Naming conventions:
import hashlib
import json
from collections.abc import Generator
from dataclasses import asdict, dataclass, field
from functools import cached_property, lru_cache
from typing import Literal
from typing import Literal, TypedDict
import cupy as cp
import cyclopts
import geopandas as gpd
import numpy as np
import pandas as pd
import seaborn as sns
import torch
@ -110,8 +113,8 @@ def bin_values(
@dataclass(frozen=True, eq=False)
class DatasetLabels:
binned: pd.Series
train: torch.Tensor
test: torch.Tensor
train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray
raw_values: pd.Series
@cached_property
@ -135,8 +138,8 @@ class DatasetLabels:
@dataclass(frozen=True, eq=False)
class DatasetInputs:
data: pd.DataFrame
train: torch.Tensor
test: torch.Tensor
train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray
@dataclass(frozen=True)
@ -151,6 +154,13 @@ class CategoricalTrainingDataset:
return len(self.z)
class DatasetStats(TypedDict):
target: str
num_target_samples: int
members: dict[str, dict[str, object]]
total_features: int
@cyclopts.Parameter("*")
@dataclass(frozen=True)
class DatasetEnsemble:
@ -283,15 +293,15 @@ class DatasetEnsemble:
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
return arcticdem_df
def get_stats(self) -> dict:
def get_stats(self) -> DatasetStats:
"""Get dataset statistics.
Returns:
dict: Dictionary containing target stats, member stats, and total features count.
DatasetStats: Dictionary containing target stats, member stats, and total features count.
"""
targets = self._read_target()
stats = {
stats: DatasetStats = {
"target": self.target,
"num_target_samples": len(targets),
"members": {},
@ -365,11 +375,64 @@ class DatasetEnsemble:
print(f"Saved dataset to cache at {cache_file}.")
return dataset
def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset:
def create_batches(
self,
batch_size: int,
filter_target_col: str | None = None,
cache_mode: Literal["n", "o", "r"] = "r",
) -> Generator[pd.DataFrame]:
targets = self._read_target()
if len(targets) == 0:
raise ValueError("No target samples found.")
elif len(targets) < batch_size:
yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode)
return
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
for i in range(0, len(targets), batch_size):
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.paths.get_dataset_cache(
self.id(), subset=filter_target_col, batch=(i, i + batch_size)
)
if cache_mode == "r" and cache_file.exists():
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
yield dataset
else:
targets_batch = targets.iloc[i : i + batch_size]
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets_batch, era5_agg))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets_batch))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets_batch))
else:
raise NotImplementedError(f"Member {member} not implemented.")
dataset = targets_batch.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["o", "r"]:
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
yield dataset
def create_cat_training_dataset(
self, task: Task, device: Literal["cpu", "cuda", "torch"]
) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
@ -414,10 +477,23 @@ class DatasetEnsemble:
split.loc[test_idx] = "test"
split = split.astype("category")
X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0)
X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0)
y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0)
y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0)
X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64")
X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64")
y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64")
y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64")
if device == "cuda":
X_train = cp.asarray(X_train)
X_test = cp.asarray(X_test)
y_train = cp.asarray(y_train)
y_test = cp.asarray(y_test)
elif device == "torch":
X_train = torch.from_numpy(X_train)
X_test = torch.from_numpy(X_test)
y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
else:
assert device == "cpu", "Invalid device specified."
return CategoricalTrainingDataset(
dataset=dataset.to_crs("EPSG:4326"),

View file

@ -1,6 +1,7 @@
# ruff: noqa: N806
"""Inference runs on trained models."""
import cupy as cp
import geopandas as gpd
import pandas as pd
import torch
@ -32,35 +33,32 @@ def predict_proba(
list: A list of predicted probabilities for each cell.
"""
data = e.create()
print(f"Predicting probabilities for {len(data)} cells...")
# Predict in batches to avoid memory issues
batch_size = 10_000
batch_size = 10000
preds = []
for batch in e.create_batches(batch_size=batch_size):
cols_to_drop = ["geometry"]
if e.target == "darts_mllabels":
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
for i in range(0, len(data), batch_size):
batch = data.iloc[i : i + batch_size]
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
X_batch = batch.drop(columns=cols_to_drop).dropna()
cell_ids = X_batch.index.to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float64")
X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = clf.predict(X_batch)
if isinstance(batch_preds, cp.ndarray):
batch_preds = batch_preds.get()
elif torch.is_tensor(batch_preds):
batch_preds = batch_preds.cpu().numpy()
batch_preds = gpd.GeoDataFrame(
{
"cell_id": cell_ids,
"predicted_class": [classes[i] for i in batch_preds],
"geometry": cell_geoms,
},
crs="epsg:3413",
)
).set_crs(epsg=3413, inplace=False)
preds.append(batch_preds)
preds = gpd.GeoDataFrame(pd.concat(preds))
return preds
return gpd.GeoDataFrame(pd.concat(preds, ignore_index=True))

View file

@ -9,7 +9,7 @@ from typing import Literal
DATA_DIR = (
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
)
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
# DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures")
@ -110,7 +110,9 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file
def get_dataset_cache(eid: str, subset: str | None = None) -> Path:
def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int] | None = None) -> Path:
if batch is not None:
eid = f"{eid}_batch{batch[0]}-{batch[1]}"
if subset is None:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
else:

View file

@ -4,6 +4,7 @@ import pickle
from dataclasses import asdict, dataclass
from typing import Literal
import cupy as cp
import cyclopts
import pandas as pd
import toml
@ -14,7 +15,6 @@ from entropy import ESPAClassifier
from rich import pretty, traceback
from scipy.stats import loguniform, randint
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV
from stopuhr import stopwatch
from xgboost.sklearn import XGBClassifier
@ -26,7 +26,9 @@ from entropice.paths import get_cv_results_dir
traceback.install()
pretty.install()
set_config(array_api_dispatch=True)
# Disabled array_api_dispatch to avoid namespace conflicts between NumPy and CuPy
# when using XGBoost with device="cuda"
# set_config(array_api_dispatch=True)
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
@ -85,8 +87,8 @@ def _create_clf(
objective="multi:softprob" if settings.task != "binary" else "binary:logistic",
eval_metric="mlogloss" if settings.task != "binary" else "logloss",
random_state=42,
tree_method="gpu_hist",
device="cuda",
tree_method="hist",
device="gpu", # Using CPU to avoid CuPy/NumPy namespace conflicts in sklearn metrics
)
fit_params = {}
elif settings.model == "rf":
@ -98,11 +100,10 @@ def _create_clf(
fit_params = {}
elif settings.model == "knn":
param_grid = {
"n_neighbors": randint(3, 15),
"n_neighbors": randint(10, 200),
"weights": ["uniform", "distance"],
"algorithm": ["brute", "kd_tree", "ball_tree"],
}
clf = KNeighborsClassifier(random_state=42)
clf = KNeighborsClassifier()
fit_params = {}
else:
raise ValueError(f"Unknown model: {settings.model}")
@ -148,8 +149,9 @@ def random_cv(
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
"""
device = "torch" if settings.model in ["espa"] else "cuda"
print("Creating training data...")
training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task, device=device)
clf, param_grid, fit_params = _create_clf(settings)
print(f"Using model: {settings.model} with parameters: {param_grid}")
@ -159,7 +161,7 @@ def random_cv(
clf,
param_grid,
n_iter=settings.n_iter,
n_jobs=8,
n_jobs=1,
cv=cv,
random_state=42,
verbose=10,
@ -169,14 +171,25 @@ def random_cv(
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
search.fit(training_data.X.train, training_data.y.train, **fit_params)
y_train = (
training_data.y.train.get()
if settings.model == "xgboost" and isinstance(training_data.y.train, cp.ndarray)
else training_data.y.train
)
search.fit(training_data.X.train, y_train, **fit_params)
print("Best parameters combination found:")
best_parameters = search.best_estimator_.get_params()
best_estimator = search.best_estimator_
best_parameters = best_estimator.get_params()
for param_name in sorted(param_grid.keys()):
print(f"{param_name}: {best_parameters[param_name]}")
test_accuracy = search.score(training_data.X.test, training_data.y.test)
y_test = (
training_data.y.test.get()
if settings.model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
else training_data.y.test
)
test_accuracy = search.score(training_data.X.test, y_test)
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}")
@ -207,7 +220,7 @@ def random_cv(
best_model_file = results_dir / "best_estimator_model.pkl"
print(f"Storing best estimator model to {best_model_file}")
with open(best_model_file, "wb") as f:
pickle.dump(search.best_estimator_, f, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(best_estimator, f, protocol=pickle.HIGHEST_PROTOCOL)
# Store the search results
results = pd.DataFrame(search.cv_results_)
@ -220,10 +233,10 @@ def random_cv(
results.to_parquet(results_file)
# Get the inner state of the best estimator
if settings.model == "espa":
best_estimator = search.best_estimator_
# Annotate the state with xarray metadata
features = training_data.X.data.columns.tolist()
if settings.model == "espa":
# Annotate the state with xarray metadata
boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(),
@ -260,9 +273,154 @@ def random_cv(
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
elif settings.model == "xgboost":
# Extract XGBoost-specific information
# Get the underlying booster
booster = best_estimator.get_booster()
# Feature importance with different importance types
importance_weight = booster.get_score(importance_type="weight")
importance_gain = booster.get_score(importance_type="gain")
importance_cover = booster.get_score(importance_type="cover")
importance_total_gain = booster.get_score(importance_type="total_gain")
importance_total_cover = booster.get_score(importance_type="total_cover")
# Create aligned arrays for all features (including zero-importance)
def align_importance(importance_dict, features):
"""Align importance dict to feature list, filling missing with 0."""
return [importance_dict.get(f, 0.0) for f in features]
feature_importance_weight = xr.DataArray(
align_importance(importance_weight, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_weight",
attrs={"description": "Number of times a feature is used to split the data across all trees."},
)
feature_importance_gain = xr.DataArray(
align_importance(importance_gain, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_gain",
attrs={"description": "Average gain across all splits the feature is used in."},
)
feature_importance_cover = xr.DataArray(
align_importance(importance_cover, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_cover",
attrs={"description": "Average coverage across all splits the feature is used in."},
)
feature_importance_total_gain = xr.DataArray(
align_importance(importance_total_gain, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_total_gain",
attrs={"description": "Total gain across all splits the feature is used in."},
)
feature_importance_total_cover = xr.DataArray(
align_importance(importance_total_cover, features),
dims=["feature"],
coords={"feature": features},
name="feature_importance_total_cover",
attrs={"description": "Total coverage across all splits the feature is used in."},
)
# Store tree information
n_trees = booster.num_boosted_rounds()
state = xr.Dataset(
{
"feature_importance_weight": feature_importance_weight,
"feature_importance_gain": feature_importance_gain,
"feature_importance_cover": feature_importance_cover,
"feature_importance_total_gain": feature_importance_total_gain,
"feature_importance_total_cover": feature_importance_total_cover,
},
attrs={
"description": "Inner state of the best XGBClassifier from RandomizedSearchCV.",
"n_trees": n_trees,
"objective": str(best_estimator.objective),
},
)
state_file = results_dir / "best_estimator_state.nc"
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
elif settings.model == "rf":
# Extract Random Forest-specific information
# Note: cuML's RandomForestClassifier doesn't expose individual trees (estimators_)
# like sklearn does, so we can only extract feature importances and model parameters
# Feature importances (Gini importance)
feature_importances = best_estimator.feature_importances_
feature_importance = xr.DataArray(
feature_importances,
dims=["feature"],
coords={"feature": features},
name="feature_importance",
attrs={"description": "Gini importance (impurity-based feature importance)."},
)
# cuML RF doesn't expose individual trees, so we store model parameters instead
n_estimators = best_estimator.n_estimators
max_depth = best_estimator.max_depth
# OOB score if available
oob_score = None
if hasattr(best_estimator, "oob_score_") and best_estimator.oob_score:
oob_score = float(best_estimator.oob_score_)
# cuML RandomForest doesn't provide per-tree statistics like sklearn
# Store what we have: feature importances and model configuration
attrs = {
"description": "Inner state of the best RandomForestClassifier from RandomizedSearchCV (cuML).",
"n_estimators": int(n_estimators),
"note": "cuML RandomForest does not expose individual tree statistics like sklearn",
}
# Only add optional attributes if they have values
if max_depth is not None:
attrs["max_depth"] = int(max_depth)
if oob_score is not None:
attrs["oob_score"] = oob_score
state = xr.Dataset(
{
"feature_importance": feature_importance,
},
attrs=attrs,
)
state_file = results_dir / "best_estimator_state.nc"
print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
elif settings.model == "knn":
# KNN doesn't have traditional feature importance
# Store information about the training data and neighbors
n_neighbors = best_estimator.n_neighbors
# We can't extract meaningful "state" from KNN in the same way
# but we can store metadata about the model
state = xr.Dataset(
attrs={
"description": "Metadata of the best KNeighborsClassifier from RandomizedSearchCV.",
"n_neighbors": n_neighbors,
"weights": str(best_estimator.weights),
"algorithm": str(best_estimator.algorithm),
"metric": str(best_estimator.metric),
"n_samples_fit": best_estimator.n_samples_fit_,
},
)
state_file = results_dir / "best_estimator_state.nc"
print(f"Storing best estimator metadata to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
# Predict probabilities for all cells
print("Predicting probabilities for all cells...")
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels)
print(f"Predicted probabilities DataFrame with {len(preds)} entries.")
preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file)