# entropice/src/entropice/dashboard/sections/experiment_feature_importance.py

"""Feature importance analysis section for experiment comparison."""
import pandas as pd
import streamlit as st
from entropice.dashboard.plots.experiment_comparison import (
create_data_source_importance_bars,
create_feature_consistency_plot,
create_feature_importance_by_grid_level,
)
from entropice.dashboard.utils.loaders import (
AutogluonTrainingResult,
TrainingResult,
)


def _extract_feature_importance_from_results(
    training_results: list[TrainingResult],
) -> pd.DataFrame:
    """Extract feature importance from all training results.

    Args:
        training_results: List of TrainingResult objects

    Returns:
        DataFrame with columns: feature, importance, model, grid, level, task, target
    """
    # Importance variables to look for, in priority order; only the first
    # match per model state is used:
    #   - "feature_importance": eSPA or similar models with direct importance
    #   - "gain": XGBoost-style feature importance
    #   - "feature_importances_": Random Forest style
    importance_vars = ("feature_importance", "gain", "feature_importances_")
    records = []
    for tr in training_results:
        # Load model state if available; skip results without a saved state
        model_state = tr.load_model_state()
        if model_state is None:
            continue
        info = tr.display_info
        for var_name in importance_vars:
            if var_name not in model_state.data_vars:
                continue
            importance_data = model_state[var_name]
            for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
                records.append(
                    {
                        "feature": str(feature_name),
                        "importance": float(importance_data.isel(feature=feature_idx).values),
                        "model": info.model,
                        "grid": info.grid,
                        "level": info.level,
                        "task": info.task,
                        "target": info.target,
                    }
                )
            break  # use only the first matching importance variable
    return pd.DataFrame(records)
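

# A minimal sketch of the model state the extractor above expects (inferred
# from the lookups it performs, not a pinned schema): an xarray.Dataset whose
# importance variable is indexed by a "feature" coordinate. Feature names
# below are illustrative only.
#
#     import xarray as xr
#     state = xr.Dataset(
#         {"gain": ("feature", [0.7, 0.3])},
#         coords={"feature": ["era5_t2m", "arcticdem_slope"]},
#     )
#     state["gain"].isel(feature=0)  # -> 0.7, the gain for "era5_t2m"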


def _extract_feature_importance_from_autogluon(
    autogluon_results: list[AutogluonTrainingResult],
) -> pd.DataFrame:
    """Extract feature importance from AutoGluon results.

    Args:
        autogluon_results: List of AutogluonTrainingResult objects

    Returns:
        DataFrame with columns: feature, importance, model, grid, level, task, target
    """
    records = []
    for ag in autogluon_results:
        if ag.feature_importance is None:
            continue
        info = ag.display_info
        # AutoGluon feature importance is already a DataFrame with features as index
        for feature_name, importance_value in ag.feature_importance["importance"].items():
            records.append(
                {
                    "feature": str(feature_name),
                    "importance": float(importance_value),
                    "model": "autogluon",
                    "grid": info.grid,
                    "level": info.level,
                    "task": info.task,
                    "target": info.target,
                }
            )
    return pd.DataFrame(records)
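

# For reference (hedged, based on how the loop above consumes it): AutoGluon's
# feature importance frame is indexed by feature name with an "importance"
# column, roughly:
#
#                      importance
#     era5_t2m               0.12
#     arcticdem_slope        0.04
#
# so iterating ag.feature_importance["importance"].items() yields
# (feature_name, importance_value) pairs. Values above are illustrative.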


def _categorize_feature(feature_name: str) -> str:
    """Categorize feature by data source."""
    feature_lower = feature_name.lower()
    if feature_lower.startswith("arcticdem"):
        return "ArcticDEM"
    if feature_lower.startswith("era5"):
        return "ERA5"
    if feature_lower.startswith(("embeddings", "alphaearth")):
        return "Embeddings"
    return "General"


def _prepare_feature_importance_data(
    training_results: list[TrainingResult],
    autogluon_results: list[AutogluonTrainingResult],
) -> pd.DataFrame | None:
    """Extract and prepare feature importance data.

    Args:
        training_results: List of RandomSearchCV training results
        autogluon_results: List of AutoGluon training results

    Returns:
        DataFrame with feature importance data or None if no data available
    """
    fi_df_cv = _extract_feature_importance_from_results(training_results)
    fi_df_ag = _extract_feature_importance_from_autogluon(autogluon_results)
    if fi_df_cv.empty and fi_df_ag.empty:
        return None
    # Combine both sources into one frame
    fi_df = pd.concat([fi_df_cv, fi_df_ag], ignore_index=True)
    # Add data source categorization and a combined grid/level key
    fi_df["data_source"] = fi_df["feature"].apply(_categorize_feature)
    fi_df["grid_level"] = fi_df["grid"] + "_" + fi_df["level"].astype(str)
    return fi_df
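

# The prepared frame carries the extracted columns plus two derived ones:
#     feature, importance, model, grid, level, task, target,
#     data_source, grid_level
# where grid_level joins grid and level with an underscore, e.g. grid="h3"
# with level=5 gives "h3_5" (the grid name here is a hypothetical example).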


@st.fragment
def render_feature_importance_analysis(
    training_results: list[TrainingResult],
    autogluon_results: list[AutogluonTrainingResult],
):
    """Render the feature importance analysis section.

    Args:
        training_results: List of RandomSearchCV training results
        autogluon_results: List of AutoGluon training results
    """
    # st.fragment lets this section rerun on its own widget interactions
    # without rerunning the whole page.
    st.header("🔍 Feature Importance Analysis")
    st.markdown(
        """
        This section analyzes which features are most important across different
        models, grid levels, tasks, and targets.
        """
    )

    # Extract feature importance
    with st.spinner("Extracting feature importance from training results..."):
        fi_df = _prepare_feature_importance_data(training_results, autogluon_results)

    if fi_df is None:
        st.warning("No feature importance data available. Model state files may be missing.")
        return

    st.success(f"Extracted {len(fi_df)} feature importance records")

    # Filters
    st.subheader("Filters")
    col1, col2, col3 = st.columns(3)
    with col1:
        # Task filter
        available_tasks = ["All", *sorted(fi_df["task"].unique().tolist())]
        selected_task = st.selectbox("Task", options=available_tasks, index=0, key="fi_task_filter")
    with col2:
        # Target filter
        available_targets = ["All", *sorted(fi_df["target"].unique().tolist())]
        selected_target = st.selectbox(
            "Target Dataset", options=available_targets, index=0, key="fi_target_filter"
        )
    with col3:
        # Top N features
        top_n_features = st.number_input(
            "Top N Features", min_value=5, max_value=50, value=15, key="top_n_features"
        )

    # Apply filters
    filtered_fi_df = fi_df.copy()
    if selected_task != "All":
        filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["task"] == selected_task]
    if selected_target != "All":
        filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["target"] == selected_target]
    if len(filtered_fi_df) == 0:
        st.warning("No feature importance data available for the selected filters.")
        return

    # Section 1: Top features by grid level
    st.subheader("Top Features by Grid Level")
    try:
        fig = create_feature_importance_by_grid_level(filtered_fi_df, top_n=top_n_features)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create feature importance by grid level plot: {e}")

    # Show detailed breakdown in expander
    grid_levels = sorted(filtered_fi_df["grid_level"].unique())
    with st.expander("Show Detailed Breakdown by Grid Level", expanded=False):
        for grid_level in grid_levels:
            grid_data = filtered_fi_df[filtered_fi_df["grid_level"] == grid_level]
            # Get top features for this grid level
            top_features_grid = (
                grid_data.groupby("feature")["importance"]
                .mean()
                .reset_index()
                .nlargest(top_n_features, "importance")
            )
            st.markdown(f"**{grid_level.replace('_', '-').title()}**")
            # Create display dataframe with data source
            display_df = top_features_grid.merge(
                grid_data[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
            )
            display_df.columns = ["Feature", "Mean Importance", "Data Source"]
            display_df = display_df.sort_values("Mean Importance", ascending=False)
            st.dataframe(display_df, width="stretch", hide_index=True)

    # Section 2: Feature importance consistency across models
    st.subheader("Feature Importance Consistency Across Models")
    st.markdown(
        """
        **Coefficient of Variation (CV)**: the standard deviation of a feature's
        importance divided by its mean. Lower values indicate more consistent
        importance across models; a high CV means the feature's importance varies
        significantly between models.
        """
    )
    try:
        fig = create_feature_consistency_plot(filtered_fi_df, top_n=top_n_features)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create feature consistency plot: {e}")

    # Show detailed statistics in expander
    with st.expander("Show Detailed Statistics", expanded=False):
        # Get top features overall
        overall_top_features = (
            filtered_fi_df.groupby("feature")["importance"]
            .mean()
            .reset_index()
            .nlargest(top_n_features, "importance")["feature"]
            .tolist()
        )
        # Spread of importance across models for each top feature
        feature_variance = (
            filtered_fi_df[filtered_fi_df["feature"].isin(overall_top_features)]
            .groupby("feature")["importance"]
            .agg(["mean", "std", "min", "max"])
            .reset_index()
        )
        # CV = std / mean
        feature_variance["coefficient_of_variation"] = feature_variance["std"] / feature_variance["mean"]
        feature_variance = feature_variance.sort_values("mean", ascending=False)
        # Add data source
        feature_variance = feature_variance.merge(
            filtered_fi_df[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
        )
        feature_variance.columns = ["Feature", "Mean", "Std Dev", "Min", "Max", "CV", "Data Source"]
        st.dataframe(
            feature_variance[["Feature", "Data Source", "Mean", "Std Dev", "CV"]],
            width="stretch",
            hide_index=True,
        )

    # Section 3: Feature importance by data source
    st.subheader("Feature Importance by Data Source")
    try:
        fig = create_data_source_importance_bars(filtered_fi_df)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create data source importance chart: {e}")

    # Show detailed table in expander
    with st.expander("Show Data Source Statistics", expanded=False):
        # Aggregate importance by data source
        source_importance = (
            filtered_fi_df.groupby("data_source")["importance"].agg(["sum", "mean", "count"]).reset_index()
        )
        source_importance.columns = ["Data Source", "Total Importance", "Mean Importance", "Feature Count"]
        source_importance = source_importance.sort_values("Total Importance", ascending=False)
        st.dataframe(source_importance, width="stretch", hide_index=True)
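

# A minimal usage sketch (hypothetical wiring; the loader calls below are
# assumptions for illustration, not part of this module):
#
#     training_results = load_training_results(experiment_dir)      # hypothetical
#     autogluon_results = load_autogluon_results(experiment_dir)    # hypothetical
#     render_feature_importance_analysis(training_results, autogluon_results)
#
# Because the renderer is an st.fragment, interacting with its filter widgets
# reruns only this section rather than the whole Streamlit page.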