Update training analysis page

This commit is contained in:
Tobias Hölzer 2025-12-19 16:52:02 +01:00
parent 8338efb31e
commit 6ed5a9c224
4 changed files with 279 additions and 83 deletions

View file

@ -72,7 +72,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
scale_factor = scale_factors[grid][level] scale_factor = scale_factors[grid][level]
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.") print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
for year in track(range(2021, 2025), total=4, description="Processing years..."): # 2024-2025 for hex-6
for year in track(range(2022, 2025), total=3, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL") embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31") embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"] aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]

View file

@ -1,9 +1,12 @@
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results.""" """Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
import altair as alt import altair as alt
import numpy as np
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
from entropice.dashboard.plots.colors import get_cmap
def render_performance_summary(results: pd.DataFrame, refit_metric: str): def render_performance_summary(results: pd.DataFrame, refit_metric: str):
"""Render summary statistics of model performance. """Render summary statistics of model performance.
@ -13,8 +16,6 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted'). refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
""" """
st.subheader("📊 Performance Summary")
# Get all test score columns # Get all test score columns
score_cols = [col for col in results.columns if col.startswith("mean_test_")] score_cols = [col for col in results.columns if col.startswith("mean_test_")]
@ -51,7 +52,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True) st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
# Show best parameter combination # Show best parameter combination in a cleaner format (similar to old dashboard)
st.markdown("#### 🏆 Best Parameter Combination") st.markdown("#### 🏆 Best Parameter Combination")
refit_col = f"mean_test_{refit_metric}" refit_col = f"mean_test_{refit_metric}"
@ -74,30 +75,48 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
if param_cols: if param_cols:
best_params = {col.replace("param_", ""): best_row[col] for col in param_cols} best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
# Display in a nice formatted way # Display in a container with metrics (similar to old dashboard style)
param_df = pd.DataFrame([best_params]).T with st.container(border=True):
param_df.columns = ["Value"] st.caption(f"Parameters of the best model (selected by {refit_metric.replace('_', ' ').title()} score)")
param_df.index.name = "Parameter"
col1, col2 = st.columns([1, 1]) # Display parameters as metrics
with col1: n_params = len(best_params)
st.dataframe(param_df, use_container_width=True) cols = st.columns(n_params)
with col2: for idx, (param_name, param_value) in enumerate(best_params.items()):
st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}") with cols[idx]:
rank_col = "rank_test_" + refit_metric # Format value based on type and magnitude
if rank_col in best_row.index: if isinstance(param_value, (int, np.integer)):
try: formatted_value = f"{param_value:.0f}"
# Handle potential Series or scalar values elif isinstance(param_value, (float, np.floating)):
rank_val = best_row[rank_col] # Use scientific notation for very small numbers
if hasattr(rank_val, "item"): if abs(param_value) < 0.01:
rank_val = rank_val.item() formatted_value = f"{param_value:.2e}"
rank_display = str(int(float(rank_val))) else:
except (ValueError, TypeError, AttributeError): formatted_value = f"{param_value:.4f}"
rank_display = "N/A" else:
else: formatted_value = str(param_value)
rank_display = "N/A"
st.metric("Rank", rank_display) st.metric(param_name, formatted_value)
# Show all metrics for the best model
st.divider()
st.caption("Performance across all metrics")
# Get all metrics from score_cols
all_metrics = [col.replace("mean_test_", "") for col in score_cols]
metric_cols = st.columns(len(all_metrics))
for idx, metric in enumerate(all_metrics):
with metric_cols[idx]:
best_score = results.loc[best_idx, f"mean_test_{metric}"]
best_std = results.loc[best_idx, f"std_test_{metric}"]
st.metric(
metric.replace("_", " ").title(),
f"{best_score:.4f}",
delta=f"±{best_std:.4f}",
help="Mean ± std across cross-validation folds",
)
def render_parameter_distributions(results: pd.DataFrame): def render_parameter_distributions(results: pd.DataFrame):
@ -107,8 +126,6 @@ def render_parameter_distributions(results: pd.DataFrame):
results: DataFrame with CV results. results: DataFrame with CV results.
""" """
st.subheader("📈 Parameter Space Exploration")
# Get parameter columns # Get parameter columns
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"] param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
@ -116,6 +133,12 @@ def render_parameter_distributions(results: pd.DataFrame):
st.warning("No parameter columns found in results.") st.warning("No parameter columns found in results.")
return return
# Get colormap from colors module
cmap = get_cmap("parameter_distribution")
import matplotlib.colors as mcolors
bar_color = mcolors.rgb2hex(cmap(0.5))
# Create histograms for each parameter # Create histograms for each parameter
n_params = len(param_cols) n_params = len(param_cols)
n_cols = min(3, n_params) n_cols = min(3, n_params)
@ -146,25 +169,25 @@ def render_parameter_distributions(results: pd.DataFrame):
if use_log: if use_log:
chart = ( chart = (
alt.Chart(df_plot) alt.Chart(df_plot)
.mark_bar() .mark_bar(color=bar_color)
.encode( .encode(
alt.X( alt.X(
param_name, param_name,
bin=alt.Bin(maxbins=30), bin=alt.Bin(maxbins=20),
scale=alt.Scale(type="log"), scale=alt.Scale(type="log"),
title=param_name, title=param_name,
), ),
alt.Y("count()", title="Count"), alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"], tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
) )
.properties(height=250, title=f"{param_name} (log scale)") .properties(height=250, title=param_name)
) )
else: else:
chart = ( chart = (
alt.Chart(df_plot) alt.Chart(df_plot)
.mark_bar() .mark_bar(color=bar_color)
.encode( .encode(
alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name), alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
alt.Y("count()", title="Count"), alt.Y("count()", title="Count"),
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"], tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
) )
@ -180,7 +203,7 @@ def render_parameter_distributions(results: pd.DataFrame):
chart = ( chart = (
alt.Chart(value_counts) alt.Chart(value_counts)
.mark_bar() .mark_bar(color=bar_color)
.encode( .encode(
alt.X(param_name, title=param_name, sort="-y"), alt.X(param_name, title=param_name, sort="-y"),
alt.Y("count", title="Count"), alt.Y("count", title="Count"),
@ -298,15 +321,13 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
def render_parameter_correlation(results: pd.DataFrame, metric: str): def render_parameter_correlation(results: pd.DataFrame, metric: str):
"""Render correlation heatmap between parameters and score. """Render correlation bar chart between parameters and score.
Args: Args:
results: DataFrame with CV results. results: DataFrame with CV results.
metric: The metric to analyze (e.g., 'f1', 'accuracy'). metric: The metric to analyze (e.g., 'f1', 'accuracy').
""" """
st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")
score_col = f"mean_test_{metric}" score_col = f"mean_test_{metric}"
if score_col not in results.columns: if score_col not in results.columns:
st.warning(f"Metric {metric} not found in results.") st.warning(f"Metric {metric} not found in results.")
@ -329,6 +350,14 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False) corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
# Get colormap from colors module
cmap = get_cmap("correlation")
# Sample colors for red-blue diverging scheme
colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
import matplotlib.colors as mcolors
hex_colors = [mcolors.rgb2hex(c) for c in colors]
# Create bar chart # Create bar chart
chart = ( chart = (
alt.Chart(corr_df) alt.Chart(corr_df)
@ -348,13 +377,183 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
st.altair_chart(chart, use_container_width=True) st.altair_chart(chart, use_container_width=True)
# Show correlation table
with st.expander("📋 Correlation Table"): def render_binned_parameter_space(results: pd.DataFrame, metric: str):
st.dataframe( """Render binned parameter space plots similar to old dashboard.
corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
hide_index=True, This creates plots where parameters are binned and plotted against each other,
use_container_width=True, showing the metric value as color. Handles different hyperparameters dynamically.
Args:
results: DataFrame with CV results.
metric: The metric to visualize (e.g., 'f1', 'accuracy').
"""
score_col = f"mean_test_{metric}"
if score_col not in results.columns:
st.warning(f"Metric {metric} not found in results.")
return
# Get numeric parameter columns
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
if len(numeric_params) < 2:
st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
return
# Prepare binned data
results_binned = results.copy()
bin_info = {}
for param_col in numeric_params:
param_name = param_col.replace("param_", "")
param_values = results[param_col].dropna()
if len(param_values) == 0:
continue
# Determine if we should use log bins
value_range = param_values.max() / (param_values.min() + 1e-10)
use_log = value_range > 100 or param_values.max() < 1.0
if use_log:
# Logarithmic binning for parameters spanning many orders of magnitude
log_min = np.log10(param_values.min())
log_max = np.log10(param_values.max())
n_bins = min(10, int(log_max - log_min) + 1)
bins = np.logspace(log_min, log_max, num=n_bins)
else:
# Linear binning
n_bins = min(10, int(np.sqrt(len(param_values))))
bins = np.linspace(param_values.min(), param_values.max(), num=n_bins)
results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
bin_info[param_name] = {"use_log": use_log, "bins": bins}
# Get colormap from colors module
cmap = get_cmap(metric)
# Create binned plots for pairs of parameters
# Show the most meaningful combinations
param_names = [col.replace("param_", "") for col in numeric_params]
# If we have 3+ parameters, create multiple plots
if len(param_names) >= 3:
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
# Plot first parameter vs others, binned by third
for i in range(1, min(3, len(param_names))):
x_param = param_names[0]
y_param = param_names[i]
# Find a third parameter for faceting
facet_param = None
for p in param_names:
if p != x_param and p != y_param:
facet_param = p
break
if facet_param and f"{facet_param}_binned" in results_binned.columns:
# Create faceted plot
plot_data = results_binned[
[f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
].dropna()
# Convert binned column to string with sorted categories
plot_data = plot_data.sort_values(f"{facet_param}_binned")
plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
# Get colormap colors
cmap_colors = get_cmap(metric)
import matplotlib.colors as mcolors
# Sample colors from colormap
hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)]
chart = (
alt.Chart(plot_data)
.mark_circle(size=60, opacity=0.7)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
f"{facet_param}_binned:N",
],
)
.properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
.facet(
facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
columns=4,
)
)
st.altair_chart(chart, use_container_width=True)
# Also create a simple 2D plot without faceting
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
chart = (
alt.Chart(plot_data)
.mark_circle(size=80, opacity=0.6)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
],
)
.properties(height=400, title=f"{x_param} vs {y_param}")
)
st.altair_chart(chart, use_container_width=True)
elif len(param_names) == 2:
# Simple 2-parameter plot
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
x_param = param_names[0]
y_param = param_names[1]
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
chart = (
alt.Chart(plot_data)
.mark_circle(size=100, opacity=0.6)
.encode(
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
color=alt.Color(
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
),
tooltip=[
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
alt.Tooltip(f"{score_col}:Q", format=".4f"),
],
)
.properties(height=500, title=f"{x_param} vs {y_param}")
) )
st.altair_chart(chart, use_container_width=True)
def render_score_evolution(results: pd.DataFrame, metric: str): def render_score_evolution(results: pd.DataFrame, metric: str):
@ -422,8 +621,6 @@ def render_multi_metric_comparison(results: pd.DataFrame):
results: DataFrame with CV results. results: DataFrame with CV results.
""" """
st.subheader("📊 Multi-Metric Comparison")
# Get all test score columns # Get all test score columns
score_cols = [col for col in results.columns if col.startswith("mean_test_")] score_cols = [col for col in results.columns if col.startswith("mean_test_")]
@ -462,6 +659,9 @@ def render_multi_metric_comparison(results: pd.DataFrame):
} }
) )
# Get colormap from colors module
cmap = get_cmap("iteration")
chart = ( chart = (
alt.Chart(df_plot) alt.Chart(df_plot)
.mark_circle(size=60, opacity=0.6) .mark_circle(size=60, opacity=0.6)
@ -494,8 +694,6 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
top_n: Number of top configurations to show. top_n: Number of top configurations to show.
""" """
st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")
score_col = f"mean_test_{metric}" score_col = f"mean_test_{metric}"
if score_col not in results.columns: if score_col not in results.columns:
st.warning(f"Metric {metric} not found in results.") st.warning(f"Metric {metric} not found in results.")

View file

@ -12,6 +12,8 @@ from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap from entropice.dashboard.plots.colors import get_cmap
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations
def _fix_hex_geometry(geom): def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian.""" """Fix hexagon geometry crossing the antimeridian."""

View file

@ -3,12 +3,11 @@
import streamlit as st import streamlit as st
from entropice.dashboard.plots.hyperparameter_analysis import ( from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
render_multi_metric_comparison, render_multi_metric_comparison,
render_parameter_correlation, render_parameter_correlation,
render_parameter_distributions, render_parameter_distributions,
render_performance_summary, render_performance_summary,
render_score_evolution,
render_score_vs_parameter,
render_top_configurations, render_top_configurations,
) )
from entropice.dashboard.utils.data import load_all_training_results from entropice.dashboard.utils.data import load_all_training_results
@ -50,16 +49,9 @@ def render_training_analysis_page():
st.divider() st.divider()
# Display selected run info # Metric selection for detailed analysis
st.subheader("Run Information") st.subheader("Analysis Settings")
st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
st.write(f"**Trials:** {len(selected_result.results)}")
st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")
# Refit metric - determine from available metrics
available_metrics = get_available_metrics(selected_result.results) available_metrics = get_available_metrics(selected_result.results)
# Try to get refit metric from settings # Try to get refit metric from settings
@ -80,15 +72,6 @@ def render_training_analysis_page():
st.error("No metrics found in results.") st.error("No metrics found in results.")
return return
st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results)
if refit_metric in available_metrics: if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric) default_metric_idx = available_metrics.index(refit_metric)
else: else:
@ -112,6 +95,27 @@ def render_training_analysis_page():
help="Number of top configurations to display", help="Number of top configurations to display",
) )
# Main content area - Run Information at the top
st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6)
with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
with col2:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
with col3:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
with col4:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
with col5:
st.metric("Trials", len(selected_result.results))
with col6:
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Main content area # Main content area
results = selected_result.results results = selected_result.results
settings = selected_result.settings settings = selected_result.settings
@ -151,14 +155,7 @@ def render_training_analysis_page():
st.divider() st.divider()
# Score Evolution # Parameter Space Analysis
st.header("📉 Training Progress")
render_score_evolution(results, selected_metric)
st.divider()
# Parameter Space Exploration
st.header("🔍 Parameter Space Analysis") st.header("🔍 Parameter Space Analysis")
# Show parameter space summary # Show parameter space summary
@ -170,19 +167,17 @@ def render_training_analysis_page():
st.info("No parameter information available.") st.info("No parameter information available.")
# Parameter distributions # Parameter distributions
st.subheader("📈 Parameter Distributions")
render_parameter_distributions(results) render_parameter_distributions(results)
st.divider() # Binned parameter space plots
st.subheader("🎨 Binned Parameter Space")
# Score vs Parameters render_binned_parameter_space(results, selected_metric)
st.header("🎯 Parameter Impact Analysis")
render_score_vs_parameter(results, selected_metric)
st.divider() st.divider()
# Parameter Correlation # Parameter Correlation
st.header("🔗 Parameter Correlation Analysis") st.header("🔗 Parameter Correlation")
render_parameter_correlation(results, selected_metric) render_parameter_correlation(results, selected_metric)
@ -190,7 +185,7 @@ def render_training_analysis_page():
# Multi-Metric Comparison # Multi-Metric Comparison
if len(available_metrics) >= 2: if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Analysis") st.header("📊 Multi-Metric Comparison")
render_multi_metric_comparison(results) render_multi_metric_comparison(results)