331 lines
12 KiB
Python
331 lines
12 KiB
Python
"""Feature importance analysis section for experiment comparison."""
|
|
|
|
import pandas as pd
|
|
import streamlit as st
|
|
|
|
from entropice.dashboard.plots.experiment_comparison import (
|
|
create_data_source_importance_bars,
|
|
create_feature_consistency_plot,
|
|
create_feature_importance_by_grid_level,
|
|
)
|
|
from entropice.dashboard.utils.loaders import (
|
|
AutogluonTrainingResult,
|
|
TrainingResult,
|
|
)
|
|
|
|
|
|
def _extract_feature_importance_from_results(
    training_results: list[TrainingResult],
) -> pd.DataFrame:
    """Extract feature importance from all training results.

    Supports three storage conventions for per-feature importance inside the
    saved model state, checked in priority order:

    - ``feature_importance``   : eSPA or similar models (direct importance)
    - ``gain``                 : XGBoost-style gain importance
    - ``feature_importances_`` : Random Forest style

    Args:
        training_results: List of TrainingResult objects

    Returns:
        DataFrame with columns: feature, importance, model, grid, level, task, target

    """
    # Candidate data variables, in priority order. Previously this was three
    # copy-pasted branches building identical records; a single loop keeps the
    # record schema in one place.
    importance_vars = ("feature_importance", "gain", "feature_importances_")

    records = []

    for tr in training_results:
        # Model state may be missing on disk; skip silently as before.
        model_state = tr.load_model_state()
        if model_state is None:
            continue

        # Pick the first importance variable present in this model state.
        var_name = next((v for v in importance_vars if v in model_state.data_vars), None)
        if var_name is None:
            continue

        info = tr.display_info
        importance_data = model_state[var_name]

        for feature_idx, feature_name in enumerate(importance_data.coords["feature"].values):
            records.append(
                {
                    "feature": str(feature_name),
                    "importance": float(importance_data.isel(feature=feature_idx).values),
                    "model": info.model,
                    "grid": info.grid,
                    "level": info.level,
                    "task": info.task,
                    "target": info.target,
                }
            )

    return pd.DataFrame(records)
|
|
|
|
|
|
def _extract_feature_importance_from_autogluon(
    autogluon_results: list[AutogluonTrainingResult],
) -> pd.DataFrame:
    """Extract feature importance from AutoGluon results.

    Args:
        autogluon_results: List of AutogluonTrainingResult objects

    Returns:
        DataFrame with columns: feature, importance, model, grid, level, task, target

    """
    rows: list[dict] = []

    for result in autogluon_results:
        importance = result.feature_importance
        if importance is None:
            continue

        info = result.display_info

        # AutoGluon feature importance is already a DataFrame with features
        # as the index; iterate the "importance" column directly.
        rows.extend(
            {
                "feature": str(feature_name),
                "importance": float(importance_value),
                "model": "autogluon",
                "grid": info.grid,
                "level": info.level,
                "task": info.task,
                "target": info.target,
            }
            for feature_name, importance_value in importance["importance"].items()
        )

    return pd.DataFrame(rows)
|
|
|
|
|
|
def _categorize_feature(feature_name: str) -> str:
|
|
"""Categorize feature by data source."""
|
|
feature_lower = feature_name.lower()
|
|
if feature_lower.startswith("arcticdem"):
|
|
return "ArcticDEM"
|
|
if feature_lower.startswith("era5"):
|
|
return "ERA5"
|
|
if feature_lower.startswith("embeddings") or feature_lower.startswith("alphaearth"):
|
|
return "Embeddings"
|
|
return "General"
|
|
|
|
|
|
def _prepare_feature_importance_data(
    training_results: list[TrainingResult],
    autogluon_results: list[AutogluonTrainingResult],
) -> pd.DataFrame | None:
    """Extract and prepare feature importance data.

    Args:
        training_results: List of RandomSearchCV training results
        autogluon_results: List of AutoGluon training results

    Returns:
        DataFrame with feature importance data or None if no data available

    """
    cv_frame = _extract_feature_importance_from_results(training_results)
    ag_frame = _extract_feature_importance_from_autogluon(autogluon_results)

    # Nothing extracted from either source -> signal "no data" to the caller.
    if cv_frame.empty and ag_frame.empty:
        return None

    combined = pd.concat([cv_frame, ag_frame], ignore_index=True)

    # Derived columns: data-source category and a combined grid_level key.
    combined["data_source"] = combined["feature"].apply(_categorize_feature)
    combined["grid_level"] = combined["grid"] + "_" + combined["level"].astype(str)

    return combined
|
|
|
|
|
|
def _render_filters(fi_df: pd.DataFrame) -> tuple[str, str, int]:
    """Render the task/target/top-N filter widgets and return their values."""
    st.subheader("Filters")
    col1, col2, col3 = st.columns(3)

    with col1:
        # Task filter
        available_tasks = ["All", *sorted(fi_df["task"].unique().tolist())]
        selected_task = st.selectbox("Task", options=available_tasks, index=0, key="fi_task_filter")

    with col2:
        # Target filter
        available_targets = ["All", *sorted(fi_df["target"].unique().tolist())]
        selected_target = st.selectbox("Target Dataset", options=available_targets, index=0, key="fi_target_filter")

    with col3:
        # Top N features
        top_n_features = st.number_input("Top N Features", min_value=5, max_value=50, value=15, key="top_n_features")

    return selected_task, selected_target, top_n_features


def _render_grid_level_section(filtered_fi_df: pd.DataFrame, top_n_features: int) -> None:
    """Render the 'Top Features by Grid Level' plot and per-grid-level tables."""
    st.subheader("Top Features by Grid Level")

    try:
        fig = create_feature_importance_by_grid_level(filtered_fi_df, top_n=top_n_features)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create feature importance by grid level plot: {e}")

    # Show detailed breakdown in expander
    grid_levels = sorted(filtered_fi_df["grid_level"].unique())

    with st.expander("Show Detailed Breakdown by Grid Level", expanded=False):
        for grid_level in grid_levels:
            grid_data = filtered_fi_df[filtered_fi_df["grid_level"] == grid_level]

            # Get top features for this grid level (mean importance across models)
            top_features_grid = (
                grid_data.groupby("feature")["importance"].mean().reset_index().nlargest(top_n_features, "importance")
            )

            st.markdown(f"**{grid_level.replace('_', '-').title()}**")

            # Attach data-source label to each feature for display
            display_df = top_features_grid.merge(
                grid_data[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
            )
            display_df.columns = ["Feature", "Mean Importance", "Data Source"]
            display_df = display_df.sort_values("Mean Importance", ascending=False)

            st.dataframe(display_df, width="stretch", hide_index=True)


def _render_consistency_section(filtered_fi_df: pd.DataFrame, top_n_features: int) -> None:
    """Render the cross-model consistency plot and its statistics table."""
    st.subheader("Feature Importance Consistency Across Models")

    st.markdown(
        """
        **Coefficient of Variation (CV)**: Lower values indicate more consistent importance across models.
        High CV suggests the feature's importance varies significantly between different models.
        """
    )

    try:
        fig = create_feature_consistency_plot(filtered_fi_df, top_n=top_n_features)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create feature consistency plot: {e}")

    # Show detailed statistics in expander
    with st.expander("Show Detailed Statistics", expanded=False):
        # Get top features overall (by mean importance)
        overall_top_features = (
            filtered_fi_df.groupby("feature")["importance"]
            .mean()
            .reset_index()
            .nlargest(top_n_features, "importance")["feature"]
            .tolist()
        )

        # Spread of importance across models for each top feature
        feature_variance = (
            filtered_fi_df[filtered_fi_df["feature"].isin(overall_top_features)]
            .groupby("feature")["importance"]
            .agg(["mean", "std", "min", "max"])
            .reset_index()
        )
        feature_variance["coefficient_of_variation"] = feature_variance["std"] / feature_variance["mean"]
        feature_variance = feature_variance.sort_values("mean", ascending=False)

        # Add data source
        feature_variance = feature_variance.merge(
            filtered_fi_df[["feature", "data_source"]].drop_duplicates(), on="feature", how="left"
        )

        feature_variance.columns = ["Feature", "Mean", "Std Dev", "Min", "Max", "CV", "Data Source"]

        st.dataframe(
            feature_variance[["Feature", "Data Source", "Mean", "Std Dev", "CV"]],
            width="stretch",
            hide_index=True,
        )


def _render_data_source_section(filtered_fi_df: pd.DataFrame) -> None:
    """Render the per-data-source importance chart and aggregate table."""
    st.subheader("Feature Importance by Data Source")

    try:
        fig = create_data_source_importance_bars(filtered_fi_df)
        st.plotly_chart(fig, width="stretch")
    except Exception as e:
        st.error(f"Could not create data source importance chart: {e}")

    # Show detailed table in expander
    with st.expander("Show Data Source Statistics", expanded=False):
        # Aggregate importance by data source
        source_importance = (
            filtered_fi_df.groupby("data_source")["importance"].agg(["sum", "mean", "count"]).reset_index()
        )
        source_importance.columns = ["Data Source", "Total Importance", "Mean Importance", "Feature Count"]
        source_importance = source_importance.sort_values("Total Importance", ascending=False)

        st.dataframe(source_importance, width="stretch", hide_index=True)


@st.fragment
def render_feature_importance_analysis(
    training_results: list[TrainingResult],
    autogluon_results: list[AutogluonTrainingResult],
):
    """Render feature importance analysis section.

    Extracts feature importance from both result types, renders filter widgets,
    then three sections: top features by grid level, cross-model consistency,
    and importance aggregated by data source.

    Args:
        training_results: List of RandomSearchCV training results
        autogluon_results: List of AutoGluon training results

    """
    st.header("🔍 Feature Importance Analysis")

    st.markdown(
        """
        This section analyzes which features are most important across different
        models, grid levels, tasks, and targets.
        """
    )

    # Extract feature importance
    with st.spinner("Extracting feature importance from training results..."):
        fi_df = _prepare_feature_importance_data(training_results, autogluon_results)

    if fi_df is None:
        st.warning("No feature importance data available. Model state files may be missing.")
        return

    st.success(f"Extracted feature importance from {len(fi_df)} feature-model combinations")

    selected_task, selected_target, top_n_features = _render_filters(fi_df)

    # Apply filters (copy so the unfiltered frame is untouched)
    filtered_fi_df = fi_df.copy()
    if selected_task != "All":
        filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["task"] == selected_task]
    if selected_target != "All":
        filtered_fi_df = filtered_fi_df.loc[filtered_fi_df["target"] == selected_target]

    if len(filtered_fi_df) == 0:
        st.warning("No feature importance data available for the selected filters.")
        return

    _render_grid_level_section(filtered_fi_df, top_n_features)
    _render_consistency_section(filtered_fi_df, top_n_features)
    _render_data_source_section(filtered_fi_df)