Add Dataset Analysis to the overview page
This commit is contained in:
parent
6960571742
commit
a304c96e4e
1 changed file with 641 additions and 3 deletions
|
|
@ -3,9 +3,639 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import plotly.express as px
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.colors import get_palette
|
||||||
from entropice.dashboard.utils.data import load_all_training_results
|
from entropice.dashboard.utils.data import load_all_training_results
|
||||||
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
|
def render_sample_count_overview():
    """Render overview of sample counts per task+target+grid+level combination.

    Builds a minimal ``DatasetEnsemble`` (no members) for every
    (grid, level, target) combination just to read the target data, counts the
    samples that have coverage and/or valid task labels, and presents the
    results in three tabs: a per-target heatmap, a grouped bar chart, and a
    formatted data table.
    """
    st.subheader("📊 Sample Counts by Configuration")

    st.markdown(
        """
        This visualization shows the number of available samples for each combination of:
        - **Task**: binary, count, density
        - **Target Dataset**: darts_rts, darts_mllabels
        - **Grid System**: hex, healpix
        - **Grid Level**: varying by grid type
        """
    )

    # Define all possible grid configurations (grid system, resolution level).
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]

    target_datasets = ["darts_rts", "darts_mllabels"]
    tasks = ["binary", "count", "density"]

    # Collect one record per (grid, level, target, task) combination.
    sample_data = []

    with st.spinner("Computing sample counts for all configurations..."):
        for grid, level in grid_configs:
            for target in target_datasets:
                # Create minimal ensemble (no members) just to get target data.
                ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[])  # type: ignore[arg-type]

                targets = ensemble._read_target()

                for task in tasks:
                    # Task-specific label column and the coverage column.
                    taskcol = ensemble.taskcol(task)  # type: ignore[arg-type]
                    covcol = ensemble.covcol

                    # Count samples with coverage and valid labels.
                    if covcol in targets.columns and taskcol in targets.columns:
                        valid_coverage = targets[covcol].sum()
                        valid_labels = targets[taskcol].notna().sum()
                        valid_both = (targets[covcol] & targets[taskcol].notna()).sum()

                        sample_data.append(
                            {
                                "Grid": f"{grid}-{level}",
                                "Grid Type": grid,
                                "Level": level,
                                "Target": target.replace("darts_", ""),
                                "Task": task.capitalize(),
                                "Samples (Coverage)": valid_coverage,
                                "Samples (Labels)": valid_labels,
                                "Samples (Both)": valid_both,
                                # Zero-padded level so lexicographic sort equals numeric sort.
                                "Grid_Level_Sort": f"{grid}_{level:02d}",
                            }
                        )

    sample_df = pd.DataFrame(sample_data)

    # Create tabs for different views.
    tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])

    with tab1:
        st.markdown("### Sample Counts Heatmap")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        # Create one heatmap per target dataset.
        for target in target_datasets:
            target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]

            # Pivot for heatmap: Grid x Task. "first" keeps the integer counts
            # intact — there is exactly one row per Grid/Task cell, and the
            # previous "mean" aggregation silently cast the counts to floats,
            # which made the cell annotations render as floats.
            pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="first")

            # Sort index by grid type and level.
            sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
            pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)

            # Get color palette for sample counts.
            sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)

            fig = px.imshow(
                pivot_df,
                labels={"x": "Task", "y": "Grid Configuration", "color": "Sample Count"},
                x=pivot_df.columns,
                y=pivot_df.index,
                color_continuous_scale=sample_colors,
                aspect="auto",
                title=f"Target: {target}",
            )

            # Annotate every heatmap cell with its raw count.
            fig.update_traces(text=pivot_df.values, texttemplate="%{text:,}", textfont_size=10)

            fig.update_layout(height=400)
            st.plotly_chart(fig, use_container_width=True)

    with tab2:
        st.markdown("### Sample Counts Bar Chart")
        st.markdown("Showing counts of samples with both coverage and valid labels")

        # Create a faceted bar chart showing both targets side by side.
        # Get color palette for tasks.
        n_tasks = sample_df["Task"].nunique()
        task_colors = get_palette("task_types", n_colors=n_tasks)

        fig = px.bar(
            sample_df,
            x="Grid",
            y="Samples (Both)",
            color="Task",
            facet_col="Target",
            barmode="group",
            title="Sample Counts by Grid Configuration and Target Dataset",
            labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
            color_discrete_sequence=task_colors,
            height=500,
        )

        # Update facet labels to show just the target name ("Target=rts" -> "rts").
        fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
        fig.update_xaxes(tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

    with tab3:
        st.markdown("### Detailed Sample Counts")

        # Display full table with formatting.
        display_df = sample_df[
            ["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
        ].copy()

        # Format numbers with thousands separators.
        for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

        st.dataframe(display_df, hide_index=True, use_container_width=True)
|
|
||||||
|
def _members_for(grid: str, level: int) -> list[str]:
    """Return the data-source members available for a grid configuration.

    AlphaEarth embeddings are not available at the finest resolutions
    (healpix level 10 and hex level 6), so AlphaEarth is dropped there.
    """
    if (grid == "healpix" and level == 10) or (grid == "hex" and level == 6):
        return ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
    return ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]


@st.fragment
def render_feature_count_fragment():
    """Render interactive feature count visualization using fragments.

    Two sections:

    1. A comparison of total feature counts across all grid configurations
       (stacked bar, per-config donut breakdown, and a data table), using
       ``darts_rts`` as the reference target.
    2. An interactive explorer where the user picks a grid configuration,
       target dataset, and data sources, and sees detailed statistics.

    Each ``DatasetEnsemble.get_stats()`` result for the comparison section is
    computed exactly once and cached in ``stats_by_grid`` — previously the same
    expensive stats were recomputed identically in every tab.
    """
    st.subheader("🔢 Feature Counts by Dataset Configuration")

    st.markdown(
        """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """
    )

    # First section: Comparison across all grid configurations.
    st.markdown("### Feature Count Comparison Across Grid Configurations")
    st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")

    # Define all possible grid configurations (grid system, resolution level).
    grid_configs = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]

    # Collect feature statistics for all configurations.
    feature_comparison_data = []
    # Cache: grid label -> (stats, add_lonlat). Computed once here and reused
    # by the breakdown tabs below instead of rebuilding every ensemble per tab.
    stats_by_grid = {}

    with st.spinner("Computing feature counts for all grid configurations..."):
        for grid, level in grid_configs:
            members = _members_for(grid, level)

            # Use darts_rts as default target for comparison.
            ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members)  # type: ignore[arg-type]
            stats = ensemble.get_stats()

            grid_label = f"{grid}-{level}"
            stats_by_grid[grid_label] = (stats, ensemble.add_lonlat)

            # The smallest cell count across data sources bounds where
            # inference is possible for this configuration.
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )

            feature_comparison_data.append(
                {
                    "Grid": grid_label,
                    "Grid Type": grid,
                    "Level": level,
                    "Total Features": stats["total_features"],
                    "Data Sources": len(members),
                    "Inference Cells": min_cells,
                    "Total Samples": stats["num_target_samples"],
                    "AlphaEarth": "AlphaEarth" in members,
                    # Zero-padded level so lexicographic sort equals numeric sort.
                    "Grid_Level_Sort": f"{grid}_{level:02d}",
                }
            )

    comparison_df = pd.DataFrame(feature_comparison_data)

    # Create tabs for different comparison views.
    comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])

    with comp_tab1:
        st.markdown("#### Total Features by Grid Configuration")

        # Build stacked-bar rows from the cached stats.
        stacked_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            stats, add_lonlat = stats_by_grid[grid_config]

            # One row per data source.
            for member, member_stats in stats["members"].items():
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )

            # Lon/Lat adds two pseudo-features when enabled on the ensemble.
            if add_lonlat:
                stacked_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )

        stacked_df = pd.DataFrame(stacked_data)
        stacked_df = stacked_df.sort_values("Grid_Level_Sort")

        # Get color palette for data sources.
        unique_sources = stacked_df["Data Source"].unique()
        source_colors = get_palette("data_sources", n_colors=len(unique_sources))

        # Create stacked bar chart.
        fig = px.bar(
            stacked_df,
            x="Grid",
            y="Number of Features",
            color="Data Source",
            barmode="stack",
            title="Total Features by Data Source Across Grid Configurations",
            labels={"Grid": "Grid Configuration", "Number of Features": "Number of Features"},
            color_discrete_sequence=source_colors,
            text_auto=False,
        )

        fig.update_layout(height=500, xaxis_tickangle=-45)
        st.plotly_chart(fig, use_container_width=True)

        # Secondary metrics: cells and samples per configuration, side by side.
        col1, col2 = st.columns(2)
        with col1:
            # Get color palette for grid configs.
            grid_colors = get_palette("grid_configs", n_colors=len(comparison_df))

            fig_cells = px.bar(
                comparison_df,
                x="Grid",
                y="Inference Cells",
                color="Grid",
                title="Inference Cells by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Inference Cells": "Number of Cells"},
                color_discrete_sequence=grid_colors,
                text="Inference Cells",
            )
            fig_cells.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_cells.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_cells, use_container_width=True)

        with col2:
            fig_samples = px.bar(
                comparison_df,
                x="Grid",
                y="Total Samples",
                color="Grid",
                title="Total Samples by Grid Configuration",
                labels={"Grid": "Grid Configuration", "Total Samples": "Number of Samples"},
                color_discrete_sequence=grid_colors,
                text="Total Samples",
            )
            fig_samples.update_traces(texttemplate="%{text:,}", textposition="outside")
            fig_samples.update_layout(xaxis_tickangle=-45, showlegend=False)
            st.plotly_chart(fig_samples, use_container_width=True)

    with comp_tab2:
        st.markdown("#### Feature Breakdown by Data Source")
        st.markdown("Showing percentage contribution of each data source across all grid configurations")

        # Build per-source percentage rows from the cached stats.
        all_breakdown_data = []
        for _, row in comparison_df.iterrows():
            grid_config = row["Grid"]
            stats, add_lonlat = stats_by_grid[grid_config]
            total_features = stats["total_features"]

            for member, member_stats in stats["members"].items():
                percentage = (member_stats["num_features"] / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": member,
                        "Percentage": percentage,
                        "Number of Features": member_stats["num_features"],
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )

            if add_lonlat:
                percentage = (2 / total_features) * 100  # type: ignore[operator]
                all_breakdown_data.append(
                    {
                        "Grid": grid_config,
                        "Data Source": "Lon/Lat",
                        "Percentage": percentage,
                        "Number of Features": 2,
                        "Grid_Level_Sort": row["Grid_Level_Sort"],
                    }
                )

        breakdown_all_df = pd.DataFrame(all_breakdown_data)
        breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")

        # Get color palette for data sources.
        unique_sources = breakdown_all_df["Data Source"].unique()
        source_colors = get_palette("data_sources", n_colors=len(unique_sources))

        # Donut chart per grid configuration, laid out three per row.
        num_grids = len(comparison_df)
        cols_per_row = 3
        num_rows = (num_grids + cols_per_row - 1) // cols_per_row

        for row_idx in range(num_rows):
            cols = st.columns(cols_per_row)
            for col_idx in range(cols_per_row):
                grid_idx = row_idx * cols_per_row + col_idx
                if grid_idx < num_grids:
                    grid_config = comparison_df.iloc[grid_idx]["Grid"]
                    grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]

                    with cols[col_idx]:
                        fig = px.pie(
                            grid_data,
                            names="Data Source",
                            values="Number of Features",
                            title=grid_config,
                            hole=0.4,
                            color_discrete_sequence=source_colors,
                        )
                        fig.update_traces(textposition="inside", textinfo="percent")
                        fig.update_layout(showlegend=True, height=350)
                        st.plotly_chart(fig, use_container_width=True)

    with comp_tab3:
        st.markdown("#### Detailed Feature Count Comparison")

        # Display full comparison table with formatting.
        display_df = comparison_df[
            ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
        ].copy()

        # Format numbers with thousands separators.
        for col in ["Total Features", "Inference Cells", "Total Samples"]:
            display_df[col] = display_df[col].apply(lambda x: f"{x:,}")

        # Format boolean as check/cross marks.
        display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")

        st.dataframe(display_df, hide_index=True, use_container_width=True)

    st.divider()

    # Second section: Detailed configuration with user selection.
    st.markdown("### Detailed Configuration Explorer")
    st.markdown("Select specific grid configuration and data sources for detailed statistics")

    # Grid selection.
    grid_options = [
        "hex-3",
        "hex-4",
        "hex-5",
        "hex-6",
        "healpix-6",
        "healpix-7",
        "healpix-8",
        "healpix-9",
        "healpix-10",
    ]

    col1, col2 = st.columns(2)

    with col1:
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
            key="feature_grid_select",
        )

    with col2:
        target = st.selectbox(
            "Target Dataset",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target dataset",
            key="feature_target_select",
        )

    # Parse grid type and level from the "grid-level" option label.
    grid, level_str = grid_level_combined.split("-")
    level = int(level_str)

    # Members selection.
    st.markdown("#### Select Data Sources")

    # AlphaEarth is unavailable at the finest resolutions (see _members_for).
    disable_alphaearth = "AlphaEarth" not in _members_for(grid, level)

    all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]

    # Use columns for checkboxes, one per member.
    cols = st.columns(len(all_members))
    selected_members = []

    for idx, member in enumerate(all_members):
        with cols[idx]:
            if member == "AlphaEarth" and disable_alphaearth:
                # Render a disabled, unchecked box with an explanatory tooltip.
                st.checkbox(
                    member,
                    value=False,
                    disabled=True,
                    help=f"Not available for {grid} level {level}",
                    key=f"feature_member_{member}",
                )
            else:
                if st.checkbox(member, value=True, key=f"feature_member_{member}"):
                    selected_members.append(member)

    # Show results if at least one member is selected.
    if selected_members:
        st.markdown("---")

        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)

        with st.spinner("Computing dataset statistics..."):
            stats = ensemble.get_stats()

        # High-level metrics.
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Total Features", f"{stats['total_features']:,}")
        with col2:
            # The minimum cell count across the selected data sources bounds
            # inference capability (it is NOT a union of cells).
            min_cells = min(
                member_stats["dimensions"]["cell_ids"]  # type: ignore[index]
                for member_stats in stats["members"].values()
            )
            st.metric("Inference Cells", f"{min_cells:,}", help="Minimum number of cells across all data sources")
        with col3:
            st.metric("Data Sources", len(selected_members))
        with col4:
            st.metric("Total Samples", f"{stats['num_target_samples']:,}")
        with col5:
            # Total data points = features x samples.
            total_points = stats["total_features"] * stats["num_target_samples"]
            st.metric("Total Data Points", f"{total_points:,}")

        # Feature breakdown by source.
        st.markdown("#### Feature Breakdown by Data Source")

        breakdown_data = []
        for member, member_stats in stats["members"].items():
            breakdown_data.append(
                {
                    "Data Source": member,
                    "Number of Features": member_stats["num_features"],
                    "Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%",  # type: ignore[operator]
                }
            )

        # Lon/Lat adds two pseudo-features when enabled on the ensemble.
        if ensemble.add_lonlat:
            breakdown_data.append(
                {
                    "Data Source": "Lon/Lat",
                    "Number of Features": 2,
                    "Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
                }
            )

        breakdown_df = pd.DataFrame(breakdown_data)

        # Get color palette for data sources.
        source_colors = get_palette("data_sources", n_colors=len(breakdown_df))

        # Create donut chart.
        fig = px.pie(
            breakdown_df,
            names="Data Source",
            values="Number of Features",
            title="Feature Distribution by Data Source",
            hole=0.4,
            color_discrete_sequence=source_colors,
        )
        fig.update_traces(textposition="inside", textinfo="percent+label")
        st.plotly_chart(fig, use_container_width=True)

        # Show detailed table.
        st.dataframe(breakdown_df, hide_index=True, use_container_width=True)

        # Detailed member information.
        with st.expander("📦 Detailed Source Information", expanded=False):
            for member, member_stats in stats["members"].items():
                st.markdown(f"### {member}")

                metric_cols = st.columns(4)
                with metric_cols[0]:
                    st.metric("Features", member_stats["num_features"])
                with metric_cols[1]:
                    st.metric("Variables", member_stats["num_variables"])
                with metric_cols[2]:
                    dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
                    st.metric("Shape", dim_str)
                with metric_cols[3]:
                    # Product of all dimension sizes = raw data points stored.
                    total_points = 1
                    for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                        total_points *= dim_size
                    st.metric("Data Points", f"{total_points:,}")

                # Variables, rendered as styled chips.
                st.markdown("**Variables:**")
                vars_html = " ".join(
                    [
                        f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                        for v in member_stats["variables"]  # type: ignore[union-attr]
                    ]
                )
                st.markdown(vars_html, unsafe_allow_html=True)

                # Dimensions, rendered as styled chips.
                st.markdown("**Dimensions:**")
                dim_html = " ".join(
                    [
                        f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                        f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                        f"{dim_name}: {dim_size}</span>"
                        for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                    ]
                )
                st.markdown(dim_html, unsafe_allow_html=True)

                st.markdown("---")
    else:
        st.info("👆 Select at least one data source to see feature statistics")
||||||
|
def render_dataset_analysis():
    """Render the dataset analysis section with sample and feature counts."""
    st.header("📈 Dataset Analysis")

    # Two tabs: one per analysis view.
    samples_tab, features_tab = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])

    with samples_tab:
        render_sample_count_overview()

    with features_tab:
        render_feature_count_fragment()
|
|
||||||
def render_overview_page():
|
def render_overview_page():
|
||||||
|
|
@ -20,8 +650,10 @@ def render_overview_page():
|
||||||
|
|
||||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
# Summary statistics at the top
|
# Summary statistics at the top
|
||||||
st.subheader("Summary Statistics")
|
st.header("📊 Training Results Summary")
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
|
|
@ -43,8 +675,14 @@ def render_overview_page():
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
|
# Add dataset analysis section
|
||||||
|
render_dataset_analysis()
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
# Detailed results table
|
# Detailed results table
|
||||||
st.subheader("Training Results")
|
st.header("🎯 Experiment Results")
|
||||||
|
st.subheader("Results Table")
|
||||||
|
|
||||||
# Build a summary dataframe
|
# Build a summary dataframe
|
||||||
summary_data = []
|
summary_data = []
|
||||||
|
|
@ -93,7 +731,7 @@ def render_overview_page():
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Expandable details for each result
|
# Expandable details for each result
|
||||||
st.subheader("Detailed Results")
|
st.subheader("Individual Experiment Details")
|
||||||
|
|
||||||
for tr in training_results:
|
for tr in training_results:
|
||||||
with st.expander(tr.get_display_name("task_first")):
|
with st.expander(tr.get_display_name("task_first")):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue