Add an analysis dashboard with streamlit
This commit is contained in:
parent
3e0e6e0d2d
commit
ccd40ace48
2 changed files with 237 additions and 1 deletion
|
|
@ -259,7 +259,7 @@ def plot_random_cv_results(file: Path):
|
||||||
figdir = file.parent
|
figdir = file.parent
|
||||||
|
|
||||||
# K-Plots
|
# K-Plots
|
||||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
|
metrics = ["f1"]
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
_plot_k_binned(
|
_plot_k_binned(
|
||||||
results,
|
results,
|
||||||
|
|
|
||||||
236
src/entropice/training_analysis_dashboard.py
Normal file
236
src/entropice/training_analysis_dashboard.py
Normal file
|
|
@ -0,0 +1,236 @@
|
||||||
|
"""Streamlit dashboard for training analysis results visualization."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.paths import RESULTS_DIR
|
||||||
|
|
||||||
|
# Module-wide seaborn defaults: "talk" plotting context with the "whitegrid" style.
sns.set_theme(context="talk", style="whitegrid")
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_k_binned(
|
||||||
|
results: pd.DataFrame,
|
||||||
|
target: str,
|
||||||
|
*,
|
||||||
|
vmin_percentile: float | None = None,
|
||||||
|
vmax_percentile: float | None = None,
|
||||||
|
):
|
||||||
|
"""Plot K-binned results with epsilon parameters."""
|
||||||
|
assert vmin_percentile is None or vmax_percentile is None, (
|
||||||
|
"Only one of vmin_percentile or vmax_percentile can be set."
|
||||||
|
)
|
||||||
|
assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results."
|
||||||
|
assert target in results.columns, f"{target} column not found in results."
|
||||||
|
assert "eps_e" in results.columns, "eps_e column not found in results."
|
||||||
|
assert "eps_cl" in results.columns, "eps_cl column not found in results."
|
||||||
|
|
||||||
|
# add a colorbar instead of the sampled legend
|
||||||
|
cmap = sns.color_palette("ch:", as_cmap=True)
|
||||||
|
# sophisticated normalization
|
||||||
|
if vmin_percentile is not None:
|
||||||
|
vmin = np.percentile(results[target], vmin_percentile)
|
||||||
|
norm = mcolors.Normalize(vmin=vmin)
|
||||||
|
elif vmax_percentile is not None:
|
||||||
|
vmax = np.percentile(results[target], vmax_percentile)
|
||||||
|
norm = mcolors.Normalize(vmax=vmax)
|
||||||
|
else:
|
||||||
|
norm = mcolors.Normalize()
|
||||||
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
||||||
|
|
||||||
|
# nice col-wrap based on columns
|
||||||
|
n_cols = results["initial_K_binned"].unique().size
|
||||||
|
col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
|
||||||
|
|
||||||
|
scatter = sns.relplot(
|
||||||
|
data=results,
|
||||||
|
x="eps_e",
|
||||||
|
y="eps_cl",
|
||||||
|
hue=target,
|
||||||
|
hue_norm=sm.norm,
|
||||||
|
palette=cmap,
|
||||||
|
legend=False,
|
||||||
|
col="initial_K_binned",
|
||||||
|
col_wrap=col_wrap,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply log scale to all axes
|
||||||
|
for ax in scatter.axes.flat:
|
||||||
|
ax.set_xscale("log")
|
||||||
|
ax.set_yscale("log")
|
||||||
|
|
||||||
|
# Tight layout
|
||||||
|
scatter.figure.tight_layout()
|
||||||
|
|
||||||
|
# Add a shared colorbar at the bottom
|
||||||
|
scatter.figure.subplots_adjust(bottom=0.15) # Make room for the colorbar
|
||||||
|
cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02]) # [left, bottom, width, height]
|
||||||
|
cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
|
||||||
|
cbar.set_label(target)
|
||||||
|
|
||||||
|
return scatter
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
    """Plot ``metric`` against ``initial_K``, faceted by the bins of the other epsilon.

    Args:
        results: Search results. Must contain ``initial_K``, the ``metric``
            column, the ``target`` column and the binned column of the other
            epsilon (``eps_e_binned`` for target ``eps_cl``, ``eps_cl_binned``
            for target ``eps_e``).
        target: Which epsilon is mapped to the point colour; either
            ``"eps_cl"`` or ``"eps_e"``.
        metric: Name of the column plotted on the y-axis.

    Returns:
        The grid object returned by ``sns.relplot``.

    Raises:
        ValueError: If ``target`` is not one of the two epsilons or a required
            column is missing from ``results``.
    """
    # All input errors raise ValueError (not ``assert``) so they are reported
    # consistently and are not stripped when Python runs with ``-O``.
    if "initial_K" not in results.columns:
        raise ValueError("initial_K column not found in results.")
    if metric not in results.columns:
        raise ValueError(f"{metric} not found in results.")

    # The coloured epsilon determines which *other* epsilon's bins form the facets.
    facet_column_for = {"eps_cl": "eps_e_binned", "eps_e": "eps_cl_binned"}
    if target not in facet_column_for:
        raise ValueError(f"Invalid target: {target}")
    hue = target
    col = facet_column_for[target]

    for column in (hue, col):
        if column not in results.columns:
            raise ValueError(f"{column} column not found in results.")

    # Colours use a logarithmic norm for the epsilon values.
    return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_prepare_results(file_path: Path) -> pd.DataFrame:
    """Read a search-results parquet file and attach the binned helper columns.

    Three columns are added to the loaded frame, in this order:
    ``initial_K_binned``, ``eps_cl_binned`` and ``eps_e_binned``.

    Args:
        file_path: Location of the parquet file to read.

    Returns:
        The loaded frame including the three binned columns.
    """
    frame = pd.read_parquet(file_path)

    # initial_K: left-closed bins of width 40 built from the edges 20, 60, ..., 380.
    k_edges = range(20, 401, 40)
    frame["initial_K_binned"] = pd.cut(frame["initial_K"], bins=k_edges, right=False)

    # eps_cl and eps_e: both use 10 log-spaced edges between 1e-3 and 1e7.
    for eps_name in ("eps_cl", "eps_e"):
        log_edges = np.logspace(-3, 7, num=10)
        frame[f"{eps_name}_binned"] = pd.cut(frame[eps_name], bins=log_edges)

    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_result_files() -> list[Path]:
    """Collect every ``search_results.parquet`` found one level below RESULTS_DIR.

    Returns:
        The matching paths sorted in descending order, or an empty list when
        RESULTS_DIR does not exist.
    """
    if not RESULTS_DIR.exists():
        return []

    # One candidate file per sub-directory; plain files in RESULTS_DIR are ignored.
    candidates = (entry / "search_results.parquet" for entry in RESULTS_DIR.iterdir() if entry.is_dir())
    found = [candidate for candidate in candidates if candidate.exists()]

    # Descending path order; shows the most recent run first.
    # NOTE(review): presumably relies on run directory names sorting chronologically — confirm.
    found.sort(reverse=True)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _render_grid(grid) -> None:
    """Draw a seaborn grid's figure in Streamlit and then close that figure."""
    st.pyplot(grid.figure)
    # Close exactly the figure that was rendered instead of relying on
    # pyplot's notion of the "current" figure.
    plt.close(grid.figure)


def main():
    """Run Streamlit dashboard application.

    The page lets the user pick a result file and a metric in the sidebar,
    shows summary numbers for the best run, and renders four plot groups:
    K-binned mean and std plots and two epsilon-binned plots. The raw table
    is available in an expander at the bottom.
    """
    st.set_page_config(page_title="Training Analysis Dashboard", layout="wide")

    st.title("Training Analysis Dashboard")
    st.markdown("Interactive visualization of RandomizedSearchCV results")

    # Sidebar for file and parameter selection
    st.sidebar.header("Configuration")

    result_files = get_available_result_files()
    if not result_files:
        st.error(f"No result files found in {RESULTS_DIR}")
        st.info("Please run a random CV search first to generate results.")
        return

    # File selection: entries are labelled with the name of their parent directory.
    file_options = {str(f.parent.name): f for f in result_files}
    selected_file_name = st.sidebar.selectbox(
        "Select Result File", options=list(file_options.keys()), help="Choose a search result file to visualize"
    )
    selected_file = file_options[selected_file_name]

    with st.spinner("Loading results..."):
        results = load_and_prepare_results(selected_file)

    st.sidebar.success(f"Loaded {len(results)} results")

    # Metric selection: only offer metrics whose score column exists, so a
    # result file produced with fewer scorers does not crash with a KeyError.
    known_metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
    available_metrics = [m for m in known_metrics if f"mean_test_{m}" in results.columns]
    if not available_metrics:
        st.error(f"No known metric columns found in {selected_file}")
        return
    selected_metric = st.sidebar.selectbox(
        "Select Metric", options=available_metrics, help="Choose which metric to visualize"
    )
    mean_col = f"mean_test_{selected_metric}"
    std_col = f"std_test_{selected_metric}"

    # Percentile normalization option
    use_percentile = st.sidebar.checkbox(
        "Use Percentile Normalization", value=True, help="Apply percentile-based color normalization to plots"
    )

    # Row with the highest mean score; looked up once and reused below.
    best_idx = results[mean_col].idxmax()

    # Display some basic statistics
    st.header("Dataset Overview")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Runs", len(results))
    with col2:
        best_score = results[mean_col].max()
        st.metric(f"Best {selected_metric.capitalize()}", f"{best_score:.4f}")
    with col3:
        best_k = results.loc[best_idx, "initial_K"]
        st.metric("Best K", f"{best_k:.0f}")

    # Show best parameters
    with st.expander("Best Parameters"):
        best_params = results.loc[best_idx, ["initial_K", "eps_cl", "eps_e", mean_col]]
        st.dataframe(best_params.to_frame().T, use_container_width=True)

    # Main plots
    st.header(f"Visualization for {selected_metric.capitalize()}")

    # K-binned plots. With percentile normalization the mean plot pins the
    # lower colour limit and the std plot pins the upper one, both at the median.
    st.subheader("K-Binned Parameter Space (Mean)")
    with st.spinner("Generating mean plot..."):
        mean_kwargs = {"vmin_percentile": 50} if use_percentile else {}
        _render_grid(_plot_k_binned(results, mean_col, **mean_kwargs))

    st.subheader("K-Binned Parameter Space (Std)")
    with st.spinner("Generating std plot..."):
        std_kwargs = {"vmax_percentile": 50} if use_percentile else {}
        _render_grid(_plot_k_binned(results, std_col, **std_kwargs))

    # Epsilon-binned plots, side by side
    left, right = st.columns(2)
    for column, eps_name in ((left, "eps_cl"), (right, "eps_e")):
        with column:
            st.subheader(f"K vs {eps_name}")
            with st.spinner(f"Generating {eps_name} plot..."):
                _render_grid(_plot_eps_binned(results, eps_name, mean_col))

    # Optional: Raw data table
    with st.expander("View Raw Results Data"):
        st.dataframe(results, use_container_width=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: build the dashboard only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue