Refactor overview page
This commit is contained in:
parent
a304c96e4e
commit
1ee3d532fc
1 changed file with 302 additions and 257 deletions
|
|
@ -1,6 +1,8 @@
|
||||||
"""Overview page: List of available result directories with some summary statistics."""
|
"""Overview page: List of available result directories with some summary statistics."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
|
|
@ -11,22 +13,110 @@ from entropice.dashboard.utils.data import load_all_training_results
|
||||||
from entropice.dataset import DatasetEnsemble
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
def render_sample_count_overview():
|
# Type definitions for dataset statistics
|
||||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
class GridConfig(TypedDict):
|
||||||
st.subheader("📊 Sample Counts by Configuration")
|
"""Grid configuration specification with metadata."""
|
||||||
|
|
||||||
st.markdown(
|
grid: str
|
||||||
"""
|
level: int
|
||||||
This visualization shows the number of available samples for each combination of:
|
grid_name: str
|
||||||
- **Task**: binary, count, density
|
grid_sort_key: str
|
||||||
- **Target Dataset**: darts_rts, darts_mllabels
|
disable_alphaearth: bool
|
||||||
- **Grid System**: hex, healpix
|
|
||||||
- **Grid Level**: varying by grid type
|
|
||||||
"""
|
class SampleCountData(TypedDict):
    """Sample tallies for one grid configuration / target / task triple.

    One record is appended per (grid configuration, target dataset, task)
    combination while the dataset analysis data is being loaded.
    """

    # Grid configuration these counts were computed for.
    grid_config: GridConfig
    # Target dataset identifier, e.g. "darts_rts" or "darts_mllabels".
    target: str
    # Task name, e.g. "binary", "count" or "density".
    task: str
    # Number of samples counted as covered.
    # NOTE(review): the coverage computation is outside this view - confirm.
    samples_coverage: int
    # Number of samples whose task column is not null.
    samples_labels: int
    # Number of samples that are covered AND have a non-null task label.
    samples_both: int
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureCountData(TypedDict):
    """Feature statistics gathered for a single grid configuration."""

    # Grid configuration these statistics were computed for.
    grid_config: GridConfig
    # Total feature count as reported by the ensemble statistics.
    total_features: int
    # Names of the ensemble members used, plus "Lon/Lat" when the ensemble
    # adds longitude/latitude features.
    data_sources: list[str]
    # Smallest "cell_ids" dimension found among the ensemble members.
    inference_cells: int
    # Number of target samples as reported by the ensemble statistics.
    total_samples: int
    # Feature count per data source; "Lon/Lat" contributes 2 when present.
    member_breakdown: dict[str, int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DatasetAnalysisCache:
    """Immutable bundle of precomputed dataset statistics.

    Holds the grid configurations together with the sample-count and
    feature-count records computed for them, and converts those records into
    the DataFrames consumed by the overview page.

    Every ``get_*_df`` method passes an explicit column list to
    ``pd.DataFrame`` so that an empty record list still yields a frame with
    the full schema. Without it, downstream code that sorts or selects by
    column name (for example ``sort_values("Grid_Level_Sort")``) would raise
    ``KeyError`` on an empty cache.
    """

    grid_configs: list[GridConfig]
    sample_counts: list[SampleCountData]
    feature_counts: list[FeatureCountData]

    def get_sample_count_df(self) -> pd.DataFrame:
        """Return one row per sample-count record.

        Returns:
            DataFrame with the columns ``Grid``, ``Grid Type``, ``Level``,
            ``Target``, ``Task``, ``Samples (Coverage)``, ``Samples (Labels)``,
            ``Samples (Both)`` and ``Grid_Level_Sort``. The frame is empty,
            but still has these columns, when there are no records.
        """
        columns = [
            "Grid",
            "Grid Type",
            "Level",
            "Target",
            "Task",
            "Samples (Coverage)",
            "Samples (Labels)",
            "Samples (Both)",
            "Grid_Level_Sort",
        ]
        rows = [
            {
                "Grid": item["grid_config"]["grid_name"],
                "Grid Type": item["grid_config"]["grid"],
                "Level": item["grid_config"]["level"],
                # Display name drops the "darts_" prefix of the dataset id.
                "Target": item["target"].replace("darts_", ""),
                "Task": item["task"].capitalize(),
                "Samples (Coverage)": item["samples_coverage"],
                "Samples (Labels)": item["samples_labels"],
                "Samples (Both)": item["samples_both"],
                "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
            }
            for item in self.sample_counts
        ]
        return pd.DataFrame(rows, columns=columns)

    def get_feature_count_df(self) -> pd.DataFrame:
        """Return one row per feature-count record.

        Returns:
            DataFrame with the columns ``Grid``, ``Grid Type``, ``Level``,
            ``Total Features``, ``Data Sources``, ``Inference Cells``,
            ``Total Samples``, ``AlphaEarth`` and ``Grid_Level_Sort``. The
            frame is empty, but still has these columns, when there are no
            records.
        """
        columns = [
            "Grid",
            "Grid Type",
            "Level",
            "Total Features",
            "Data Sources",
            "Inference Cells",
            "Total Samples",
            "AlphaEarth",
            "Grid_Level_Sort",
        ]
        rows = [
            {
                "Grid": item["grid_config"]["grid_name"],
                "Grid Type": item["grid_config"]["grid"],
                "Level": item["grid_config"]["level"],
                "Total Features": item["total_features"],
                # Counts every entry of data_sources as given.
                # NOTE(review): the list may contain the synthetic "Lon/Lat"
                # entry, so this can exceed the number of ensemble members -
                # confirm that is intended.
                "Data Sources": len(item["data_sources"]),
                "Inference Cells": item["inference_cells"],
                "Total Samples": item["total_samples"],
                "AlphaEarth": "AlphaEarth" in item["data_sources"],
                "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
            }
            for item in self.feature_counts
        ]
        return pd.DataFrame(rows, columns=columns)

    def get_feature_breakdown_df(self) -> pd.DataFrame:
        """Return one row per (grid, data source) pair for stacked/donut charts.

        Returns:
            DataFrame with the columns ``Grid``, ``Data Source``,
            ``Number of Features`` and ``Grid_Level_Sort``. The frame is
            empty, but still has these columns, when there are no records.
        """
        columns = ["Grid", "Data Source", "Number of Features", "Grid_Level_Sort"]
        rows = [
            {
                "Grid": item["grid_config"]["grid_name"],
                "Data Source": source,
                "Number of Features": count,
                "Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
            }
            for item in self.feature_counts
            for source, count in item["member_breakdown"].items()
        ]
        return pd.DataFrame(rows, columns=columns)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data(show_spinner=False)
|
||||||
|
def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
||||||
|
"""Load and cache all dataset analysis data.
|
||||||
|
|
||||||
|
This function computes both sample counts and feature counts for all grid configurations.
|
||||||
|
Results are cached to avoid redundant computations across different tabs.
|
||||||
|
"""
|
||||||
# Define all possible grid configurations
|
# Define all possible grid configurations
|
||||||
grid_configs = [
|
grid_configs_raw = [
|
||||||
("hex", 3),
|
("hex", 3),
|
||||||
("hex", 4),
|
("hex", 4),
|
||||||
("hex", 5),
|
("hex", 5),
|
||||||
|
|
@ -38,19 +128,34 @@ def render_sample_count_overview():
|
||||||
("healpix", 10),
|
("healpix", 10),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Create structured grid config objects
|
||||||
|
grid_configs: list[GridConfig] = []
|
||||||
|
for grid, level in grid_configs_raw:
|
||||||
|
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||||
|
grid_configs.append(
|
||||||
|
{
|
||||||
|
"grid": grid,
|
||||||
|
"level": level,
|
||||||
|
"grid_name": f"{grid}-{level}",
|
||||||
|
"grid_sort_key": f"{grid}_{level:02d}",
|
||||||
|
"disable_alphaearth": disable_alphaearth,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute sample counts
|
||||||
|
sample_counts: list[SampleCountData] = []
|
||||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||||
tasks = ["binary", "count", "density"]
|
tasks = ["binary", "count", "density"]
|
||||||
|
|
||||||
# Collect sample counts
|
for grid_config in grid_configs:
|
||||||
sample_data = []
|
|
||||||
|
|
||||||
with st.spinner("Computing sample counts for all configurations..."):
|
|
||||||
for grid, level in grid_configs:
|
|
||||||
for target in target_datasets:
|
for target in target_datasets:
|
||||||
# Create minimal ensemble just to get target data
|
# Create minimal ensemble just to get target data
|
||||||
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=[]) # type: ignore[arg-type]
|
ensemble = DatasetEnsemble(
|
||||||
|
grid=grid_config["grid"],
|
||||||
# Read target data
|
level=grid_config["level"],
|
||||||
|
target=target,
|
||||||
|
members=[], # type: ignore[arg-type]
|
||||||
|
)
|
||||||
targets = ensemble._read_target()
|
targets = ensemble._read_target()
|
||||||
|
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
|
|
@ -64,21 +169,85 @@ def render_sample_count_overview():
|
||||||
valid_labels = targets[taskcol].notna().sum()
|
valid_labels = targets[taskcol].notna().sum()
|
||||||
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
|
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
|
||||||
|
|
||||||
sample_data.append(
|
sample_counts.append(
|
||||||
{
|
{
|
||||||
"Grid": f"{grid}-{level}",
|
"grid_config": grid_config,
|
||||||
"Grid Type": grid,
|
"target": target,
|
||||||
"Level": level,
|
"task": task,
|
||||||
"Target": target.replace("darts_", ""),
|
"samples_coverage": valid_coverage,
|
||||||
"Task": task.capitalize(),
|
"samples_labels": valid_labels,
|
||||||
"Samples (Coverage)": valid_coverage,
|
"samples_both": valid_both,
|
||||||
"Samples (Labels)": valid_labels,
|
|
||||||
"Samples (Both)": valid_both,
|
|
||||||
"Grid_Level_Sort": f"{grid}_{level:02d}",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_df = pd.DataFrame(sample_data)
|
# Compute feature counts
|
||||||
|
feature_counts: list[FeatureCountData] = []
|
||||||
|
|
||||||
|
for grid_config in grid_configs:
|
||||||
|
# Determine which members are available for this configuration
|
||||||
|
if grid_config["disable_alphaearth"]:
|
||||||
|
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
|
else:
|
||||||
|
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
|
|
||||||
|
# Use darts_rts as default target for comparison
|
||||||
|
ensemble = DatasetEnsemble(
|
||||||
|
grid=grid_config["grid"],
|
||||||
|
level=grid_config["level"],
|
||||||
|
target="darts_rts",
|
||||||
|
members=members, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
stats = ensemble.get_stats()
|
||||||
|
|
||||||
|
# Calculate minimum cells across all data sources
|
||||||
|
min_cells = min(
|
||||||
|
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
||||||
|
for member_stats in stats["members"].values()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build member breakdown including lon/lat
|
||||||
|
member_breakdown = {}
|
||||||
|
for member, member_stats in stats["members"].items():
|
||||||
|
member_breakdown[member] = member_stats["num_features"]
|
||||||
|
|
||||||
|
if ensemble.add_lonlat:
|
||||||
|
member_breakdown["Lon/Lat"] = 2
|
||||||
|
|
||||||
|
feature_counts.append(
|
||||||
|
{
|
||||||
|
"grid_config": grid_config,
|
||||||
|
"total_features": stats["total_features"],
|
||||||
|
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
|
||||||
|
"inference_cells": min_cells,
|
||||||
|
"total_samples": stats["num_target_samples"],
|
||||||
|
"member_breakdown": member_breakdown,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return DatasetAnalysisCache(
|
||||||
|
grid_configs=grid_configs,
|
||||||
|
sample_counts=sample_counts,
|
||||||
|
feature_counts=feature_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
|
"""Render overview of sample counts per task+target+grid+level combination."""
|
||||||
|
st.subheader("📊 Sample Counts by Configuration")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
This visualization shows the number of available samples for each combination of:
|
||||||
|
- **Task**: binary, count, density
|
||||||
|
- **Target Dataset**: darts_rts, darts_mllabels
|
||||||
|
- **Grid System**: hex, healpix
|
||||||
|
- **Grid Level**: varying by grid type
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sample count DataFrame from cache
|
||||||
|
sample_df = cache.get_sample_count_df()
|
||||||
|
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||||
|
|
||||||
# Create tabs for different views
|
# Create tabs for different views
|
||||||
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
|
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
|
||||||
|
|
@ -159,73 +328,15 @@ def render_sample_count_overview():
|
||||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
||||||
def render_feature_count_fragment():
|
"""Render static comparison of feature counts across all grid configurations."""
|
||||||
"""Render interactive feature count visualization using fragments."""
|
|
||||||
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
|
||||||
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
This visualization shows the total number of features that would be generated
|
|
||||||
for different combinations of data sources and grid configurations.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# First section: Comparison across all grid configurations
|
|
||||||
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
||||||
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
||||||
|
|
||||||
# Define all possible grid configurations
|
# Get data from cache
|
||||||
grid_configs = [
|
comparison_df = cache.get_feature_count_df()
|
||||||
("hex", 3),
|
breakdown_df = cache.get_feature_breakdown_df()
|
||||||
("hex", 4),
|
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
||||||
("hex", 5),
|
|
||||||
("hex", 6),
|
|
||||||
("healpix", 6),
|
|
||||||
("healpix", 7),
|
|
||||||
("healpix", 8),
|
|
||||||
("healpix", 9),
|
|
||||||
("healpix", 10),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Collect feature statistics for all configurations
|
|
||||||
feature_comparison_data = []
|
|
||||||
|
|
||||||
with st.spinner("Computing feature counts for all grid configurations..."):
|
|
||||||
for grid, level in grid_configs:
|
|
||||||
# Determine which members are available for this configuration
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
|
|
||||||
if disable_alphaearth:
|
|
||||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
else:
|
|
||||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
|
|
||||||
# Use darts_rts as default target for comparison
|
|
||||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
|
||||||
stats = ensemble.get_stats()
|
|
||||||
|
|
||||||
# Calculate minimum cells across all data sources
|
|
||||||
min_cells = min(
|
|
||||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
|
||||||
for member_stats in stats["members"].values()
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_comparison_data.append(
|
|
||||||
{
|
|
||||||
"Grid": f"{grid}-{level}",
|
|
||||||
"Grid Type": grid,
|
|
||||||
"Level": level,
|
|
||||||
"Total Features": stats["total_features"],
|
|
||||||
"Data Sources": len(members),
|
|
||||||
"Inference Cells": min_cells,
|
|
||||||
"Total Samples": stats["num_target_samples"],
|
|
||||||
"AlphaEarth": "AlphaEarth" in members,
|
|
||||||
"Grid_Level_Sort": f"{grid}_{level:02d}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
comparison_df = pd.DataFrame(feature_comparison_data)
|
|
||||||
|
|
||||||
# Create tabs for different comparison views
|
# Create tabs for different comparison views
|
||||||
comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
|
comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
|
||||||
|
|
@ -233,56 +344,14 @@ def render_feature_count_fragment():
|
||||||
with comp_tab1:
|
with comp_tab1:
|
||||||
st.markdown("#### Total Features by Grid Configuration")
|
st.markdown("#### Total Features by Grid Configuration")
|
||||||
|
|
||||||
# Collect breakdown data for stacked bar chart
|
|
||||||
stacked_data = []
|
|
||||||
|
|
||||||
for idx, row in comparison_df.iterrows():
|
|
||||||
grid_config = row["Grid"]
|
|
||||||
grid, level_str = grid_config.split("-")
|
|
||||||
level = int(level_str)
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
|
|
||||||
if disable_alphaearth:
|
|
||||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
else:
|
|
||||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
|
|
||||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
|
||||||
stats = ensemble.get_stats()
|
|
||||||
|
|
||||||
# Add data for each member
|
|
||||||
for member, member_stats in stats["members"].items():
|
|
||||||
stacked_data.append(
|
|
||||||
{
|
|
||||||
"Grid": grid_config,
|
|
||||||
"Data Source": member,
|
|
||||||
"Number of Features": member_stats["num_features"],
|
|
||||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add lon/lat
|
|
||||||
if ensemble.add_lonlat:
|
|
||||||
stacked_data.append(
|
|
||||||
{
|
|
||||||
"Grid": grid_config,
|
|
||||||
"Data Source": "Lon/Lat",
|
|
||||||
"Number of Features": 2,
|
|
||||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
stacked_df = pd.DataFrame(stacked_data)
|
|
||||||
stacked_df = stacked_df.sort_values("Grid_Level_Sort")
|
|
||||||
|
|
||||||
# Get color palette for data sources
|
# Get color palette for data sources
|
||||||
unique_sources = stacked_df["Data Source"].unique()
|
unique_sources = breakdown_df["Data Source"].unique()
|
||||||
n_sources = len(unique_sources)
|
n_sources = len(unique_sources)
|
||||||
source_colors = get_palette("data_sources", n_colors=n_sources)
|
source_colors = get_palette("data_sources", n_colors=n_sources)
|
||||||
|
|
||||||
# Create stacked bar chart
|
# Create stacked bar chart
|
||||||
fig = px.bar(
|
fig = px.bar(
|
||||||
stacked_df,
|
breakdown_df,
|
||||||
x="Grid",
|
x="Grid",
|
||||||
y="Number of Features",
|
y="Number of Features",
|
||||||
color="Data Source",
|
color="Data Source",
|
||||||
|
|
@ -336,58 +405,8 @@ def render_feature_count_fragment():
|
||||||
st.markdown("#### Feature Breakdown by Data Source")
|
st.markdown("#### Feature Breakdown by Data Source")
|
||||||
st.markdown("Showing percentage contribution of each data source across all grid configurations")
|
st.markdown("Showing percentage contribution of each data source across all grid configurations")
|
||||||
|
|
||||||
# Collect breakdown data for all grid configurations
|
|
||||||
all_breakdown_data = []
|
|
||||||
|
|
||||||
for idx, row in comparison_df.iterrows():
|
|
||||||
grid_config = row["Grid"]
|
|
||||||
grid, level_str = grid_config.split("-")
|
|
||||||
level = int(level_str)
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
|
|
||||||
if disable_alphaearth:
|
|
||||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
else:
|
|
||||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
|
|
||||||
ensemble = DatasetEnsemble(grid=grid, level=level, target="darts_rts", members=members) # type: ignore[arg-type]
|
|
||||||
stats = ensemble.get_stats()
|
|
||||||
|
|
||||||
total_features = stats["total_features"]
|
|
||||||
|
|
||||||
# Add data for each member with percentage
|
|
||||||
for member, member_stats in stats["members"].items():
|
|
||||||
percentage = (member_stats["num_features"] / total_features) * 100 # type: ignore[operator]
|
|
||||||
all_breakdown_data.append(
|
|
||||||
{
|
|
||||||
"Grid": grid_config,
|
|
||||||
"Data Source": member,
|
|
||||||
"Percentage": percentage,
|
|
||||||
"Number of Features": member_stats["num_features"],
|
|
||||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add lon/lat
|
|
||||||
if ensemble.add_lonlat:
|
|
||||||
percentage = (2 / total_features) * 100 # type: ignore[operator]
|
|
||||||
all_breakdown_data.append(
|
|
||||||
{
|
|
||||||
"Grid": grid_config,
|
|
||||||
"Data Source": "Lon/Lat",
|
|
||||||
"Percentage": percentage,
|
|
||||||
"Number of Features": 2,
|
|
||||||
"Grid_Level_Sort": row["Grid_Level_Sort"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
breakdown_all_df = pd.DataFrame(all_breakdown_data)
|
|
||||||
|
|
||||||
# Sort by grid configuration
|
|
||||||
breakdown_all_df = breakdown_all_df.sort_values("Grid_Level_Sort")
|
|
||||||
|
|
||||||
# Get color palette for data sources
|
# Get color palette for data sources
|
||||||
unique_sources = breakdown_all_df["Data Source"].unique()
|
unique_sources = breakdown_df["Data Source"].unique()
|
||||||
n_sources = len(unique_sources)
|
n_sources = len(unique_sources)
|
||||||
source_colors = get_palette("data_sources", n_colors=n_sources)
|
source_colors = get_palette("data_sources", n_colors=n_sources)
|
||||||
|
|
||||||
|
|
@ -403,7 +422,7 @@ def render_feature_count_fragment():
|
||||||
grid_idx = row_idx * cols_per_row + col_idx
|
grid_idx = row_idx * cols_per_row + col_idx
|
||||||
if grid_idx < num_grids:
|
if grid_idx < num_grids:
|
||||||
grid_config = comparison_df.iloc[grid_idx]["Grid"]
|
grid_config = comparison_df.iloc[grid_idx]["Grid"]
|
||||||
grid_data = breakdown_all_df[breakdown_all_df["Grid"] == grid_config]
|
grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
|
||||||
|
|
||||||
with cols[col_idx]:
|
with cols[col_idx]:
|
||||||
fig = px.pie(
|
fig = px.pie(
|
||||||
|
|
@ -435,24 +454,15 @@ def render_feature_count_fragment():
|
||||||
|
|
||||||
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
st.dataframe(display_df, hide_index=True, use_container_width=True)
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Second section: Detailed configuration with user selection
|
@st.fragment
|
||||||
|
def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||||
|
"""Render interactive detailed configuration explorer using fragments."""
|
||||||
st.markdown("### Detailed Configuration Explorer")
|
st.markdown("### Detailed Configuration Explorer")
|
||||||
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
||||||
|
|
||||||
# Grid selection
|
# Grid selection
|
||||||
grid_options = [
|
grid_options = [gc["grid_name"] for gc in cache.grid_configs]
|
||||||
"hex-3",
|
|
||||||
"hex-4",
|
|
||||||
"hex-5",
|
|
||||||
"hex-6",
|
|
||||||
"healpix-6",
|
|
||||||
"healpix-7",
|
|
||||||
"healpix-8",
|
|
||||||
"healpix-9",
|
|
||||||
"healpix-10",
|
|
||||||
]
|
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
|
@ -474,16 +484,15 @@ def render_feature_count_fragment():
|
||||||
key="feature_target_select",
|
key="feature_target_select",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse grid type and level
|
# Find the selected grid config
|
||||||
grid, level_str = grid_level_combined.split("-")
|
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
|
||||||
level = int(level_str)
|
grid = selected_grid_config["grid"]
|
||||||
|
level = selected_grid_config["level"]
|
||||||
|
disable_alphaearth = selected_grid_config["disable_alphaearth"]
|
||||||
|
|
||||||
# Members selection
|
# Members selection
|
||||||
st.markdown("#### Select Data Sources")
|
st.markdown("#### Select Data Sources")
|
||||||
|
|
||||||
# Check if AlphaEarth should be disabled
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
|
|
||||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
|
|
||||||
# Use columns for checkboxes
|
# Use columns for checkboxes
|
||||||
|
|
@ -624,35 +633,46 @@ def render_feature_count_fragment():
|
||||||
st.info("👆 Select at least one data source to see feature statistics")
|
st.info("👆 Select at least one data source to see feature statistics")
|
||||||
|
|
||||||
|
|
||||||
|
def render_feature_count_section(cache: DatasetAnalysisCache):
    """Render the feature-count tab: heading, intro, comparison and explorer.

    Args:
        cache: Precomputed dataset statistics passed on to both sub-sections.
    """
    intro_text = """
        This visualization shows the total number of features that would be generated
        for different combinations of data sources and grid configurations.
        """

    st.subheader("🔢 Feature Counts by Dataset Configuration")
    st.markdown(intro_text)

    # Comparison across all grid configurations comes first ...
    render_feature_count_comparison(cache)

    st.divider()

    # ... followed by the detailed configuration explorer.
    render_feature_count_explorer(cache)
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_analysis():
|
def render_dataset_analysis():
|
||||||
"""Render the dataset analysis section with sample and feature counts."""
|
"""Render the dataset analysis section with sample and feature counts."""
|
||||||
st.header("📈 Dataset Analysis")
|
st.header("📈 Dataset Analysis")
|
||||||
|
|
||||||
|
# Load all data once and cache it
|
||||||
|
with st.spinner("Loading dataset analysis data..."):
|
||||||
|
cache = load_dataset_analysis_data()
|
||||||
|
|
||||||
# Create tabs for the two different analyses
|
# Create tabs for the two different analyses
|
||||||
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
||||||
|
|
||||||
with analysis_tabs[0]:
|
with analysis_tabs[0]:
|
||||||
render_sample_count_overview()
|
render_sample_count_overview(cache)
|
||||||
|
|
||||||
with analysis_tabs[1]:
|
with analysis_tabs[1]:
|
||||||
render_feature_count_fragment()
|
render_feature_count_section(cache)
|
||||||
|
|
||||||
|
|
||||||
def render_overview_page():
|
def render_training_results_summary(training_results):
|
||||||
"""Render the Overview page of the dashboard."""
|
"""Render summary metrics for training results."""
|
||||||
st.title("🏡 Training Results Overview")
|
|
||||||
|
|
||||||
training_results = load_all_training_results()
|
|
||||||
|
|
||||||
if not training_results:
|
|
||||||
st.warning("No training results found. Please run some training experiments first.")
|
|
||||||
return
|
|
||||||
|
|
||||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Summary statistics at the top
|
|
||||||
st.header("📊 Training Results Summary")
|
st.header("📊 Training Results Summary")
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
|
@ -673,14 +693,9 @@ def render_overview_page():
|
||||||
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
|
latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
|
||||||
st.metric("Latest Run", latest_date)
|
st.metric("Latest Run", latest_date)
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Add dataset analysis section
|
def render_experiment_results(training_results):
|
||||||
render_dataset_analysis()
|
"""Render detailed experiment results table and expandable details."""
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Detailed results table
|
|
||||||
st.header("🎯 Experiment Results")
|
st.header("🎯 Experiment Results")
|
||||||
st.subheader("Results Table")
|
st.subheader("Results Table")
|
||||||
|
|
||||||
|
|
@ -791,3 +806,33 @@ def render_overview_page():
|
||||||
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
||||||
|
|
||||||
st.write(f"\n**Path:** `{tr.path}`")
|
st.write(f"\n**Path:** `{tr.path}`")
|
||||||
|
|
||||||
|
|
||||||
|
def render_overview_page():
    """Render the dashboard's Overview page.

    Loads all training results. When none exist, only a warning is shown.
    Otherwise the page shows the number of results, the summary metrics, the
    experiment results and the dataset analysis, separated by dividers, and
    ends with a balloon animation.
    """
    st.title("🏡 Training Results Overview")

    results = load_all_training_results()

    if results:
        st.write(f"Found **{len(results)}** training result(s)")

        st.divider()

        # Training-result sections
        render_training_results_summary(results)

        st.divider()

        render_experiment_results(results)

        st.divider()

        # Dataset analysis section
        render_dataset_analysis()

        st.balloons()
    else:
        st.warning("No training results found. Please run some training experiments first.")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue