Refactor overview page
This commit is contained in:
parent
a304c96e4e
commit
1ee3d532fc
1 changed files with 302 additions and 257 deletions
|
|
@ -1,6 +1,8 @@
|
|||
"""Overview page: List of available result directories with some summary statistics."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
|
|
@ -11,22 +13,110 @@ from entropice.dashboard.utils.data import load_all_training_results
|
|||
from entropice.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
def render_sample_count_overview():
|
||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
||||
st.subheader("📊 Sample Counts by Configuration")
|
||||
# Type definitions for dataset statistics
|
||||
class GridConfig(TypedDict):
|
||||
"""Grid configuration specification with metadata."""
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows the number of available samples for each combination of:
|
||||
- **Task**: binary, count, density
|
||||
- **Target Dataset**: darts_rts, darts_mllabels
|
||||
- **Grid System**: hex, healpix
|
||||
- **Grid Level**: varying by grid type
|
||||
"""
|
||||
grid: str
|
||||
level: int
|
||||
grid_name: str
|
||||
grid_sort_key: str
|
||||
disable_alphaearth: bool
|
||||
|
||||
|
||||
class SampleCountData(TypedDict):
|
||||
"""Sample count statistics for a specific grid/target/task combination."""
|
||||
|
||||
grid_config: GridConfig
|
||||
target: str
|
||||
task: str
|
||||
samples_coverage: int
|
||||
samples_labels: int
|
||||
samples_both: int
|
||||
|
||||
|
||||
class FeatureCountData(TypedDict):
|
||||
"""Feature count statistics for a specific grid configuration."""
|
||||
|
||||
grid_config: GridConfig
|
||||
total_features: int
|
||||
data_sources: list[str]
|
||||
inference_cells: int
|
||||
total_samples: int
|
||||
member_breakdown: dict[str, int]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatasetAnalysisCache:
|
||||
"""Cache for dataset analysis data to avoid redundant computations."""
|
||||
|
||||
grid_configs: list[GridConfig]
|
||||
sample_counts: list[SampleCountData]
|
||||
feature_counts: list[FeatureCountData]
|
||||
|
||||
def get_sample_count_df(self) -> pd.DataFrame:
|
||||
"""Convert sample count data to DataFrame."""
|
||||
rows = []
|
||||
for item in self.sample_counts:
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Grid Type": item["grid_config"]["grid"],
|
||||
"Level": item["grid_config"]["level"],
|
||||
"Target": item["target"].replace("darts_", ""),
|
||||
"Task": item["task"].capitalize(),
|
||||
"Samples (Coverage)": item["samples_coverage"],
|
||||
"Samples (Labels)": item["samples_labels"],
|
||||
"Samples (Both)": item["samples_both"],
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
def get_feature_count_df(self) -> pd.DataFrame:
|
||||
"""Convert feature count data to DataFrame."""
|
||||
rows = []
|
||||
for item in self.feature_counts:
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Grid Type": item["grid_config"]["grid"],
|
||||
"Level": item["grid_config"]["level"],
|
||||
"Total Features": item["total_features"],
|
||||
"Data Sources": len(item["data_sources"]),
|
||||
"Inference Cells": item["inference_cells"],
|
||||
"Total Samples": item["total_samples"],
|
||||
"AlphaEarth": "AlphaEarth" in item["data_sources"],
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
def get_feature_breakdown_df(self) -> pd.DataFrame:
|
||||
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
||||
rows = []
|
||||
for item in self.feature_counts:
|
||||
for source, count in item["member_breakdown"].items():
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Data Source": source,
|
||||
"Number of Features": count,
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
||||
"""Load and cache all dataset analysis data.
|
||||
|
||||
This function computes both sample counts and feature counts for all grid configurations.
|
||||
Results are cached to avoid redundant computations across different tabs.
|
||||
"""
|
||||
# Define all possible grid configurations
|
||||
grid_configs = [
|
||||
grid_configs_raw = [
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
|
|
@ -38,19 +128,34 @@ def render_sample_count_overview():
|
|||
("healpix", 10),
|
||||
]
|
||||
|
||||
# Create structured grid config objects
|
||||
grid_configs: list[GridConfig] = []
|
||||
for grid, level in grid_configs_raw:
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
grid_configs.append(
|
||||
{
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"grid_name": f"{grid}-{level}",
|
||||
"grid_sort_key": f"{grid}_{level:02d}",
|
||||
"disable_alphaearth": disable_alphaearth,
|
||||
}
|
||||
)
|
||||
|
||||
# Compute sample counts
|
||||
sample_counts: list[SampleCountData] = []
|
||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||
tasks = ["binary", "count", "density"]
|
||||
|
||||
# Collect sample counts
|
||||
sample_data = []
|
||||
|
||||
with st.spinner("Computing sample counts for all configurations..."):
|
||||
for grid, level in grid_configs:
|
||||
for grid_config in grid_configs:
|
||||
for target in target_datasets:
|
||||
# Create minimal ensemble just to get target data
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[]) # type: ignore[arg-type]
|
||||
|
||||
# Read target data
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid_config["grid"],
|
||||
level=grid_config["level"],
|
||||
target=target,
|
||||
members=[], # type: ignore[arg-type]
|
||||
)
|
||||
targets = ensemble._read_target()
|
||||
|
||||
for task in tasks:
|
||||
|
|
@ -64,21 +169,85 @@ def render_sample_count_overview():
|
|||
valid_labels = targets[taskcol].notna().sum()
|
||||
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
|
||||
|
||||
sample_data.append(
|
||||
sample_counts.append(
|
||||
{
|
||||
"Grid": f"{grid}-{level}",
|
||||
"Grid Type": grid,
|
||||
"Level": level,
|
||||
"Target": target.replace("darts_", ""),
|
||||
"Task": task.capitalize(),
|
||||
"Samples (Coverage)": valid_coverage,
|
||||
"Samples (Labels)": valid_labels,
|
||||
"Samples (Both)": valid_both,
|
||||
"Grid_Level_Sort": f"{grid}_{level:02d}",
|
||||
"grid_config": grid_config,
|
||||
"target": target,
|
||||
"task": task,
|
||||
"samples_coverage": valid_coverage,
|
||||
"samples_labels": valid_labels,
|
||||
"samples_both": valid_both,
|
||||
}
|
||||
)
|
||||
|
||||
sample_df = pd.DataFrame(sample_data)
|
||||
# Compute feature counts
|
||||
feature_counts: list[FeatureCountData] = []
|
||||
|
||||
for grid_config in grid_configs:
|
||||
# Determine which members are available for this configuration
|
||||
if grid_config["disable_alphaearth"]:
|
||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
else:
|
||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
# Use darts_rts as default target for comparison
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid_config["grid"],
|
||||
level=grid_config["level"],
|
||||
target="darts_rts",
|
||||
members=members, # type: ignore[arg-type]
|
||||
)
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
# Calculate minimum cells across all data sources
|
||||
min_cells = min(
|
||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
||||
for member_stats in stats["members"].values()
|
||||
)
|
||||
|
||||
# Build member breakdown including lon/lat
|
||||
member_breakdown = {}
|
||||
for member, member_stats in stats["members"].items():
|
||||
member_breakdown[member] = member_stats["num_features"]
|
||||
|
||||
if ensemble.add_lonlat:
|
||||
member_breakdown["Lon/Lat"] = 2
|
||||
|
||||
feature_counts.append(
|
||||
{
|
||||
"grid_config": grid_config,
|
||||
"total_features": stats["total_features"],
|
||||
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
|
||||
"inference_cells": min_cells,
|
||||
"total_samples": stats["num_target_samples"],
|
||||
"member_breakdown": member_breakdown,
|
||||
}
|
||||
)
|
||||
|
||||
return DatasetAnalysisCache(
|
||||
grid_configs=grid_configs,
|
||||
sample_counts=sample_counts,
|
||||
feature_counts=feature_counts,
|
||||
)
|
||||
|
||||
|
||||
def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
||||
st.subheader("📊 Sample Counts by Configuration")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows the number of available samples for each combination of:
|
||||
- **Task**: binary, count, density
|
||||
- **Target Dataset**: darts_rts, darts_mllabels
|
||||
- **Grid System**: hex, healpix
|
||||
- **Grid Level**: varying by grid type
|
||||
"""
|
||||
)
|
||||
|
||||
# Get sample count DataFrame from cache
|
||||
sample_df = cache.get_sample_count_df()
|
||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||
|
||||
# Create tabs for different views
|
||||
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
|
||||
|
|
@ -159,73 +328,15 @@ def render_sample_count_overview():
|
|||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_feature_count_fragment():
|
||||
"""Render interactive feature count visualization using fragments."""
|
||||
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows the total number of features that would be generated
|
||||
for different combinations of data sources and grid configurations.
|
||||
"""
|
||||
)
|
||||
|
||||
# First section: Comparison across all grid configurations
|
||||
def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
||||
"""Render static comparison of feature counts across all grid configurations."""
|
||||
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
||||
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
||||
|
||||
# Define all possible grid configurations
|
||||
grid_configs = [
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
("hex", 6),
|
||||
("healpix", 6),
|
||||
("healpix", 7),
|
||||
("healpix", 8),
|
||||
("healpix", 9),
|
||||
("healpix", 10),
|
||||
]
|
||||
|
||||
# Collect feature statistics for all configurations
|
||||
feature_comparison_data = []
|
||||
|
||||
with st.spinner("Computing feature counts for all grid configurations..."):
|
||||
for grid, level in grid_configs:
|
||||
# Determine which members are available for this configuration
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
|
||||
if disable_alphaearth:
|
||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
else:
|
||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
# Use darts_rts as default target for comparison
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
# Calculate minimum cells across all data sources
|
||||
min_cells = min(
|
||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
||||
for member_stats in stats["members"].values()
|
||||
)
|
||||
|
||||
feature_comparison_data.append(
|
||||
{
|
||||
"Grid": f"{grid}-{level}",
|
||||
"Grid Type": grid,
|
||||
"Level": level,
|
||||
"Total Features": stats["total_features"],
|
||||
"Data Sources": len(members),
|
||||
"Inference Cells": min_cells,
|
||||
"Total Samples": stats["num_target_samples"],
|
||||
"AlphaEarth": "AlphaEarth" in members,
|
||||
"Grid_Level_Sort": f"{grid}_{level:02d}",
|
||||
}
|
||||
)
|
||||
|
||||
comparison_df = pd.DataFrame(feature_comparison_data)
|
||||
# Get data from cache
|
||||
comparison_df = cache.get_feature_count_df()
|
||||
breakdown_df = cache.get_feature_breakdown_df()
|
||||
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
||||
|
||||
# Create tabs for different comparison views
|
||||
comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
|
||||
|
|
@ -233,56 +344,14 @@ def render_feature_count_fragment():
|
|||
with comp_tab1:
|
||||
st.markdown("#### Total Features by Grid Configuration")
|
||||
|
||||
# Collect breakdown data for stacked bar chart
|
||||
stacked_data = []
|
||||
|
||||
for idx, row in comparison_df.iterrows():
|
||||
grid_config = row["Grid"]
|
||||
grid, level_str = grid_config.split("-")
|
||||
level = int(level_str)
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
|
||||
if disable_alphaearth:
|
||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
else:
|
||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
# Add data for each member
|
||||
for member, member_stats in stats["members"].items():
|
||||
stacked_data.append(
|
||||
{
|
||||
"Grid": grid_config,
|
||||
"Data Source": member,
|
||||
"Number of Features": member_stats["num_features"],
|
||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
||||
}
|
||||
)
|
||||
|
||||
# Add lon/lat
|
||||
if ensemble.add_lonlat:
|
||||
stacked_data.append(
|
||||
{
|
||||
"Grid": grid_config,
|
||||
"Data Source": "Lon/Lat",
|
||||
"Number of Features": 2,
|
||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
||||
}
|
||||
)
|
||||
|
||||
stacked_df = pd.DataFrame(stacked_data)
|
||||
stacked_df = stacked_df.sort_values("Grid_Level_Sort")
|
||||
|
||||
# Get color palette for data sources
|
||||
unique_sources = stacked_df["Data Source"].unique()
|
||||
unique_sources = breakdown_df["Data Source"].unique()
|
||||
n_sources = len(unique_sources)
|
||||
source_colors = get_palette("data_sources", n_colors=n_sources)
|
||||
|
||||
# Create stacked bar chart
|
||||
fig = px.bar(
|
||||
stacked_df,
|
||||
breakdown_df,
|
||||
x="Grid",
|
||||
y="Number of Features",
|
||||
color="Data Source",
|
||||
|
|
@ -336,58 +405,8 @@ def render_feature_count_fragment():
|
|||
st.markdown("#### Feature Breakdown by Data Source")
|
||||
st.markdown("Showing percentage contribution of each data source across all grid configurations")
|
||||
|
||||
# Collect breakdown data for all grid configurations
|
||||
all_breakdown_data = []
|
||||
|
||||
for idx, row in comparison_df.iterrows():
|
||||
grid_config = row["Grid"]
|
||||
grid, level_str = grid_config.split("-")
|
||||
level = int(level_str)
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
|
||||
if disable_alphaearth:
|
||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
else:
|
||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
total_features = stats["total_features"]
|
||||
|
||||
# Add data for each member with percentage
|
||||
for member, member_stats in stats["members"].items():
|
||||
percentage = (member_stats["num_features"] / total_features) * 100 # type: ignore[operator]
|
||||
all_breakdown_data.append(
|
||||
{
|
||||
"Grid": grid_config,
|
||||
"Data Source": member,
|
||||
"Percentage": percentage,
|
||||
"Number of Features": member_stats["num_features"],
|
||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
||||
}
|
||||
)
|
||||
|
||||
# Add lon/lat
|
||||
if ensemble.add_lonlat:
|
||||
percentage = (2 / total_features) * 100 # type: ignore[operator]
|
||||
all_breakdown_data.append(
|
||||
{
|
||||
"Grid": grid_config,
|
||||
"Data Source": "Lon/Lat",
|
||||
"Percentage": percentage,
|
||||
"Number of Features": 2,
|
||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
||||
}
|
||||
)
|
||||
|
||||
breakdown_all_df = pd.DataFrame(all_breakdown_data)
|
||||
|
||||
# Sort by grid configuration
|
||||
breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")
|
||||
|
||||
# Get color palette for data sources
|
||||
unique_sources = breakdown_all_df["Data Source"].unique()
|
||||
unique_sources = breakdown_df["Data Source"].unique()
|
||||
n_sources = len(unique_sources)
|
||||
source_colors = get_palette("data_sources", n_colors=n_sources)
|
||||
|
||||
|
|
@ -403,7 +422,7 @@ def render_feature_count_fragment():
|
|||
grid_idx = row_idx * cols_per_row + col_idx
|
||||
if grid_idx < num_grids:
|
||||
grid_config = comparison_df.iloc[grid_idx]["Grid"]
|
||||
grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]
|
||||
grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
|
||||
|
||||
with cols[col_idx]:
|
||||
fig = px.pie(
|
||||
|
|
@ -435,24 +454,15 @@ def render_feature_count_fragment():
|
|||
|
||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Second section: Detailed configuration with user selection
|
||||
@st.fragment
|
||||
def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||
"""Render interactive detailed configuration explorer using fragments."""
|
||||
st.markdown("### Detailed Configuration Explorer")
|
||||
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
||||
|
||||
# Grid selection
|
||||
grid_options = [
|
||||
"hex-3",
|
||||
"hex-4",
|
||||
"hex-5",
|
||||
"hex-6",
|
||||
"healpix-6",
|
||||
"healpix-7",
|
||||
"healpix-8",
|
||||
"healpix-9",
|
||||
"healpix-10",
|
||||
]
|
||||
grid_options = [gc["grid_name"] for gc in cache.grid_configs]
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
|
|
@ -474,16 +484,15 @@ def render_feature_count_fragment():
|
|||
key="feature_target_select",
|
||||
)
|
||||
|
||||
# Parse grid type and level
|
||||
grid, level_str = grid_level_combined.split("-")
|
||||
level = int(level_str)
|
||||
# Find the selected grid config
|
||||
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
|
||||
grid = selected_grid_config["grid"]
|
||||
level = selected_grid_config["level"]
|
||||
disable_alphaearth = selected_grid_config["disable_alphaearth"]
|
||||
|
||||
# Members selection
|
||||
st.markdown("#### Select Data Sources")
|
||||
|
||||
# Check if AlphaEarth should be disabled
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
|
||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
# Use columns for checkboxes
|
||||
|
|
@ -624,35 +633,46 @@ def render_feature_count_fragment():
|
|||
st.info("👆 Select at least one data source to see feature statistics")
|
||||
|
||||
|
||||
def render_feature_count_section(cache: DatasetAnalysisCache):
|
||||
"""Render the feature count section with comparison and explorer."""
|
||||
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
This visualization shows the total number of features that would be generated
|
||||
for different combinations of data sources and grid configurations.
|
||||
"""
|
||||
)
|
||||
|
||||
# Static comparison across all grids
|
||||
render_feature_count_comparison(cache)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Interactive explorer for detailed analysis
|
||||
render_feature_count_explorer(cache)
|
||||
|
||||
|
||||
def render_dataset_analysis():
|
||||
"""Render the dataset analysis section with sample and feature counts."""
|
||||
st.header("📈 Dataset Analysis")
|
||||
|
||||
# Load all data once and cache it
|
||||
with st.spinner("Loading dataset analysis data..."):
|
||||
cache = load_dataset_analysis_data()
|
||||
|
||||
# Create tabs for the two different analyses
|
||||
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
||||
|
||||
with analysis_tabs[0]:
|
||||
render_sample_count_overview()
|
||||
render_sample_count_overview(cache)
|
||||
|
||||
with analysis_tabs[1]:
|
||||
render_feature_count_fragment()
|
||||
render_feature_count_section(cache)
|
||||
|
||||
|
||||
def render_overview_page():
|
||||
"""Render the Overview page of the dashboard."""
|
||||
st.title("🏡 Training Results Overview")
|
||||
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
return
|
||||
|
||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Summary statistics at the top
|
||||
def render_training_results_summary(training_results):
|
||||
"""Render summary metrics for training results."""
|
||||
st.header("📊 Training Results Summary")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
|
|
@ -673,14 +693,9 @@ def render_overview_page():
|
|||
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
|
||||
st.metric("Latest Run", latest_date)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Add dataset analysis section
|
||||
render_dataset_analysis()
|
||||
|
||||
st.divider()
|
||||
|
||||
# Detailed results table
|
||||
def render_experiment_results(training_results):
|
||||
"""Render detailed experiment results table and expandable details."""
|
||||
st.header("🎯 Experiment Results")
|
||||
st.subheader("Results Table")
|
||||
|
||||
|
|
@ -791,3 +806,33 @@ def render_overview_page():
|
|||
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
||||
|
||||
st.write(f"\n**Path:** `{tr.path}`")
|
||||
|
||||
|
||||
def render_overview_page():
|
||||
"""Render the Overview page of the dashboard."""
|
||||
st.title("🏡 Training Results Overview")
|
||||
|
||||
# Load training results
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
return
|
||||
|
||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Render training results sections
|
||||
render_training_results_summary(training_results)
|
||||
|
||||
st.divider()
|
||||
|
||||
render_experiment_results(training_results)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Render dataset analysis section
|
||||
render_dataset_analysis()
|
||||
|
||||
st.balloons()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue