Update training analysis page
This commit is contained in:
parent
8338efb31e
commit
6ed5a9c224
4 changed files with 279 additions and 83 deletions
|
|
@ -72,7 +72,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
scale_factor = scale_factors[grid][level]
|
scale_factor = scale_factors[grid][level]
|
||||||
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
|
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
|
||||||
|
|
||||||
for year in track(range(2021, 2025), total=4, description="Processing years..."):
|
# 2024-2025 for hex-6
|
||||||
|
for year in track(range(2022, 2025), total=3, description="Processing years..."):
|
||||||
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
|
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
|
||||||
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
|
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
|
||||||
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
|
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
|
"""Hyperparameter analysis plotting functions for RandomizedSearchCV results."""
|
||||||
|
|
||||||
import altair as alt
|
import altair as alt
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.colors import get_cmap
|
||||||
|
|
||||||
|
|
||||||
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||||
"""Render summary statistics of model performance.
|
"""Render summary statistics of model performance.
|
||||||
|
|
@ -13,8 +16,6 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||||
refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
|
refit_metric: The metric used for refit (e.g., 'f1', 'f1_weighted').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("📊 Performance Summary")
|
|
||||||
|
|
||||||
# Get all test score columns
|
# Get all test score columns
|
||||||
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
||||||
|
|
||||||
|
|
@ -51,7 +52,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||||
|
|
||||||
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
|
st.dataframe(pd.DataFrame(score_stats), hide_index=True, use_container_width=True)
|
||||||
|
|
||||||
# Show best parameter combination
|
# Show best parameter combination in a cleaner format (similar to old dashboard)
|
||||||
st.markdown("#### 🏆 Best Parameter Combination")
|
st.markdown("#### 🏆 Best Parameter Combination")
|
||||||
refit_col = f"mean_test_{refit_metric}"
|
refit_col = f"mean_test_{refit_metric}"
|
||||||
|
|
||||||
|
|
@ -74,30 +75,48 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
|
||||||
if param_cols:
|
if param_cols:
|
||||||
best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
|
best_params = {col.replace("param_", ""): best_row[col] for col in param_cols}
|
||||||
|
|
||||||
# Display in a nice formatted way
|
# Display in a container with metrics (similar to old dashboard style)
|
||||||
param_df = pd.DataFrame([best_params]).T
|
with st.container(border=True):
|
||||||
param_df.columns = ["Value"]
|
st.caption(f"Parameters of the best model (selected by {refit_metric.replace('_', ' ').title()} score)")
|
||||||
param_df.index.name = "Parameter"
|
|
||||||
|
|
||||||
col1, col2 = st.columns([1, 1])
|
# Display parameters as metrics
|
||||||
with col1:
|
n_params = len(best_params)
|
||||||
st.dataframe(param_df, use_container_width=True)
|
cols = st.columns(n_params)
|
||||||
|
|
||||||
with col2:
|
for idx, (param_name, param_value) in enumerate(best_params.items()):
|
||||||
st.metric(f"Best {refit_metric.replace('_', ' ').title()}", f"{best_row[refit_col]:.4f}")
|
with cols[idx]:
|
||||||
rank_col = "rank_test_" + refit_metric
|
# Format value based on type and magnitude
|
||||||
if rank_col in best_row.index:
|
if isinstance(param_value, (int, np.integer)):
|
||||||
try:
|
formatted_value = f"{param_value:.0f}"
|
||||||
# Handle potential Series or scalar values
|
elif isinstance(param_value, (float, np.floating)):
|
||||||
rank_val = best_row[rank_col]
|
# Use scientific notation for very small numbers
|
||||||
if hasattr(rank_val, "item"):
|
if abs(param_value) < 0.01:
|
||||||
rank_val = rank_val.item()
|
formatted_value = f"{param_value:.2e}"
|
||||||
rank_display = str(int(float(rank_val)))
|
else:
|
||||||
except (ValueError, TypeError, AttributeError):
|
formatted_value = f"{param_value:.4f}"
|
||||||
rank_display = "N/A"
|
else:
|
||||||
else:
|
formatted_value = str(param_value)
|
||||||
rank_display = "N/A"
|
|
||||||
st.metric("Rank", rank_display)
|
st.metric(param_name, formatted_value)
|
||||||
|
|
||||||
|
# Show all metrics for the best model
|
||||||
|
st.divider()
|
||||||
|
st.caption("Performance across all metrics")
|
||||||
|
|
||||||
|
# Get all metrics from score_cols
|
||||||
|
all_metrics = [col.replace("mean_test_", "") for col in score_cols]
|
||||||
|
metric_cols = st.columns(len(all_metrics))
|
||||||
|
|
||||||
|
for idx, metric in enumerate(all_metrics):
|
||||||
|
with metric_cols[idx]:
|
||||||
|
best_score = results.loc[best_idx, f"mean_test_{metric}"]
|
||||||
|
best_std = results.loc[best_idx, f"std_test_{metric}"]
|
||||||
|
st.metric(
|
||||||
|
metric.replace("_", " ").title(),
|
||||||
|
f"{best_score:.4f}",
|
||||||
|
delta=f"±{best_std:.4f}",
|
||||||
|
help="Mean ± std across cross-validation folds",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def render_parameter_distributions(results: pd.DataFrame):
|
def render_parameter_distributions(results: pd.DataFrame):
|
||||||
|
|
@ -107,8 +126,6 @@ def render_parameter_distributions(results: pd.DataFrame):
|
||||||
results: DataFrame with CV results.
|
results: DataFrame with CV results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("📈 Parameter Space Exploration")
|
|
||||||
|
|
||||||
# Get parameter columns
|
# Get parameter columns
|
||||||
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
||||||
|
|
||||||
|
|
@ -116,6 +133,12 @@ def render_parameter_distributions(results: pd.DataFrame):
|
||||||
st.warning("No parameter columns found in results.")
|
st.warning("No parameter columns found in results.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Get colormap from colors module
|
||||||
|
cmap = get_cmap("parameter_distribution")
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
|
||||||
|
bar_color = mcolors.rgb2hex(cmap(0.5))
|
||||||
|
|
||||||
# Create histograms for each parameter
|
# Create histograms for each parameter
|
||||||
n_params = len(param_cols)
|
n_params = len(param_cols)
|
||||||
n_cols = min(3, n_params)
|
n_cols = min(3, n_params)
|
||||||
|
|
@ -146,25 +169,25 @@ def render_parameter_distributions(results: pd.DataFrame):
|
||||||
if use_log:
|
if use_log:
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(df_plot)
|
alt.Chart(df_plot)
|
||||||
.mark_bar()
|
.mark_bar(color=bar_color)
|
||||||
.encode(
|
.encode(
|
||||||
alt.X(
|
alt.X(
|
||||||
param_name,
|
param_name,
|
||||||
bin=alt.Bin(maxbins=30),
|
bin=alt.Bin(maxbins=20),
|
||||||
scale=alt.Scale(type="log"),
|
scale=alt.Scale(type="log"),
|
||||||
title=param_name,
|
title=param_name,
|
||||||
),
|
),
|
||||||
alt.Y("count()", title="Count"),
|
alt.Y("count()", title="Count"),
|
||||||
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
|
tooltip=[alt.Tooltip(param_name, format=".2e"), "count()"],
|
||||||
)
|
)
|
||||||
.properties(height=250, title=f"{param_name} (log scale)")
|
.properties(height=250, title=param_name)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(df_plot)
|
alt.Chart(df_plot)
|
||||||
.mark_bar()
|
.mark_bar(color=bar_color)
|
||||||
.encode(
|
.encode(
|
||||||
alt.X(param_name, bin=alt.Bin(maxbins=30), title=param_name),
|
alt.X(param_name, bin=alt.Bin(maxbins=20), title=param_name),
|
||||||
alt.Y("count()", title="Count"),
|
alt.Y("count()", title="Count"),
|
||||||
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
|
tooltip=[alt.Tooltip(param_name, format=".3f"), "count()"],
|
||||||
)
|
)
|
||||||
|
|
@ -180,7 +203,7 @@ def render_parameter_distributions(results: pd.DataFrame):
|
||||||
|
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(value_counts)
|
alt.Chart(value_counts)
|
||||||
.mark_bar()
|
.mark_bar(color=bar_color)
|
||||||
.encode(
|
.encode(
|
||||||
alt.X(param_name, title=param_name, sort="-y"),
|
alt.X(param_name, title=param_name, sort="-y"),
|
||||||
alt.Y("count", title="Count"),
|
alt.Y("count", title="Count"),
|
||||||
|
|
@ -298,15 +321,13 @@ def render_score_vs_parameter(results: pd.DataFrame, metric: str):
|
||||||
|
|
||||||
|
|
||||||
def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
||||||
"""Render correlation heatmap between parameters and score.
|
"""Render correlation bar chart between parameters and score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
results: DataFrame with CV results.
|
results: DataFrame with CV results.
|
||||||
metric: The metric to analyze (e.g., 'f1', 'accuracy').
|
metric: The metric to analyze (e.g., 'f1', 'accuracy').
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader(f"🔗 Parameter Correlations with {metric.replace('_', ' ').title()}")
|
|
||||||
|
|
||||||
score_col = f"mean_test_{metric}"
|
score_col = f"mean_test_{metric}"
|
||||||
if score_col not in results.columns:
|
if score_col not in results.columns:
|
||||||
st.warning(f"Metric {metric} not found in results.")
|
st.warning(f"Metric {metric} not found in results.")
|
||||||
|
|
@ -329,6 +350,14 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
||||||
|
|
||||||
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
|
corr_df = pd.DataFrame(correlations).sort_values("Correlation", ascending=False)
|
||||||
|
|
||||||
|
# Get colormap from colors module
|
||||||
|
cmap = get_cmap("correlation")
|
||||||
|
# Sample colors for red-blue diverging scheme
|
||||||
|
colors = [cmap(i / (cmap.N - 1)) for i in range(cmap.N)]
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
|
||||||
|
hex_colors = [mcolors.rgb2hex(c) for c in colors]
|
||||||
|
|
||||||
# Create bar chart
|
# Create bar chart
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(corr_df)
|
alt.Chart(corr_df)
|
||||||
|
|
@ -348,13 +377,183 @@ def render_parameter_correlation(results: pd.DataFrame, metric: str):
|
||||||
|
|
||||||
st.altair_chart(chart, use_container_width=True)
|
st.altair_chart(chart, use_container_width=True)
|
||||||
|
|
||||||
# Show correlation table
|
|
||||||
with st.expander("📋 Correlation Table"):
|
def render_binned_parameter_space(results: pd.DataFrame, metric: str):
|
||||||
st.dataframe(
|
"""Render binned parameter space plots similar to old dashboard.
|
||||||
corr_df.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1, subset=["Correlation"]),
|
|
||||||
hide_index=True,
|
This creates plots where parameters are binned and plotted against each other,
|
||||||
use_container_width=True,
|
showing the metric value as color. Handles different hyperparameters dynamically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: DataFrame with CV results.
|
||||||
|
metric: The metric to visualize (e.g., 'f1', 'accuracy').
|
||||||
|
|
||||||
|
"""
|
||||||
|
score_col = f"mean_test_{metric}"
|
||||||
|
if score_col not in results.columns:
|
||||||
|
st.warning(f"Metric {metric} not found in results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get numeric parameter columns
|
||||||
|
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
||||||
|
numeric_params = [col for col in param_cols if pd.api.types.is_numeric_dtype(results[col])]
|
||||||
|
|
||||||
|
if len(numeric_params) < 2:
|
||||||
|
st.info("Need at least 2 numeric parameters for binned parameter space analysis.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Prepare binned data
|
||||||
|
results_binned = results.copy()
|
||||||
|
bin_info = {}
|
||||||
|
|
||||||
|
for param_col in numeric_params:
|
||||||
|
param_name = param_col.replace("param_", "")
|
||||||
|
param_values = results[param_col].dropna()
|
||||||
|
|
||||||
|
if len(param_values) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine if we should use log bins
|
||||||
|
value_range = param_values.max() / (param_values.min() + 1e-10)
|
||||||
|
use_log = value_range > 100 or param_values.max() < 1.0
|
||||||
|
|
||||||
|
if use_log:
|
||||||
|
# Logarithmic binning for parameters spanning many orders of magnitude
|
||||||
|
log_min = np.log10(param_values.min())
|
||||||
|
log_max = np.log10(param_values.max())
|
||||||
|
n_bins = min(10, int(log_max - log_min) + 1)
|
||||||
|
bins = np.logspace(log_min, log_max, num=n_bins)
|
||||||
|
else:
|
||||||
|
# Linear binning
|
||||||
|
n_bins = min(10, int(np.sqrt(len(param_values))))
|
||||||
|
bins = np.linspace(param_values.min(), param_values.max(), num=n_bins)
|
||||||
|
|
||||||
|
results_binned[f"{param_name}_binned"] = pd.cut(results[param_col], bins=bins)
|
||||||
|
bin_info[param_name] = {"use_log": use_log, "bins": bins}
|
||||||
|
|
||||||
|
# Get colormap from colors module
|
||||||
|
cmap = get_cmap(metric)
|
||||||
|
|
||||||
|
# Create binned plots for pairs of parameters
|
||||||
|
# Show the most meaningful combinations
|
||||||
|
param_names = [col.replace("param_", "") for col in numeric_params]
|
||||||
|
|
||||||
|
# If we have 3+ parameters, create multiple plots
|
||||||
|
if len(param_names) >= 3:
|
||||||
|
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
|
||||||
|
|
||||||
|
# Plot first parameter vs others, binned by third
|
||||||
|
for i in range(1, min(3, len(param_names))):
|
||||||
|
x_param = param_names[0]
|
||||||
|
y_param = param_names[i]
|
||||||
|
|
||||||
|
# Find a third parameter for faceting
|
||||||
|
facet_param = None
|
||||||
|
for p in param_names:
|
||||||
|
if p != x_param and p != y_param:
|
||||||
|
facet_param = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if facet_param and f"{facet_param}_binned" in results_binned.columns:
|
||||||
|
# Create faceted plot
|
||||||
|
plot_data = results_binned[
|
||||||
|
[f"param_{x_param}", f"param_{y_param}", f"{facet_param}_binned", score_col]
|
||||||
|
].dropna()
|
||||||
|
|
||||||
|
# Convert binned column to string with sorted categories
|
||||||
|
plot_data = plot_data.sort_values(f"{facet_param}_binned")
|
||||||
|
plot_data[f"{facet_param}_binned"] = plot_data[f"{facet_param}_binned"].astype(str)
|
||||||
|
bin_order = plot_data[f"{facet_param}_binned"].unique().tolist()
|
||||||
|
|
||||||
|
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||||
|
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||||
|
|
||||||
|
# Get colormap colors
|
||||||
|
cmap_colors = get_cmap(metric)
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
|
||||||
|
# Sample colors from colormap
|
||||||
|
hex_colors = [mcolors.rgb2hex(cmap_colors(i / (cmap_colors.N - 1))) for i in range(cmap_colors.N)]
|
||||||
|
|
||||||
|
chart = (
|
||||||
|
alt.Chart(plot_data)
|
||||||
|
.mark_circle(size=60, opacity=0.7)
|
||||||
|
.encode(
|
||||||
|
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||||
|
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||||
|
color=alt.Color(
|
||||||
|
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||||
|
),
|
||||||
|
tooltip=[
|
||||||
|
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||||
|
f"{facet_param}_binned:N",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.properties(width=200, height=200, title=f"{x_param} vs {y_param} (faceted by {facet_param})")
|
||||||
|
.facet(
|
||||||
|
facet=alt.Facet(f"{facet_param}_binned:N", title=facet_param, sort=bin_order),
|
||||||
|
columns=4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
st.altair_chart(chart, use_container_width=True)
|
||||||
|
|
||||||
|
# Also create a simple 2D plot without faceting
|
||||||
|
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
|
||||||
|
|
||||||
|
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||||
|
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||||
|
|
||||||
|
chart = (
|
||||||
|
alt.Chart(plot_data)
|
||||||
|
.mark_circle(size=80, opacity=0.6)
|
||||||
|
.encode(
|
||||||
|
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||||
|
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||||
|
color=alt.Color(
|
||||||
|
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||||
|
),
|
||||||
|
tooltip=[
|
||||||
|
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.properties(height=400, title=f"{x_param} vs {y_param}")
|
||||||
|
)
|
||||||
|
st.altair_chart(chart, use_container_width=True)
|
||||||
|
|
||||||
|
elif len(param_names) == 2:
|
||||||
|
# Simple 2-parameter plot
|
||||||
|
st.caption(f"Parameter space exploration showing {metric.replace('_', ' ').title()} values")
|
||||||
|
|
||||||
|
x_param = param_names[0]
|
||||||
|
y_param = param_names[1]
|
||||||
|
|
||||||
|
plot_data = results_binned[[f"param_{x_param}", f"param_{y_param}", score_col]].dropna()
|
||||||
|
|
||||||
|
x_scale = alt.Scale(type="log") if bin_info[x_param]["use_log"] else alt.Scale()
|
||||||
|
y_scale = alt.Scale(type="log") if bin_info[y_param]["use_log"] else alt.Scale()
|
||||||
|
|
||||||
|
chart = (
|
||||||
|
alt.Chart(plot_data)
|
||||||
|
.mark_circle(size=100, opacity=0.6)
|
||||||
|
.encode(
|
||||||
|
x=alt.X(f"param_{x_param}:Q", title=x_param, scale=x_scale),
|
||||||
|
y=alt.Y(f"param_{y_param}:Q", title=y_param, scale=y_scale),
|
||||||
|
color=alt.Color(
|
||||||
|
f"{score_col}:Q", scale=alt.Scale(scheme="viridis"), title=metric.replace("_", " ").title()
|
||||||
|
),
|
||||||
|
tooltip=[
|
||||||
|
alt.Tooltip(f"param_{x_param}:Q", format=".2e" if bin_info[x_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"param_{y_param}:Q", format=".2e" if bin_info[y_param]["use_log"] else ".3f"),
|
||||||
|
alt.Tooltip(f"{score_col}:Q", format=".4f"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.properties(height=500, title=f"{x_param} vs {y_param}")
|
||||||
)
|
)
|
||||||
|
st.altair_chart(chart, use_container_width=True)
|
||||||
|
|
||||||
|
|
||||||
def render_score_evolution(results: pd.DataFrame, metric: str):
|
def render_score_evolution(results: pd.DataFrame, metric: str):
|
||||||
|
|
@ -422,8 +621,6 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
||||||
results: DataFrame with CV results.
|
results: DataFrame with CV results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("📊 Multi-Metric Comparison")
|
|
||||||
|
|
||||||
# Get all test score columns
|
# Get all test score columns
|
||||||
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
||||||
|
|
||||||
|
|
@ -462,6 +659,9 @@ def render_multi_metric_comparison(results: pd.DataFrame):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get colormap from colors module
|
||||||
|
cmap = get_cmap("iteration")
|
||||||
|
|
||||||
chart = (
|
chart = (
|
||||||
alt.Chart(df_plot)
|
alt.Chart(df_plot)
|
||||||
.mark_circle(size=60, opacity=0.6)
|
.mark_circle(size=60, opacity=0.6)
|
||||||
|
|
@ -494,8 +694,6 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
|
||||||
top_n: Number of top configurations to show.
|
top_n: Number of top configurations to show.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader(f"🏆 Top {top_n} Configurations by {metric.replace('_', ' ').title()}")
|
|
||||||
|
|
||||||
score_col = f"mean_test_{metric}"
|
score_col = f"mean_test_{metric}"
|
||||||
if score_col not in results.columns:
|
if score_col not in results.columns:
|
||||||
st.warning(f"Metric {metric} not found in results.")
|
st.warning(f"Metric {metric} not found in results.")
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ from shapely.geometry import shape
|
||||||
|
|
||||||
from entropice.dashboard.plots.colors import get_cmap
|
from entropice.dashboard.plots.colors import get_cmap
|
||||||
|
|
||||||
|
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations
|
||||||
|
|
||||||
|
|
||||||
def _fix_hex_geometry(geom):
|
def _fix_hex_geometry(geom):
|
||||||
"""Fix hexagon geometry crossing the antimeridian."""
|
"""Fix hexagon geometry crossing the antimeridian."""
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||||
|
render_binned_parameter_space,
|
||||||
render_multi_metric_comparison,
|
render_multi_metric_comparison,
|
||||||
render_parameter_correlation,
|
render_parameter_correlation,
|
||||||
render_parameter_distributions,
|
render_parameter_distributions,
|
||||||
render_performance_summary,
|
render_performance_summary,
|
||||||
render_score_evolution,
|
|
||||||
render_score_vs_parameter,
|
|
||||||
render_top_configurations,
|
render_top_configurations,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.data import load_all_training_results
|
from entropice.dashboard.utils.data import load_all_training_results
|
||||||
|
|
@ -50,16 +49,9 @@ def render_training_analysis_page():
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Display selected run info
|
# Metric selection for detailed analysis
|
||||||
st.subheader("Run Information")
|
st.subheader("Analysis Settings")
|
||||||
st.write(f"**Task:** {selected_result.settings.get('task', 'Unknown').capitalize()}")
|
|
||||||
st.write(f"**Grid:** {selected_result.settings.get('grid', 'Unknown').capitalize()}")
|
|
||||||
st.write(f"**Level:** {selected_result.settings.get('level', 'Unknown')}")
|
|
||||||
st.write(f"**Model:** {selected_result.settings.get('model', 'Unknown').upper()}")
|
|
||||||
st.write(f"**Trials:** {len(selected_result.results)}")
|
|
||||||
st.write(f"**CV Splits:** {selected_result.settings.get('cv_splits', 'Unknown')}")
|
|
||||||
|
|
||||||
# Refit metric - determine from available metrics
|
|
||||||
available_metrics = get_available_metrics(selected_result.results)
|
available_metrics = get_available_metrics(selected_result.results)
|
||||||
|
|
||||||
# Try to get refit metric from settings
|
# Try to get refit metric from settings
|
||||||
|
|
@ -80,15 +72,6 @@ def render_training_analysis_page():
|
||||||
st.error("No metrics found in results.")
|
st.error("No metrics found in results.")
|
||||||
return
|
return
|
||||||
|
|
||||||
st.write(f"**Refit Metric:** {format_metric_name(refit_metric)}")
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Metric selection for detailed analysis
|
|
||||||
st.subheader("Analysis Settings")
|
|
||||||
|
|
||||||
available_metrics = get_available_metrics(selected_result.results)
|
|
||||||
|
|
||||||
if refit_metric in available_metrics:
|
if refit_metric in available_metrics:
|
||||||
default_metric_idx = available_metrics.index(refit_metric)
|
default_metric_idx = available_metrics.index(refit_metric)
|
||||||
else:
|
else:
|
||||||
|
|
@ -112,6 +95,27 @@ def render_training_analysis_page():
|
||||||
help="Number of top configurations to display",
|
help="Number of top configurations to display",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Main content area - Run Information at the top
|
||||||
|
st.header("📋 Run Information")
|
||||||
|
|
||||||
|
col1, col2, col3, col4, col5, col6 = st.columns(6)
|
||||||
|
with col1:
|
||||||
|
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
|
||||||
|
with col2:
|
||||||
|
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
|
||||||
|
with col3:
|
||||||
|
st.metric("Level", selected_result.settings.get("level", "Unknown"))
|
||||||
|
with col4:
|
||||||
|
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
|
||||||
|
with col5:
|
||||||
|
st.metric("Trials", len(selected_result.results))
|
||||||
|
with col6:
|
||||||
|
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
|
||||||
|
|
||||||
|
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
# Main content area
|
# Main content area
|
||||||
results = selected_result.results
|
results = selected_result.results
|
||||||
settings = selected_result.settings
|
settings = selected_result.settings
|
||||||
|
|
@ -151,14 +155,7 @@ def render_training_analysis_page():
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Score Evolution
|
# Parameter Space Analysis
|
||||||
st.header("📉 Training Progress")
|
|
||||||
|
|
||||||
render_score_evolution(results, selected_metric)
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Parameter Space Exploration
|
|
||||||
st.header("🔍 Parameter Space Analysis")
|
st.header("🔍 Parameter Space Analysis")
|
||||||
|
|
||||||
# Show parameter space summary
|
# Show parameter space summary
|
||||||
|
|
@ -170,19 +167,17 @@ def render_training_analysis_page():
|
||||||
st.info("No parameter information available.")
|
st.info("No parameter information available.")
|
||||||
|
|
||||||
# Parameter distributions
|
# Parameter distributions
|
||||||
|
st.subheader("📈 Parameter Distributions")
|
||||||
render_parameter_distributions(results)
|
render_parameter_distributions(results)
|
||||||
|
|
||||||
st.divider()
|
# Binned parameter space plots
|
||||||
|
st.subheader("🎨 Binned Parameter Space")
|
||||||
# Score vs Parameters
|
render_binned_parameter_space(results, selected_metric)
|
||||||
st.header("🎯 Parameter Impact Analysis")
|
|
||||||
|
|
||||||
render_score_vs_parameter(results, selected_metric)
|
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Parameter Correlation
|
# Parameter Correlation
|
||||||
st.header("🔗 Parameter Correlation Analysis")
|
st.header("🔗 Parameter Correlation")
|
||||||
|
|
||||||
render_parameter_correlation(results, selected_metric)
|
render_parameter_correlation(results, selected_metric)
|
||||||
|
|
||||||
|
|
@ -190,7 +185,7 @@ def render_training_analysis_page():
|
||||||
|
|
||||||
# Multi-Metric Comparison
|
# Multi-Metric Comparison
|
||||||
if len(available_metrics) >= 2:
|
if len(available_metrics) >= 2:
|
||||||
st.header("📊 Multi-Metric Analysis")
|
st.header("📊 Multi-Metric Comparison")
|
||||||
|
|
||||||
render_multi_metric_comparison(results)
|
render_multi_metric_comparison(results)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue