Overview Page refactor done

Tobias Hölzer 2026-01-04 03:14:48 +01:00
parent 36f8737075
commit fca232da91
7 changed files with 270 additions and 363 deletions

View file

@@ -11,11 +11,12 @@ Pages:
import streamlit as st
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
# from entropice.dashboard.views.inference_page import render_inference_page
# from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
from entropice.dashboard.views.training_data_page import render_training_data_page
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
# from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@@ -24,17 +25,17 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
"Training": [training_data_page, training_analysis_page],
"Model State": [model_state_page],
"Inference": [inference_page],
# "Training": [training_data_page, training_analysis_page],
# "Model State": [model_state_page],
# "Inference": [inference_page],
}
)
pg.run()

View file

@@ -44,8 +44,12 @@ class TrainingResult:
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
raise FileNotFoundError(f"Missing required files in {result_path}")
if not result_file.exists():
raise FileNotFoundError(f"Missing results file in {result_path}")
if not settings_file.exists():
raise FileNotFoundError(f"Missing settings file in {result_path}")
if not preds_file.exists():
raise FileNotFoundError(f"Missing predictions file in {result_path}")
created_at = result_path.stat().st_ctime
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
@@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]:
for result_path in results_dir.iterdir():
if not result_path.is_dir():
continue
training_result = TrainingResult.from_path(result_path)
try:
training_result = TrainingResult.from_path(result_path)
except FileNotFoundError as e:
st.warning(f"Skipping incomplete training result at {result_path}: {e}")
continue
training_results.append(training_result)
# Sort by creation time (most recent first)

View file

@@ -5,7 +5,7 @@
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Literal, cast, get_args
from typing import Literal
import geopandas as gpd
import pandas as pd
@@ -17,7 +17,17 @@ import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
from entropice.utils.types import (
Grid,
GridLevel,
L2SourceDataset,
TargetDataset,
Task,
all_l2_source_datasets,
all_target_datasets,
all_tasks,
grid_configs,
)
@dataclass(frozen=True)
@@ -87,8 +97,7 @@ class TargetStatistics:
training_coverage: dict[Task, float] = {}
class_counts: dict[Task, dict[str, int]] = {}
class_distribution: dict[Task, dict[str, float]] = {}
tasks = cast(list[Task], get_args(Task))
for task in tasks:
for task in all_tasks:
task_col = taskcol[task][target]
cov_col = covcol[target]
@@ -132,43 +141,97 @@ class DatasetStatistics:
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
@staticmethod
def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
for target_name, target_stats in stats.target.items():
for task in all_tasks:
training_cells = target_stats.training_cells[task]
coverage = target_stats.coverage[task]
rows.append(
{
"Grid": grid_config.display_name,
"Target": target_name.replace("darts_", ""),
"Task": task.capitalize(),
"Samples (Coverage)": training_cells,
"Coverage %": coverage,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
@staticmethod
def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
data_sources = list(stats.members.keys())
# Determine minimum cells across all data sources
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
# Get sample count from first target dataset (darts_rts)
first_target = stats.target["darts_rts"]
total_samples = first_target.training_cells["binary"]
rows.append(
{
"Grid": grid_config.display_name,
"Total Features": stats.total_features,
"Data Sources": len(data_sources),
"Inference Cells": min_cells,
"Total Samples": total_samples,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
@staticmethod
def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
for member_name, member_stats in stats.members.items():
rows.append(
{
"Grid": grid_config.display_name,
"Data Source": member_name,
"Number of Features": member_stats.feature_count,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
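# st.cache_data memoizes the loader below, so the per-grid statistics are
# computed once and then reused by every tab that calls it.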
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
grid_levels: set[tuple[Grid, int]] = {
("hex", 3),
("hex", 4),
("hex", 5),
("hex", 6),
("healpix", 6),
("healpix", 7),
("healpix", 8),
("healpix", 9),
("healpix", 10),
}
for grid, level in grid_levels:
with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered
for grid_config in grid_configs:
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
total_cells = len(grid_gdf)
assert total_cells > 0, "Grid must contain at least one cell."
target_statistics: dict[TargetDataset, TargetStatistics] = {}
targets = cast(list[TargetDataset], get_args(TargetDataset))
for target in targets:
for target in all_target_datasets:
target_statistics[target] = TargetStatistics.compute(
grid=grid, level=level, target=target, total_cells=total_cells
grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
)
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
for member in members:
member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
for member in all_l2_source_datasets:
member_statistics[member] = MemberStatistics.compute(
grid=grid_config.grid, level=grid_config.level, member=member
)
total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
ts.size_bytes for ts in target_statistics.values()
)
grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
dataset_stats[grid_level] = DatasetStatistics(
dataset_stats[grid_config.id] = DatasetStatistics(
total_features=total_features,
total_cells=total_cells,
size_bytes=total_size_bytes,

View file

@@ -1,6 +1,7 @@
"""Inference page: Visualization of model inference results across the study region."""
import streamlit as st
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.plots.inference import (
render_class_comparison,
@@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics,
render_spatial_distribution_stats,
)
from entropice.dashboard.utils.data import load_all_training_results
def render_inference_page():

View file

@@ -2,6 +2,15 @@
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap,
@@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import (
plot_top_features,
)
from entropice.dashboard.utils.colors import generate_unified_colormap
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
def render_model_state_page():

View file

@@ -1,238 +1,20 @@
"""Overview page: List of available result directories with some summary statistics."""
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
from typing import cast
import pandas as pd
import plotly.express as px
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import Grid, TargetDataset, Task
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
# Type definitions for dataset statistics
class GridConfig(TypedDict):
"""Grid configuration specification with metadata."""
grid: Grid
level: int
grid_name: str
grid_sort_key: str
disable_alphaearth: bool
class SampleCountData(TypedDict):
"""Sample count statistics for a specific grid/target/task combination."""
grid_config: GridConfig
target: str
task: str
samples_coverage: int
samples_labels: int
samples_both: int
class FeatureCountData(TypedDict):
"""Feature count statistics for a specific grid configuration."""
grid_config: GridConfig
total_features: int
data_sources: list[str]
inference_cells: int
total_samples: int
member_breakdown: dict[str, int]
@dataclass(frozen=True)
class DatasetAnalysisCache:
"""Cache for dataset analysis data to avoid redundant computations."""
grid_configs: list[GridConfig]
sample_counts: list[SampleCountData]
feature_counts: list[FeatureCountData]
def get_sample_count_df(self) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for item in self.sample_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Target": item["target"].replace("darts_", ""),
"Task": item["task"].capitalize(),
"Samples (Coverage)": item["samples_coverage"],
"Samples (Labels)": item["samples_labels"],
"Samples (Both)": item["samples_both"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
def get_feature_count_df(self) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for item in self.feature_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Total Features": item["total_features"],
"Data Sources": len(item["data_sources"]),
"Inference Cells": item["inference_cells"],
"Total Samples": item["total_samples"],
"AlphaEarth": "AlphaEarth" in item["data_sources"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
def get_feature_breakdown_df(self) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for item in self.feature_counts:
for source, count in item["member_breakdown"].items():
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Data Source": source,
"Number of Features": count,
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
@st.cache_data(show_spinner=False)
def load_dataset_analysis_data() -> DatasetAnalysisCache:
"""Load and cache all dataset analysis data.
This function computes both sample counts and feature counts for all grid configurations.
Results are cached to avoid redundant computations across different tabs.
"""
# Define all possible grid configurations
grid_configs_raw: list[tuple[Grid, int]] = [
("hex", 3),
("hex", 4),
("hex", 5),
("hex", 6),
("healpix", 6),
("healpix", 7),
("healpix", 8),
("healpix", 9),
("healpix", 10),
]
# Create structured grid config objects
grid_configs: list[GridConfig] = []
for grid, level in grid_configs_raw:
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
grid_configs.append(
{
"grid": grid,
"level": level,
"grid_name": f"{grid}-{level}",
"grid_sort_key": f"{grid}_{level:02d}",
"disable_alphaearth": disable_alphaearth,
}
)
# Compute sample counts
sample_counts: list[SampleCountData] = []
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
tasks: list[Task] = ["binary", "count", "density"]
for grid_config in grid_configs:
for target in target_datasets:
# Create minimal ensemble just to get target data
ensemble = DatasetEnsemble(
grid=grid_config["grid"],
level=grid_config["level"],
target=target,
members=[], # type: ignore[arg-type]
)
targets = ensemble._read_target()
for task in tasks:
# Get task-specific column
taskcol = ensemble.taskcol(task) # type: ignore[arg-type]
covcol = ensemble.covcol
# Count samples with coverage and valid labels
if covcol in targets.columns and taskcol in targets.columns:
valid_coverage = targets[covcol].sum()
valid_labels = targets[taskcol].notna().sum()
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
sample_counts.append(
{
"grid_config": grid_config,
"target": target,
"task": task,
"samples_coverage": valid_coverage,
"samples_labels": valid_labels,
"samples_both": valid_both,
}
)
# Compute feature counts
feature_counts: list[FeatureCountData] = []
for grid_config in grid_configs:
# Determine which members are available for this configuration
if grid_config["disable_alphaearth"]:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
# Use darts_rts as default target for comparison
ensemble = DatasetEnsemble(
grid=grid_config["grid"],
level=grid_config["level"],
target="darts_rts",
members=members, # type: ignore[arg-type]
)
stats = ensemble.get_stats()
# Calculate minimum cells across all data sources
min_cells = min(
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
# Build member breakdown including lon/lat
member_breakdown = {}
for member, member_stats in stats["members"].items():
member_breakdown[member] = member_stats["num_features"]
if ensemble.add_lonlat:
member_breakdown["Lon/Lat"] = 2
feature_counts.append(
{
"grid_config": grid_config,
"total_features": stats["total_features"],
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
"inference_cells": min_cells,
"total_samples": stats["num_target_samples"],
"member_breakdown": member_breakdown,
}
)
return DatasetAnalysisCache(
grid_configs=grid_configs,
sample_counts=sample_counts,
feature_counts=feature_counts,
)
def render_sample_count_overview(cache: DatasetAnalysisCache):
def render_sample_count_overview():
"""Render overview of sample counts per task+target+grid+level combination."""
st.subheader("📊 Sample Counts by Configuration")
@@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
)
# Get sample count DataFrame from cache
sample_df = cache.get_sample_count_df()
all_stats = load_all_default_dataset_statistics()
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
target_datasets = ["darts_rts", "darts_mllabels"]
# Create tabs for different views
@@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab1:
st.markdown("### Sample Counts Heatmap")
st.markdown("Showing counts of samples with both coverage and valid labels")
st.markdown("Showing counts of samples with coverage")
# Create heatmap for each target dataset
for target in target_datasets:
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
# Pivot for heatmap: Grid x Task
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
# Sort index by grid type and level
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab2:
st.markdown("### Sample Counts Bar Chart")
st.markdown("Showing counts of samples with both coverage and valid labels")
st.markdown("Showing counts of samples with coverage")
# Create a faceted bar chart showing both targets side by side
# Get color palette for tasks
@@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
fig = px.bar(
sample_df,
x="Grid",
y="Samples (Both)",
y="Samples (Coverage)",
color="Task",
facet_col="Target",
barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset",
labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
color_discrete_sequence=task_colors,
height=500,
)
@@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
st.markdown("### Detailed Sample Counts")
# Display full table with formatting
display_df = sample_df[
["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
].copy()
display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
# Format numbers with commas
for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
# Format coverage as percentage with 2 decimal places
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
st.dataframe(display_df, hide_index=True, width="stretch")
def render_feature_count_comparison(cache: DatasetAnalysisCache):
def render_feature_count_comparison():
"""Render static comparison of feature counts across all grid configurations."""
st.markdown("### Feature Count Comparison Across Grid Configurations")
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
# Get data from cache
comparison_df = cache.get_feature_count_df()
breakdown_df = cache.get_feature_breakdown_df()
all_stats = load_all_default_dataset_statistics()
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
# Create tabs for different comparison views
@@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache):
# Display full comparison table with formatting
display_df = comparison_df[
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
].copy()
# Format numbers with commas
for col in ["Total Features", "Inference Cells", "Total Samples"]:
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
# Format boolean as Yes/No
display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "" if x else "")
st.dataframe(display_df, hide_index=True, width="stretch")
@st.fragment
def render_feature_count_explorer(cache: DatasetAnalysisCache):
def render_feature_count_explorer():
"""Render interactive detailed configuration explorer using fragments."""
st.markdown("### Detailed Configuration Explorer")
st.markdown("Select specific grid configuration and data sources for detailed statistics")
# Grid selection
grid_options = [gc["grid_name"] for gc in cache.grid_configs]
grid_options = [gc.display_name for gc in grid_configs]
col1, col2 = st.columns(2)
@@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
)
# Find the selected grid config
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
grid = selected_grid_config["grid"]
level = selected_grid_config["level"]
disable_alphaearth = selected_grid_config["disable_alphaearth"]
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
# Get available members from the stats
all_stats = load_all_default_dataset_statistics()
stats = all_stats[selected_grid_config.id]
available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
# Members selection
st.markdown("#### Select Data Sources")
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
all_members = cast(
list[L2SourceDataset],
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
)
# Use columns for checkboxes
cols = st.columns(len(all_members))
selected_members = []
selected_members: list[L2SourceDataset] = []
for idx, member in enumerate(all_members):
with cols[idx]:
if member == "AlphaEarth" and disable_alphaearth:
st.checkbox(
member,
value=False,
disabled=True,
help=f"Not available for {grid} level {level}",
key=f"feature_member_{member}",
)
else:
if st.checkbox(member, value=True, key=f"feature_member_{member}"):
selected_members.append(member)
default_value = member in available_members
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
selected_members.append(cast(L2SourceDataset, member))
# Show results if at least one member is selected
if selected_members:
st.markdown("---")
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
# Get statistics from cache (already loaded)
grid_stats = all_stats[selected_grid_config.id]
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Filter to selected members only
selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
# Calculate total features for selected members
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
# Get target stats
target_stats = grid_stats.target[cast(TargetDataset, target)]
# High-level metrics
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Total Features", f"{stats['total_features']:,}")
st.metric("Total Features", f"{total_features:,}")
with col2:
# Calculate minimum cells across all data sources (for inference capability)
min_cells = min(
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
with col3:
st.metric("Data Sources", len(selected_members))
with col4:
st.metric("Total Samples", f"{stats['num_target_samples']:,}")
# Use binary task training cells as sample count
st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
with col5:
# Calculate total data points
total_points = stats["total_features"] * stats["num_target_samples"]
total_points = total_features * target_stats.training_cells["binary"]
st.metric("Total Data Points", f"{total_points:,}")
# Feature breakdown by source
st.markdown("#### Feature Breakdown by Data Source")
breakdown_data = []
for member, member_stats in stats["members"].items():
for member, member_stats in selected_member_stats.items():
breakdown_data.append(
{
"Data Source": member,
"Number of Features": member_stats["num_features"],
"Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator]
}
)
# Add lon/lat
if ensemble.add_lonlat:
breakdown_data.append(
{
"Data Source": "Lon/Lat",
"Number of Features": 2,
"Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
"Number of Features": member_stats.feature_count,
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
}
)
@@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
# Detailed member information
with st.expander("📦 Detailed Source Information", expanded=False):
for member, member_stats in stats["members"].items():
for member, member_stats in selected_member_stats.items():
st.markdown(f"### {member}")
metric_cols = st.columns(4)
with metric_cols[0]:
st.metric("Features", member_stats["num_features"])
st.metric("Features", member_stats.feature_count)
with metric_cols[1]:
st.metric("Variables", member_stats["num_variables"])
st.metric("Variables", len(member_stats.variable_names))
with metric_cols[2]:
dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
total_points = 1
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
for dim_size in member_stats.dimensions.values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
@@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"] # type: ignore[union-attr]
for v in member_stats.variable_names
]
)
st.markdown(vars_html, unsafe_allow_html=True)
@@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
for dim_name, dim_size in member_stats.dimensions.items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
# Size on disk
size_mb = member_stats.size_bytes / (1024 * 1024)
st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
st.markdown("---")
else:
st.info("👆 Select at least one data source to see feature statistics")
def render_feature_count_section(cache: DatasetAnalysisCache):
def render_feature_count_section():
"""Render the feature count section with comparison and explorer."""
st.subheader("🔢 Feature Counts by Dataset Configuration")
@@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache):
)
# Static comparison across all grids
render_feature_count_comparison(cache)
render_feature_count_comparison()
st.divider()
# Interactive explorer for detailed analysis
render_feature_count_explorer(cache)
render_feature_count_explorer()
def render_dataset_analysis():
"""Render the dataset analysis section with sample and feature counts."""
st.header("📈 Dataset Analysis")
# Load all data once and cache it
with st.spinner("Loading dataset analysis data..."):
cache = load_dataset_analysis_data()
# Create tabs for the two different analyses
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
with analysis_tabs[0]:
render_sample_count_overview(cache)
render_sample_count_overview()
with analysis_tabs[1]:
render_feature_count_section(cache)
render_feature_count_section()
def render_training_results_summary(training_results):
@@ -678,15 +450,15 @@ def render_training_results_summary(training_results):
col1, col2, col3, col4 = st.columns(4)
with col1:
tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
tasks = {tr.settings.task for tr in training_results}
st.metric("Tasks", len(tasks))
with col2:
grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
grids = {tr.settings.grid for tr in training_results}
st.metric("Grid Types", len(grids))
with col3:
models = {tr.settings.get("model", "Unknown") for tr in training_results}
models = {tr.settings.model for tr in training_results}
st.metric("Model Types", len(models))
with col4:
@@ -725,10 +497,10 @@ def render_experiment_results(training_results):
summary_data.append(
{
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
"Task": tr.settings.get("task", "Unknown"),
"Grid": tr.settings.get("grid", "Unknown"),
"Level": tr.settings.get("level", "Unknown"),
"Model": tr.settings.get("model", "Unknown"),
"Task": tr.settings.task,
"Grid": tr.settings.grid,
"Level": tr.settings.level,
"Model": tr.settings.model,
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
"Trials": len(tr.results),
"Path": str(tr.path.name),
@@ -750,17 +522,20 @@ def render_experiment_results(training_results):
st.subheader("Individual Experiment Details")
for tr in training_results:
with st.expander(tr.get_display_name("task_first")):
display_name = (
f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
)
with st.expander(display_name):
col1, col2 = st.columns([1, 2])
with col1:
st.write("**Configuration:**")
st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
st.write(f"- **Task:** {tr.settings.task}")
st.write(f"- **Grid:** {tr.settings.grid}")
st.write(f"- **Level:** {tr.settings.level}")
st.write(f"- **Model:** {tr.settings.model}")
st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
st.write(f"- **Classes:** {tr.settings.classes}")
st.write("\n**Files:**")
st.write("- 📊 search_results.parquet")
@@ -844,3 +619,5 @@ def render_overview_page():
render_dataset_analysis()
st.balloons()
stopwatch.summary()

View file

@@ -1,5 +1,6 @@
"""Shared types used across the entropice codebase."""
from dataclasses import dataclass
from typing import Literal
type Grid = Literal["hex", "healpix"]
@@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"]
type Model = Literal["espa", "xgboost", "rf", "knn"]
@dataclass(frozen=True)
class GridConfig:
"""Grid configuration specification with metadata."""
grid: Grid
level: int
id: GridLevel
display_name: str
sort_key: str
@classmethod
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
"""Create a GridConfig from a GridLevel string."""
if grid_level.startswith("hex"):
grid = "hex"
level = int(grid_level[3:])
elif grid_level.startswith("healpix"):
grid = "healpix"
level = int(grid_level[7:])
else:
raise ValueError(f"Invalid grid level: {grid_level}")
display_name = f"{grid.capitalize()}-{level}"
return cls(
grid=grid,
level=level,
id=grid_level,
display_name=display_name,
sort_key=f"{grid}_{level:02d}",
)
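# e.g. GridConfig.from_grid_level("healpix7") ->
#   GridConfig(grid="healpix", level=7, id="healpix7", display_name="Healpix-7", sort_key="healpix_07")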
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
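# (get_args() returns () for a `type`-statement alias because it is a lazily
# evaluated TypeAliasType; get_args(Task.__value__) would resolve the Literal,
# but explicit lists are simpler and keep type checkers happy.)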
all_tasks: list[Task] = ["binary", "count", "density"]
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
all_l2_source_datasets: list[L2SourceDataset] = [
"ArcticDEM",
"ERA5-shoulder",
"ERA5-seasonal",
"ERA5-yearly",
"AlphaEarth",
]
grid_configs: list[GridConfig] = [
GridConfig.from_grid_level("hex3"),
GridConfig.from_grid_level("hex4"),
GridConfig.from_grid_level("hex5"),
GridConfig.from_grid_level("hex6"),
GridConfig.from_grid_level("healpix6"),
GridConfig.from_grid_level("healpix7"),
GridConfig.from_grid_level("healpix8"),
GridConfig.from_grid_level("healpix9"),
GridConfig.from_grid_level("healpix10"),
]