Add Dataset Analysis to the overview page
This commit is contained in:
parent
6960571742
commit
a304c96e4e
1 changed file with 641 additions and 3 deletions
|
|
@ -3,9 +3,639 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.colors import get_palette
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
from entropice.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
def _collect_sample_counts(grid_configs, target_datasets, tasks):
    """Collect per-configuration sample counts into a long-format DataFrame.

    For each (grid, level, target) combination a minimal ``DatasetEnsemble``
    (no members) is created just to read the target table; for each task we
    count cells with coverage, cells with a valid label, and cells with both.

    Combinations whose expected coverage/task columns are missing from the
    target table are reported via ``st.warning`` instead of being silently
    dropped, so gaps in the charts are explainable.

    Returns:
        pd.DataFrame with one row per (grid, level, target, task).
    """
    rows = []
    missing = []
    for grid, level in grid_configs:
        for target in target_datasets:
            # Create a minimal ensemble just to get the target data.
            ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]
            targets = ensemble._read_target()

            for task in tasks:
                # Task-specific label column and coverage column.
                taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
                covcol = ensemble.covcol

                if covcol not in targets.columns or taskcol not in targets.columns:
                    missing.append(f"{grid}-{level} / {target} / {task}")
                    continue

                has_label = targets[taskcol].notna()
                rows.append(
                    {
                        "Grid": f"{grid}-{level}",
                        "Grid Type": grid,
                        "Level": level,
                        "Target": target.replace("darts_", ""),
                        "Task": task.capitalize(),
                        "Samples (Coverage)": targets[covcol].sum(),
                        "Samples (Labels)": has_label.sum(),
                        "Samples (Both)": (targets[covcol] & has_label).sum(),
                        # Zero-padded level so string sort == numeric sort per grid type.
                        "Grid_Level_Sort": f"{grid}_{level:02d}",
                    }
                )

    if missing:
        st.warning("Skipped configurations with missing columns: " + ", ".join(missing))
    return pd.DataFrame(rows)


def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    Presents the counts as a heatmap (per target), a grouped bar chart, and a
    raw data table, each in its own tab.
    """
    st.subheader("📊 Sample Counts by Configuration")

    st.markdown(
        """
        This visualization shows the number of available samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    # All grid configurations considered by the dashboard: (grid system, level).
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    target_datasets = ["darts_rts", "darts_mllabels"]
    tasks = ["binary", "count", "density"]

    with st.spinner("Computing sample counts for all configurations..."):
        sample_df = _collect_sample_counts(grid_configs, target_datasets, tasks)

    # Guard: pivoting/plotting an empty frame would raise further down.
    if sample_df.empty:
        st.info("No sample counts available for the configured datasets.")
        return

    # Create tabs for different views.
    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])

    with tab1:
        st.markdown("### Sample Counts Heatmap")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        # Row order shared by all heatmaps: grid type, then numeric level
        # (hoisted out of the per-target loop — it does not depend on target).
        sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
        ordered_grids = sort_order.sort_values("Grid_Level_Sort").index

        # One heatmap per target dataset.
        for target in target_datasets:
            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]

            # Pivot for heatmap: Grid x Task.
            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
            pivot_df = pivot_df.reindex(ordered_grids)

            # Continuous color scale for sample counts.
            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)

            fig = px.imshow(
                pivot_df,
                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
                x=pivot_df.columns,
                y=pivot_df.index,
                color_continuous_scale=sample_colors,
                aspect="auto",
                title=f"Target: {target}",
            )

            # Numeric annotations (thousands separators) on each cell.
            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)

            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

    with tab2:
        st.markdown("### Sample Counts Bar Chart")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        # Faceted bar chart showing both targets side by side,
        # one color per task.
        n_tasks = sample_df["Task"].nunique()
        task_colors = get_palette("task_types", n_colors=n_tasks)

        fig = px.bar(
            sample_df,
            x="Grid",
            y="Samples (Both)",
            color="Task",
            facet_col="Target",
            barmode="group",
            title="Sample Counts by Grid Configuration and Target Dataset",
            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
            color_discrete_sequence=task_colors,
            height=500,
        )

        # Clean up facet labels ("Target=rts" -> "rts").
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
        fig.update_xaxes(tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

    with tab3:
        st.markdown("### Detailed Sample Counts")

        # Full table, with counts formatted using thousands separators.
        display_df = sample_df[
            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
        ].copy()

        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

        st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||
|
||||
|
||||
def _available_members(grid, level):
    """Return the data-source members usable for a grid configuration.

    AlphaEarth is not available at the finest resolutions (healpix level 10
    and hex level 6), so it is omitted there.  This is the single source of
    truth for the availability rule that was previously copy-pasted in four
    places.
    """
    members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
    alphaearth_unavailable = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
    if not alphaearth_unavailable:
        members.insert(0, "AlphaEarth")
    return members


@st.fragment
def render_feature_count_fragment():
    """Render interactive feature count visualization using fragments.

    Two sections:
      1. A comparison of feature counts across all grid configurations (with
         all available data sources enabled), shown as stacked bars, donut
         breakdowns, and a table.
      2. An interactive explorer where the user picks one grid configuration,
         target, and subset of data sources and gets detailed statistics.
    """
    st.subheader("🔢 Feature Counts by Dataset Configuration")

    st.markdown(
        """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """
    )

    # First section: comparison across all grid configurations.
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")

    # All grid configurations considered: (grid system, level).
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]

    # Per-configuration summary rows, and per-(configuration, data source)
    # breakdown rows.  The breakdown is collected once here and reused by the
    # bar-chart and donut tabs below — previously each tab re-instantiated the
    # ensembles and re-ran get_stats() for every configuration.
    feature_comparison_data = []
    member_breakdown_data = []

    with st.spinner("Computing feature counts for all grid configurations..."):
        for grid, level in grid_configs:
            members = _available_members(grid, level)

            # Use darts_rts as default target for comparison.
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()

            # Minimum cell count across sources bounds where inference is
            # possible with all sources present.
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )

            grid_config = f"{grid}-{level}"
            # Zero-padded level so string sort == numeric sort per grid type.
            sort_key = f"{grid}_{level:02d}"
            total_features = stats["total_features"]

            feature_comparison_data.append(
                {
                    "Grid": grid_config,
                    "Grid Type": grid,
                    "Level": level,
                    "Total Features": total_features,
                    "Data Sources": len(members),
                    "Inference Cells": min_cells,
                    "Total Samples": stats["num_target_samples"],
                    "AlphaEarth": "AlphaEarth" in members,
                    "Grid_Level_Sort": sort_key,
                }
            )

            for member, member_stats in stats["members"].items():
                num_features = member_stats["num_features"]
                member_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Number of Features": num_features,
                        "Percentage": (num_features / total_features) * 100,  # type: ignore[operator]
                        "Grid_Level_Sort": sort_key,
                    }
                )

            # Lon/Lat contributes two extra features when enabled.
            if ensemble.add_lonlat:
                member_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Number of Features": 2,
                        "Percentage": (2 / total_features) * 100,  # type: ignore[operator]
                        "Grid_Level_Sort": sort_key,
                    }
                )

    comparison_df = pd.DataFrame(feature_comparison_data)
    breakdown_all_df = pd.DataFrame(member_breakdown_data).sort_values("Grid_Level_Sort")

    # Create tabs for different comparison views.
    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])

    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")

        # One color per data source, consistent across charts.
        n_sources = breakdown_all_df["Data Source"].nunique()
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # Stacked bar chart over the per-source breakdown collected above.
        fig = px.bar(
            breakdown_all_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            barmode="stack",
            title="Total Features by Data Source Across Grid Configurations",
            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
            color_discrete_sequence=source_colors,
            text_auto=False,
        )

        fig.update_layout(height=500, xaxis_tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

        # Secondary metrics: inference cells and total samples per config.
        grid_colors = get_palette("grid_configs", n_colors=len(comparison_df))
        col1, col2 = st.columns(2)
        with col1:
            fig_cells = px.bar(
                comparison_df,
                x="Grid",
                y="Inference Cells",
                color="Grid",
                title="Inference Cells by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
                color_discrete_sequence=grid_colors,
                text="Inference Cells",
            )
            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_cells, use_container_width=True)

        with col2:
            fig_samples = px.bar(
                comparison_df,
                x="Grid",
                y="Total Samples",
                color="Grid",
                title="Total Samples by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
                color_discrete_sequence=grid_colors,
                text="Total Samples",
            )
            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_samples, use_container_width=True)

    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        n_sources = breakdown_all_df["Data Source"].nunique()
        source_colors = get_palette("data_sources", n_colors=n_sources)

        # One donut per grid configuration, laid out three per row.
        num_grids = len(comparison_df)
        cols_per_row = 3
        for start in range(0, num_grids, cols_per_row):
            cols = st.columns(cols_per_row)
            for offset, col in enumerate(cols):
                grid_idx = start + offset
                if grid_idx >= num_grids:
                    break
                grid_config = comparison_df.iloc[grid_idx]["Grid"]
                grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]

                with col:
                    fig = px.pie(
                        grid_data,
                        names="Data Source",
                        values="Number of Features",
                        title=grid_config,
                        hole=0.4,
                        color_discrete_sequence=source_colors,
                    )
                    fig.update_traces(textposition="inside", textinfo="percent")
                    fig.update_layout(showlegend=True, height=350)
                    st.plotly_chart(fig, use_container_width=True)

    with comp_tab3:
        st.markdown("#### Detailed Feature Count Comparison")

        # Full comparison table with formatted numbers.
        display_df = comparison_df[
            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
        ].copy()

        for col in ["Total Features", "Inference Cells", "Total Samples"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

        # Boolean availability as check/cross marks.
        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")

        st.dataframe(display_df, hide_index=True, use_container_width=True)

    st.divider()

    # Second section: detailed configuration with user selection.
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    grid_options = [
        "hex-3",
        "hex-4",
        "hex-5",
        "hex-6",
        "healpix-6",
        "healpix-7",
        "healpix-8",
        "healpix-9",
        "healpix-10",
    ]

    col1, col2 = st.columns(2)

    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )

    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    # Parse grid type and level from the combined "grid-level" option.
    grid, level_str = grid_level_combined.split("-")
    level = int(level_str)

    # Members selection.
    st.markdown("#### Select Data Sources")

    available = _available_members(grid, level)
    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]

    # One checkbox column per data source.
    cols = st.columns(len(all_members))
    selected_members = []

    for idx, member in enumerate(all_members):
        with cols[idx]:
            if member not in available:
                # Render a disabled, unchecked box so the user sees why this
                # source cannot be selected at this resolution.
                st.checkbox(
                    member,
                    value=False,
                    disabled=True,
                    help=f"Not available for {grid} level {level}",
                    key=f"feature_member_{member}",
                )
            elif st.checkbox(member, value=True, key=f"feature_member_{member}"):
                selected_members.append(member)

    # Guard clause: nothing to show without at least one data source.
    if not selected_members:
        st.info("👆 Select at least one data source to see feature statistics")
        return

    st.markdown("---")

    ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)

    with st.spinner("Computing dataset statistics..."):
        stats = ensemble.get_stats()

    # High-level metrics row.
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("Total Features", f"{stats['total_features']:,}")
    with col2:
        # Minimum cells across all selected data sources (inference capability).
        min_cells = min(
            member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
            for member_stats in stats["members"].values()
        )
        st.metric(
            "Inference Cells",
            f"{min_cells:,}",
            help="Minimum number of cells available across all selected data sources",
        )
    with col3:
        st.metric("Data Sources", len(selected_members))
    with col4:
        st.metric("Total Samples", f"{stats['num_target_samples']:,}")
    with col5:
        # Total data points = features x samples.
        total_points = stats["total_features"] * stats["num_target_samples"]
        st.metric("Total Data Points", f"{total_points:,}")

    # Feature breakdown by source.
    st.markdown("#### Feature Breakdown by Data Source")

    breakdown_data = [
        {
            "Data Source": member,
            "Number of Features": member_stats["num_features"],
            "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
        }
        for member, member_stats in stats["members"].items()
    ]

    # Lon/Lat contributes two extra features when enabled.
    if ensemble.add_lonlat:
        breakdown_data.append(
            {
                "Data Source": "Lon/Lat",
                "Number of Features": 2,
                "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
            }
        )

    breakdown_df = pd.DataFrame(breakdown_data)

    source_colors = get_palette("data_sources", n_colors=len(breakdown_df))

    # Donut chart of the feature distribution.
    fig = px.pie(
        breakdown_df,
        names="Data Source",
        values="Number of Features",
        title="Feature Distribution by Data Source",
        hole=0.4,
        color_discrete_sequence=source_colors,
    )
    fig.update_traces(textposition="inside", textinfo="percent+label")
    st.plotly_chart(fig, use_container_width=True)

    # Detailed table backing the chart.
    st.dataframe(breakdown_df, hide_index=True, use_container_width=True)

    # Detailed per-member information.
    with st.expander("📦 Detailed Source Information", expanded=False):
        for member, member_stats in stats["members"].items():
            st.markdown(f"### {member}")

            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Features", member_stats["num_features"])
            with metric_cols[1]:
                st.metric("Variables", member_stats["num_variables"])
            with metric_cols[2]:
                dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
                st.metric("Shape", dim_str)
            with metric_cols[3]:
                # Product of all dimension sizes = raw data points stored.
                total_points = 1
                for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                    total_points *= dim_size
                st.metric("Data Points", f"{total_points:,}")

            # Variables as inline chips.
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats["variables"]  # type: ignore[union-attr]
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)

            # Dimensions as inline chips.
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size}</span>"
                    for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)

            st.markdown("---")
|
||||
|
||||
|
||||
def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts."""
    st.header("📈 Dataset Analysis")

    # Two independent analyses, one per tab.
    sample_tab, feature_tab = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])

    with sample_tab:
        render_sample_count_overview()

    with feature_tab:
        render_feature_count_fragment()
|
||||
|
||||
|
||||
def render_overview_page():
|
||||
|
|
@ -20,8 +650,10 @@ def render_overview_page():
|
|||
|
||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Summary statistics at the top
|
||||
st.subheader("Summary Statistics")
|
||||
st.header("📊 Training Results Summary")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
|
|
@ -43,8 +675,14 @@ def render_overview_page():
|
|||
|
||||
st.divider()
|
||||
|
||||
# Add dataset analysis section
|
||||
render_dataset_analysis()
|
||||
|
||||
st.divider()
|
||||
|
||||
# Detailed results table
|
||||
st.subheader("Training Results")
|
||||
st.header("🎯 Experiment Results")
|
||||
st.subheader("Results Table")
|
||||
|
||||
# Build a summary dataframe
|
||||
summary_data = []
|
||||
|
|
@ -93,7 +731,7 @@ def render_overview_page():
|
|||
st.divider()
|
||||
|
||||
# Expandable details for each result
|
||||
st.subheader("Detailed Results")
|
||||
st.subheader("Individual Experiment Details")
|
||||
|
||||
for tr in training_results:
|
||||
with st.expander(tr.get_display_name("task_first")):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue