Add training results page and model state page

This commit is contained in:
Tobias Hölzer 2025-12-19 02:59:42 +01:00
parent 31933b58d3
commit 696bef39c2
9 changed files with 1988 additions and 17 deletions

View file

@@ -1,8 +1,322 @@
"""Model State page for the Entropice dashboard."""
import streamlit as st import streamlit as st
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import (
plot_box_assignment_bars,
plot_box_assignments,
plot_common_features,
plot_embedding_aggregation_summary,
plot_embedding_heatmap,
plot_era5_heatmap,
plot_era5_summary,
plot_top_features,
)
from entropice.dashboard.utils.data import (
extract_common_features,
extract_embedding_features,
extract_era5_features,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
def render_model_state_page():
    """Render the Model State page of the dashboard.

    Lets the user pick a stored training result, loads that result's model
    state, and renders feature-importance, box-to-label assignment, and
    per-feature-family (embedding / ERA5 / common) visualizations.
    """
    st.title("Model State")
    st.markdown("Comprehensive visualization of the best model's internal state and feature importance")

    # Nothing to show without at least one stored training result.
    training_results = load_all_training_results()
    if not training_results:
        st.error("No training results found. Please run a training search first.")
        return

    # Let the user choose which training result to inspect.
    result_options = {tr.name: tr for tr in training_results}
    selected_name = st.selectbox(
        "Select Training Result",
        options=list(result_options.keys()),
        help="Choose a training result to visualize model state",
    )
    selected_result = result_options[selected_name]

    # Load the persisted model state for the chosen result.
    with st.spinner("Loading model state..."):
        model_state = load_model_state(selected_result)
    if model_state is None:
        st.error("Could not load model state for this result.")
        return

    # Rescale the weights so magnitudes are comparable across models with
    # different feature counts.
    # NOTE(review): this mutates the dataset in place — if load_model_state
    # returns a cached object, the scaling would compound on reruns; confirm.
    n_features = model_state.sizes["feature"]
    model_state["feature_weights"] *= n_features

    # Split the flat weight vector into the three feature families (each
    # extractor returns None when its family is absent).
    embedding_feature_array = extract_embedding_features(model_state)
    era5_feature_array = extract_era5_features(model_state)
    common_feature_array = extract_common_features(model_state)

    # Shared color scheme across plotting backends; only the Altair colors
    # are needed on this page.
    _, _, altair_colors = generate_unified_colormap(selected_result.settings)

    # Basic dataset metadata and overall weight statistics.
    with st.expander("Model State Information", expanded=False):
        st.write(f"**Variables:** {list(model_state.data_vars)}")
        st.write(f"**Dimensions:** {dict(model_state.sizes)}")
        st.write(f"**Coordinates:** {list(model_state.coords)}")
        st.write("**Feature Weight Statistics:**")
        feature_weights = model_state["feature_weights"].to_pandas()
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
        with col2:
            st.metric("Max Weight", f"{feature_weights.max():.4f}")
        with col3:
            st.metric("Total Features", len(feature_weights))

    # ------------------------------------------------------------------
    # Feature importance
    # ------------------------------------------------------------------
    st.header("Feature Importance")
    st.markdown("The most important features based on learned feature weights from the best estimator.")

    @st.fragment
    def render_feature_importance():
        # Fragment so that moving the slider reruns only this section.
        top_n = st.slider(
            "Number of top features to display",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Select how many of the most important features to visualize",
        )
        with st.spinner("Generating feature importance plot..."):
            feature_chart = plot_top_features(model_state, top_n=top_n)
        st.altair_chart(feature_chart, use_container_width=True)
        st.markdown(
            """
            **Interpretation:**
            - **Magnitude**: Larger absolute values indicate more important features
            - **Color**: Blue bars indicate positive weights, coral bars indicate negative weights
            """
        )

    render_feature_importance()

    # ------------------------------------------------------------------
    # Box-to-label assignments
    # ------------------------------------------------------------------
    st.header("Box-to-Label Assignments")
    st.markdown(
        """
        This visualization shows how the learned boxes (prototypes in feature space) are
        assigned to different class labels. The ESPA classifier learns K boxes and assigns
        them to classes through the Lambda matrix. Higher values indicate stronger assignment
        of a box to a particular class.
        """
    )
    with st.spinner("Generating box assignment visualizations..."):
        col1, col2 = st.columns([0.7, 0.3])
        with col1:
            st.markdown("### Assignment Heatmap")
            box_assignment_heatmap = plot_box_assignments(model_state)
            st.altair_chart(box_assignment_heatmap, use_container_width=True)
        with col2:
            st.markdown("### Box Count by Class")
            box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
            st.altair_chart(box_assignment_bars, use_container_width=True)

    with st.expander("Box Assignment Statistics"):
        box_assignments = model_state["box_assignments"].to_pandas()
        st.write("**Assignment Matrix Statistics:**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Boxes", len(box_assignments.columns))
        with col2:
            st.metric("Number of Classes", len(box_assignments.index))
        with col3:
            st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
        with col4:
            st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")
        # Strongest five box assignments for every class.
        st.write("**Top Box Assignments per Class:**")
        for class_label in box_assignments.index:
            top_boxes = box_assignments.loc[class_label].nlargest(5)
            st.write(
                f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
                f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
            )

    st.markdown(
        """
        **Interpretation:**
        - Each box can be assigned to multiple classes with different strengths
        - Boxes with higher assignment values for a class contribute more to that class's predictions
        - The distribution shows how the model partitions the feature space for classification
        """
    )

    # ------------------------------------------------------------------
    # Embedding features (only when present)
    # ------------------------------------------------------------------
    if embedding_feature_array is not None:
        with st.container(border=True):
            st.header("🛰️ Embedding Feature Analysis")
            st.markdown(
                """
                Analysis of embedding features showing which aggregations, bands, and years
                are most important for the model predictions.
                """
            )
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating dimension summaries..."):
                chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
            col1, col2, col3 = st.columns(3)
            with col1:
                st.altair_chart(chart_agg, use_container_width=True)
            with col2:
                st.altair_chart(chart_band, use_container_width=True)
            with col3:
                st.altair_chart(chart_year, use_container_width=True)

            st.markdown("### Detailed Heatmap by Aggregation")
            st.markdown("Shows the weight of each band-year combination for each aggregation type.")
            with st.spinner("Generating heatmap..."):
                heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
            st.altair_chart(heatmap_chart, use_container_width=True)

            with st.expander("Embedding Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_emb_features = embedding_feature_array.size
                mean_weight = float(embedding_feature_array.mean().values)
                max_weight = float(embedding_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Embedding Features", n_emb_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                st.write("**Top 10 Embedding Features:**")
                emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
                top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
                st.dataframe(top_emb, width="stretch")
    else:
        st.info("No embedding features found in this model.")

    # ------------------------------------------------------------------
    # ERA5 features (only when present)
    # ------------------------------------------------------------------
    if era5_feature_array is not None:
        with st.container(border=True):
            st.header("⛅ ERA5 Feature Analysis")
            st.markdown(
                """
                Analysis of ERA5 climate features showing which variables and time periods
                are most important for the model predictions.
                """
            )
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating ERA5 dimension summaries..."):
                chart_variable, chart_time = plot_era5_summary(era5_feature_array)
            col1, col2 = st.columns(2)
            with col1:
                st.altair_chart(chart_variable, use_container_width=True)
            with col2:
                st.altair_chart(chart_time, use_container_width=True)

            st.markdown("### Detailed Heatmap")
            st.markdown("Shows the weight of each variable-time combination.")
            with st.spinner("Generating ERA5 heatmap..."):
                era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
            st.altair_chart(era5_heatmap_chart, use_container_width=True)

            with st.expander("ERA5 Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_era5_features = era5_feature_array.size
                mean_weight = float(era5_feature_array.mean().values)
                max_weight = float(era5_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total ERA5 Features", n_era5_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                st.write("**Top 10 ERA5 Features:**")
                era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
                top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
                st.dataframe(top_era5, width="stretch")
    else:
        st.info("No ERA5 features found in this model.")

    # ------------------------------------------------------------------
    # Common features (only when present)
    # ------------------------------------------------------------------
    if common_feature_array is not None:
        with st.container(border=True):
            st.header("🗺️ Common Feature Analysis")
            st.markdown(
                """
                Analysis of common features including cell area, water area, land area, land ratio,
                longitude, and latitude. These features provide spatial and geographic context.
                """
            )
            with st.spinner("Generating common features chart..."):
                common_chart = plot_common_features(common_feature_array)
            st.altair_chart(common_chart, use_container_width=True)

            with st.expander("Common Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_common_features = common_feature_array.size
                mean_weight = float(common_feature_array.mean().values)
                max_weight = float(common_feature_array.max().values)
                min_weight = float(common_feature_array.min().values)
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Common Features", n_common_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                with col4:
                    st.metric("Min Weight", f"{min_weight:.4f}")
                # All common features, ranked by magnitude.
                st.write("**All Common Features (by absolute weight):**")
                common_df = common_feature_array.to_dataframe(name="weight").reset_index()
                common_df["abs_weight"] = common_df["weight"].abs()
                common_df = common_df.sort_values("abs_weight", ascending=False)
                st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")

            st.markdown(
                """
                **Interpretation:**
                - **cell_area, water_area, land_area**: Spatial extent features that may indicate
                  size-related patterns
                - **land_ratio**: Proportion of land vs water in each cell
                - **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
                - Positive weights indicate features that increase the probability of the positive class
                - Negative weights indicate features that decrease the probability of the positive class
                """
            )
    else:
        st.info("No common features found in this model.")

View file

@@ -25,6 +25,9 @@ Material palettes:
""" """
import matplotlib.colors as mcolors import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from pypalettes import load_cmap from pypalettes import load_cmap
@@ -90,3 +93,54 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
cmap = get_cmap(variable).resampled(n_colors) cmap = get_cmap(variable).resampled(n_colors)
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)] colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
return colors return colors
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
    """Generate unified colormaps for all plotting libraries.

    Creates consistent color schemes across Matplotlib/Ultraplot,
    Folium/Leaflet, and Altair/Vega-Lite by determining the task type and
    number of classes from the settings, then generating appropriate
    colormaps for each library.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair

    Raises:
        ValueError: If a non-binary task is configured with an empty class
            list, so no colors could be generated.
    """
    # Determine task type and number of classes from settings.
    task = settings.get("task", "binary")
    n_classes = len(settings.get("classes", []))

    # Theme-aware colors keep charts legible in both Streamlit themes.
    is_dark_theme = st.context.theme.type == "dark"

    if task == "binary":
        # Binary task: a fixed two-color scheme per theme.
        if is_dark_theme:
            base_colors = ["#1f77b4", "#ff7f0e"]  # Blue and orange for dark theme
        else:
            base_colors = ["#3498db", "#e74c3c"]  # Brighter blue and red for light theme
    else:
        # Fail early instead of building an empty ListedColormap, which
        # would otherwise only error later at plot time.
        if n_classes < 1:
            raise ValueError("generate_unified_colormap: non-binary task requires at least one class in settings['classes']")
        # Multi-class: sample a sequential colormap (viridis) evenly,
        # avoiding the extreme (too dark / too bright) ends.
        cmap = plt.get_cmap("viridis")
        indices = np.linspace(0.1, 0.9, n_classes)
        base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]

    # Separate colormap objects so a caller mutating one cannot affect the
    # other; Altair takes the raw hex strings.
    matplotlib_cmap = mcolors.ListedColormap(base_colors)
    folium_cmap = mcolors.ListedColormap(base_colors)
    altair_colors = base_colors
    return matplotlib_cmap, folium_cmap, altair_colors

View file

@@ -0,0 +1,535 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt
import pandas as pd
import streamlit as st
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
    """Render summary statistics of model performance.

    Args:
        results: DataFrame with CV results.
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
    """
    st.subheader("📊 Performance Summary")

    score_cols = [c for c in results.columns if c.startswith("mean_test_")]
    if not score_cols:
        st.warning("No test score columns found in results.")
        return

    def _pretty(col: str) -> str:
        # "mean_test_f1_weighted" -> "F1 Weighted"
        return col.replace("mean_test_", "").replace("_", " ").title()

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### Best Scores")
        best_rows = [{"Metric": _pretty(c), "Best Score": f"{results[c].max():.4f}"} for c in score_cols]
        st.dataframe(pd.DataFrame(best_rows), hide_index=True, use_container_width=True)
    with col2:
        st.markdown("#### Score Statistics")
        stat_rows = [
            {"Metric": _pretty(c), "Mean ± Std": f"{results[c].mean():.4f} ± {results[c].std():.4f}"}
            for c in score_cols
        ]
        st.dataframe(pd.DataFrame(stat_rows), hide_index=True, use_container_width=True)

    st.markdown("#### 🏆 Best Parameter Combination")
    refit_col = f"mean_test_{refit_metric}"
    if refit_col not in results.columns:
        st.warning(
            f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
        )
        # Fall back to the first metric that is actually present.
        refit_col = score_cols[0]
        refit_metric = refit_col.replace("mean_test_", "")
        st.info(f"Using '{refit_metric}' as fallback metric.")

    best_row = results.loc[results[refit_col].idxmax()]

    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if param_cols:
        best_params = {c.replace("param_", ""): best_row[c] for c in param_cols}
        # Transpose into a one-column "Parameter -> Value" table for display.
        param_df = pd.DataFrame([best_params]).T
        param_df.columns = ["Value"]
        param_df.index.name = "Parameter"
        col1, col2 = st.columns([1, 1])
        with col1:
            st.dataframe(param_df, use_container_width=True)
        with col2:
            st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
            rank_col = "rank_test_" + refit_metric
            if rank_col in best_row.index:
                try:
                    # The rank may arrive as a scalar or a one-element Series.
                    rank_val = best_row[rank_col]
                    if hasattr(rank_val, "item"):
                        rank_val = rank_val.item()
                    rank_display = str(int(float(rank_val)))
                except (ValueError, TypeError, AttributeError):
                    rank_display = "N/A"
            else:
                rank_display = "N/A"
            st.metric("Rank", rank_display)
def render_parameter_distributions(results: pd.DataFrame):
    """Render histograms of parameter distributions explored.

    Args:
        results: DataFrame with CV results.
    """
    st.subheader("📈 Parameter Space Exploration")

    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Lay the parameters out in rows of up to three charts each.
    n_cols = min(3, len(param_cols))
    for start in range(0, len(param_cols), n_cols):
        slots = st.columns(n_cols)
        for slot, param_col in zip(slots, param_cols[start:start + n_cols]):
            param_name = param_col.replace("param_", "")
            with slot:
                param_values = results[param_col].dropna()
                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter: binned histogram.
                    df_plot = pd.DataFrame({param_name: param_values})
                    # Switch to a log axis when values span >2 orders of magnitude.
                    spans_orders = param_values.max() / (param_values.min() + 1e-10) > 100
                    if spans_orders:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(
                                    param_name,
                                    bin=alt.Bin(maxbins=30),
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
                            )
                            .properties(height=250, title=f"{param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
                            )
                            .properties(height=250, title=param_name)
                        )
                else:
                    # Categorical parameter: frequency bar chart.
                    value_counts = param_values.value_counts().reset_index()
                    value_counts.columns = [param_name, "count"]
                    chart = (
                        alt.Chart(value_counts)
                        .mark_bar()
                        .encode(
                            alt.X(param_name, title=param_name, sort="-y"),
                            alt.Y("count", title="Count"),
                            tooltip=[param_name, "count"],
                        )
                        .properties(height=250, title=param_name)
                    )
                st.altair_chart(chart, use_container_width=True)
def render_score_vs_parameter(results: pd.DataFrame, metric: str):
    """Render scatter plots of score vs each parameter.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Two charts per row.
    n_cols = min(2, len(param_cols))
    for start in range(0, len(param_cols), n_cols):
        slots = st.columns(n_cols)
        for slot, param_col in zip(slots, param_cols[start:start + n_cols]):
            param_name = param_col.replace("param_", "")
            with slot:
                param_values = results[param_col].dropna()
                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter: scatter plot, colored by score.
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
                    # Switch to a log axis when values span >2 orders of magnitude.
                    spans_orders = param_values.max() / (param_values.min() + 1e-10) > 100
                    if spans_orders:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(param_name, scale=alt.Scale(type="log"), title=param_name),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(metric, scale=alt.Scale(scheme="viridis"), legend=None),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(param_name, title=param_name),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(metric, scale=alt.Scale(scheme="viridis"), legend=None),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name}")
                        )
                else:
                    # Categorical parameter: per-category score distribution.
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
                    chart = (
                        alt.Chart(df_plot)
                        .mark_boxplot()
                        .encode(
                            alt.X(param_name, title=param_name),
                            alt.Y(metric, title=metric.replace("_", " ").title()),
                            tooltip=[param_name, alt.Tooltip(metric, format=".4f")],
                        )
                        .properties(height=300, title=f"{metric} vs {param_name}")
                    )
                st.altair_chart(chart, use_container_width=True)
def render_parameter_correlation(results: pd.DataFrame, metric: str):
    """Render correlation heatmap between parameters and score.

    Args:
        results: DataFrame with CV results.
        metric: The metric to analyze (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Only numeric parameters admit a Pearson correlation.
    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    numeric_params = [c for c in param_cols if pd.api.types.is_numeric_dtype(results[c])]
    if not numeric_params:
        st.warning("No numeric parameters found for correlation analysis.")
        return

    # One correlation value per numeric parameter.
    correlations = [
        {
            "Parameter": c.replace("param_", ""),
            "Correlation": results[[c, score_col]].corr().iloc[0, 1],
        }
        for c in numeric_params
    ]
    corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)

    # Diverging bar chart of correlations.
    chart = (
        alt.Chart(corr_df)
        .mark_bar()
        .encode(
            alt.X("Correlation", title="Correlation with Score"),
            alt.Y("Parameter", sort="-x", title="Parameter"),
            alt.Color(
                "Correlation",
                scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
                legend=None,
            ),
            tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
        )
        .properties(height=max(200, len(correlations) * 30))
    )
    st.altair_chart(chart, use_container_width=True)

    # Raw numbers, color-graded, behind an expander.
    with st.expander("📋 Correlation Table"):
        st.dataframe(
            corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
            hide_index=True,
            use_container_width=True,
        )
def render_score_evolution(results: pd.DataFrame, metric: str):
    """Render evolution of scores during search.

    Plots the raw per-iteration score alongside the running best ("Best So
    Far"), then shows summary metrics below the chart.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Work on a copy with an explicit 0-based iteration counter.
    df_plot = results[[score_col]].copy()
    df_plot["Iteration"] = range(len(df_plot))
    df_plot["Best So Far"] = df_plot[score_col].cummax()
    df_plot = df_plot.rename(columns={score_col: "Score"})

    # Long format for Altair: one row per (iteration, series).
    df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")

    # Solid line for the raw score, dashed for the running best.
    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            alt.X("Iteration", title="Iteration"),
            alt.Y("value", title=metric.replace("_", " ").title()),
            alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
            strokeDash=alt.StrokeDash(
                "Type",
                legend=None,
                scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
            ),
            tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
        )
        .properties(height=400)
    )
    st.altair_chart(chart, use_container_width=True)

    # Summary metrics.
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}")
    with col2:
        st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}")
    with col3:
        st.metric("Std Dev", f"{df_plot['Score'].std():.4f}")
    with col4:
        # BUGFIX: use the positional argmax, not Series.idxmax(). df_plot
        # keeps `results`' original index, so idxmax() returns that index's
        # *label*, which only matches the displayed "Iteration" when the
        # source has a clean 0..n-1 RangeIndex.
        best_iter = int(df_plot["Score"].to_numpy().argmax())
        st.metric("Best at Iteration", best_iter)
def render_multi_metric_comparison(results: pd.DataFrame):
    """Render comparison of multiple metrics.

    Args:
        results: DataFrame with CV results.
    """
    st.subheader("📊 Multi-Metric Comparison")

    score_cols = [c for c in results.columns if c.startswith("mean_test_")]
    if len(score_cols) < 2:
        st.warning("Need at least 2 metrics for comparison.")
        return

    # User picks any two distinct metrics to compare.
    metric_names = [c.replace("mean_test_", "") for c in score_cols]
    col1, col2 = st.columns(2)
    with col1:
        metric1 = st.selectbox(
            "Select First Metric",
            options=metric_names,
            index=0,
            key="metric1_select",
        )
    with col2:
        metric2 = st.selectbox(
            "Select Second Metric",
            options=metric_names,
            index=min(1, len(score_cols) - 1),
            key="metric2_select",
        )
    if metric1 == metric2:
        st.warning("Please select different metrics.")
        return

    # Scatter of metric1 vs metric2, colored by search iteration.
    scatter_df = pd.DataFrame(
        {
            metric1: results[f"mean_test_{metric1}"],
            metric2: results[f"mean_test_{metric2}"],
            "Iteration": range(len(results)),
        }
    )
    chart = (
        alt.Chart(scatter_df)
        .mark_circle(size=60, opacity=0.6)
        .encode(
            alt.X(metric1, title=metric1.replace("_", " ").title()),
            alt.Y(metric2, title=metric2.replace("_", " ").title()),
            alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
            tooltip=[
                alt.Tooltip(metric1, format=".4f"),
                alt.Tooltip(metric2, format=".4f"),
                "Iteration",
            ],
        )
        .properties(height=500)
    )
    st.altair_chart(chart, use_container_width=True)

    # Pearson correlation between the two selected metrics.
    corr = scatter_df[[metric1, metric2]].corr().iloc[0, 1]
    st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}")
def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
    """Render table of top N configurations.

    Args:
        results: DataFrame with CV results.
        metric: The metric to rank by (e.g., 'f1', 'accuracy').
        top_n: Number of top configurations to show.
    """
    st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")

    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    param_cols = [c for c in results.columns if c.startswith("param_") and c != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return

    # Keep only the rank/score/parameter columns that actually exist.
    top_rows = results.nlargest(top_n, score_col)
    wanted = ["rank_test_" + metric, score_col, *param_cols]
    table = top_rows[[c for c in wanted if c in top_rows.columns]].copy()

    # Human-friendly headers: "Rank", titled metric name, bare parameter names.
    pretty_metric = metric.replace("_", " ").title()
    table = table.rename(columns={"rank_test_" + metric: "Rank", score_col: pretty_metric})
    table.columns = [c.replace("param_", "") if c.startswith("param_") else c for c in table.columns]

    # Fixed-precision display of the score column.
    table[pretty_metric] = table[pretty_metric].apply(lambda s: f"{s:.4f}")
    st.dataframe(table, hide_index=True, use_container_width=True)

View file

@@ -0,0 +1,449 @@
"""Plotting functions for model state visualization."""
import altair as alt
import pandas as pd
import xarray as xr
def plot_top_features(model_state: xr.Dataset, top_n: int = 10) -> alt.Chart:
    """Plot the top N most important features based on feature weights.

    Args:
        model_state: The xarray Dataset containing the model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.
    """
    weights = model_state["feature_weights"].to_pandas()

    # Rank by magnitude; keep the signed values for display.
    ranked = weights.abs().nlargest(top_n).sort_values(ascending=True)
    plot_data = pd.DataFrame(
        {
            "feature": ranked.index,
            "weight": weights.loc[ranked.index].to_numpy(),
            "abs_weight": ranked.to_numpy(),
        }
    )

    # Horizontal bars, colored by weight sign.
    return (
        alt.Chart(plot_data)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=alt.condition(
                alt.datum.weight > 0,
                alt.value("steelblue"),  # positive weights
                alt.value("coral"),  # negative weights
            ),
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(
            width=600,
            height=400,
            title=f"Top {top_n} Most Important Features",
        )
    )
def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
    """Create a heatmap showing embedding feature weights across bands and years.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Altair chart showing the heatmap, faceted by aggregation.
    """
    # Long format: one row per (agg, band, year) cell.
    long_df = embedding_array.to_dataframe(name="weight").reset_index()

    base = (
        alt.Chart(long_df)
        .mark_rect()
        .encode(
            x=alt.X("year:O", title="Year"),
            y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
            color=alt.Color(
                "weight:Q",
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("agg:N", title="Aggregation"),
                alt.Tooltip("band:N", title="Band"),
                alt.Tooltip("year:O", title="Year"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
            ],
        )
        .properties(width=200, height=200)
    )
    # One small panel per aggregation type.
    return base.facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11)
def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
    """Summarize embedding weights along each dimension as bar charts.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing
            feature weights.

    Returns:
        Tuple of three Altair charts (by_agg, by_band, by_year), each showing
        the mean absolute weight per category of that dimension.
    """

    def _summary_chart(series, field_type: str, dim_title: str, scheme: str, chart_title: str) -> alt.Chart:
        # Build a sorted long-form frame, then a horizontal bar chart colored by magnitude.
        frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
        frame = frame.sort_values("mean_abs_weight", ascending=True)
        return (
            alt.Chart(frame)
            .mark_bar()
            .encode(
                y=alt.Y(f"dimension:{field_type}", title=dim_title, sort="-x"),
                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
                color=alt.Color("mean_abs_weight:Q", scale=alt.Scale(scheme=scheme), legend=None),
                tooltip=[
                    alt.Tooltip(f"dimension:{field_type}", title=dim_title),
                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
                ],
            )
            .properties(width=250, height=200, title=chart_title)
        )

    # Mean over the two complementary dims, then absolute value (sign is irrelevant here).
    agg_series = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
    band_series = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
    year_series = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()

    return (
        _summary_chart(agg_series, "N", "Aggregation", "blues", "By Aggregation"),
        _summary_chart(band_series, "N", "Band", "greens", "By Band"),
        _summary_chart(year_series, "O", "Year", "oranges", "By Year"),
    )
def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
    """Render a heatmap of ERA5 feature weights over variables and time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing
            feature weights.

    Returns:
        Altair heatmap chart, variables on the y-axis, time on the x-axis.
    """
    long_df = era5_array.to_dataframe(name="weight").reset_index()

    # Keep the time axis in data order; center the diverging color scale on zero.
    x_axis = alt.X("time:N", title="Time", sort=None)
    y_axis = alt.Y("variable:N", title="Variable", sort="-color")
    weight_color = alt.Color(
        "weight:Q",
        scale=alt.Scale(scheme="redblue", domainMid=0),
        title="Weight",
    )
    hover = [
        alt.Tooltip("variable:N", title="Variable"),
        alt.Tooltip("time:N", title="Time"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
    ]

    heatmap = alt.Chart(long_df).mark_rect().encode(x=x_axis, y=y_axis, color=weight_color, tooltip=hover)
    return heatmap.properties(height=400, title="ERA5 Feature Weights Heatmap")
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
    """Summarize ERA5 weights by variable and by time as bar charts.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing
            feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_time) of mean absolute weight.
    """

    def _bar_chart(series, dim_title: str, scheme: str, label_limit: int, chart_title: str) -> alt.Chart:
        # Long-form frame sorted so the largest weights end up on top of the bars.
        frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
        frame = frame.sort_values("mean_abs_weight", ascending=True)
        return (
            alt.Chart(frame)
            .mark_bar()
            .encode(
                y=alt.Y("dimension:N", title=dim_title, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
                color=alt.Color("mean_abs_weight:Q", scale=alt.Scale(scheme=scheme), legend=None),
                tooltip=[
                    alt.Tooltip("dimension:N", title=dim_title),
                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
                ],
            )
            .properties(width=400, height=300, title=chart_title)
        )

    # Collapse one dimension at a time; magnitude only, so take absolute values.
    by_variable = era5_array.mean(dim="time").to_pandas().abs()
    by_time = era5_array.mean(dim="variable").to_pandas().abs()

    return (
        _bar_chart(by_variable, "Variable", "purples", 300, "By Variable"),
        _bar_chart(by_time, "Time", "teals", 200, "By Time"),
    )
def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
    """Render the lambda matrix: which boxes map to which class labels.

    Args:
        model_state: xarray Dataset containing a ``box_assignments`` variable.

    Returns:
        Altair heatmap of assignment strength, box IDs vs. class labels.
    """
    long_df = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    hover = [
        alt.Tooltip("class:N", title="Class"),
        alt.Tooltip("box:O", title="Box"),
        alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
    ]
    heatmap = (
        alt.Chart(long_df)
        .mark_rect()
        .encode(
            x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("class:N", title="Class Label"),
            color=alt.Color(
                "assignment:Q",
                scale=alt.Scale(scheme="viridis"),
                title="Assignment Strength",
            ),
            tooltip=hover,
        )
    )
    return heatmap.properties(height=150, title="Box-to-Label Assignments (Lambda Matrix)")
def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]) -> alt.Chart:
    """Render a bar chart of how many boxes each class claims.

    Each box is attributed to the class with its strongest assignment value.

    Args:
        model_state: xarray Dataset containing a ``box_assignments`` variable.
        altair_colors: Hex color strings used for the class color scale.

    Returns:
        Altair bar chart with one bar per class.
    """
    long_df = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()

    # Primary assignment: for every box, keep the row with the maximum strength.
    strongest_rows = long_df.groupby("box")["assignment"].idxmax()
    primary = long_df.loc[strongest_rows, ["box", "class", "assignment"]].reset_index(drop=True)

    counts = primary.groupby("class").size().reset_index(name="count")
    # The special (-1, 0] interval is the background class; show it as "No RTS".
    counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")

    def _class_rank(label):
        # "No RTS" sorts first; interval labels like "(0, 4]" sort by lower bound.
        if label == "No RTS":
            return -1
        try:
            return float(str(label).split(",")[0].strip("([ "))
        except (ValueError, IndexError):
            return float("inf")  # unparseable labels go last

    counts["sort_key"] = counts["class"].apply(_class_rank)
    counts = counts.sort_values("sort_key")
    # Ordered class list keeps bar order and color mapping consistent.
    class_order = counts["class"].tolist()

    return (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
            y=alt.Y("count:Q", title="Number of Boxes"),
            color=alt.Color(
                "class:N",
                title="Class",
                scale=alt.Scale(domain=class_order, range=altair_colors),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("count:Q", title="Number of Boxes"),
            ],
        )
        .properties(
            width=600,
            height=300,
            title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
        )
    )
def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
    """Render a bar chart of the common (shared) feature weights.

    Args:
        common_array: DataArray with a single ``feature`` dimension containing
            feature weights.

    Returns:
        Altair bar chart; positive weights are blue, negative weights coral.
    """
    frame = common_array.to_dataframe(name="weight").reset_index()
    # Sort by magnitude so the most influential features are easy to spot.
    frame["abs_weight"] = frame["weight"].abs()
    frame = frame.sort_values("abs_weight", ascending=True)

    sign_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),  # positive weights
        alt.value("coral"),  # negative weights
    )
    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
        alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
    ]
    return (
        alt.Chart(frame)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x"),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=sign_color,
            tooltip=hover,
        )
        .properties(width=600, height=300, title="Common Feature Weights")
    )

View file

@ -2,9 +2,237 @@
import streamlit as st import streamlit as st
from entropice.dashboard.plots.hyperparameter_analysis import (
render_multi_metric_comparison,
render_parameter_correlation,
render_parameter_distributions,
render_performance_summary,
render_score_evolution,
render_score_vs_parameter,
render_top_configurations,
)
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.utils.training import (
format_metric_name,
get_available_metrics,
get_cv_statistics,
get_parameter_space_summary,
)
def render_training_analysis_page(): def render_training_analysis_page():
"""Render the Training Results Analysis page of the dashboard.""" """Render the Training Results Analysis page of the dashboard."""
st.title("Training Results Analysis") st.title("🦾 Training Results Analysis")
st.write("This page will display analysis of training results and model performance.")
# Add more components and visualizations as needed for training results analysis. # Load all available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.training`")
return
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
# Create selection options
training_options = {tr.name: tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to analyze",
)
selected_result = training_options[selected_name]
st.divider()
# Display selected run info
st.subheader("Run Information")
st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
st.write(f"**Trials:** {len(selected_result.results)}")
st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")
# Refit metric - determine from available metrics
available_metrics = get_available_metrics(selected_result.results)
# Try to get refit metric from settings
refit_metric = selected_result.settings.get("refit_metric")
if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric
task = selected_result.settings.get("task", "binary")
if task == "binary" and "f1" in available_metrics:
refit_metric = "f1"
elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted"
elif "accuracy" in available_metrics:
refit_metric = "accuracy"
elif available_metrics:
refit_metric = available_metrics[0]
else:
st.error("No metrics found in results.")
return
st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results)
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
else:
default_metric_idx = 0
selected_metric = st.selectbox(
"Primary Metric for Analysis",
options=available_metrics,
index=default_metric_idx,
format_func=format_metric_name,
help="Select the metric to focus on for detailed analysis",
)
# Top N configurations
top_n = st.slider(
"Top N Configurations",
min_value=5,
max_value=50,
value=10,
step=5,
help="Number of top configurations to display",
)
# Main content area
results = selected_result.results
settings = selected_result.settings
# Performance Summary Section
st.header("📊 Performance Overview")
render_performance_summary(results, refit_metric)
st.divider()
# Quick Statistics
st.header("📈 Cross-Validation Statistics")
cv_stats = get_cv_statistics(results, selected_metric)
if cv_stats:
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
with col2:
st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
with col3:
st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
with col4:
st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
with col5:
st.metric("Median Score", f"{cv_stats['median_score']:.4f}")
if "mean_cv_std" in cv_stats:
st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")
st.divider()
# Score Evolution
st.header("📉 Training Progress")
render_score_evolution(results, selected_metric)
st.divider()
# Parameter Space Exploration
st.header("🔍 Parameter Space Analysis")
# Show parameter space summary
with st.expander("📋 Parameter Space Summary", expanded=False):
param_summary = get_parameter_space_summary(results)
if not param_summary.empty:
st.dataframe(param_summary, hide_index=True, use_container_width=True)
else:
st.info("No parameter information available.")
# Parameter distributions
render_parameter_distributions(results)
st.divider()
# Score vs Parameters
st.header("🎯 Parameter Impact Analysis")
render_score_vs_parameter(results, selected_metric)
st.divider()
# Parameter Correlation
st.header("🔗 Parameter Correlation Analysis")
render_parameter_correlation(results, selected_metric)
st.divider()
# Multi-Metric Comparison
if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Analysis")
render_multi_metric_comparison(results)
st.divider()
# Top Configurations
st.header("🏆 Top Performing Configurations")
render_top_configurations(results, selected_metric, top_n)
st.divider()
# Raw Data Export
with st.expander("💾 Export Data", expanded=False):
st.subheader("Download Results")
col1, col2 = st.columns(2)
with col1:
# Download full results as CSV
csv_data = results.to_csv(index=False)
st.download_button(
label="📥 Download Full Results (CSV)",
data=csv_data,
file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv",
use_container_width=True,
)
with col2:
# Download settings as text
import json
settings_json = json.dumps(settings, indent=2)
st.download_button(
label="⚙️ Download Settings (JSON)",
data=settings_json,
file_name=f"{selected_result.path.name}_settings.json",
mime="application/json",
use_container_width=True,
)
# Show raw data preview
st.subheader("Raw Data Preview")
st.dataframe(results.head(100), use_container_width=True)

View file

@ -118,6 +118,42 @@ def render_training_data_page():
# Display dataset ID in a styled container # Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`") st.info(f"**Dataset ID:** `{ensemble.id()}`")
# Display dataset statistics
st.markdown("---")
st.subheader("📈 Dataset Statistics")
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Display target information
col1, col2 = st.columns(2)
with col1:
st.metric(label="Target", value=stats["target"].replace("darts_", ""))
with col2:
st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}")
# Display member statistics
st.markdown("**Member Statistics:**")
for member, member_stats in stats["members"].items():
with st.expander(f"📦 {member}", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.markdown(f"**Number of Features:** {member_stats['num_features']}")
st.markdown(f"**Number of Variables:** {member_stats['num_variables']}")
with col2:
st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`")
# Display variables as a compact list
st.markdown(f"**Variables ({member_stats['num_variables']}):**")
vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]])
st.markdown(vars_str)
# Display total features
st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}")
st.markdown("---")
# Create tabs for different data views # Create tabs for different data views
tab_names = ["📊 Labels"] tab_names = ["📊 Labels"]

View file

@ -8,6 +8,7 @@ import antimeridian
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
import toml import toml
import xarray as xr
from shapely.geometry import shape from shapely.geometry import shape
import entropice.paths import entropice.paths
@ -119,3 +120,98 @@ def load_source_data(e: DatasetEnsemble, source: str):
ds = e._read_member(source, targets, lazy=False) ds = e._read_member(source, targets, lazy=False)
return ds, targets return ds, targets
def extract_embedding_features(model_state) -> xr.DataArray | None:
    """Extract embedding feature weights reshaped to (agg, band, year).

    Embedding features are named ``embedding_<agg>_<band>_<year>``; the flat
    feature dimension is split into those three coordinates.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        DataArray with dimensions (agg, band, year), or None if the state
        contains no embedding features.
    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("embedding_")]
    if not names:
        return None
    # Each name encodes three coordinates separated by underscores.
    parts = [name.split("_") for name in names]
    weights = model_state.sel(feature=names)["feature_weights"].assign_coords(
        agg=("feature", [p[1] for p in parts]),
        band=("feature", [p[2] for p in parts]),
        year=("feature", [p[3] for p in parts]),
    )
    # Promote the three coordinates to real dimensions.
    return weights.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
def extract_era5_features(model_state) -> xr.DataArray | None:
    """Extract ERA5 feature weights reshaped to (variable, time).

    ERA5 features are named ``era5_<variable>_<timetype>``; variable names may
    themselves contain underscores, so only the last segment is the time type.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        DataArray with dimensions (variable, time), or None if the state
        contains no ERA5 features.
    """
    names = [f for f in model_state.feature.to_numpy() if f.startswith("era5_")]
    if not names:
        return None
    # Strip the "era5_" prefix and peel the time type off the end.
    variables = ["_".join(name.split("_")[1:-1]) for name in names]
    times = [name.split("_")[-1] for name in names]
    weights = model_state.sel(feature=names)["feature_weights"].assign_coords(
        variable=("feature", variables),
        time=("feature", times),
    )
    # Promote the two coordinates to real dimensions.
    return weights.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
def extract_common_features(model_state) -> xr.DataArray | None:
    """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat).

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        DataArray of the common feature weights with a single ``feature``
        dimension, or None if none of the common features are present.
    """
    # Set for O(1) membership tests; iteration order follows model_state.feature.
    known = {"cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"}
    present = [f for f in model_state.feature.to_numpy() if f in known]
    if not present:
        return None
    return model_state.sel(feature=present)["feature_weights"]

View file

@ -0,0 +1,232 @@
"""Training utilities for dashboard."""
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.data import TrainingResult
def format_metric_name(metric: str) -> str:
    """Format a raw metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
    """
    # "f1" is a special token that keeps its digit; all other parts are capitalized.
    return " ".join(
        "F1" if part.lower() == "f1" else part.capitalize()
        for part in metric.split("_")
    )
def get_available_metrics(results: pd.DataFrame) -> list[str]:
    """List the metrics present in a CV results frame.

    Args:
        results: DataFrame with CV results.

    Returns:
        Metric names with the 'mean_test_' prefix removed, in column order.
    """
    prefix = "mean_test_"
    score_columns = (col for col in results.columns if col.startswith(prefix))
    return [col.replace(prefix, "") for col in score_columns]
def load_best_model(result: TrainingResult):
    """Load the pickled best estimator for a training run.

    Args:
        result: TrainingResult describing the run directory.

    Returns:
        The unpickled model object, or None if the file is missing or unreadable.
    """
    model_path = result.path / "best_estimator_model.pkl"
    if not model_path.exists():
        return None
    try:
        # NOTE: pickle is only acceptable here because these files are produced
        # by our own training runs; never point this at untrusted input.
        with open(model_path, "rb") as fh:
            return pickle.load(fh)
    except Exception as e:  # surface the failure in the UI instead of crashing
        st.error(f"Error loading model: {e}")
        return None
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
    """Open the saved model state for a training run.

    Args:
        result: TrainingResult describing the run directory.

    Returns:
        xarray Dataset with the model state, or None if missing or unreadable.
    """
    state_path = result.path / "best_estimator_state.nc"
    if not state_path.exists():
        return None
    try:
        return xr.open_dataset(state_path, engine="h5netcdf")
    except Exception as e:  # surface the failure in the UI instead of crashing
        st.error(f"Error loading model state: {e}")
        return None
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
    """Load the saved prediction probabilities for a training run.

    Args:
        result: TrainingResult describing the run directory.

    Returns:
        DataFrame with predictions, or None if missing or unreadable.
    """
    preds_path = result.path / "predicted_probabilities.parquet"
    if not preds_path.exists():
        return None
    try:
        return pd.read_parquet(preds_path)
    except Exception as e:  # surface the failure in the UI instead of crashing
        st.error(f"Error loading predictions: {e}")
        return None
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
    """Summarize the explored hyperparameter space.

    Args:
        results: DataFrame with CV results (``param_*`` columns).

    Returns:
        DataFrame with one row per parameter: type, range statistics (numeric
        parameters only, in scientific notation) and number of unique values.
    """
    rows = []
    for column in results.columns:
        if not column.startswith("param_") or column == "params":
            continue
        values = results[column].dropna()
        # Categorical defaults; numeric parameters overwrite the range fields.
        row = {
            "Parameter": column.replace("param_", ""),
            "Type": "Categorical",
            "Min": "-",
            "Max": "-",
            "Mean": "-",
            "Unique Values": values.nunique(),
        }
        if pd.api.types.is_numeric_dtype(values):
            row["Type"] = "Numeric"
            row["Min"] = f"{values.min():.2e}"
            row["Max"] = f"{values.max():.2e}"
            row["Mean"] = f"{values.mean():.2e}"
        rows.append(row)
    return pd.DataFrame(rows)
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
    """Compute cross-validation summary statistics for one metric.

    Args:
        results: DataFrame with CV results.
        metric: Metric name (without the 'mean_test_' prefix).

    Returns:
        Dict with best/mean/std/worst/median scores, plus 'mean_cv_std' when
        the per-fold std column exists; empty dict if the metric is absent.
    """
    score_column = f"mean_test_{metric}"
    if score_column not in results.columns:
        return {}

    scores = results[score_column]
    stats = {
        "best_score": scores.max(),
        "mean_score": scores.mean(),
        "std_score": scores.std(),
        "worst_score": scores.min(),
        "median_score": scores.median(),
    }

    std_column = f"std_test_{metric}"
    if std_column in results.columns:
        stats["mean_cv_std"] = results[std_column].mean()
    return stats
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
    """Add binned parameter columns used by the plotting helpers.

    Args:
        results: DataFrame with CV results.
        k_bin_width: Width of the linear bins for the ``initial_K`` parameter.

    Returns:
        Copy of ``results`` with ``initial_K_binned``, ``eps_cl_binned`` and
        ``eps_e_binned`` columns added where the source columns allow it.
    """
    prepared = results.copy()

    # Linear, left-closed bins for initial_K.
    # NOTE(review): with right=False a maximum value that lands exactly on the
    # last bin edge is left unbinned (NaN) — confirm this is intended.
    if "param_initial_K" in results.columns:
        k_values = results["param_initial_K"].dropna()
        if len(k_values) > 0:
            edges = range(int(k_values.min()), int(k_values.max()) + k_bin_width, k_bin_width)
            prepared["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=edges, right=False)

    # Logarithmic bins for the epsilon parameters (valid only for positive values).
    for param in ("eps_cl", "eps_e"):
        column = f"param_{param}"
        if column not in results.columns:
            continue
        values = results[column].dropna()
        if len(values) == 0 or values.min() <= 0:
            continue
        edges = np.logspace(np.log10(values.min()), np.log10(values.max()), num=10)
        prepared[f"{param}_binned"] = pd.cut(results[column], bins=edges)

    return prepared

View file

@ -283,25 +283,52 @@ class DatasetEnsemble:
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns] arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
return arcticdem_df return arcticdem_df
def print_stats(self): def get_stats(self) -> dict:
targets = self._read_target() """Get dataset statistics.
print(f"=== Target: {self.target}")
print(f"\tNumber of target samples: {len(targets)}") Returns:
dict: Dictionary containing target stats, member stats, and total features count.
"""
targets = self._read_target()
stats = {
"target": self.target,
"num_target_samples": len(targets),
"members": {},
"total_features": 2 if self.add_lonlat else 0, # Lat and Lon
}
n_cols = 2 if self.add_lonlat else 0 # Lat and Lon
for member in self.members: for member in self.members:
ds = self._read_member(member, targets, lazy=True) ds = self._read_member(member, targets, lazy=True)
print(f"=== Member: {member}")
print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}")
print(f"\tDimensions: {dict(ds.sizes)}")
print(f"\tCoordinates: {list(ds.coords)}")
n_cols_member = len(ds.data_vars) n_cols_member = len(ds.data_vars)
for dim in ds.sizes: for dim in ds.sizes:
if dim != "cell_ids": if dim != "cell_ids":
n_cols_member *= ds.sizes[dim] n_cols_member *= ds.sizes[dim]
print(f"\tNumber of features from member: {n_cols_member}")
n_cols += n_cols_member stats["members"][member] = {
print(f"=== Total number of features in dataset: {n_cols}") "variables": list(ds.data_vars),
"num_variables": len(ds.data_vars),
"dimensions": dict(ds.sizes),
"coordinates": list(ds.coords),
"num_features": n_cols_member,
}
stats["total_features"] += n_cols_member
return stats
def print_stats(self):
stats = self.get_stats()
print(f"=== Target: {stats['target']}")
print(f"\tNumber of target samples: {stats['num_target_samples']}")
for member, member_stats in stats["members"].items():
print(f"=== Member: {member}")
print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}")
print(f"\tDimensions: {member_stats['dimensions']}")
print(f"\tCoordinates: {member_stats['coordinates']}")
print(f"\tNumber of features from member: {member_stats['num_features']}")
print(f"=== Total number of features in dataset: {stats['total_features']}")
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: