Add training results page and model state page

This commit is contained in:
Tobias Hölzer 2025-12-19 02:59:42 +01:00
parent 31933b58d3
commit 696bef39c2
9 changed files with 1988 additions and 17 deletions

View file

@ -1,8 +1,322 @@
"""Model State page for the Entropice dashboard."""
import streamlit as st
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import (
plot_box_assignment_bars,
plot_box_assignments,
plot_common_features,
plot_embedding_aggregation_summary,
plot_embedding_heatmap,
plot_era5_heatmap,
plot_era5_summary,
plot_top_features,
)
from entropice.dashboard.utils.data import (
extract_common_features,
extract_embedding_features,
extract_era5_features,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
def render_model_state_page():
    """Render the Model State page of the dashboard.

    Loads the available training results, lets the user pick one, then
    visualizes the best estimator's internal state: overall feature weights,
    box-to-label assignments, and per-source breakdowns for embedding, ERA5,
    and common features (each section is shown only when the corresponding
    feature group is present in the model state).
    """
    st.title("Model State")
    st.markdown("Comprehensive visualization of the best model's internal state and feature importance")

    # Load available training results
    training_results = load_all_training_results()
    if not training_results:
        st.error("No training results found. Please run a training search first.")
        return

    # Result selection
    result_options = {tr.name: tr for tr in training_results}
    selected_name = st.selectbox(
        "Select Training Result",
        options=list(result_options.keys()),
        help="Choose a training result to visualize model state",
    )
    selected_result = result_options[selected_name]

    # Load model state
    with st.spinner("Loading model state..."):
        model_state = load_model_state(selected_result)
    if model_state is None:
        st.error("Could not load model state for this result.")
        return

    # Scale feature weights by the number of features so magnitudes are
    # comparable across models with different feature counts.
    n_features = model_state.sizes["feature"]
    model_state["feature_weights"] *= n_features

    # Extract different feature types (each extractor may return None
    # when that feature group is absent from the model state).
    embedding_feature_array = extract_embedding_features(model_state)
    era5_feature_array = extract_era5_features(model_state)
    common_feature_array = extract_common_features(model_state)

    # Generate unified colormaps; only the Altair colors are needed here.
    _, _, altair_colors = generate_unified_colormap(selected_result.settings)

    # Display basic model state info
    with st.expander("Model State Information", expanded=False):
        st.write(f"**Variables:** {list(model_state.data_vars)}")
        st.write(f"**Dimensions:** {dict(model_state.sizes)}")
        st.write(f"**Coordinates:** {list(model_state.coords)}")
        # Show statistics
        st.write("**Feature Weight Statistics:**")
        feature_weights = model_state["feature_weights"].to_pandas()
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Mean Weight", f"{feature_weights.mean():.4f}")
        with col2:
            st.metric("Max Weight", f"{feature_weights.max():.4f}")
        with col3:
            st.metric("Total Features", len(feature_weights))

    # Feature importance section
    st.header("Feature Importance")
    st.markdown("The most important features based on learned feature weights from the best estimator.")

    # Fragment so moving the slider reruns only this section, not the whole page.
    @st.fragment
    def render_feature_importance():
        # Slider to control number of features to display
        top_n = st.slider(
            "Number of top features to display",
            min_value=5,
            max_value=50,
            value=10,
            step=5,
            help="Select how many of the most important features to visualize",
        )
        with st.spinner("Generating feature importance plot..."):
            feature_chart = plot_top_features(model_state, top_n=top_n)
        st.altair_chart(feature_chart, use_container_width=True)
        st.markdown(
            """
**Interpretation:**
- **Magnitude**: Larger absolute values indicate more important features
- **Color**: Blue bars indicate positive weights, coral bars indicate negative weights
"""
        )

    render_feature_importance()

    # Box-to-Label Assignment Visualization
    st.header("Box-to-Label Assignments")
    st.markdown(
        """
This visualization shows how the learned boxes (prototypes in feature space) are
assigned to different class labels. The ESPA classifier learns K boxes and assigns
them to classes through the Lambda matrix. Higher values indicate stronger assignment
of a box to a particular class.
"""
    )
    with st.spinner("Generating box assignment visualizations..."):
        col1, col2 = st.columns([0.7, 0.3])
        with col1:
            st.markdown("### Assignment Heatmap")
            box_assignment_heatmap = plot_box_assignments(model_state)
            st.altair_chart(box_assignment_heatmap, use_container_width=True)
        with col2:
            st.markdown("### Box Count by Class")
            box_assignment_bars = plot_box_assignment_bars(model_state, altair_colors)
            st.altair_chart(box_assignment_bars, use_container_width=True)

    # Show statistics
    with st.expander("Box Assignment Statistics"):
        box_assignments = model_state["box_assignments"].to_pandas()
        st.write("**Assignment Matrix Statistics:**")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Total Boxes", len(box_assignments.columns))
        with col2:
            st.metric("Number of Classes", len(box_assignments.index))
        with col3:
            st.metric("Mean Assignment", f"{box_assignments.to_numpy().mean():.4f}")
        with col4:
            st.metric("Max Assignment", f"{box_assignments.to_numpy().max():.4f}")
        # Show which boxes are most strongly assigned to each class
        st.write("**Top Box Assignments per Class:**")
        for class_label in box_assignments.index:
            top_boxes = box_assignments.loc[class_label].nlargest(5)
            st.write(
                f"**Class {class_label}:** Boxes {', '.join(map(str, top_boxes.index.tolist()))} "
                f"(strengths: {', '.join(f'{v:.3f}' for v in top_boxes.to_numpy())})"
            )
    st.markdown(
        """
**Interpretation:**
- Each box can be assigned to multiple classes with different strengths
- Boxes with higher assignment values for a class contribute more to that class's predictions
- The distribution shows how the model partitions the feature space for classification
"""
    )

    # Embedding features analysis (if present)
    if embedding_feature_array is not None:
        with st.container(border=True):
            st.header("🛰️ Embedding Feature Analysis")
            st.markdown(
                """
Analysis of embedding features showing which aggregations, bands, and years
are most important for the model predictions.
"""
            )
            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating dimension summaries..."):
                chart_agg, chart_band, chart_year = plot_embedding_aggregation_summary(embedding_feature_array)
            col1, col2, col3 = st.columns(3)
            with col1:
                st.altair_chart(chart_agg, use_container_width=True)
            with col2:
                st.altair_chart(chart_band, use_container_width=True)
            with col3:
                st.altair_chart(chart_year, use_container_width=True)
            # Detailed heatmap
            st.markdown("### Detailed Heatmap by Aggregation")
            st.markdown("Shows the weight of each band-year combination for each aggregation type.")
            with st.spinner("Generating heatmap..."):
                heatmap_chart = plot_embedding_heatmap(embedding_feature_array)
            st.altair_chart(heatmap_chart, use_container_width=True)
            # Statistics
            with st.expander("Embedding Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_emb_features = embedding_feature_array.size
                mean_weight = float(embedding_feature_array.mean().values)
                max_weight = float(embedding_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Embedding Features", n_emb_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                # Show top embedding features
                st.write("**Top 10 Embedding Features:**")
                emb_df = embedding_feature_array.to_dataframe(name="weight").reset_index()
                top_emb = emb_df.nlargest(10, "weight")[["agg", "band", "year", "weight"]]
                st.dataframe(top_emb, width="stretch")
    else:
        st.info("No embedding features found in this model.")

    # ERA5 features analysis (if present)
    if era5_feature_array is not None:
        with st.container(border=True):
            st.header("⛅ ERA5 Feature Analysis")
            st.markdown(
                """
Analysis of ERA5 climate features showing which variables and time periods
are most important for the model predictions.
"""
            )
            # Summary bar charts
            st.markdown("### Importance by Dimension")
            with st.spinner("Generating ERA5 dimension summaries..."):
                chart_variable, chart_time = plot_era5_summary(era5_feature_array)
            col1, col2 = st.columns(2)
            with col1:
                st.altair_chart(chart_variable, use_container_width=True)
            with col2:
                st.altair_chart(chart_time, use_container_width=True)
            # Detailed heatmap
            st.markdown("### Detailed Heatmap")
            st.markdown("Shows the weight of each variable-time combination.")
            with st.spinner("Generating ERA5 heatmap..."):
                era5_heatmap_chart = plot_era5_heatmap(era5_feature_array)
            st.altair_chart(era5_heatmap_chart, use_container_width=True)
            # Statistics
            with st.expander("ERA5 Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_era5_features = era5_feature_array.size
                mean_weight = float(era5_feature_array.mean().values)
                max_weight = float(era5_feature_array.max().values)
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total ERA5 Features", n_era5_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                # Show top ERA5 features
                st.write("**Top 10 ERA5 Features:**")
                era5_df = era5_feature_array.to_dataframe(name="weight").reset_index()
                top_era5 = era5_df.nlargest(10, "weight")[["variable", "time", "weight"]]
                st.dataframe(top_era5, width="stretch")
    else:
        st.info("No ERA5 features found in this model.")

    # Common features analysis (if present)
    if common_feature_array is not None:
        with st.container(border=True):
            st.header("🗺️ Common Feature Analysis")
            st.markdown(
                """
Analysis of common features including cell area, water area, land area, land ratio,
longitude, and latitude. These features provide spatial and geographic context.
"""
            )
            # Bar chart showing all common feature weights
            with st.spinner("Generating common features chart..."):
                common_chart = plot_common_features(common_feature_array)
            st.altair_chart(common_chart, use_container_width=True)
            # Statistics
            with st.expander("Common Feature Statistics"):
                st.write("**Overall Statistics:**")
                n_common_features = common_feature_array.size
                mean_weight = float(common_feature_array.mean().values)
                max_weight = float(common_feature_array.max().values)
                min_weight = float(common_feature_array.min().values)
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Common Features", n_common_features)
                with col2:
                    st.metric("Mean Weight", f"{mean_weight:.4f}")
                with col3:
                    st.metric("Max Weight", f"{max_weight:.4f}")
                with col4:
                    st.metric("Min Weight", f"{min_weight:.4f}")
                # Show all common features sorted by importance
                st.write("**All Common Features (by absolute weight):**")
                common_df = common_feature_array.to_dataframe(name="weight").reset_index()
                common_df["abs_weight"] = common_df["weight"].abs()
                common_df = common_df.sort_values("abs_weight", ascending=False)
                st.dataframe(common_df[["feature", "weight", "abs_weight"]], width="stretch")
            st.markdown(
                """
**Interpretation:**
- **cell_area, water_area, land_area**: Spatial extent features that may indicate
  size-related patterns
- **land_ratio**: Proportion of land vs water in each cell
- **lon, lat**: Geographic coordinates that can capture spatial trends or regional patterns
- Positive weights indicate features that increase the probability of the positive class
- Negative weights indicate features that decrease the probability of the positive class
"""
            )
    else:
        st.info("No common features found in this model.")

View file

@ -25,6 +25,9 @@ Material palettes:
"""
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from pypalettes import load_cmap
@ -90,3 +93,54 @@ def get_palette(variable: str, n_colors: int) -> list[str]:
cmap = get_cmap(variable).resampled(n_colors)
colors = [mcolors.to_hex(cmap(i)) for i in range(cmap.N)]
return colors
def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, mcolors.ListedColormap, list[str]]:
    """Generate unified colormaps for all plotting libraries.

    Creates consistent color schemes across Matplotlib/Ultraplot,
    Folium/Leaflet, and Altair/Vega-Lite: the task type and class count are
    read from the settings, and the same base palette is packaged once for
    each library.

    Args:
        settings: Settings dictionary containing task type, classes, and other configuration.

    Returns:
        Tuple of (matplotlib_cmap, folium_cmap, altair_colors) where:
            - matplotlib_cmap: matplotlib ListedColormap object
            - folium_cmap: matplotlib ListedColormap object (for geopandas.explore)
            - altair_colors: list of hex color strings for Altair
    """
    # Task type and class count drive the palette choice.
    task = settings.get("task", "binary")
    n_classes = len(settings.get("classes", []))

    # Palette brightness follows the active Streamlit theme.
    dark_theme = st.context.theme.type == "dark"

    if task == "binary":
        # Two fixed colors; the light theme gets a brighter pair.
        if dark_theme:
            hex_colors = ["#1f77b4", "#ff7f0e"]  # Blue and orange for dark theme
        else:
            hex_colors = ["#3498db", "#e74c3c"]  # Brighter blue and red for light theme
    else:
        # Multi-class: sample viridis evenly, avoiding its extreme ends.
        viridis = plt.get_cmap("viridis")
        sample_points = np.linspace(0.1, 0.9, n_classes)
        hex_colors = [mcolors.rgb2hex(viridis(point)[:3]) for point in sample_points]

    # One ListedColormap for ultraplot, one for geopandas.explore (Folium),
    # and the raw hex list for Altair's scale range.
    matplotlib_cmap = mcolors.ListedColormap(hex_colors)
    folium_cmap = mcolors.ListedColormap(hex_colors)
    return matplotlib_cmap, folium_cmap, hex_colors

View file

@ -0,0 +1,535 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt
import pandas as pd
import streamlit as st
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
    """Render summary statistics of model performance.

    Shows best scores and mean/std per metric, plus the best parameter
    combination ranked by the refit metric (with a fallback when that
    metric is missing from the results).

    Args:
        results: DataFrame with CV results.
        refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
    """
    st.subheader("📊 Performance Summary")
    # Get all test score columns
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]
    if not score_cols:
        st.warning("No test score columns found in results.")
        return
    # Calculate statistics for each metric
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### Best Scores")
        best_scores = []
        for col in score_cols:
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            best_score = results[col].max()
            best_scores.append({"Metric": metric_name, "Best Score": f"{best_score:.4f}"})
        # width="stretch" keeps dataframe sizing consistent with the rest of
        # the dashboard (the deprecated use_container_width is avoided).
        st.dataframe(pd.DataFrame(best_scores), hide_index=True, width="stretch")
    with col2:
        st.markdown("#### Score Statistics")
        score_stats = []
        for col in score_cols:
            metric_name = col.replace("mean_test_", "").replace("_", " ").title()
            mean_score = results[col].mean()
            std_score = results[col].std()
            score_stats.append(
                {
                    "Metric": metric_name,
                    "Mean ± Std": f"{mean_score:.4f} ± {std_score:.4f}",
                }
            )
        st.dataframe(pd.DataFrame(score_stats), hide_index=True, width="stretch")
    # Show best parameter combination
    st.markdown("#### 🏆 Best Parameter Combination")
    refit_col = f"mean_test_{refit_metric}"
    # Check if refit metric exists in results
    if refit_col not in results.columns:
        st.warning(
            f"Refit metric '{refit_metric}' not found in results. Available metrics: {[col.replace('mean_test_', '') for col in score_cols]}"
        )
        # Use the first available metric as fallback
        refit_col = score_cols[0]
        refit_metric = refit_col.replace("mean_test_", "")
        st.info(f"Using '{refit_metric}' as fallback metric.")
    best_idx = results[refit_col].idxmax()
    best_row = results.loc[best_idx]
    # Extract parameter columns
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    if param_cols:
        best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
        # Display in a nice formatted way
        param_df = pd.DataFrame([best_params]).T
        param_df.columns = ["Value"]
        param_df.index.name = "Parameter"
        col1, col2 = st.columns([1, 1])
        with col1:
            st.dataframe(param_df, width="stretch")
        with col2:
            st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
            rank_col = "rank_test_" + refit_metric
            if rank_col in best_row.index:
                try:
                    # Handle potential Series or scalar values
                    rank_val = best_row[rank_col]
                    if hasattr(rank_val, "item"):
                        rank_val = rank_val.item()
                    rank_display = str(int(float(rank_val)))
                except (ValueError, TypeError, AttributeError):
                    rank_display = "N/A"
            else:
                rank_display = "N/A"
            st.metric("Rank", rank_display)
def render_parameter_distributions(results: pd.DataFrame) -> None:
    """Render histograms of parameter distributions explored.

    Lays the parameters out on a grid of up to three Streamlit columns;
    numeric parameters get histograms (log-scaled when the value range spans
    more than two orders of magnitude), categorical parameters get bar charts.

    Args:
        results: DataFrame with CV results.
    """
    st.subheader("📈 Parameter Space Exploration")
    # Get parameter columns ("params" itself holds the full dict and is skipped)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return
    # Create histograms for each parameter
    n_params = len(param_cols)
    n_cols = min(3, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols  # ceiling division
    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                # Last row may be only partially filled; leave the rest empty.
                break
            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")
            with cols[col_idx]:
                # Check if parameter is numeric or categorical
                param_values = results[param_col].dropna()
                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - use histogram
                    df_plot = pd.DataFrame({param_name: param_values})
                    # Use log scale if the range spans multiple orders of magnitude
                    # (the epsilon guards against division by zero when min == 0)
                    value_range = param_values.max() / (param_values.min() + 1e-10)
                    use_log = value_range > 100
                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(
                                    param_name,
                                    bin=alt.Bin(maxbins=30),
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
                            )
                            .properties(height=250, title=f"{param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_bar()
                            .encode(
                                alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
                                alt.Y("count()", title="Count"),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
                            )
                            .properties(height=250, title=param_name)
                        )
                    st.altair_chart(chart, use_container_width=True)
                else:
                    # Categorical parameter - use bar chart of value frequencies
                    value_counts = param_values.value_counts().reset_index()
                    value_counts.columns = [param_name, "count"]
                    chart = (
                        alt.Chart(value_counts)
                        .mark_bar()
                        .encode(
                            alt.X(param_name, title=param_name, sort="-y"),
                            alt.Y("count", title="Count"),
                            tooltip=[param_name, "count"],
                        )
                        .properties(height=250, title=param_name)
                    )
                    st.altair_chart(chart, use_container_width=True)
def render_score_vs_parameter(results: pd.DataFrame, metric: str) -> None:
    """Render scatter plots of score vs each parameter.

    Numeric parameters get scatter plots (log-scaled x-axis when the value
    range spans more than two orders of magnitude); categorical parameters
    get box plots. Charts are laid out on a grid of up to two columns.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"🎯 {metric.replace('_', ' ').title()} vs Parameters")
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return
    # Get parameter columns ("params" itself holds the full dict and is skipped)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return
    # Create scatter plots for each parameter
    n_params = len(param_cols)
    n_cols = min(2, n_params)
    n_rows = (n_params + n_cols - 1) // n_cols  # ceiling division
    for row in range(n_rows):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            param_idx = row * n_cols + col_idx
            if param_idx >= n_params:
                # Last row may be only partially filled; leave the rest empty.
                break
            param_col = param_cols[param_idx]
            param_name = param_col.replace("param_", "")
            with cols[col_idx]:
                param_values = results[param_col].dropna()
                if pd.api.types.is_numeric_dtype(param_values):
                    # Numeric parameter - scatter plot
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
                    # Use log scale if needed (epsilon guards min == 0)
                    value_range = param_values.max() / (param_values.min() + 1e-10)
                    use_log = value_range > 100
                    if use_log:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(
                                    param_name,
                                    scale=alt.Scale(type="log"),
                                    title=param_name,
                                ),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".2e"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name} (log scale)")
                        )
                    else:
                        chart = (
                            alt.Chart(df_plot)
                            .mark_circle(size=60, opacity=0.6)
                            .encode(
                                alt.X(param_name, title=param_name),
                                alt.Y(metric, title=metric.replace("_", " ").title()),
                                alt.Color(
                                    metric,
                                    scale=alt.Scale(scheme="viridis"),
                                    legend=None,
                                ),
                                tooltip=[alt.Tooltip(param_name, format=".3f"), alt.Tooltip(metric, format=".4f")],
                            )
                            .properties(height=300, title=f"{metric} vs {param_name}")
                        )
                    st.altair_chart(chart, use_container_width=True)
                else:
                    # Categorical parameter - box plot of score per category
                    df_plot = pd.DataFrame({param_name: results[param_col], metric: results[score_col]})
                    chart = (
                        alt.Chart(df_plot)
                        .mark_boxplot()
                        .encode(
                            alt.X(param_name, title=param_name),
                            alt.Y(metric, title=metric.replace("_", " ").title()),
                            tooltip=[param_name, alt.Tooltip(metric, format=".4f")],
                        )
                        .properties(height=300, title=f"{metric} vs {param_name}")
                    )
                    st.altair_chart(chart, use_container_width=True)
def render_parameter_correlation(results: pd.DataFrame, metric: str):
    """Render correlation heatmap between parameters and score.

    Computes the Pearson correlation of every numeric parameter with the
    chosen metric and shows the result as a diverging bar chart plus a
    styled table.

    Args:
        results: DataFrame with CV results.
        metric: The metric to analyze (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return
    # Get numeric parameter columns (categorical ones cannot be correlated directly)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
    if not numeric_params:
        st.warning("No numeric parameters found for correlation analysis.")
        return
    # Calculate correlations
    # NOTE(review): a constant parameter yields NaN correlation here; the bar
    # simply does not render for it — confirm that is acceptable.
    correlations = []
    for param_col in numeric_params:
        param_name = param_col.replace("param_", "")
        corr = results[[param_col, score_col]].corr().iloc[0, 1]
        correlations.append({"Parameter": param_name, "Correlation": corr})
    corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
    # Create bar chart with a diverging red/blue color scale centered on 0
    chart = (
        alt.Chart(corr_df)
        .mark_bar()
        .encode(
            alt.X("Correlation", title="Correlation with Score"),
            alt.Y("Parameter", sort="-x", title="Parameter"),
            alt.Color(
                "Correlation",
                scale=alt.Scale(scheme="redblue", domain=[-1, 1]),
                legend=None,
            ),
            tooltip=["Parameter", alt.Tooltip("Correlation", format=".3f")],
        )
        .properties(height=max(200, len(correlations) * 30))
    )
    st.altair_chart(chart, use_container_width=True)
    # Show correlation table
    with st.expander("📋 Correlation Table"):
        # width="stretch" keeps dataframe sizing consistent with the rest of
        # the dashboard (the deprecated use_container_width is avoided).
        st.dataframe(
            corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
            hide_index=True,
            width="stretch",
        )
def render_score_evolution(results: pd.DataFrame, metric: str):
    """Render evolution of scores during search.

    Plots the raw score per search iteration together with the running best
    ("Best So Far"), followed by summary metrics.

    Args:
        results: DataFrame with CV results.
        metric: The metric to plot (e.g., 'f1', 'accuracy').
    """
    st.subheader(f"📉 {metric.replace('_', ' ').title()} Evolution")
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return
    # Create a copy with iteration number
    df_plot = results[[score_col]].copy()
    df_plot["Iteration"] = range(len(df_plot))
    df_plot["Best So Far"] = df_plot[score_col].cummax()
    df_plot = df_plot.rename(columns={score_col: "Score"})
    # Reshape for Altair (long format: one row per (iteration, series) pair)
    df_long = df_plot.melt(id_vars=["Iteration"], value_vars=["Score", "Best So Far"], var_name="Type")
    # Create line chart; "Best So Far" is dashed to distinguish it
    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            alt.X("Iteration", title="Iteration"),
            alt.Y("value", title=metric.replace("_", " ").title()),
            alt.Color("Type", legend=alt.Legend(title=""), scale=alt.Scale(scheme="category10")),
            strokeDash=alt.StrokeDash(
                "Type",
                legend=None,
                scale=alt.Scale(domain=["Score", "Best So Far"], range=[[1, 0], [5, 5]]),
            ),
            tooltip=["Iteration", "Type", alt.Tooltip("value", format=".4f", title="Score")],
        )
        .properties(height=400)
    )
    st.altair_chart(chart, use_container_width=True)
    # Show statistics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Best Score", f"{df_plot['Best So Far'].iloc[-1]:.4f}")
    with col2:
        st.metric("Mean Score", f"{df_plot['Score'].mean():.4f}")
    with col3:
        st.metric("Std Dev", f"{df_plot['Score'].std():.4f}")
    with col4:
        # Find iteration where best was found. Use the positional argmax so
        # the value matches the "Iteration" axis even when `results` carries
        # a non-default index (idxmax would return the index label instead).
        best_iter = int(df_plot["Score"].to_numpy().argmax())
        st.metric("Best at Iteration", best_iter)
def render_multi_metric_comparison(results: pd.DataFrame) -> None:
    """Render comparison of multiple metrics.

    Lets the user pick two metrics, scatters them against each other colored
    by search iteration, and reports their Pearson correlation.

    Args:
        results: DataFrame with CV results.
    """
    st.subheader("📊 Multi-Metric Comparison")
    # Get all test score columns
    score_cols = [col for col in results.columns if col.startswith("mean_test_")]
    if len(score_cols) < 2:
        st.warning("Need at least 2 metrics for comparison.")
        return
    # Let user select two metrics to compare (explicit keys keep the two
    # selectboxes from colliding in Streamlit's widget state)
    col1, col2 = st.columns(2)
    with col1:
        metric1 = st.selectbox(
            "Select First Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            index=0,
            key="metric1_select",
        )
    with col2:
        metric2 = st.selectbox(
            "Select Second Metric",
            options=[col.replace("mean_test_", "") for col in score_cols],
            index=min(1, len(score_cols) - 1),
            key="metric2_select",
        )
    if metric1 == metric2:
        st.warning("Please select different metrics.")
        return
    # Create scatter plot, one point per search iteration
    df_plot = pd.DataFrame(
        {
            metric1: results[f"mean_test_{metric1}"],
            metric2: results[f"mean_test_{metric2}"],
            "Iteration": range(len(results)),
        }
    )
    chart = (
        alt.Chart(df_plot)
        .mark_circle(size=60, opacity=0.6)
        .encode(
            alt.X(metric1, title=metric1.replace("_", " ").title()),
            alt.Y(metric2, title=metric2.replace("_", " ").title()),
            alt.Color("Iteration", scale=alt.Scale(scheme="viridis")),
            tooltip=[
                alt.Tooltip(metric1, format=".4f"),
                alt.Tooltip(metric2, format=".4f"),
                "Iteration",
            ],
        )
        .properties(height=500)
    )
    st.altair_chart(chart, use_container_width=True)
    # Calculate correlation between the two selected metrics
    corr = df_plot[[metric1, metric2]].corr().iloc[0, 1]
    st.metric(f"Correlation between {metric1} and {metric2}", f"{corr:.3f}")
def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 10):
    """Render table of top N configurations.

    Selects the top_n rows by the chosen metric and displays their rank,
    score, and parameter values in a cleaned-up table.

    Args:
        results: DataFrame with CV results.
        metric: The metric to rank by (e.g., 'f1', 'accuracy').
        top_n: Number of top configurations to show.
    """
    st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return
    # Get parameter columns ("params" itself holds the full dict and is skipped)
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    if not param_cols:
        st.warning("No parameter columns found in results.")
        return
    # Get top N configurations
    top_configs = results.nlargest(top_n, score_col)
    # Create display dataframe (only columns that actually exist)
    display_cols = ["rank_test_" + metric, score_col, *param_cols]
    display_cols = [col for col in display_cols if col in top_configs.columns]
    display_df = top_configs[display_cols].copy()
    # Rename columns for better display
    display_df = display_df.rename(
        columns={
            "rank_test_" + metric: "Rank",
            score_col: metric.replace("_", " ").title(),
        }
    )
    # Rename parameter columns by stripping the "param_" prefix
    display_df.columns = [col.replace("param_", "") if col.startswith("param_") else col for col in display_df.columns]
    # Format score column as fixed-point strings
    score_col_display = metric.replace("_", " ").title()
    display_df[score_col_display] = display_df[score_col_display].apply(lambda x: f"{x:.4f}")
    # width="stretch" keeps dataframe sizing consistent with the rest of
    # the dashboard (the deprecated use_container_width is avoided).
    st.dataframe(display_df, hide_index=True, width="stretch")

View file

@ -0,0 +1,449 @@
"""Plotting functions for model state visualization."""
import altair as alt
import pandas as pd
import xarray as xr
def plot_top_features(model_state: xr.Dataset, top_n: int = 10) -> alt.Chart:
    """Plot the top N most important features based on feature weights.

    Features are ranked by absolute weight; the bars show the original
    signed weights, colored by sign.

    Args:
        model_state: The xarray Dataset containing the model state.
        top_n: Number of top features to display.

    Returns:
        Altair chart showing the top features by importance.
    """
    # Pull the weights out as a pandas Series.
    weights = model_state["feature_weights"].to_pandas()

    # Rank by absolute magnitude, keep the strongest top_n.
    ranked = weights.abs().nlargest(top_n).sort_values(ascending=True)

    # Build the plotting frame with the original (signed) weights alongside
    # their magnitudes.
    plot_data = pd.DataFrame(
        {
            "feature": ranked.index,
            "weight": weights.loc[ranked.index].to_numpy(),
            "abs_weight": ranked.to_numpy(),
        }
    )

    # Sign-dependent bar color: blue for positive, coral for negative.
    sign_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),
        alt.value("coral"),
    )
    hover = [
        alt.Tooltip("feature:N", title="Feature"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
        alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
    ]

    # Horizontal bar chart, largest weights on top.
    return (
        alt.Chart(plot_data)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x", axis=alt.Axis(labelLimit=300)),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=sign_color,
            tooltip=hover,
        )
        .properties(
            width=600,
            height=400,
            title=f"Top {top_n} Most Important Features",
        )
    )
def plot_embedding_heatmap(embedding_array: xr.DataArray) -> alt.Chart:
    """Create a heatmap showing embedding feature weights across bands and years.

    Produces one band-by-year heatmap facet per aggregation type, with a
    diverging color scale centered on zero.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Altair chart showing the heatmap.
    """
    # Flatten the (agg, band, year) cube into long form for Altair.
    long_df = embedding_array.to_dataframe(name="weight").reset_index()

    # Diverging scale so positive and negative weights are visually symmetric.
    weight_scale = alt.Scale(scheme="redblue", domainMid=0)
    hover = [
        alt.Tooltip("agg:N", title="Aggregation"),
        alt.Tooltip("band:N", title="Band"),
        alt.Tooltip("year:O", title="Year"),
        alt.Tooltip("weight:Q", format=".4f", title="Weight"),
    ]

    base = (
        alt.Chart(long_df)
        .mark_rect()
        .encode(
            x=alt.X("year:O", title="Year"),
            y=alt.Y("band:O", title="Band", sort=alt.SortField(field="band", order="ascending")),
            color=alt.Color("weight:Q", scale=weight_scale, title="Weight"),
            tooltip=hover,
        )
        .properties(width=200, height=200)
    )
    # One facet per aggregation type.
    return base.facet(facet=alt.Facet("agg:N", title="Aggregation"), columns=11)
def _embedding_summary_bar(
    series,
    *,
    axis_title: str,
    chart_title: str,
    scheme: str,
    field_type: str = "N",
) -> alt.Chart:
    """Build one horizontal bar chart of mean absolute weights for a dimension.

    Args:
        series: pandas Series of mean absolute weights, indexed by dimension value.
        axis_title: Axis and tooltip title for the dimension.
        chart_title: Chart title.
        scheme: Altair color scheme name.
        field_type: Altair field type for the dimension ("N" nominal, "O" ordinal).

    Returns:
        Altair bar chart sorted ascending by weight.
    """
    frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
    frame = frame.sort_values("mean_abs_weight", ascending=True)
    return (
        alt.Chart(frame)
        .mark_bar()
        .encode(
            y=alt.Y(f"dimension:{field_type}", title=axis_title, sort="-x"),
            x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
            color=alt.Color(
                "mean_abs_weight:Q",
                scale=alt.Scale(scheme=scheme),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip(f"dimension:{field_type}", title=axis_title),
                alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
            ],
        )
        .properties(width=250, height=200, title=chart_title)
    )


def plot_embedding_aggregation_summary(embedding_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart, alt.Chart]:
    """Create bar charts summarizing embedding weights by aggregation, band, and year.

    Args:
        embedding_array: DataArray with dimensions (agg, band, year) containing feature weights.

    Returns:
        Tuple of three Altair charts (by_agg, by_band, by_year).
    """
    # Mean over the other two dimensions, then absolute value (note: abs of the
    # mean, so oppositely-signed weights can cancel before taking abs).
    by_agg = embedding_array.mean(dim=["band", "year"]).to_pandas().abs()
    by_band = embedding_array.mean(dim=["agg", "year"]).to_pandas().abs()
    by_year = embedding_array.mean(dim=["agg", "band"]).to_pandas().abs()
    # The three charts share a single builder; only titles, color scheme, and
    # the field type (years are ordinal) differ.
    chart_agg = _embedding_summary_bar(
        by_agg, axis_title="Aggregation", chart_title="By Aggregation", scheme="blues"
    )
    chart_band = _embedding_summary_bar(
        by_band, axis_title="Band", chart_title="By Band", scheme="greens"
    )
    chart_year = _embedding_summary_bar(
        by_year, axis_title="Year", chart_title="By Year", scheme="oranges", field_type="O"
    )
    return chart_agg, chart_band, chart_year
def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
    """Render ERA5 feature weights as a variable-by-time heatmap.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Altair chart showing the heatmap.
    """
    # Long-form records: one row per (variable, time) pair.
    records = era5_array.to_dataframe(name="weight").reset_index()
    heatmap = (
        alt.Chart(records)
        .mark_rect()
        .encode(
            x=alt.X("time:N", title="Time", sort=None),
            y=alt.Y("variable:N", title="Variable", sort="-color"),
            color=alt.Color(
                "weight:Q",
                # Diverging scale centered at zero separates positive/negative weights.
                scale=alt.Scale(scheme="redblue", domainMid=0),
                title="Weight",
            ),
            tooltip=[
                alt.Tooltip("variable:N", title="Variable"),
                alt.Tooltip("time:N", title="Time"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
            ],
        )
    )
    return heatmap.properties(height=400, title="ERA5 Feature Weights Heatmap")
def plot_era5_summary(era5_array: xr.DataArray) -> tuple[alt.Chart, alt.Chart]:
    """Create bar charts summarizing ERA5 weights by variable and by time.

    Args:
        era5_array: DataArray with dimensions (variable, time) containing feature weights.

    Returns:
        Tuple of two Altair charts (by_variable, by_time).
    """

    def bar_chart(series, axis_title, scheme, label_limit, title):
        # Build a sorted bar chart of mean absolute weights for one dimension.
        frame = pd.DataFrame({"dimension": series.index, "mean_abs_weight": series.to_numpy()})
        frame = frame.sort_values("mean_abs_weight", ascending=True)
        return (
            alt.Chart(frame)
            .mark_bar()
            .encode(
                y=alt.Y("dimension:N", title=axis_title, sort="-x", axis=alt.Axis(labelLimit=label_limit)),
                x=alt.X("mean_abs_weight:Q", title="Mean Absolute Weight"),
                color=alt.Color(
                    "mean_abs_weight:Q",
                    scale=alt.Scale(scheme=scheme),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip("dimension:N", title=axis_title),
                    alt.Tooltip("mean_abs_weight:Q", format=".4f", title="Mean Abs Weight"),
                ],
            )
            .properties(width=400, height=300, title=title)
        )

    # Mean over the other dimension, then absolute value.
    by_variable = era5_array.mean(dim="time").to_pandas().abs()
    by_time = era5_array.mean(dim="variable").to_pandas().abs()
    chart_variable = bar_chart(by_variable, "Variable", "purples", 300, "By Variable")
    chart_time = bar_chart(by_time, "Time", "teals", 200, "By Time")
    return chart_variable, chart_time
def plot_box_assignments(model_state: xr.Dataset) -> alt.Chart:
    """Visualize the box-to-label assignment (lambda) matrix as a heatmap.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.

    Returns:
        Altair chart showing the box-to-label assignment heatmap.
    """
    # One row per (box, class) pair with its assignment strength.
    assignments_df = model_state["box_assignments"].to_dataframe(name="assignment").reset_index()
    return (
        alt.Chart(assignments_df)
        .mark_rect()
        .encode(
            x=alt.X("box:O", title="Box ID", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("class:N", title="Class Label"),
            color=alt.Color(
                "assignment:Q",
                scale=alt.Scale(scheme="viridis"),
                title="Assignment Strength",
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("box:O", title="Box"),
                alt.Tooltip("assignment:Q", format=".4f", title="Assignment"),
            ],
        )
        .properties(
            height=150,
            title="Box-to-Label Assignments (Lambda Matrix)",
        )
    )
def plot_box_assignment_bars(model_state: xr.Dataset, altair_colors: list[str]) -> alt.Chart:
    """Create a bar chart showing how many boxes are assigned to each class.

    Each box is attributed to the single class with the highest assignment
    strength ("primary assignment"); the chart then counts boxes per class.

    Args:
        model_state: The xarray Dataset containing the model state with box_assignments.
        altair_colors: List of hex color strings for altair.

    Returns:
        Altair chart showing count of boxes per class.
    """
    # Extract box assignments
    box_assignments = model_state["box_assignments"]
    # Convert to long-form DataFrame: one row per (box, class) pair
    df = box_assignments.to_dataframe(name="assignment").reset_index()
    # For each box, find the row index of its strongest class assignment
    box_to_class = df.groupby("box")["assignment"].idxmax()
    primary_classes = df.loc[box_to_class, ["box", "class", "assignment"]].reset_index(drop=True)
    # Count boxes per class
    counts = primary_classes.groupby("class").size().reset_index(name="count")
    # Replace the special (-1, 0] interval with "No RTS" if present.
    # NOTE(review): this assumes class labels are the *string* form of the
    # intervals (e.g. "(-1, 0]"); pandas Interval objects would not match the
    # replace — confirm how the class coordinate is built upstream.
    counts["class"] = counts["class"].replace("(-1, 0]", "No RTS")
    # Sort the classes: "No RTS" first, then by the lower bound of intervals
    def sort_key(class_str) -> float:
        if class_str == "No RTS":
            return -1  # Put "No RTS" first
        # Parse interval string like "(0, 4]" or "(4, 36]"
        try:
            lower = float(str(class_str).split(",")[0].strip("([ "))
            return lower
        except (ValueError, IndexError):
            return float("inf")  # Put unparseable values at the end
    # Sort counts by the same key
    counts["sort_key"] = counts["class"].apply(sort_key)
    counts = counts.sort_values("sort_key")
    # Create an ordered list of classes for consistent color mapping
    class_order = counts["class"].tolist()
    # Create bar chart
    chart = (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("class:N", title="Class Label", sort=class_order, axis=alt.Axis(labelAngle=-45)),
            y=alt.Y("count:Q", title="Number of Boxes"),
            color=alt.Color(
                "class:N",
                title="Class",
                # Domain/range pairing assumes altair_colors has at least as
                # many entries as classes — TODO confirm with the caller.
                scale=alt.Scale(domain=class_order, range=altair_colors),
                legend=None,
            ),
            tooltip=[
                alt.Tooltip("class:N", title="Class"),
                alt.Tooltip("count:Q", title="Number of Boxes"),
            ],
        )
        .properties(
            width=600,
            height=300,
            title="Number of Boxes Assigned to Each Class (by Primary Assignment)",
        )
    )
    return chart
def plot_common_features(common_array: xr.DataArray) -> alt.Chart:
    """Plot the signed weights of the common features as a bar chart.

    Args:
        common_array: DataArray with dimension (feature) containing feature weights.

    Returns:
        Altair chart showing the common feature weights.
    """
    weights_df = common_array.to_dataframe(name="weight").reset_index()
    # Order bars by magnitude while still plotting the signed weight.
    weights_df["abs_weight"] = weights_df["weight"].abs()
    weights_df = weights_df.sort_values("abs_weight", ascending=True)
    # Positive weights render steelblue, negative ones coral.
    sign_color = alt.condition(
        alt.datum.weight > 0,
        alt.value("steelblue"),
        alt.value("coral"),
    )
    return (
        alt.Chart(weights_df)
        .mark_bar()
        .encode(
            y=alt.Y("feature:N", title="Feature", sort="-x"),
            x=alt.X("weight:Q", title="Feature Weight (scaled by number of features)"),
            color=sign_color,
            tooltip=[
                alt.Tooltip("feature:N", title="Feature"),
                alt.Tooltip("weight:Q", format=".4f", title="Weight"),
                alt.Tooltip("abs_weight:Q", format=".4f", title="Absolute Weight"),
            ],
        )
        .properties(
            width=600,
            height=300,
            title="Common Feature Weights",
        )
    )

View file

@ -2,9 +2,237 @@
import streamlit as st
from entropice.dashboard.plots.hyperparameter_analysis import (
render_multi_metric_comparison,
render_parameter_correlation,
render_parameter_distributions,
render_performance_summary,
render_score_evolution,
render_score_vs_parameter,
render_top_configurations,
)
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.utils.training import (
format_metric_name,
get_available_metrics,
get_cv_statistics,
get_parameter_space_summary,
)
def render_training_analysis_page():
"""Render the Training Results Analysis page of the dashboard."""
st.title("Training Results Analysis")
st.write("This page will display analysis of training results and model performance.")
# Add more components and visualizations as needed for training results analysis.
st.title("🦾 Training Results Analysis")
# Load all available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.training`")
return
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
# Create selection options
training_options = {tr.name: tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to analyze",
)
selected_result = training_options[selected_name]
st.divider()
# Display selected run info
st.subheader("Run Information")
st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
st.write(f"**Trials:** {len(selected_result.results)}")
st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")
# Refit metric - determine from available metrics
available_metrics = get_available_metrics(selected_result.results)
# Try to get refit metric from settings
refit_metric = selected_result.settings.get("refit_metric")
if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric
task = selected_result.settings.get("task", "binary")
if task == "binary" and "f1" in available_metrics:
refit_metric = "f1"
elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted"
elif "accuracy" in available_metrics:
refit_metric = "accuracy"
elif available_metrics:
refit_metric = available_metrics[0]
else:
st.error("No metrics found in results.")
return
st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results)
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
else:
default_metric_idx = 0
selected_metric = st.selectbox(
"Primary Metric for Analysis",
options=available_metrics,
index=default_metric_idx,
format_func=format_metric_name,
help="Select the metric to focus on for detailed analysis",
)
# Top N configurations
top_n = st.slider(
"Top N Configurations",
min_value=5,
max_value=50,
value=10,
step=5,
help="Number of top configurations to display",
)
# Main content area
results = selected_result.results
settings = selected_result.settings
# Performance Summary Section
st.header("📊 Performance Overview")
render_performance_summary(results, refit_metric)
st.divider()
# Quick Statistics
st.header("📈 Cross-Validation Statistics")
cv_stats = get_cv_statistics(results, selected_metric)
if cv_stats:
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
with col2:
st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
with col3:
st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
with col4:
st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
with col5:
st.metric("Median Score", f"{cv_stats['median_score']:.4f}")
if "mean_cv_std" in cv_stats:
st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")
st.divider()
# Score Evolution
st.header("📉 Training Progress")
render_score_evolution(results, selected_metric)
st.divider()
# Parameter Space Exploration
st.header("🔍 Parameter Space Analysis")
# Show parameter space summary
with st.expander("📋 Parameter Space Summary", expanded=False):
param_summary = get_parameter_space_summary(results)
if not param_summary.empty:
st.dataframe(param_summary, hide_index=True, use_container_width=True)
else:
st.info("No parameter information available.")
# Parameter distributions
render_parameter_distributions(results)
st.divider()
# Score vs Parameters
st.header("🎯 Parameter Impact Analysis")
render_score_vs_parameter(results, selected_metric)
st.divider()
# Parameter Correlation
st.header("🔗 Parameter Correlation Analysis")
render_parameter_correlation(results, selected_metric)
st.divider()
# Multi-Metric Comparison
if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Analysis")
render_multi_metric_comparison(results)
st.divider()
# Top Configurations
st.header("🏆 Top Performing Configurations")
render_top_configurations(results, selected_metric, top_n)
st.divider()
# Raw Data Export
with st.expander("💾 Export Data", expanded=False):
st.subheader("Download Results")
col1, col2 = st.columns(2)
with col1:
# Download full results as CSV
csv_data = results.to_csv(index=False)
st.download_button(
label="📥 Download Full Results (CSV)",
data=csv_data,
file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv",
use_container_width=True,
)
with col2:
# Download settings as text
import json
settings_json = json.dumps(settings, indent=2)
st.download_button(
label="⚙️ Download Settings (JSON)",
data=settings_json,
file_name=f"{selected_result.path.name}_settings.json",
mime="application/json",
use_container_width=True,
)
# Show raw data preview
st.subheader("Raw Data Preview")
st.dataframe(results.head(100), use_container_width=True)

View file

@ -118,6 +118,42 @@ def render_training_data_page():
# Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`")
# Display dataset statistics
st.markdown("---")
st.subheader("📈 Dataset Statistics")
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Display target information
col1, col2 = st.columns(2)
with col1:
st.metric(label="Target", value=stats["target"].replace("darts_", ""))
with col2:
st.metric(label="Number of Target Samples", value=f"{stats['num_target_samples']:,}")
# Display member statistics
st.markdown("**Member Statistics:**")
for member, member_stats in stats["members"].items():
with st.expander(f"📦 {member}", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.markdown(f"**Number of Features:** {member_stats['num_features']}")
st.markdown(f"**Number of Variables:** {member_stats['num_variables']}")
with col2:
st.markdown(f"**Dimensions:** `{member_stats['dimensions']}`")
# Display variables as a compact list
st.markdown(f"**Variables ({member_stats['num_variables']}):**")
vars_str = ", ".join([f"`{v}`" for v in member_stats["variables"]])
st.markdown(vars_str)
# Display total features
st.metric(label="🎯 Total Number of Features", value=f"{stats['total_features']:,}")
st.markdown("---")
# Create tabs for different data views
tab_names = ["📊 Labels"]

View file

@ -8,6 +8,7 @@ import antimeridian
import pandas as pd
import streamlit as st
import toml
import xarray as xr
from shapely.geometry import shape
import entropice.paths
@ -119,3 +120,98 @@ def load_source_data(e: DatasetEnsemble, source: str):
ds = e._read_member(source, targets, lazy=False)
return ds, targets
def extract_embedding_features(model_state) -> xr.DataArray | None:
    """Extract embedding features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted embedding features. This DataArray has dimensions
        ('agg', 'band', 'year') corresponding to the different components of the embedding features.
        Returns None if no embedding features are found.
    """
    feature_names = model_state.feature.to_numpy()
    embedding_features = [name for name in feature_names if name.startswith("embedding_")]
    if not embedding_features:
        return None
    # Feature names follow "embedding_<agg>_<band>_<year>"; split them once so
    # the flat feature axis can be unstacked into three separate dimensions.
    name_parts = [name.split("_") for name in embedding_features]
    weights = model_state["feature_weights"].sel(feature=embedding_features)
    weights = weights.assign_coords(
        agg=("feature", [parts[1] for parts in name_parts]),
        band=("feature", [parts[2] for parts in name_parts]),
        year=("feature", [parts[3] for parts in name_parts]),
    )
    return weights.set_index(feature=["agg", "band", "year"]).unstack("feature")  # noqa: PD010
def extract_era5_features(model_state) -> xr.DataArray | None:
    """Extract ERA5 features from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted ERA5 features. This DataArray has dimensions
        ('variable', 'time') corresponding to the different components of the ERA5 features.
        Returns None if no ERA5 features are found.
    """
    feature_names = model_state.feature.to_numpy()
    era5_features = [name for name in feature_names if name.startswith("era5_")]
    if not era5_features:
        return None
    # Feature names follow "era5_<variable>_<timetype>"; the variable name may
    # itself contain underscores, so the time type is always the last component.
    variables = ["_".join(name.split("_")[1:-1]) for name in era5_features]
    times = [name.split("_")[-1] for name in era5_features]
    weights = model_state["feature_weights"].sel(feature=era5_features)
    weights = weights.assign_coords(
        variable=("feature", variables),
        time=("feature", times),
    )
    return weights.set_index(feature=["variable", "time"]).unstack("feature")  # noqa: PD010
def extract_common_features(model_state) -> xr.DataArray | None:
    """Extract common features (cell_area, water_area, land_area, land_ratio, lon, lat) from the model state.

    Args:
        model_state: The xarray Dataset containing the model state.

    Returns:
        xr.DataArray: The extracted common features with a single 'feature' dimension.
        Returns None if no common features are found.
    """
    known_common = {"cell_area", "water_area", "land_area", "land_ratio", "lon", "lat"}
    # Preserve the order in which the features appear in the model state.
    present = [name for name in model_state.feature.to_numpy() if name in known_common]
    if not present:
        return None
    return model_state["feature_weights"].sel(feature=present)

View file

@ -0,0 +1,232 @@
"""Training utilities for dashboard."""
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.data import TrainingResult
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
    """
    # Capitalize each underscore-separated part; "f1" is special-cased to "F1".
    return " ".join(
        "F1" if part.lower() == "f1" else part.capitalize()
        for part in metric.split("_")
    )
def get_available_metrics(results: pd.DataFrame) -> list[str]:
    """Get list of available metrics from results.

    Args:
        results: DataFrame with CV results.

    Returns:
        List of metric names (without 'mean_test_' prefix).
    """
    prefix = "mean_test_"
    # Every scorer produces a "mean_test_<metric>" column in CV results.
    return [col[len(prefix):] for col in results.columns if col.startswith(prefix)]
def load_best_model(result: TrainingResult):
    """Load the best model from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        The loaded model object, or None if loading fails.
    """
    # The training pipeline pickles the refit best estimator next to the results.
    model_file = result.path / "best_estimator_model.pkl"
    if not model_file.exists():
        return None
    try:
        with open(model_file, "rb") as f:
            # SECURITY NOTE: pickle.load executes arbitrary code from the file;
            # only load model files produced by trusted training runs.
            model = pickle.load(f)
        return model
    except Exception as e:
        # Broad catch keeps the dashboard responsive; the failure is shown in the UI.
        st.error(f"Error loading model: {e}")
        return None
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
    """Load the model state from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        xarray Dataset with model state, or None if not available.
    """
    # The training pipeline writes the best estimator's state as NetCDF.
    state_file = result.path / "best_estimator_state.nc"
    if not state_file.exists():
        return None
    try:
        # h5netcdf engine is requested explicitly; presumably the file is
        # written with the same engine — TODO confirm against the training code.
        state = xr.open_dataset(state_file, engine="h5netcdf")
        return state
    except Exception as e:
        # Broad catch keeps the dashboard responsive; the failure is shown in the UI.
        st.error(f"Error loading model state: {e}")
        return None
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
    """Load predictions from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        DataFrame with predictions, or None if not available.
    """
    predictions_path = result.path / "predicted_probabilities.parquet"
    if not predictions_path.exists():
        return None
    try:
        return pd.read_parquet(predictions_path)
    except Exception as e:
        # Surface the failure in the UI rather than crashing the dashboard.
        st.error(f"Error loading predictions: {e}")
        return None
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
    """Get summary of parameter space explored.

    Args:
        results: DataFrame with CV results.

    Returns:
        DataFrame with one row per hyperparameter listing its type, value
        range and mean (numeric parameters only), and number of unique values.
    """
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    summary_data = []
    for param_col in param_cols:
        param_name = param_col.replace("param_", "")
        # NaN entries mean the parameter was absent for some trials; ignore them.
        param_values = results[param_col].dropna()
        if pd.api.types.is_numeric_dtype(param_values):
            summary_data.append(
                {
                    "Parameter": param_name,
                    "Type": "Numeric",
                    "Min": f"{param_values.min():.2e}",
                    "Max": f"{param_values.max():.2e}",
                    "Mean": f"{param_values.mean():.2e}",
                    "Unique Values": param_values.nunique(),
                }
            )
        else:
            summary_data.append(
                {
                    "Parameter": param_name,
                    "Type": "Categorical",
                    "Min": "-",
                    "Max": "-",
                    "Mean": "-",
                    # nunique() for consistency with the numeric branch
                    # (previously len(unique()) — same result, one idiom).
                    "Unique Values": param_values.nunique(),
                }
            )
    return pd.DataFrame(summary_data)
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
    """Get cross-validation statistics for a metric.

    Args:
        results: DataFrame with CV results.
        metric: Metric name (without 'mean_test_' prefix).

    Returns:
        Dictionary with CV statistics (empty if the metric column is missing).
    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return {}
    scores = results[score_col]
    stats = {
        "best_score": scores.max(),
        "mean_score": scores.mean(),
        "std_score": scores.std(),
        "worst_score": scores.min(),
        "median_score": scores.median(),
    }
    # Per-trial CV fold spread, when the search recorded it.
    std_col = f"std_test_{metric}"
    if std_col in results.columns:
        stats["mean_cv_std"] = results[std_col].mean()
    return stats
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
    """Prepare results dataframe with binned columns for plotting.

    Adds ``initial_K_binned`` (fixed-width bins) and logarithmic
    ``eps_cl_binned`` / ``eps_e_binned`` columns where the corresponding
    parameter columns are present.

    Args:
        results: DataFrame with CV results.
        k_bin_width: Width of bins for initial_K parameter.

    Returns:
        Copy of ``results`` with added binned columns.
    """
    results_copy = results.copy()
    if "param_initial_K" in results.columns:
        k_values = results["param_initial_K"].dropna()
        if len(k_values) > 0:
            # Extend the last edge past the maximum so that with right=False the
            # maximum value still falls inside a bin. Previously the edges could
            # end exactly at the max (leaving it unbinned as NaN) or collapse to
            # a single edge when all values were equal (pd.cut error).
            k_lo = int(np.floor(k_values.min()))
            k_hi = int(np.floor(k_values.max())) + k_bin_width
            k_bins = range(k_lo, k_hi + 1, k_bin_width)
            results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False)
    if "param_eps_cl" in results.columns:
        # Logarithmic bins for eps_cl (only valid for strictly positive values).
        eps_cl_values = results["param_eps_cl"].dropna()
        if len(eps_cl_values) > 0 and eps_cl_values.min() > 0:
            eps_cl_bins = np.logspace(np.log10(eps_cl_values.min()), np.log10(eps_cl_values.max()), num=10)
            # include_lowest=True keeps the minimum value (which sits exactly on
            # the first bin edge) from being dropped as NaN.
            results_copy["eps_cl_binned"] = pd.cut(results["param_eps_cl"], bins=eps_cl_bins, include_lowest=True)
    if "param_eps_e" in results.columns:
        # Logarithmic bins for eps_e, same handling as eps_cl.
        eps_e_values = results["param_eps_e"].dropna()
        if len(eps_e_values) > 0 and eps_e_values.min() > 0:
            eps_e_bins = np.logspace(np.log10(eps_e_values.min()), np.log10(eps_e_values.max()), num=10)
            results_copy["eps_e_binned"] = pd.cut(results["param_eps_e"], bins=eps_e_bins, include_lowest=True)
    return results_copy

View file

@ -283,25 +283,52 @@ class DatasetEnsemble:
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
return arcticdem_df
def print_stats(self):
targets = self._read_target()
print(f"=== Target: {self.target}")
print(f"\tNumber of target samples: {len(targets)}")
def get_stats(self) -> dict:
"""Get dataset statistics.
Returns:
dict: Dictionary containing target stats, member stats, and total features count.
"""
targets = self._read_target()
stats = {
"target": self.target,
"num_target_samples": len(targets),
"members": {},
"total_features": 2 if self.add_lonlat else 0, # Lat and Lon
}
n_cols = 2 if self.add_lonlat else 0 # Lat and Lon
for member in self.members:
ds = self._read_member(member, targets, lazy=True)
print(f"=== Member: {member}")
print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}")
print(f"\tDimensions: {dict(ds.sizes)}")
print(f"\tCoordinates: {list(ds.coords)}")
n_cols_member = len(ds.data_vars)
for dim in ds.sizes:
if dim != "cell_ids":
n_cols_member *= ds.sizes[dim]
print(f"\tNumber of features from member: {n_cols_member}")
n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}")
stats["members"][member] = {
"variables": list(ds.data_vars),
"num_variables": len(ds.data_vars),
"dimensions": dict(ds.sizes),
"coordinates": list(ds.coords),
"num_features": n_cols_member,
}
stats["total_features"] += n_cols_member
return stats
def print_stats(self):
stats = self.get_stats()
print(f"=== Target: {stats['target']}")
print(f"\tNumber of target samples: {stats['num_target_samples']}")
for member, member_stats in stats["members"].items():
print(f"=== Member: {member}")
print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}")
print(f"\tDimensions: {member_stats['dimensions']}")
print(f"\tCoordinates: {member_stats['coordinates']}")
print(f"\tNumber of features from member: {member_stats['num_features']}")
print(f"=== Total number of features in dataset: {stats['total_features']}")
@lru_cache(maxsize=1)
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: