Update training analysis page

This commit is contained in:
Tobias Hölzer 2025-12-19 16:52:02 +01:00
parent 8338efb31e
commit 6ed5a9c224
4 changed files with 279 additions and 83 deletions

View file

@ -72,7 +72,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
scale_factor = scale_factors[grid][level]
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
for year in track(range(2021, 2025), total=4, description="Processing years..."):
# 2024-2025 for hex-6 (NOTE(review): the loop below covers 2022-2024 — range end is exclusive; confirm intended years)
for year in track(range(2022, 2025), total=3, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]

View file

@ -1,9 +1,12 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
from entropice.dashboard.plots.colors import get_cmap
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
"""Render summary statistics of model performance.
@ -13,8 +16,6 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
"""
st.subheader("📊 Performance Summary")
# Get all test score columns
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
@ -51,7 +52,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
# Show best parameter combination
# Show best parameter combination in a cleaner format (similar to old dashboard)
st.markdown("#### 🏆 Best Parameter Combination")
refit_col = f"mean_test_{refit_metric}"
@ -74,30 +75,48 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
if param_cols:
best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
# Display in a nice formatted way
param_df = pd.DataFrame([best_params]).T
param_df.columns = ["Value"]
param_df.index.name = "Parameter"
# Display in a container with metrics (similar to old dashboard style)
with st.container(border=True):
st.caption(f"Parameters of the best model (selected by {refit_metric.replace('_', ' ').title()} score)")
col1, col2 = st.columns([1, 1])
with col1:
st.dataframe(param_df, use_container_width=True)
# Display parameters as metrics
n_params = len(best_params)
cols = st.columns(n_params)
with col2:
st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
rank_col = "rank_test_" + refit_metric
if rank_col in best_row.index:
try:
# Handle potential Series or scalar values
rank_val = best_row[rank_col]
if hasattr(rank_val, "item"):
rank_val = rank_val.item()
rank_display = str(int(float(rank_val)))
except (ValueError, TypeError, AttributeError):
rank_display = "N/A"
else:
rank_display = "N/A"
st.metric("Rank", rank_display)
for idx, (param_name, param_value) in enumerate(best_params.items()):
with cols[idx]:
# Format value based on type and magnitude
if isinstance(param_value, (int, np.integer)):
formatted_value = f"{param_value:.0f}"
elif isinstance(param_value, (float, np.floating)):
# Use scientific notation for very small numbers
if abs(param_value) < 0.01:
formatted_value = f"{param_value:.2e}"
else:
formatted_value = f"{param_value:.4f}"
else:
formatted_value = str(param_value)
st.metric(param_name, formatted_value)
# Show all metrics for the best model
st.divider()
st.caption("Performance across all metrics")
# Get all metrics from score_cols
all_metrics = [col.replace("mean_test_", "") for col in score_cols]
metric_cols = st.columns(len(all_metrics))
for idx, metric in enumerate(all_metrics):
with metric_cols[idx]:
best_score = results.loc[best_idx, f"mean_test_{metric}"]
best_std = results.loc[best_idx, f"std_test_{metric}"]
st.metric(
metric.replace("_", " ").title(),
f"{best_score:.4f}",
delta=f"±{best_std:.4f}",
help="Mean ± std across cross-validation folds",
)
def render_parameter_distributions(results: pd.DataFrame):
@ -107,8 +126,6 @@ def render_parameter_distributions(results: pd.DataFrame):
results: DataFrame with CV results.
"""
st.subheader("📈 Parameter Space Exploration")
# Get parameter columns
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
@ -116,6 +133,12 @@ def render_parameter_distributions(results: pd.DataFrame):
st.warning("No parameter columns found in results.")
return
# Get colormap from colors module
cmap = get_cmap("parameter_distribution")
import matplotlib.colors as mcolors
bar_color = mcolors.rgb2hex(cmap(0.5))
# Create histograms for each parameter
n_params = len(param_cols)
n_cols = min(3, n_params)
@ -146,25 +169,25 @@ def render_parameter_distributions(results: pd.DataFrame):
if use_log:
chart = (
alt.Chart(df_plot)
.mark_bar()
.mark_bar(color=bar_color)
.encode(
alt.X(
param_name,
bin=alt.Bin(maxbins=30),
bin=alt.Bin(maxbins=20),
scale=alt.Scale(type="log"),
title=param_name,
),
alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
)
.properties(height=250, title=f"{param_name} (log scale)")
.properties(height=250, title=param_name)
)
else:
chart = (
alt.Chart(df_plot)
.mark_bar()
.mark_bar(color=bar_color)
.encode(
alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
)
@ -180,7 +203,7 @@ def render_parameter_distributions(results: pd.DataFrame):
chart = (
alt.Chart(value_counts)
.mark_bar()
.mark_bar(color=bar_color)
.encode(
alt.X(param_name, title=param_name, sort="-y"),
alt.Y("count", title="Count"),
@ -298,15 +321,13 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
def render_parameter_correlation(results: pd.DataFrame, metric: str):
"""Render correlation heatmap between parameters and score.
"""Render correlation bar chart between parameters and score.
Args:
results: DataFrame with CV results.
metric: The metric to analyze (e.g., 'f1', 'accuracy').
"""
st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")
score_col = f"mean_test_{metric}"
if score_col not in results.columns:
st.warning(f"Metric {metric} not found in results.")
@ -329,6 +350,14 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
# Get colormap from colors module
cmap = get_cmap("correlation")
# Sample colors for red-blue diverging scheme
colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
import matplotlib.colors as mcolors
hex_colors = [mcolors.rgb2hex(c) for c in colors]
# Create bar chart
chart = (
alt.Chart(corr_df)
@ -348,13 +377,183 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
st.altair_chart(chart, use_container_width=True)
# Show correlation table
with st.expander("📋 Correlation Table"):
st.dataframe(
corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
hide_index=True,
use_container_width=True,
def _param_scale(bin_info: dict, param_name: str):
    """Altair axis scale for a parameter: log if it was log-binned, linear otherwise."""
    return alt.Scale(type="log") if bin_info[param_name]["use_log"] else alt.Scale()


def _param_tooltip(bin_info: dict, param_name: str):
    """Altair tooltip for a parameter column (scientific notation for log-binned values)."""
    fmt = ".2e" if bin_info[param_name]["use_log"] else ".3f"
    return alt.Tooltip(f"param_{param_name}:Q", format=fmt)


def _scatter_encoding(x_param: str, y_param: str, score_col: str, metric: str, bin_info: dict, extra_tooltips=None):
    """Shared x/y/color/tooltip encoding for the parameter-space scatter charts."""
    tooltips = [
        _param_tooltip(bin_info, x_param),
        _param_tooltip(bin_info, y_param),
        alt.Tooltip(f"{score_col}:Q", format=".4f"),
    ]
    if extra_tooltips:
        tooltips.extend(extra_tooltips)
    return dict(
        x=alt.X(f"param_{x_param}:Q", title=x_param, scale=_param_scale(bin_info, x_param)),
        y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=_param_scale(bin_info, y_param)),
        color=alt.Color(f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()),
        tooltip=tooltips,
    )


def render_binned_parameter_space(results: pd.DataFrame, metric: str):
    """Render binned parameter space plots similar to the old dashboard.

    Numeric parameters are binned (log or linear bins, chosen per parameter)
    and plotted against each other as scatter charts colored by the metric
    value. With 3+ numeric parameters, additional charts faceted by a third
    parameter's bins are shown.

    Args:
        results: DataFrame with CV results.
        metric: The metric to visualize (e.g., 'f1', 'accuracy').
    """
    score_col = f"mean_test_{metric}"

    if score_col not in results.columns:
        st.warning(f"Metric {metric} not found in results.")
        return

    # Only numeric parameters can be binned.
    param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
    numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]

    if len(numeric_params) < 2:
        st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
        return

    # Bin every usable numeric parameter, remembering per-parameter bin metadata.
    results_binned = results.copy()
    bin_info: dict = {}

    for param_col in numeric_params:
        param_name = param_col.replace("param_", "")
        param_values = results[param_col].dropna()

        # Skip parameters that cannot be binned (all-NaN or constant) —
        # pd.cut would otherwise raise on non-increasing bin edges.
        if len(param_values) == 0 or param_values.min() == param_values.max():
            continue

        # Heuristic: log bins for parameters spanning many orders of magnitude
        # or living entirely below 1.0 (e.g. learning rates). Log binning is
        # only valid for strictly positive values.
        value_range = param_values.max() / (param_values.min() + 1e-10)
        use_log = param_values.min() > 0 and (value_range > 100 or param_values.max() < 1.0)

        if use_log:
            log_min = np.log10(param_values.min())
            log_max = np.log10(param_values.max())
            # Clamp to at least 2 edges so pd.cut gets a valid interval.
            n_edges = max(2, min(10, int(log_max - log_min) + 1))
            bins = np.logspace(log_min, log_max, num=n_edges)
        else:
            n_edges = max(2, min(10, int(np.sqrt(len(param_values)))))
            bins = np.linspace(param_values.min(), param_values.max(), num=n_edges)

        # include_lowest so the minimum value is not dropped from the first bin.
        results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins, include_lowest=True)
        bin_info[param_name] = {"use_log": use_log, "bins": bins}

    # Only parameters that were actually binned are plottable.
    param_names = [name for name in (col.replace("param_", "") for col in numeric_params) if name in bin_info]

    if len(param_names) < 2:
        st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
        return

    st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")

    if len(param_names) >= 3:
        # Plot the first parameter against the others, faceted by a third one.
        for i in range(1, min(3, len(param_names))):
            x_param = param_names[0]
            y_param = param_names[i]

            # Pick any remaining parameter for faceting.
            facet_param = next((p for p in param_names if p not in (x_param, y_param)), None)

            if facet_param and f"{facet_param}_binned" in results_binned.columns:
                plot_data = results_binned[
                    [f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
                ].dropna()

                # Stringify the Interval bins (Altair cannot encode Interval
                # dtype), preserving bin order for the facet sort.
                plot_data = plot_data.sort_values(f"{facet_param}_binned")
                plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
                bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()

                chart = (
                    alt.Chart(plot_data)
                    .mark_circle(size=60, opacity=0.7)
                    .encode(
                        **_scatter_encoding(
                            x_param,
                            y_param,
                            score_col,
                            metric,
                            bin_info,
                            extra_tooltips=[f"{facet_param}_binned:N"],
                        )
                    )
                    .properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
                    .facet(
                        facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
                        columns=4,
                    )
                )
                st.altair_chart(chart, use_container_width=True)

            # Also show the unfaceted 2D view of the same parameter pair.
            plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
            chart = (
                alt.Chart(plot_data)
                .mark_circle(size=80, opacity=0.6)
                .encode(**_scatter_encoding(x_param, y_param, score_col, metric, bin_info))
                .properties(height=400, title=f"{x_param} vs {y_param}")
            )
            st.altair_chart(chart, use_container_width=True)
    else:
        # Exactly two binnable parameters: a single 2D scatter.
        x_param, y_param = param_names[0], param_names[1]
        plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
        chart = (
            alt.Chart(plot_data)
            .mark_circle(size=100, opacity=0.6)
            .encode(**_scatter_encoding(x_param, y_param, score_col, metric, bin_info))
            .properties(height=500, title=f"{x_param} vs {y_param}")
        )
        st.altair_chart(chart, use_container_width=True)
def render_score_evolution(results: pd.DataFrame, metric: str):
@ -422,8 +621,6 @@ def render_multi_metric_comparison(results: pd.DataFrame):
results: DataFrame with CV results.
"""
st.subheader("📊 Multi-Metric Comparison")
# Get all test score columns
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
@ -462,6 +659,9 @@ def render_multi_metric_comparison(results: pd.DataFrame):
}
)
# Get colormap from colors module
cmap = get_cmap("iteration")
chart = (
alt.Chart(df_plot)
.mark_circle(size=60, opacity=0.6)
@ -494,8 +694,6 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
top_n: Number of top configurations to show.
"""
st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")
score_col = f"mean_test_{metric}"
if score_col not in results.columns:
st.warning(f"Metric {metric} not found in results.")

View file

@ -12,6 +12,8 @@ from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""

View file

@ -3,12 +3,11 @@
import streamlit as st
from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
render_multi_metric_comparison,
render_parameter_correlation,
render_parameter_distributions,
render_performance_summary,
render_score_evolution,
render_score_vs_parameter,
render_top_configurations,
)
from entropice.dashboard.utils.data import load_all_training_results
@ -50,16 +49,9 @@ def render_training_analysis_page():
st.divider()
# Display selected run info
st.subheader("Run Information")
st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
st.write(f"**Trials:** {len(selected_result.results)}")
st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
# Refit metric - determine from available metrics
available_metrics = get_available_metrics(selected_result.results)
# Try to get refit metric from settings
@ -80,15 +72,6 @@ def render_training_analysis_page():
st.error("No metrics found in results.")
return
st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results)
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
else:
@ -112,6 +95,27 @@ def render_training_analysis_page():
help="Number of top configurations to display",
)
# Main content area - Run Information at the top
st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6)
with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
with col2:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
with col3:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
with col4:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
with col5:
st.metric("Trials", len(selected_result.results))
with col6:
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Main content area
results = selected_result.results
settings = selected_result.settings
@ -151,14 +155,7 @@ def render_training_analysis_page():
st.divider()
# Score Evolution
st.header("📉 Training Progress")
render_score_evolution(results, selected_metric)
st.divider()
# Parameter Space Exploration
# Parameter Space Analysis
st.header("🔍 Parameter Space Analysis")
# Show parameter space summary
@ -170,19 +167,17 @@ def render_training_analysis_page():
st.info("No parameter information available.")
# Parameter distributions
st.subheader("📈 Parameter Distributions")
render_parameter_distributions(results)
st.divider()
# Score vs Parameters
st.header("🎯 Parameter Impact Analysis")
render_score_vs_parameter(results, selected_metric)
# Binned parameter space plots
st.subheader("🎨 Binned Parameter Space")
render_binned_parameter_space(results, selected_metric)
st.divider()
# Parameter Correlation
st.header("🔗 Parameter Correlation Analysis")
st.header("🔗 Parameter Correlation")
render_parameter_correlation(results, selected_metric)
@ -190,7 +185,7 @@ def render_training_analysis_page():
# Multi-Metric Comparison
if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Analysis")
st.header("📊 Multi-Metric Comparison")
render_multi_metric_comparison(results)