Add training results page and model state page
This commit is contained in:
parent
31933b58d3
commit
696bef39c2
9 changed files with 1988 additions and 17 deletions
|
|
@ -1,8 +1,322 @@
|
|||
"""Model State page for the Entropice dashboard."""
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.colors import generate_unified_colormap
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_box_assignment_bars,
|
||||
plot_box_assignments,
|
||||
plot_common_features,
|
||||
plot_embedding_aggregation_summary,
|
||||
plot_embedding_heatmap,
|
||||
plot_era5_heatmap,
|
||||
plot_era5_summary,
|
||||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
load_all_training_results,
|
||||
)
|
||||
from entropice.dashboard.utils.training import load_model_state
|
||||
|
||||
|
||||
def render_model_state_page():
    """Render the Model State page of the dashboard.

    Lets the user pick one training result, loads the corresponding model
    state, and renders feature-importance, box-assignment, and per-feature-type
    (embedding / ERA5 / common) visualizations.  Returns early with an error
    message if no training results or no model state are available.
    """
    st.title("Model State")
    # NOTE: the old placeholder line ("This page will display model state and
    # feature visualizations.") was leftover scaffolding duplicating the
    # description below, and has been removed.
    st.markdown("Comprehensive visualization of the best model's internal state and feature importance")

    # Load available training results
    training_results = load_all_training_results()

    if not training_results:
        st.error("No training results found. Please run a training search first.")
        return

    # Result selection — keyed by result name; assumes names are unique
    result_options = {tr.name: tr for tr in training_results}
    selected_name = st.selectbox(
        "Select Training Result",
        options=list(result_options.keys()),
        help="Choose a training result to visualize model state",
    )
    selected_result = result_options[selected_name]

    # Load model state
    with st.spinner("Loading model state..."):
        model_state = load_model_state(selected_result)
    if model_state is None:
        st.error("Could not load model state for this result.")
        return

    # Scale feature weights by number of features (in-place on the dataset)
    n_features = model_state.sizes["feature"]
    model_state["feature_weights"] *= n_features

    # Extract different feature types; each may be None if absent
    embedding_feature_array = extract_embedding_features(model_state)
    era5_feature_array = extract_era5_features(model_state)
    common_feature_array = extract_common_features(model_state)

    # Generate unified colormaps (only the Altair colors are used here)
    _, _, altair_colors = generate_unified_colormap(selected_result.settings)

    # Display basic model state info
    with st.expander("Model State Information", expanded=False):
        st.write(f"**Variables:** {list(model_state.data_vars)}")
        st.write(f"**Dimensions:** {dict(model_state.sizes)}")
        st.write(f"**Coordinates:** {list(model_state.coords)}")

        # Show statistics
        st.write("**Feature Weight Statistics:**")
        feature_weights = model_state["feature_weights"].to_pandas()
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
        with col2:
            st.metric("Max Weight", f"{feature_weights.max():.4f}")
        with col3:
            st.metric("Total Features", len(feature_weights))

    # Feature importance section
    st.header("Feature Importance")
    st.markdown("The most important features based on learned feature weights from the best estimator.")

    # st.fragment: the slider only re-runs this nested function, not the page
    @st.fragment
    def render_feature_importance():
        # Slider to control number of features to display
        top_n = st.slider(
            "Number of top features to display",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Select how many of the most important features to visualize",
        )

        with st.spinner("Generating feature importance plot..."):
            feature_chart = plot_top_features(model_state, top_n=top_n)
            st.altair_chart(feature_chart, use_container_width=True)

        st.markdown(
            """
            **Interpretation:**
            - **Magnitude**: Larger absolute values indicate more important features
            - **Color**: Blue bars indicate positive weights, coral bars indicate negative weights
            """
        )

    render_feature_importance()

    # Box-to-Label Assignment Visualization
    st.header("Box-to-Label Assignments")
    st.markdown(
        """
        This visualization shows how the learned boxes (prototypes in feature space) are
        assigned to different class labels. The ESPA classifier learns K boxes and assigns
        them to classes through the Lambda matrix. Higher values indicate stronger assignment
        of a box to a particular class.
        """
    )

    with st.spinner("Generating box assignment visualizations..."):
        col1, col2 = st.columns([0.7, 0.3])

        with col1:
            st.markdown("### Assignment Heatmap")
            box_assignment_heatmap = plot_box_assignments(model_state)
            st.altair_chart(box_assignment_heatmap, use_container_width=True)

        with col2:
            st.markdown("### Box Count by Class")
            box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
            st.altair_chart(box_assignment_bars, use_container_width=True)

    # Show statistics
    with st.expander("Box Assignment Statistics"):
        box_assignments = model_state["box_assignments"].to_pandas()
        st.write("**Assignment Matrix Statistics:**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Boxes", len(box_assignments.columns))
        with col2:
            st.metric("Number of Classes", len(box_assignments.index))
        with col3:
            st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
        with col4:
            st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")

        # Show which boxes are most strongly assigned to each class
        st.write("**Top Box Assignments per Class:**")
        for class_label in box_assignments.index:
            top_boxes = box_assignments.loc[class_label].nlargest(5)
            st.write(
                f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
                f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
            )

    st.markdown(
        """
        **Interpretation:**
        - Each box can be assigned to multiple classes with different strengths
        - Boxes with higher assignment values for a class contribute more to that class's predictions
        - The distribution shows how the model partitions the feature space for classification
        """
    )

    # Embedding features analysis (if present)
    if embedding_feature_array is not None:
        with st.container(border=True):
            st.header("🛰️ Embedding Feature Analysis")
            st.markdown(
                """
                Analysis of embedding features showing which aggregations, bands, and years
                are most important for the model predictions.
                """
            )

            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating dimension summaries..."):
                chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.altair_chart(chart_agg, use_container_width=True)
                with col2:
                    st.altair_chart(chart_band, use_container_width=True)
                with col3:
                    st.altair_chart(chart_year, use_container_width=True)

            # Detailed heatmap
            st.markdown("### Detailed Heatmap by Aggregation")
            st.markdown("Shows the weight of each band-year combination for each aggregation type.")
            with st.spinner("Generating heatmap..."):
                heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
                st.altair_chart(heatmap_chart, use_container_width=True)

            # Statistics
            with st.expander("Embedding Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_emb_features = embedding_feature_array.size
                mean_weight = float(embedding_feature_array.mean().values)
                max_weight = float(embedding_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Embedding Features", n_emb_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")

                # Show top embedding features
                st.write("**Top 10 Embedding Features:**")
                emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
                top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
                st.dataframe(top_emb, width="stretch")
    else:
        st.info("No embedding features found in this model.")

    # ERA5 features analysis (if present)
    if era5_feature_array is not None:
        with st.container(border=True):
            st.header("⛅ ERA5 Feature Analysis")
            st.markdown(
                """
                Analysis of ERA5 climate features showing which variables and time periods
                are most important for the model predictions.
                """
            )

            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating ERA5 dimension summaries..."):
                chart_variable, chart_time = plot_era5_summary(era5_feature_array)
                col1, col2 = st.columns(2)
                with col1:
                    st.altair_chart(chart_variable, use_container_width=True)
                with col2:
                    st.altair_chart(chart_time, use_container_width=True)

            # Detailed heatmap
            st.markdown("### Detailed Heatmap")
            st.markdown("Shows the weight of each variable-time combination.")
            with st.spinner("Generating ERA5 heatmap..."):
                era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
                st.altair_chart(era5_heatmap_chart, use_container_width=True)

            # Statistics
            with st.expander("ERA5 Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_era5_features = era5_feature_array.size
                mean_weight = float(era5_feature_array.mean().values)
                max_weight = float(era5_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total ERA5 Features", n_era5_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")

                # Show top ERA5 features
                st.write("**Top 10 ERA5 Features:**")
                era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
                top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
                st.dataframe(top_era5, width="stretch")
    else:
        st.info("No ERA5 features found in this model.")

    # Common features analysis (if present)
    if common_feature_array is not None:
        with st.container(border=True):
            st.header("🗺️ Common Feature Analysis")
            st.markdown(
                """
                Analysis of common features including cell area, water area, land area, land ratio,
                longitude, and latitude. These features provide spatial and geographic context.
                """
            )

            # Bar chart showing all common feature weights
            with st.spinner("Generating common features chart..."):
                common_chart = plot_common_features(common_feature_array)
                st.altair_chart(common_chart, use_container_width=True)

            # Statistics
            with st.expander("Common Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_common_features = common_feature_array.size
                mean_weight = float(common_feature_array.mean().values)
                max_weight = float(common_feature_array.max().values)
                min_weight = float(common_feature_array.min().values)
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Common Features", n_common_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                with col4:
                    st.metric("Min Weight", f"{min_weight:.4f}")

                # Show all common features sorted by importance
                st.write("**All Common Features (by absolute weight):**")
                common_df = common_feature_array.to_dataframe(name="weight").reset_index()
                common_df["abs_weight"] = common_df["weight"].abs()
                common_df = common_df.sort_values("abs_weight", ascending=False)
                st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")

            st.markdown(
                """
                **Interpretation:**
                - **cell_area, water_area, land_area**: Spatial extent features that may indicate
                  size-related patterns
                - **land_ratio**: Proportion of land vs water in each cell
                - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
                - Positive weights indicate features that increase the probability of the positive class
                - Negative weights indicate features that decrease the probability of the positive class
                """
            )
    else:
        st.info("No common features found in this model.")
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ Material palettes:
|
|||
"""
|
||||
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
from pypalettes import load_cmap
|
||||
|
||||
|
||||
|
|
@ -90,3 +93,54 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
|
|||
cmap = get_cmap(variable).resampled(n_colors)
|
||||
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
|
||||
return colors
|
||||
|
||||
|
||||
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
    """Generate unified colormaps for all plotting libraries.

    This function creates consistent color schemes across Matplotlib/Ultraplot,
    Folium/Leaflet, and Altair/Vega-Lite by determining the task type and number
    of classes from the settings, then generating appropriate colormaps for each library.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair

    """
    # Determine task type and number of classes from settings
    task = settings.get("task", "binary")
    n_classes = len(settings.get("classes", []))

    # Check theme (requires a Streamlit runtime context)
    is_dark_theme = st.context.theme.type == "dark"

    # Define base colormaps for different tasks
    if task == "binary":
        # For binary: use a simple two-color scheme
        if is_dark_theme:
            base_colors = ["#1f77b4", "#ff7f0e"]  # Blue and orange for dark theme
        else:
            base_colors = ["#3498db", "#e74c3c"]  # Brighter blue and red for light theme
    else:
        # For multi-class: use a sequential colormap (matplotlib's viridis).
        # Guard against an empty/missing "classes" list: linspace(..., 0) would
        # yield no colors and ListedColormap([]) raises, so fall back to one color.
        n_classes = max(n_classes, 1)
        cmap = plt.get_cmap("viridis")
        # Sample colors evenly across the colormap
        indices = np.linspace(0.1, 0.9, n_classes)  # Avoid extreme ends
        base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]

    # Create matplotlib colormap (for ultraplot and geopandas)
    matplotlib_cmap = mcolors.ListedColormap(base_colors)

    # Create Folium/Leaflet colormap (geopandas.explore uses matplotlib colormaps)
    folium_cmap = mcolors.ListedColormap(base_colors)

    # Create Altair color list (Altair uses hex color strings in range)
    altair_colors = base_colors

    return matplotlib_cmap, folium_cmap, altair_colors
|
||||
|
|
|
|||
535
src/entropice/dashboard/plots/hyperparameter_analysis.py
Normal file
535
src/entropice/dashboard/plots/hyperparameter_analysis.py
Normal file
|
|
@ -0,0 +1,535 @@
|
|||
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
|
||||
|
||||
import altair as alt
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
    """Render summary statistics of model performance.

    Displays best/mean±std scores for every ``mean_test_*`` column, then the
    best parameter combination according to ``refit_metric`` (falling back to
    the first available metric if ``refit_metric`` is missing).

    Args:
        results: DataFrame with CV results (sklearn ``cv_results_`` layout).
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').

    """
    st.subheader("📊 Performance Summary")

    # Get all test score columns (sklearn names them "mean_test_<metric>")
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]

    if not score_cols:
        st.warning("No test score columns found in results.")
        return

    # Calculate statistics for each metric
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### Best Scores")
        best_scores = []
        for col in score_cols:
            # Human-readable metric name, e.g. "mean_test_f1_weighted" -> "F1 Weighted"
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            best_score = results[col].max()
            best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})

        st.dataframe(pd.DataFrame(best_scores), hide_index=True, use_container_width=True)

    with col2:
        st.markdown("#### Score Statistics")
        score_stats = []
        for col in score_cols:
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            mean_score = results[col].mean()
            std_score = results[col].std()
            score_stats.append(
                {
                    "Metric": metric_name,
                    "Mean ± Std": f"{mean_score:.4f} ± {std_score:.4f}",
                }
            )

        st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)

    # Show best parameter combination
    st.markdown("#### 🏆 Best Parameter Combination")
    refit_col = f"mean_test_{refit_metric}"

    # Check if refit metric exists in results
    if refit_col not in results.columns:
        st.warning(
            f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
        )
        # Use the first available metric as fallback
        refit_col = score_cols[0]
        refit_metric = refit_col.replace("mean_test_", "")
        st.info(f"Using '{refit_metric}' as fallback metric.")

    # Row with the highest refit-metric score
    best_idx = results[refit_col].idxmax()
    best_row = results.loc[best_idx]

    # Extract parameter columns ("params" holds the full dict; skip it)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if param_cols:
        best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}

        # Display in a nice formatted way (transposed one-row frame)
        param_df = pd.DataFrame([best_params]).T
        param_df.columns = ["Value"]
        param_df.index.name = "Parameter"

        col1, col2 = st.columns([1, 1])
        with col1:
            st.dataframe(param_df, use_container_width=True)

        with col2:
            st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
            rank_col = "rank_test_" + refit_metric
            if rank_col in best_row.index:
                try:
                    # Handle potential Series or scalar values
                    rank_val = best_row[rank_col]
                    if hasattr(rank_val, "item"):
                        rank_val = rank_val.item()
                    rank_display = str(int(float(rank_val)))
                except (ValueError, TypeError, AttributeError):
                    rank_display = "N/A"
            else:
                rank_display = "N/A"
            st.metric("Rank", rank_display)
|
||||
|
||||
|
||||
def render_parameter_distributions(results: pd.DataFrame):
    """Render histograms of parameter distributions explored.

    Numeric parameters are binned into histograms (log-scaled X axis when the
    value range spans >2 orders of magnitude); categorical parameters become
    count bar charts. Charts are laid out in a grid of up to 3 columns.

    Args:
        results: DataFrame with CV results.

    """
    st.subheader("📈 Parameter Space Exploration")

    # Get parameter columns ("params" holds the full dict; skip it)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Create histograms for each parameter, arranged in a grid
    n_params = len(param_cols)
    n_cols = min(3, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols

    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                break

            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")

            with cols[col_idx]:
                # Check if parameter is numeric or categorical
                param_values = results[param_col].dropna()

                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - use histogram
                    df_plot = pd.DataFrame({param_name: param_values})

                    # Use log scale only for strictly positive values spanning
                    # multiple orders of magnitude; a log axis is undefined for
                    # zero/negative values and would break the Altair chart.
                    min_val = param_values.min()
                    use_log = bool(min_val > 0 and (param_values.max() / min_val) > 100)

                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(
                                    param_name,
                                    bin=alt.Bin(maxbins=30),
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
                            )
                            .properties(height=250, title=f"{param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
                            )
                            .properties(height=250, title=param_name)
                        )

                    st.altair_chart(chart, use_container_width=True)

                else:
                    # Categorical parameter - use bar chart of value counts
                    value_counts = param_values.value_counts().reset_index()
                    value_counts.columns = [param_name, "count"]

                    chart = (
                        alt.Chart(value_counts)
                        .mark_bar()
                        .encode(
                            alt.X(param_name, title=param_name, sort="-y"),
                            alt.Y("count", title="Count"),
                            tooltip=[param_name, "count"],
                        )
                        .properties(height=250, title=param_name)
                    )

                    st.altair_chart(chart, use_container_width=True)
|
||||
|
||||
|
||||
def render_score_vs_parameter(results: pd.DataFrame, metric: str):
    """Render scatter plots of score vs each parameter.

    Numeric parameters get viridis-colored scatter plots (log X axis when the
    value range spans >2 orders of magnitude); categorical parameters get box
    plots. Charts are laid out in a grid of up to 2 columns.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Get parameter columns ("params" holds the full dict; skip it)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]

    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Create scatter plots for each parameter, arranged in a grid
    n_params = len(param_cols)
    n_cols = min(2, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols

    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                break

            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")

            with cols[col_idx]:
                param_values = results[param_col].dropna()

                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - scatter plot
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})

                    # Use log scale only for strictly positive values spanning
                    # multiple orders of magnitude; a log axis is undefined for
                    # zero/negative values and would break the Altair chart.
                    min_val = param_values.min()
                    use_log = bool(min_val > 0 and (param_values.max() / min_val) > 100)

                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(
                                    param_name,
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(param_name, title=param_name),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name}")
                        )

                    st.altair_chart(chart, use_container_width=True)

                else:
                    # Categorical parameter - box plot of scores per category
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})

                    chart = (
                        alt.Chart(df_plot)
                        .mark_boxplot()
                        .encode(
                            alt.X(param_name, title=param_name),
                            alt.Y(metric, title=metric.replace("_", " ").title()),
                            tooltip=[param_name, alt.Tooltip(metric, format=".4f")],
                        )
                        .properties(height=300, title=f"{metric} vs {param_name}")
                    )

                    st.altair_chart(chart, use_container_width=True)
|
||||
|
||||
|
||||
def render_parameter_correlation(results: pd.DataFrame, metric: str):
    """Render correlation heatmap between parameters and score.

    Computes the Pearson correlation of each *numeric* parameter column with
    ``mean_test_<metric>``, shows a diverging bar chart plus a styled table.
    Categorical parameters are excluded.

    Args:
        results: DataFrame with CV results.
        metric: The metric to analyze (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Get numeric parameter columns ("params" holds the full dict; skip it)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]

    if not numeric_params:
        st.warning("No numeric parameters found for correlation analysis.")
        return

    # Calculate correlations (Pearson; NaN if a parameter is constant)
    correlations = []
    for param_col in numeric_params:
        param_name = param_col.replace("param_", "")
        corr = results[[param_col, score_col]].corr().iloc[0, 1]
        correlations.append({"Parameter": param_name, "Correlation": corr})

    corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)

    # Create bar chart; diverging red/blue scale pinned to [-1, 1]
    chart = (
        alt.Chart(corr_df)
        .mark_bar()
        .encode(
            alt.X("Correlation", title="Correlation with Score"),
            alt.Y("Parameter", sort="-x", title="Parameter"),
            alt.Color(
                "Correlation",
                scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
                legend=None,
            ),
            tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
        )
        .properties(height=max(200, len(correlations) * 30))
    )

    st.altair_chart(chart, use_container_width=True)

    # Show correlation table with a matching background gradient
    with st.expander("📋 Correlation Table"):
        st.dataframe(
            corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
            hide_index=True,
            use_container_width=True,
        )
|
||||
|
||||
|
||||
def render_score_evolution(results: pd.DataFrame, metric: str):
    """Render evolution of scores during search.

    Plots the per-iteration score together with the running best ("Best So
    Far") as a two-line chart, followed by summary metrics.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').

    """
    st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Create a copy with iteration number and cumulative best
    df_plot = results[[score_col]].copy()
    df_plot["Iteration"] = range(len(df_plot))
    df_plot["Best So Far"] = df_plot[score_col].cummax()
    df_plot = df_plot.rename(columns={score_col: "Score"})

    # Reshape to long form for Altair (one row per (iteration, series))
    df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")

    # Create line chart; "Best So Far" is dashed
    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            alt.X("Iteration", title="Iteration"),
            alt.Y("value", title=metric.replace("_", " ").title()),
            alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
            strokeDash=alt.StrokeDash(
                "Type",
                legend=None,
                scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
            ),
            tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
        )
        .properties(height=400)
    )

    st.altair_chart(chart, use_container_width=True)

    # Show statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}")
    with col2:
        st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}")
    with col3:
        st.metric("Std Dev", f"{df_plot['Score'].std():.4f}")
    with col4:
        # Positional argmax, not idxmax: idxmax returns the *index label* of
        # `results`, which only matches the plotted Iteration axis when the
        # DataFrame has a default RangeIndex.
        best_iter = int(df_plot["Score"].to_numpy().argmax())
        st.metric("Best at Iteration", best_iter)
|
||||
|
||||
|
||||
def render_multi_metric_comparison(results: pd.DataFrame):
    """Render comparison of multiple metrics.

    Lets the user pick two different ``mean_test_*`` metrics, scatter-plots
    them against each other colored by search iteration, and reports their
    Pearson correlation. Requires at least two score columns.

    Args:
        results: DataFrame with CV results.

    """
    st.subheader("📊 Multi-Metric Comparison")

    # Get all test score columns
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]

    if len(score_cols) < 2:
        st.warning("Need at least 2 metrics for comparison.")
        return

    # Let user select two metrics to compare (keys keep widget state distinct)
    col1, col2 = st.columns(2)
    with col1:
        metric1 = st.selectbox(
            "Select First Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            index=0,
            key="metric1_select",
        )

    with col2:
        metric2 = st.selectbox(
            "Select Second Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            index=min(1, len(score_cols) - 1),
            key="metric2_select",
        )

    if metric1 == metric2:
        st.warning("Please select different metrics.")
        return

    # Create scatter plot data; Iteration is the search order
    df_plot = pd.DataFrame(
        {
            metric1: results[f"mean_test_{metric1}"],
            metric2: results[f"mean_test_{metric2}"],
            "Iteration": range(len(results)),
        }
    )

    chart = (
        alt.Chart(df_plot)
        .mark_circle(size=60, opacity=0.6)
        .encode(
            alt.X(metric1, title=metric1.replace("_", " ").title()),
            alt.Y(metric2, title=metric2.replace("_", " ").title()),
            alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
            tooltip=[
                alt.Tooltip(metric1, format=".4f"),
                alt.Tooltip(metric2, format=".4f"),
                "Iteration",
            ],
        )
        .properties(height=500)
    )

    st.altair_chart(chart, use_container_width=True)

    # Calculate correlation (Pearson)
    corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
    st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}")
|
||||
|
||||
|
||||
def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
    """Render a table of the top N hyperparameter configurations.

    Args:
        results: DataFrame with CV results.
        metric: The metric to rank by (e.g., 'f1', 'accuracy').
        top_n: Number of top configurations to show.

    """
    metric_title = metric.replace("_", " ").title()
    st.subheader(f"🏆 Top {top_n} Configurations by {metric_title}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Hyperparameter columns (skip the aggregate "params" dict column).
    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Best-scoring rows first.
    best_rows = results.nlargest(top_n, score_col)

    rank_col = "rank_test_" + metric
    wanted = [rank_col, score_col, *param_cols]
    table = best_rows[[c for c in wanted if c in best_rows.columns]].copy()

    # Friendlier column headers.
    table = table.rename(columns={rank_col: "Rank", score_col: metric_title})
    table.columns = [c.replace("param_", "") if c.startswith("param_") else c for c in table.columns]

    # Render scores with four decimal places.
    table[metric_title] = table[metric_title].apply(lambda v: f"{v:.4f}")

    st.dataframe(table, hide_index=True, use_container_width=True)
|
||||
449
src/entropice/dashboard/plots/model_state.py
Normal file
449
src/entropice/dashboard/plots/model_state.py
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
"""Plotting functions for model state visualization."""
|
||||
|
||||
import altair as alt
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
|
||||
|
||||
def plot_top_features(model_state: xr.Dataset, top_n: int = 10) -> alt.Chart:
    """Plot the top N most important features based on feature weights.

    Args:
        model_state: The xarray Dataset containing the model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.

    """
    weights = model_state["feature_weights"].to_pandas()

    # Rank by magnitude, then reorder ascending for the bar layout.
    ranked = weights.abs().nlargest(top_n).sort_values(ascending=True)

    # Keep the signed weights for display alongside their magnitudes.
    plot_data = pd.DataFrame(
        {
            "feature": ranked.index,
            "weight": weights.loc[ranked.index].to_numpy(),
            "abs_weight": ranked.to_numpy(),
        }
    )

    # Sign-coded bar colors: blue for positive, coral for negative.
    sign_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),  # positive weights
        alt.value("coral"),  # negative weights
    )

    return (
        alt.Chart(plot_data)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=sign_color,
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(width=600, height=400, title=f"Top {top_n} Most Important Features")
    )
|
||||
|
||||
|
||||
def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
    """Create a heatmap showing embedding feature weights across bands and years.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Altair chart showing the heatmap.

    """
    weights = embedding_array.to_dataframe(name="weight").reset_index()

    encoded = (
        alt.Chart(weights)
        .mark_rect()
        .encode(
            x=alt.X("year:O", title="Year"),
            y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
            color=alt.Color(
                "weight:Q",
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("agg:N", title="Aggregation"),
                alt.Tooltip("band:N", title="Band"),
                alt.Tooltip("year:O", title="Year"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
            ],
        )
    )

    # One small panel per aggregation type, laid out in a wide grid.
    return encoded.properties(width=200, height=200).facet(
        facet=alt.Facet("agg:N", title="Aggregation"), columns=11
    )
|
||||
|
||||
|
||||
def _embedding_dim_chart(
    series: pd.Series, *, dim_title: str, scheme: str, chart_title: str, dim_type: str = "N"
) -> alt.Chart:
    """Build one sorted bar chart of mean absolute weight for a single embedding dimension.

    Args:
        series: Mean absolute weights indexed by the dimension's labels.
        dim_title: Axis/tooltip title for the dimension.
        scheme: Vega color scheme name for the bars.
        chart_title: Chart title.
        dim_type: Altair field type for the dimension ("N" nominal, "O" ordinal).

    Returns:
        Altair bar chart sorted by weight.

    """
    df = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
    df = df.sort_values("mean_abs_weight", ascending=True)
    return (
        alt.Chart(df)
        .mark_bar()
        .encode(
            y=alt.Y(f"dimension:{dim_type}", title=dim_title, sort="-x"),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color(
                "mean_abs_weight:Q",
                scale=alt.Scale(scheme=scheme),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip(f"dimension:{dim_type}", title=dim_title),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=250, height=200, title=chart_title)
    )


def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
    """Create bar charts summarizing embedding weights by aggregation, band, and year.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Tuple of three Altair charts (by_agg, by_band, by_year).

    """
    # Mean over the other two dimensions, then take the magnitude.
    by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
    by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
    by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()

    # Same chart shape three times, distinguished only by color scheme and title.
    chart_agg = _embedding_dim_chart(by_agg, dim_title="Aggregation", scheme="blues", chart_title="By Aggregation")
    chart_band = _embedding_dim_chart(by_band, dim_title="Band", scheme="greens", chart_title="By Band")
    # Years are ordinal, matching the original encoding.
    chart_year = _embedding_dim_chart(
        by_year, dim_title="Year", scheme="oranges", chart_title="By Year", dim_type="O"
    )

    return chart_agg, chart_band, chart_year
|
||||
|
||||
|
||||
def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
    """Create a heatmap showing ERA5 feature weights across variables and time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Altair chart showing the heatmap.

    """
    weights = era5_array.to_dataframe(name="weight").reset_index()

    heat = (
        alt.Chart(weights)
        .mark_rect()
        .encode(
            x=alt.X("time:N", title="Time", sort=None),
            y=alt.Y("variable:N", title="Variable", sort="-color"),
            color=alt.Color(
                "weight:Q",
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("variable:N", title="Variable"),
                alt.Tooltip("time:N", title="Time"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
            ],
        )
    )

    return heat.properties(height=400, title="ERA5 Feature Weights Heatmap")
|
||||
|
||||
|
||||
def _era5_dim_chart(
    series: pd.Series, *, dim_title: str, scheme: str, chart_title: str, label_limit: int
) -> alt.Chart:
    """Build one sorted bar chart of mean absolute ERA5 weight for a single dimension.

    Args:
        series: Mean absolute weights indexed by the dimension's labels.
        dim_title: Axis/tooltip title for the dimension.
        scheme: Vega color scheme name for the bars.
        chart_title: Chart title.
        label_limit: Pixel limit for y-axis label truncation.

    Returns:
        Altair bar chart sorted by weight.

    """
    df = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
    df = df.sort_values("mean_abs_weight", ascending=True)
    return (
        alt.Chart(df)
        .mark_bar()
        .encode(
            y=alt.Y("dimension:N", title=dim_title, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color(
                "mean_abs_weight:Q",
                scale=alt.Scale(scheme=scheme),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("dimension:N", title=dim_title),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=400, height=300, title=chart_title)
    )


def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
    """Create bar charts summarizing ERA5 weights by variable and time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_time).

    """
    # Mean over the other dimension, then take the magnitude.
    by_variable = era5_array.mean(dim="time").to_pandas().abs()
    by_time = era5_array.mean(dim="variable").to_pandas().abs()

    # Same chart shape twice, distinguished by color scheme, title, and label room.
    chart_variable = _era5_dim_chart(
        by_variable, dim_title="Variable", scheme="purples", chart_title="By Variable", label_limit=300
    )
    chart_time = _era5_dim_chart(
        by_time, dim_title="Time", scheme="teals", chart_title="By Time", label_limit=200
    )

    return chart_variable, chart_time
|
||||
|
||||
|
||||
def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
    """Create a heatmap showing which boxes are assigned to which labels/classes.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.

    Returns:
        Altair chart showing the box-to-label assignment heatmap.

    """
    # Flatten the (class, box) assignment matrix into long form for Altair.
    assignments = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    heatmap = (
        alt.Chart(assignments)
        .mark_rect()
        .encode(
            x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("class:N", title="Class Label"),
            color=alt.Color(
                "assignment:Q",
                scale=alt.Scale(scheme="viridis"),
                title="Assignment Strength",
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("box:O", title="Box"),
                alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
            ],
        )
    )

    return heatmap.properties(height=150, title="Box-to-Label Assignments (Lambda Matrix)")
|
||||
|
||||
|
||||
def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]) -> alt.Chart:
    """Create a bar chart showing how many boxes are assigned to each class.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.
        altair_colors: List of hex color strings for altair.

    Returns:
        Altair chart showing count of boxes per class.

    """

    def _class_order_key(label) -> float:
        """Sorting key: "No RTS" first, then intervals by their lower bound."""
        if label == "No RTS":
            return -1
        # Interval labels look like "(0, 4]" or "(4, 36]".
        try:
            return float(str(label).split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")  # unparseable labels sink to the end

    assignments = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    # Each box's primary class is the one with the strongest assignment.
    strongest_idx = assignments.groupby("box")["assignment"].idxmax()
    primary = assignments.loc[strongest_idx, ["box", "class", "assignment"]].reset_index(drop=True)

    # Count how many boxes land on each class.
    counts = primary.groupby("class").size().reset_index(name="count")

    # The special (-1, 0] interval denotes absence of RTS.
    counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")

    counts["sort_key"] = counts["class"].apply(_class_order_key)
    counts = counts.sort_values("sort_key")

    # Fixed class ordering keeps the color mapping stable across renders.
    class_order = counts["class"].tolist()

    bars = (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
            y=alt.Y("count:Q", title="Number of Boxes"),
            color=alt.Color(
                "class:N",
                title="Class",
                scale=alt.Scale(domain=class_order, range=altair_colors),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("count:Q", title="Number of Boxes"),
            ],
        )
    )

    return bars.properties(
        width=600,
        height=300,
        title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
    )
|
||||
|
||||
|
||||
def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
    """Create a bar chart showing the weights of common features.

    Args:
        common_array: DataArray with dimension (feature) containing feature weights.

    Returns:
        Altair chart showing the common feature weights.

    """
    weights = common_array.to_dataframe(name="weight").reset_index()

    # Order bars by magnitude while coloring by sign.
    weights["abs_weight"] = weights["weight"].abs()
    weights = weights.sort_values("abs_weight", ascending=True)

    sign_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),  # positive weights
        alt.value("coral"),  # negative weights
    )

    return (
        alt.Chart(weights)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x"),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=sign_color,
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(width=600, height=300, title="Common Feature Weights")
    )
|
||||
|
|
@ -2,9 +2,237 @@
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||
render_multi_metric_comparison,
|
||||
render_parameter_correlation,
|
||||
render_parameter_distributions,
|
||||
render_performance_summary,
|
||||
render_score_evolution,
|
||||
render_score_vs_parameter,
|
||||
render_top_configurations,
|
||||
)
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
from entropice.dashboard.utils.training import (
|
||||
format_metric_name,
|
||||
get_available_metrics,
|
||||
get_cv_statistics,
|
||||
get_parameter_space_summary,
|
||||
)
|
||||
|
||||
|
||||
def render_training_analysis_page():
    """Render the Training Results Analysis page of the dashboard.

    Loads all available training results, lets the user pick a run and a
    primary metric in the sidebar, then renders performance, CV statistics,
    parameter-space, correlation, and export sections.
    """
    # Fix: remove leftover placeholder title/write from the old stub so the
    # page shows a single title.
    st.title("🦾 Training Results Analysis")

    # Load all available training results
    training_results = load_all_training_results()

    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.training`")
        return

    # Sidebar: Training run selection
    with st.sidebar:
        st.header("Select Training Run")

        # Create selection options
        training_options = {tr.name: tr for tr in training_results}

        selected_name = st.selectbox(
            "Training Run",
            options=list(training_options.keys()),
            index=0,
            help="Select a training run to analyze",
        )

        selected_result = training_options[selected_name]

        st.divider()

        # Display selected run info
        st.subheader("Run Information")
        st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
        st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
        st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
        st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
        st.write(f"**Trials:** {len(selected_result.results)}")
        st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")

        # Refit metric - determine from available metrics (computed once and
        # reused below; the original recomputed it redundantly).
        available_metrics = get_available_metrics(selected_result.results)

        # Try to get refit metric from settings
        refit_metric = selected_result.settings.get("refit_metric")

        if not refit_metric or refit_metric not in available_metrics:
            # Infer from task or use first available metric
            task = selected_result.settings.get("task", "binary")
            if task == "binary" and "f1" in available_metrics:
                refit_metric = "f1"
            elif "f1_weighted" in available_metrics:
                refit_metric = "f1_weighted"
            elif "accuracy" in available_metrics:
                refit_metric = "accuracy"
            elif available_metrics:
                refit_metric = available_metrics[0]
            else:
                st.error("No metrics found in results.")
                return

        st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")

        st.divider()

        # Metric selection for detailed analysis
        st.subheader("Analysis Settings")

        if refit_metric in available_metrics:
            default_metric_idx = available_metrics.index(refit_metric)
        else:
            default_metric_idx = 0

        selected_metric = st.selectbox(
            "Primary Metric for Analysis",
            options=available_metrics,
            index=default_metric_idx,
            format_func=format_metric_name,
            help="Select the metric to focus on for detailed analysis",
        )

        # Top N configurations
        top_n = st.slider(
            "Top N Configurations",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Number of top configurations to display",
        )

    # Main content area
    results = selected_result.results
    settings = selected_result.settings

    # Performance Summary Section
    st.header("📊 Performance Overview")

    render_performance_summary(results, refit_metric)

    st.divider()

    # Quick Statistics
    st.header("📈 Cross-Validation Statistics")

    cv_stats = get_cv_statistics(results, selected_metric)

    if cv_stats:
        col1, col2, col3, col4, col5 = st.columns(5)

        with col1:
            st.metric("Best Score", f"{cv_stats['best_score']:.4f}")

        with col2:
            st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")

        with col3:
            st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")

        with col4:
            st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")

        with col5:
            st.metric("Median Score", f"{cv_stats['median_score']:.4f}")

        if "mean_cv_std" in cv_stats:
            st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")

    st.divider()

    # Score Evolution
    st.header("📉 Training Progress")

    render_score_evolution(results, selected_metric)

    st.divider()

    # Parameter Space Exploration
    st.header("🔍 Parameter Space Analysis")

    # Show parameter space summary
    with st.expander("📋 Parameter Space Summary", expanded=False):
        param_summary = get_parameter_space_summary(results)
        if not param_summary.empty:
            st.dataframe(param_summary, hide_index=True, use_container_width=True)
        else:
            st.info("No parameter information available.")

    # Parameter distributions
    render_parameter_distributions(results)

    st.divider()

    # Score vs Parameters
    st.header("🎯 Parameter Impact Analysis")

    render_score_vs_parameter(results, selected_metric)

    st.divider()

    # Parameter Correlation
    st.header("🔗 Parameter Correlation Analysis")

    render_parameter_correlation(results, selected_metric)

    st.divider()

    # Multi-Metric Comparison
    if len(available_metrics) >= 2:
        st.header("📊 Multi-Metric Analysis")

        render_multi_metric_comparison(results)

        st.divider()

    # Top Configurations
    st.header("🏆 Top Performing Configurations")

    render_top_configurations(results, selected_metric, top_n)

    st.divider()

    # Raw Data Export
    with st.expander("💾 Export Data", expanded=False):
        st.subheader("Download Results")

        col1, col2 = st.columns(2)

        with col1:
            # Download full results as CSV
            csv_data = results.to_csv(index=False)
            st.download_button(
                label="📥 Download Full Results (CSV)",
                data=csv_data,
                file_name=f"{selected_result.path.name}_results.csv",
                mime="text/csv",
                use_container_width=True,
            )

        with col2:
            # Download settings as text
            import json

            settings_json = json.dumps(settings, indent=2)
            st.download_button(
                label="⚙️ Download Settings (JSON)",
                data=settings_json,
                file_name=f"{selected_result.path.name}_settings.json",
                mime="application/json",
                use_container_width=True,
            )

        # Show raw data preview
        st.subheader("Raw Data Preview")
        st.dataframe(results.head(100), use_container_width=True)
|
||||
|
|
|
|||
|
|
@ -118,6 +118,42 @@ def render_training_data_page():
|
|||
# Display dataset ID in a styled container
|
||||
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
||||
|
||||
# Display dataset statistics
|
||||
st.markdown("---")
|
||||
st.subheader("📈 Dataset Statistics")
|
||||
|
||||
with st.spinner("Computing dataset statistics..."):
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
# Display target information
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.metric(label="Target", value=stats["target"].replace("darts_", ""))
|
||||
with col2:
|
||||
st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}")
|
||||
|
||||
# Display member statistics
|
||||
st.markdown("**Member Statistics:**")
|
||||
|
||||
for member, member_stats in stats["members"].items():
|
||||
with st.expander(f"📦 {member}", expanded=False):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.markdown(f"**Number of Features:** {member_stats['num_features']}")
|
||||
st.markdown(f"**Number of Variables:** {member_stats['num_variables']}")
|
||||
with col2:
|
||||
st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`")
|
||||
|
||||
# Display variables as a compact list
|
||||
st.markdown(f"**Variables ({member_stats['num_variables']}):**")
|
||||
vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]])
|
||||
st.markdown(vars_str)
|
||||
|
||||
# Display total features
|
||||
st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Create tabs for different data views
|
||||
tab_names = ["📊 Labels"]
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import antimeridian
|
|||
import pandas as pd
|
||||
import streamlit as st
|
||||
import toml
|
||||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
import entropice.paths
|
||||
|
|
@ -119,3 +120,98 @@ def load_source_data(e: DatasetEnsemble, source: str):
|
|||
ds = e._read_member(source, targets, lazy=False)
|
||||
|
||||
return ds, targets
|
||||
|
||||
|
||||
def extract_embedding_features(model_state) -> xr.DataArray | None:
    """Extract embedding features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted embedding features. This DataArray has dimensions
            ('agg', 'band', 'year') corresponding to the different components of the
            embedding features. Returns None if no embedding features are found.

    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("embedding_")]
    if not names:
        return None

    # Feature names follow "embedding_<agg>_<band>_<year>"; split each name
    # once and fan the pieces out into separate coordinates.
    parts = [name.split("_") for name in names]
    weights = model_state.sel(feature=names)["feature_weights"]
    weights = weights.assign_coords(
        agg=("feature", [p[1] for p in parts]),
        band=("feature", [p[2] for p in parts]),
        year=("feature", [p[3] for p in parts]),
    )
    # Turn the flat feature dimension into a (agg, band, year) cube.
    return weights.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
|
||||
|
||||
|
||||
def extract_era5_features(model_state) -> xr.DataArray | None:
    """Extract ERA5 features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
            ('variable', 'time') corresponding to the different components of the
            ERA5 features. Returns None if no ERA5 features are found.

    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("era5_")]
    if not names:
        return None

    # Feature names follow "era5_<variable>_<time>". The variable name may
    # itself contain underscores, so the time type is the final segment and
    # the variable is everything in between.
    variables = ["_".join(name.split("_")[1:-1]) for name in names]
    times = [name.split("_")[-1] for name in names]

    weights = model_state.sel(feature=names)["feature_weights"]
    weights = weights.assign_coords(
        variable=("feature", variables),
        time=("feature", times),
    )
    # Turn the flat feature dimension into a (variable, time) grid.
    return weights.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
|
||||
|
||||
|
||||
def extract_common_features(model_state) -> xr.DataArray | None:
    """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted common features with a single 'feature' dimension.
            Returns None if no common features are found.

    """
    # Fixed vocabulary of geometry/location features shared by all members.
    known = {"cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"}

    present = [f for f in model_state.feature.to_numpy() if f in known]
    if not present:
        return None

    return model_state.sel(feature=present)["feature_weights"]
|
||||
|
|
|
|||
232
src/entropice/dashboard/utils/training.py
Normal file
232
src/entropice/dashboard/utils/training.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
"""Training utilities for dashboard."""
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.utils.data import TrainingResult
|
||||
|
||||
|
||||
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').

    """
    # "f1" needs special casing ("F1", not "F1".capitalize()-style mangling);
    # every other segment is simply capitalized.
    return " ".join(
        "F1" if segment.lower() == "f1" else segment.capitalize()
        for segment in metric.split("_")
    )
|
||||
|
||||
|
||||
def get_available_metrics(results: pd.DataFrame) -> list[str]:
    """Get list of available metrics from results.

    Args:
        results: DataFrame with CV results.

    Returns:
        List of metric names (without 'mean_test_' prefix).

    """
    prefix = "mean_test_"
    # Keep column order: each score column contributes its suffix as a metric name.
    return [col[len(prefix):] for col in results.columns if col.startswith(prefix)]
|
||||
|
||||
|
||||
def load_best_model(result: TrainingResult):
    """Load the best model from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        The loaded model object, or None if loading fails.

    """
    model_file = result.path / "best_estimator_model.pkl"
    if not model_file.exists():
        return None

    # NOTE(review): pickle.load can execute arbitrary code from the file --
    # acceptable here only because these are locally produced training artifacts.
    try:
        with model_file.open("rb") as fh:
            return pickle.load(fh)
    except Exception as exc:
        st.error(f"Error loading model: {exc}")
        return None
|
||||
|
||||
|
||||
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
    """Load the model state from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        xarray Dataset with model state, or None if not available.

    """
    state_file = result.path / "best_estimator_state.nc"
    if not state_file.exists():
        return None

    try:
        # The state is written as NetCDF4/HDF5, hence the h5netcdf engine.
        return xr.open_dataset(state_file, engine="h5netcdf")
    except Exception as exc:
        st.error(f"Error loading model state: {exc}")
        return None
|
||||
|
||||
|
||||
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
    """Load predictions from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        DataFrame with predictions, or None if not available.

    """
    preds_file = result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        return None

    try:
        return pd.read_parquet(preds_file)
    except Exception as exc:
        st.error(f"Error loading predictions: {exc}")
        return None
|
||||
|
||||
|
||||
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
    """Get summary of parameter space explored.

    Args:
        results: DataFrame with CV results.

    Returns:
        DataFrame with parameter ranges and statistics.

    """
    rows = []
    for col in results.columns:
        # Only per-parameter columns; the aggregate 'params' column is excluded.
        if not col.startswith("param_") or col == "params":
            continue
        name = col.replace("param_", "")
        values = results[col].dropna()

        if pd.api.types.is_numeric_dtype(values):
            rows.append(
                {
                    "Parameter": name,
                    "Type": "Numeric",
                    "Min": f"{values.min():.2e}",
                    "Max": f"{values.max():.2e}",
                    "Mean": f"{values.mean():.2e}",
                    "Unique Values": values.nunique(),
                }
            )
        else:
            # Range statistics are meaningless for categorical parameters.
            rows.append(
                {
                    "Parameter": name,
                    "Type": "Categorical",
                    "Min": "-",
                    "Max": "-",
                    "Mean": "-",
                    "Unique Values": len(values.unique()),
                }
            )

    return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
    """Get cross-validation statistics for a metric.

    Args:
        results: DataFrame with CV results.
        metric: Metric name (without 'mean_test_' prefix).

    Returns:
        Dictionary with CV statistics.

    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return {}

    scores = results[score_col]
    summary = {
        "best_score": scores.max(),
        "mean_score": scores.mean(),
        "std_score": scores.std(),
        "worst_score": scores.min(),
        "median_score": scores.median(),
    }

    # Per-fold spread is only available when the std column was recorded.
    std_col = f"std_test_{metric}"
    if std_col in results.columns:
        summary["mean_cv_std"] = results[std_col].mean()

    return summary
|
||||
|
||||
|
||||
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
    """Prepare results dataframe with binned columns for plotting.

    Args:
        results: DataFrame with CV results.
        k_bin_width: Width of bins for initial_K parameter.

    Returns:
        DataFrame with added binned columns ('initial_K_binned',
        'eps_cl_binned', 'eps_e_binned' where the source columns exist
        and contain binnable values).

    """
    results_copy = results.copy()

    if "param_initial_K" in results.columns:
        k_values = results["param_initial_K"].dropna()
        if len(k_values) > 0:
            k_min = k_values.min()
            k_max = k_values.max()
            # Extend the edges one step past k_max: with right=False
            # (left-closed bins) an edge landing exactly on k_max would
            # otherwise drop the maximum value as NaN, and a single K value
            # would yield only one edge and crash pd.cut.
            k_bins = range(int(k_min), int(k_max) + k_bin_width + 1, k_bin_width)
            results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False)

    def _add_log_bins(column: str, target: str) -> None:
        # Shared helper: logarithmic bins for a strictly positive parameter.
        values = results[column].dropna()
        # nunique() > 1 guards against identical edges, which pd.cut rejects.
        if len(values) > 0 and values.min() > 0 and values.nunique() > 1:
            bins = np.logspace(np.log10(values.min()), np.log10(values.max()), num=10)
            # include_lowest=True keeps the minimum value in the first bin;
            # the default right-closed bins would drop it as NaN.
            results_copy[target] = pd.cut(results[column], bins=bins, include_lowest=True)

    if "param_eps_cl" in results.columns:
        _add_log_bins("param_eps_cl", "eps_cl_binned")

    if "param_eps_e" in results.columns:
        _add_log_bins("param_eps_e", "eps_e_binned")

    return results_copy
|
||||
|
|
@ -283,25 +283,52 @@ class DatasetEnsemble:
|
|||
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
|
||||
return arcticdem_df
|
||||
|
||||
def print_stats(self):
|
||||
targets = self._read_target()
|
||||
print(f"=== Target: {self.target}")
|
||||
print(f"\tNumber of target samples: {len(targets)}")
|
||||
def get_stats(self) -> dict:
    """Get dataset statistics.

    Returns:
        dict: Dictionary containing target stats, member stats, and total features count.

    """
    targets = self._read_target()
    stats = {
        "target": self.target,
        "num_target_samples": len(targets),
        "members": {},
        # lon/lat contribute two extra feature columns when enabled
        "total_features": 2 if self.add_lonlat else 0,
    }

    for member in self.members:
        ds = self._read_member(member, targets, lazy=True)
        # One feature per variable, multiplied out over every
        # dimension other than the cell index.
        n_features = len(ds.data_vars)
        for dim, size in ds.sizes.items():
            if dim != "cell_ids":
                n_features *= size

        stats["members"][member] = {
            "variables": list(ds.data_vars),
            "num_variables": len(ds.data_vars),
            "dimensions": dict(ds.sizes),
            "coordinates": list(ds.coords),
            "num_features": n_features,
        }
        stats["total_features"] += n_features

    return stats
|
||||
|
||||
def print_stats(self):
    """Print a human-readable summary of the dataset statistics."""
    stats = self.get_stats()
    lines = [
        f"=== Target: {stats['target']}",
        f"\tNumber of target samples: {stats['num_target_samples']}",
    ]
    for name, info in stats["members"].items():
        lines.append(f"=== Member: {name}")
        lines.append(f"\tVariables ({info['num_variables']}): {info['variables']}")
        lines.append(f"\tDimensions: {info['dimensions']}")
        lines.append(f"\tCoordinates: {info['coordinates']}")
        lines.append(f"\tNumber of features from member: {info['num_features']}")
    lines.append(f"=== Total number of features in dataset: {stats['total_features']}")
    # Single write keeps the output byte-identical to line-by-line prints.
    print("\n".join(lines))
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue