Overview Page refactor done

This commit is contained in:
Tobias Hölzer 2026-01-04 03:14:48 +01:00
parent 36f8737075
commit fca232da91
7 changed files with 270 additions and 363 deletions

View file

@ -11,11 +11,12 @@ Pages:
import streamlit as st import streamlit as st
from entropice.dashboard.views.inference_page import render_inference_page # from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page # from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
from entropice.dashboard.views.training_data_page import render_training_data_page # from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
# from entropice.dashboard.views.training_data_page import render_training_data_page
def main(): def main():
@ -24,17 +25,17 @@ def main():
# Setup Navigation # Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") # training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") # training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") # model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") # inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation( pg = st.navigation(
{ {
"Overview": [overview_page], "Overview": [overview_page],
"Training": [training_data_page, training_analysis_page], # "Training": [training_data_page, training_analysis_page],
"Model State": [model_state_page], # "Model State": [model_state_page],
"Inference": [inference_page], # "Inference": [inference_page],
} }
) )
pg.run() pg.run()

View file

@ -44,8 +44,12 @@ class TrainingResult:
result_file = result_path / "search_results.parquet" result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet" preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml" settings_file = result_path / "search_settings.toml"
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]): if not result_file.exists():
raise FileNotFoundError(f"Missing required files in {result_path}") raise FileNotFoundError(f"Missing results file in {result_path}")
if not settings_file.exists():
raise FileNotFoundError(f"Missing settings file in {result_path}")
if not preds_file.exists():
raise FileNotFoundError(f"Missing predictions file in {result_path}")
created_at = result_path.stat().st_ctime created_at = result_path.stat().st_ctime
settings = TrainingSettings(**(toml.load(settings_file)["settings"])) settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]:
for result_path in results_dir.iterdir(): for result_path in results_dir.iterdir():
if not result_path.is_dir(): if not result_path.is_dir():
continue continue
try:
training_result = TrainingResult.from_path(result_path) training_result = TrainingResult.from_path(result_path)
except FileNotFoundError as e:
st.warning(f"Skipping incomplete training result at {result_path}: {e}")
continue
training_results.append(training_result) training_results.append(training_result)
# Sort by creation time (most recent first) # Sort by creation time (most recent first)

View file

@ -5,7 +5,7 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Literal, cast, get_args from typing import Literal
import geopandas as gpd import geopandas as gpd
import pandas as pd import pandas as pd
@ -17,7 +17,17 @@ import entropice.spatial.grids
import entropice.utils.paths import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult from entropice.dashboard.utils.loaders import TrainingResult
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task from entropice.utils.types import (
Grid,
GridLevel,
L2SourceDataset,
TargetDataset,
Task,
all_l2_source_datasets,
all_target_datasets,
all_tasks,
grid_configs,
)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -87,8 +97,7 @@ class TargetStatistics:
training_coverage: dict[Task, float] = {} training_coverage: dict[Task, float] = {}
class_counts: dict[Task, dict[str, int]] = {} class_counts: dict[Task, dict[str, int]] = {}
class_distribution: dict[Task, dict[str, float]] = {} class_distribution: dict[Task, dict[str, float]] = {}
tasks = cast(list[Task], get_args(Task)) for task in all_tasks:
for task in tasks:
task_col = taskcol[task][target] task_col = taskcol[task][target]
cov_col = covcol[target] cov_col = covcol[target]
@ -132,43 +141,97 @@ class DatasetStatistics:
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
@staticmethod
def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
for target_name, target_stats in stats.target.items():
for task in all_tasks:
training_cells = target_stats.training_cells[task]
coverage = target_stats.coverage[task]
rows.append(
{
"Grid": grid_config.display_name,
"Target": target_name.replace("darts_", ""),
"Task": task.capitalize(),
"Samples (Coverage)": training_cells,
"Coverage %": coverage,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
@staticmethod
def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
data_sources = list(stats.members.keys())
# Determine minimum cells across all data sources
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
# Get sample count from first target dataset (darts_rts)
first_target = stats.target["darts_rts"]
total_samples = first_target.training_cells["binary"]
rows.append(
{
"Grid": grid_config.display_name,
"Total Features": stats.total_features,
"Data Sources": len(data_sources),
"Inference Cells": min_cells,
"Total Samples": total_samples,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
@staticmethod
def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for grid_config in grid_configs:
stats = all_stats[grid_config.id]
for member_name, member_stats in stats.members.items():
rows.append(
{
"Grid": grid_config.display_name,
"Data Source": member_name,
"Number of Features": member_stats.feature_count,
"Grid_Level_Sort": grid_config.sort_key,
}
)
return pd.DataFrame(rows)
@st.cache_data @st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]: def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
dataset_stats: dict[GridLevel, DatasetStatistics] = {} dataset_stats: dict[GridLevel, DatasetStatistics] = {}
grid_levels: set[tuple[Grid, int]] = { for grid_config in grid_configs:
("hex", 3), with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
("hex", 4), grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
("hex", 5),
("hex", 6),
("healpix", 6),
("healpix", 7),
("healpix", 8),
("healpix", 9),
("healpix", 10),
}
for grid, level in grid_levels:
with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered
total_cells = len(grid_gdf) total_cells = len(grid_gdf)
assert total_cells > 0, "Grid must contain at least one cell." assert total_cells > 0, "Grid must contain at least one cell."
target_statistics: dict[TargetDataset, TargetStatistics] = {} target_statistics: dict[TargetDataset, TargetStatistics] = {}
targets = cast(list[TargetDataset], get_args(TargetDataset)) for target in all_target_datasets:
for target in targets:
target_statistics[target] = TargetStatistics.compute( target_statistics[target] = TargetStatistics.compute(
grid=grid, level=level, target=target, total_cells=total_cells grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
) )
member_statistics: dict[L2SourceDataset, MemberStatistics] = {} member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
members = cast(list[L2SourceDataset], get_args(L2SourceDataset)) for member in all_l2_source_datasets:
for member in members: member_statistics[member] = MemberStatistics.compute(
member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member) grid=grid_config.grid, level=grid_config.level, member=member
)
total_features = sum(ms.feature_count for ms in member_statistics.values()) total_features = sum(ms.feature_count for ms in member_statistics.values())
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum( total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
ts.size_bytes for ts in target_statistics.values() ts.size_bytes for ts in target_statistics.values()
) )
grid_level: GridLevel = cast(GridLevel, f"{grid}{level}") dataset_stats[grid_config.id] = DatasetStatistics(
dataset_stats[grid_level] = DatasetStatistics(
total_features=total_features, total_features=total_features,
total_cells=total_cells, total_cells=total_cells,
size_bytes=total_size_bytes, size_bytes=total_size_bytes,

View file

@ -1,6 +1,7 @@
"""Inference page: Visualization of model inference results across the study region.""" """Inference page: Visualization of model inference results across the study region."""
import streamlit as st import streamlit as st
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.plots.inference import ( from entropice.dashboard.plots.inference import (
render_class_comparison, render_class_comparison,
@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics, render_inference_statistics,
render_spatial_distribution_stats, render_spatial_distribution_stats,
) )
from entropice.dashboard.utils.data import load_all_training_results
def render_inference_page(): def render_inference_page():

View file

@ -2,6 +2,15 @@
import streamlit as st import streamlit as st
import xarray as xr import xarray as xr
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
from entropice.dashboard.plots.model_state import ( from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap, plot_arcticdem_heatmap,
@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import (
plot_top_features, plot_top_features,
) )
from entropice.dashboard.utils.colors import generate_unified_colormap from entropice.dashboard.utils.colors import generate_unified_colormap
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
def render_model_state_page(): def render_model_state_page():

View file

@ -1,238 +1,20 @@
"""Overview page: List of available result directories with some summary statistics.""" """Overview page: List of available result directories with some summary statistics."""
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import TypedDict from typing import cast
import pandas as pd import pandas as pd
import plotly.express as px import plotly.express as px
import streamlit as st import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.ml.dataset import DatasetEnsemble from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
from entropice.utils.types import Grid, TargetDataset, Task from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
# Type definitions for dataset statistics def render_sample_count_overview():
class GridConfig(TypedDict):
"""Grid configuration specification with metadata."""
grid: Grid
level: int
grid_name: str
grid_sort_key: str
disable_alphaearth: bool
class SampleCountData(TypedDict):
"""Sample count statistics for a specific grid/target/task combination."""
grid_config: GridConfig
target: str
task: str
samples_coverage: int
samples_labels: int
samples_both: int
class FeatureCountData(TypedDict):
"""Feature count statistics for a specific grid configuration."""
grid_config: GridConfig
total_features: int
data_sources: list[str]
inference_cells: int
total_samples: int
member_breakdown: dict[str, int]
@dataclass(frozen=True)
class DatasetAnalysisCache:
"""Cache for dataset analysis data to avoid redundant computations."""
grid_configs: list[GridConfig]
sample_counts: list[SampleCountData]
feature_counts: list[FeatureCountData]
def get_sample_count_df(self) -> pd.DataFrame:
"""Convert sample count data to DataFrame."""
rows = []
for item in self.sample_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Target": item["target"].replace("darts_", ""),
"Task": item["task"].capitalize(),
"Samples (Coverage)": item["samples_coverage"],
"Samples (Labels)": item["samples_labels"],
"Samples (Both)": item["samples_both"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
def get_feature_count_df(self) -> pd.DataFrame:
"""Convert feature count data to DataFrame."""
rows = []
for item in self.feature_counts:
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Grid Type": item["grid_config"]["grid"],
"Level": item["grid_config"]["level"],
"Total Features": item["total_features"],
"Data Sources": len(item["data_sources"]),
"Inference Cells": item["inference_cells"],
"Total Samples": item["total_samples"],
"AlphaEarth": "AlphaEarth" in item["data_sources"],
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
def get_feature_breakdown_df(self) -> pd.DataFrame:
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
rows = []
for item in self.feature_counts:
for source, count in item["member_breakdown"].items():
rows.append(
{
"Grid": item["grid_config"]["grid_name"],
"Data Source": source,
"Number of Features": count,
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
}
)
return pd.DataFrame(rows)
@st.cache_data(show_spinner=False)
def load_dataset_analysis_data() -> DatasetAnalysisCache:
"""Load and cache all dataset analysis data.
This function computes both sample counts and feature counts for all grid configurations.
Results are cached to avoid redundant computations across different tabs.
"""
# Define all possible grid configurations
grid_configs_raw: list[tuple[Grid, int]] = [
("hex", 3),
("hex", 4),
("hex", 5),
("hex", 6),
("healpix", 6),
("healpix", 7),
("healpix", 8),
("healpix", 9),
("healpix", 10),
]
# Create structured grid config objects
grid_configs: list[GridConfig] = []
for grid, level in grid_configs_raw:
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
grid_configs.append(
{
"grid": grid,
"level": level,
"grid_name": f"{grid}-{level}",
"grid_sort_key": f"{grid}_{level:02d}",
"disable_alphaearth": disable_alphaearth,
}
)
# Compute sample counts
sample_counts: list[SampleCountData] = []
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
tasks: list[Task] = ["binary", "count", "density"]
for grid_config in grid_configs:
for target in target_datasets:
# Create minimal ensemble just to get target data
ensemble = DatasetEnsemble(
grid=grid_config["grid"],
level=grid_config["level"],
target=target,
members=[], # type: ignore[arg-type]
)
targets = ensemble._read_target()
for task in tasks:
# Get task-specific column
taskcol = ensemble.taskcol(task) # type: ignore[arg-type]
covcol = ensemble.covcol
# Count samples with coverage and valid labels
if covcol in targets.columns and taskcol in targets.columns:
valid_coverage = targets[covcol].sum()
valid_labels = targets[taskcol].notna().sum()
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
sample_counts.append(
{
"grid_config": grid_config,
"target": target,
"task": task,
"samples_coverage": valid_coverage,
"samples_labels": valid_labels,
"samples_both": valid_both,
}
)
# Compute feature counts
feature_counts: list[FeatureCountData] = []
for grid_config in grid_configs:
# Determine which members are available for this configuration
if grid_config["disable_alphaearth"]:
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
else:
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
# Use darts_rts as default target for comparison
ensemble = DatasetEnsemble(
grid=grid_config["grid"],
level=grid_config["level"],
target="darts_rts",
members=members, # type: ignore[arg-type]
)
stats = ensemble.get_stats()
# Calculate minimum cells across all data sources
min_cells = min(
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
# Build member breakdown including lon/lat
member_breakdown = {}
for member, member_stats in stats["members"].items():
member_breakdown[member] = member_stats["num_features"]
if ensemble.add_lonlat:
member_breakdown["Lon/Lat"] = 2
feature_counts.append(
{
"grid_config": grid_config,
"total_features": stats["total_features"],
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
"inference_cells": min_cells,
"total_samples": stats["num_target_samples"],
"member_breakdown": member_breakdown,
}
)
return DatasetAnalysisCache(
grid_configs=grid_configs,
sample_counts=sample_counts,
feature_counts=feature_counts,
)
def render_sample_count_overview(cache: DatasetAnalysisCache):
"""Render overview of sample counts per task+target+grid+level combination.""" """Render overview of sample counts per task+target+grid+level combination."""
st.subheader("📊 Sample Counts by Configuration") st.subheader("📊 Sample Counts by Configuration")
@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
) )
# Get sample count DataFrame from cache # Get sample count DataFrame from cache
sample_df = cache.get_sample_count_df() all_stats = load_all_default_dataset_statistics()
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
target_datasets = ["darts_rts", "darts_mllabels"] target_datasets = ["darts_rts", "darts_mllabels"]
# Create tabs for different views # Create tabs for different views
@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab1: with tab1:
st.markdown("### Sample Counts Heatmap") st.markdown("### Sample Counts Heatmap")
st.markdown("Showing counts of samples with both coverage and valid labels") st.markdown("Showing counts of samples with coverage")
# Create heatmap for each target dataset # Create heatmap for each target dataset
for target in target_datasets: for target in target_datasets:
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")] target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
# Pivot for heatmap: Grid x Task # Pivot for heatmap: Grid x Task
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean") pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
# Sort index by grid type and level # Sort index by grid type and level
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid") sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
with tab2: with tab2:
st.markdown("### Sample Counts Bar Chart") st.markdown("### Sample Counts Bar Chart")
st.markdown("Showing counts of samples with both coverage and valid labels") st.markdown("Showing counts of samples with coverage")
# Create a faceted bar chart showing both targets side by side # Create a faceted bar chart showing both targets side by side
# Get color palette for tasks # Get color palette for tasks
@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
fig = px.bar( fig = px.bar(
sample_df, sample_df,
x="Grid", x="Grid",
y="Samples (Both)", y="Samples (Coverage)",
color="Task", color="Task",
facet_col="Target", facet_col="Target",
barmode="group", barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset", title="Sample Counts by Grid Configuration and Target Dataset",
labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"}, labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
color_discrete_sequence=task_colors, color_discrete_sequence=task_colors,
height=500, height=500,
) )
@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
st.markdown("### Detailed Sample Counts") st.markdown("### Detailed Sample Counts")
# Display full table with formatting # Display full table with formatting
display_df = sample_df[ display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
].copy()
# Format numbers with commas # Format numbers with commas
for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]: display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
display_df[col] = display_df[col].apply(lambda x: f"{x:,}") # Format coverage as percentage with 2 decimal places
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
st.dataframe(display_df, hide_index=True, width="stretch") st.dataframe(display_df, hide_index=True, width="stretch")
def render_feature_count_comparison(cache: DatasetAnalysisCache): def render_feature_count_comparison():
"""Render static comparison of feature counts across all grid configurations.""" """Render static comparison of feature counts across all grid configurations."""
st.markdown("### Feature Count Comparison Across Grid Configurations") st.markdown("### Feature Count Comparison Across Grid Configurations")
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled") st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
# Get data from cache # Get data from cache
comparison_df = cache.get_feature_count_df() all_stats = load_all_default_dataset_statistics()
breakdown_df = cache.get_feature_breakdown_df() comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort") breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
# Create tabs for different comparison views # Create tabs for different comparison views
@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache):
# Display full comparison table with formatting # Display full comparison table with formatting
display_df = comparison_df[ display_df = comparison_df[
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"] ["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
].copy() ].copy()
# Format numbers with commas # Format numbers with commas
for col in ["Total Features", "Inference Cells", "Total Samples"]: for col in ["Total Features", "Inference Cells", "Total Samples"]:
display_df[col] = display_df[col].apply(lambda x: f"{x:,}") display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
# Format boolean as Yes/No
display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "" if x else "")
st.dataframe(display_df, hide_index=True, width="stretch") st.dataframe(display_df, hide_index=True, width="stretch")
@st.fragment @st.fragment
def render_feature_count_explorer(cache: DatasetAnalysisCache): def render_feature_count_explorer():
"""Render interactive detailed configuration explorer using fragments.""" """Render interactive detailed configuration explorer using fragments."""
st.markdown("### Detailed Configuration Explorer") st.markdown("### Detailed Configuration Explorer")
st.markdown("Select specific grid configuration and data sources for detailed statistics") st.markdown("Select specific grid configuration and data sources for detailed statistics")
# Grid selection # Grid selection
grid_options = [gc["grid_name"] for gc in cache.grid_configs] grid_options = [gc.display_name for gc in grid_configs]
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
) )
# Find the selected grid config # Find the selected grid config
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined) selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
grid = selected_grid_config["grid"]
level = selected_grid_config["level"] # Get available members from the stats
disable_alphaearth = selected_grid_config["disable_alphaearth"] all_stats = load_all_default_dataset_statistics()
stats = all_stats[selected_grid_config.id]
available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
# Members selection # Members selection
st.markdown("#### Select Data Sources") st.markdown("#### Select Data Sources")
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] all_members = cast(
list[L2SourceDataset],
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
)
# Use columns for checkboxes # Use columns for checkboxes
cols = st.columns(len(all_members)) cols = st.columns(len(all_members))
selected_members = [] selected_members: list[L2SourceDataset] = []
for idx, member in enumerate(all_members): for idx, member in enumerate(all_members):
with cols[idx]: with cols[idx]:
if member == "AlphaEarth" and disable_alphaearth: default_value = member in available_members
st.checkbox( if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
member, selected_members.append(cast(L2SourceDataset, member))
value=False,
disabled=True,
help=f"Not available for {grid} level {level}",
key=f"feature_member_{member}",
)
else:
if st.checkbox(member, value=True, key=f"feature_member_{member}"):
selected_members.append(member)
# Show results if at least one member is selected # Show results if at least one member is selected
if selected_members: if selected_members:
st.markdown("---") st.markdown("---")
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members) # Get statistics from cache (already loaded)
grid_stats = all_stats[selected_grid_config.id]
with st.spinner("Computing dataset statistics..."): # Filter to selected members only
stats = ensemble.get_stats() selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
# Calculate total features for selected members
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
# Get target stats
target_stats = grid_stats.target[cast(TargetDataset, target)]
# High-level metrics # High-level metrics
col1, col2, col3, col4, col5 = st.columns(5) col1, col2, col3, col4, col5 = st.columns(5)
with col1: with col1:
st.metric("Total Features", f"{stats['total_features']:,}") st.metric("Total Features", f"{total_features:,}")
with col2: with col2:
# Calculate minimum cells across all data sources (for inference capability) # Calculate minimum cells across all data sources (for inference capability)
min_cells = min( min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
for member_stats in stats["members"].values()
)
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources") st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
with col3: with col3:
st.metric("Data Sources", len(selected_members)) st.metric("Data Sources", len(selected_members))
with col4: with col4:
st.metric("Total Samples", f"{stats['num_target_samples']:,}") # Use binary task training cells as sample count
st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
with col5: with col5:
# Calculate total data points # Calculate total data points
total_points = stats["total_features"] * stats["num_target_samples"] total_points = total_features * target_stats.training_cells["binary"]
st.metric("Total Data Points", f"{total_points:,}") st.metric("Total Data Points", f"{total_points:,}")
# Feature breakdown by source # Feature breakdown by source
st.markdown("#### Feature Breakdown by Data Source") st.markdown("#### Feature Breakdown by Data Source")
breakdown_data = [] breakdown_data = []
for member, member_stats in stats["members"].items(): for member, member_stats in selected_member_stats.items():
breakdown_data.append( breakdown_data.append(
{ {
"Data Source": member, "Data Source": member,
"Number of Features": member_stats["num_features"], "Number of Features": member_stats.feature_count,
"Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator] "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
}
)
# Add lon/lat
if ensemble.add_lonlat:
breakdown_data.append(
{
"Data Source": "Lon/Lat",
"Number of Features": 2,
"Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
} }
) )
@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
# Detailed member information # Detailed member information
with st.expander("📦 Detailed Source Information", expanded=False): with st.expander("📦 Detailed Source Information", expanded=False):
for member, member_stats in stats["members"].items(): for member, member_stats in selected_member_stats.items():
st.markdown(f"### {member}") st.markdown(f"### {member}")
metric_cols = st.columns(4) metric_cols = st.columns(4)
with metric_cols[0]: with metric_cols[0]:
st.metric("Features", member_stats["num_features"]) st.metric("Features", member_stats.feature_count)
with metric_cols[1]: with metric_cols[1]:
st.metric("Variables", member_stats["num_variables"]) st.metric("Variables", len(member_stats.variable_names))
with metric_cols[2]: with metric_cols[2]:
dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr] dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
st.metric("Shape", dim_str) st.metric("Shape", dim_str)
with metric_cols[3]: with metric_cols[3]:
total_points = 1 total_points = 1
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr] for dim_size in member_stats.dimensions.values():
total_points *= dim_size total_points *= dim_size
st.metric("Data Points", f"{total_points:,}") st.metric("Data Points", f"{total_points:,}")
@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
[ [
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; ' f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>' f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"] # type: ignore[union-attr] for v in member_stats.variable_names
] ]
) )
st.markdown(vars_html, unsafe_allow_html=True) st.markdown(vars_html, unsafe_allow_html=True)
@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; ' f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">' f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>" f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr] for dim_name, dim_size in member_stats.dimensions.items()
] ]
) )
st.markdown(dim_html, unsafe_allow_html=True) st.markdown(dim_html, unsafe_allow_html=True)
# Size on disk
size_mb = member_stats.size_bytes / (1024 * 1024)
st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
st.markdown("---") st.markdown("---")
else: else:
st.info("👆 Select at least one data source to see feature statistics") st.info("👆 Select at least one data source to see feature statistics")
def render_feature_count_section(cache: DatasetAnalysisCache): def render_feature_count_section():
"""Render the feature count section with comparison and explorer.""" """Render the feature count section with comparison and explorer."""
st.subheader("🔢 Feature Counts by Dataset Configuration") st.subheader("🔢 Feature Counts by Dataset Configuration")
@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache):
) )
# Static comparison across all grids # Static comparison across all grids
render_feature_count_comparison(cache) render_feature_count_comparison()
st.divider() st.divider()
# Interactive explorer for detailed analysis # Interactive explorer for detailed analysis
render_feature_count_explorer(cache) render_feature_count_explorer()
def render_dataset_analysis(): def render_dataset_analysis():
"""Render the dataset analysis section with sample and feature counts.""" """Render the dataset analysis section with sample and feature counts."""
st.header("📈 Dataset Analysis") st.header("📈 Dataset Analysis")
# Load all data once and cache it
with st.spinner("Loading dataset analysis data..."):
cache = load_dataset_analysis_data()
# Create tabs for the two different analyses # Create tabs for the two different analyses
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"]) analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
with analysis_tabs[0]: with analysis_tabs[0]:
render_sample_count_overview(cache) render_sample_count_overview()
with analysis_tabs[1]: with analysis_tabs[1]:
render_feature_count_section(cache) render_feature_count_section()
def render_training_results_summary(training_results): def render_training_results_summary(training_results):
@ -678,15 +450,15 @@ def render_training_results_summary(training_results):
col1, col2, col3, col4 = st.columns(4) col1, col2, col3, col4 = st.columns(4)
with col1: with col1:
tasks = {tr.settings.get("task", "Unknown") for tr in training_results} tasks = {tr.settings.task for tr in training_results}
st.metric("Tasks", len(tasks)) st.metric("Tasks", len(tasks))
with col2: with col2:
grids = {tr.settings.get("grid", "Unknown") for tr in training_results} grids = {tr.settings.grid for tr in training_results}
st.metric("Grid Types", len(grids)) st.metric("Grid Types", len(grids))
with col3: with col3:
models = {tr.settings.get("model", "Unknown") for tr in training_results} models = {tr.settings.model for tr in training_results}
st.metric("Model Types", len(models)) st.metric("Model Types", len(models))
with col4: with col4:
@ -725,10 +497,10 @@ def render_experiment_results(training_results):
summary_data.append( summary_data.append(
{ {
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"), "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
"Task": tr.settings.get("task", "Unknown"), "Task": tr.settings.task,
"Grid": tr.settings.get("grid", "Unknown"), "Grid": tr.settings.grid,
"Level": tr.settings.get("level", "Unknown"), "Level": tr.settings.level,
"Model": tr.settings.get("model", "Unknown"), "Model": tr.settings.model,
f"Best {primary_metric.title()}": f"{primary_score:.4f}", f"Best {primary_metric.title()}": f"{primary_score:.4f}",
"Trials": len(tr.results), "Trials": len(tr.results),
"Path": str(tr.path.name), "Path": str(tr.path.name),
@ -750,17 +522,20 @@ def render_experiment_results(training_results):
st.subheader("Individual Experiment Details") st.subheader("Individual Experiment Details")
for tr in training_results: for tr in training_results:
with st.expander(tr.get_display_name("task_first")): display_name = (
f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
)
with st.expander(display_name):
col1, col2 = st.columns([1, 2]) col1, col2 = st.columns([1, 2])
with col1: with col1:
st.write("**Configuration:**") st.write("**Configuration:**")
st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}") st.write(f"- **Task:** {tr.settings.task}")
st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}") st.write(f"- **Grid:** {tr.settings.grid}")
st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}") st.write(f"- **Level:** {tr.settings.level}")
st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}") st.write(f"- **Model:** {tr.settings.model}")
st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}") st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}") st.write(f"- **Classes:** {tr.settings.classes}")
st.write("\n**Files:**") st.write("\n**Files:**")
st.write("- 📊 search_results.parquet") st.write("- 📊 search_results.parquet")
@ -844,3 +619,5 @@ def render_overview_page():
render_dataset_analysis() render_dataset_analysis()
st.balloons() st.balloons()
stopwatch.summary()

View file

@ -1,5 +1,6 @@
"""Shared types used across the entropice codebase.""" """Shared types used across the entropice codebase."""
from dataclasses import dataclass
from typing import Literal from typing import Literal
type Grid = Literal["hex", "healpix"] type Grid = Literal["hex", "healpix"]
@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"] type Task = Literal["binary", "count", "density"]
type Model = Literal["espa", "xgboost", "rf", "knn"] type Model = Literal["espa", "xgboost", "rf", "knn"]
@dataclass(frozen=True)
class GridConfig:
    """Grid configuration specification with metadata.

    Immutable value object describing one (grid system, resolution level)
    combination plus derived presentation fields.
    """

    grid: Grid  # grid system identifier ("hex" or "healpix")
    level: int  # resolution level within the grid system
    id: GridLevel  # compact string id, e.g. "hex4" or "healpix7"
    display_name: str  # human-readable label, e.g. "Hex-4"
    sort_key: str  # lexicographically sortable key, e.g. "hex_04"

    @classmethod
    def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
        """Create a GridConfig from a GridLevel string such as "hex4" or "healpix10".

        Raises:
            ValueError: if ``grid_level`` has no known grid prefix, or the
                remainder after the prefix is not a valid integer level.
        """
        grid: Grid
        # removeprefix() instead of hard-coded slice offsets ([3:], [7:]) so the
        # parsing cannot silently drift out of sync with the prefix strings.
        if grid_level.startswith("hex"):
            grid = "hex"
            level = int(grid_level.removeprefix("hex"))
        elif grid_level.startswith("healpix"):
            grid = "healpix"
            level = int(grid_level.removeprefix("healpix"))
        else:
            raise ValueError(f"Invalid grid level: {grid_level}")
        return cls(
            grid=grid,
            level=level,
            id=grid_level,
            display_name=f"{grid.capitalize()}-{level}",
            # Zero-padded level so string sort matches numeric sort (e.g. "healpix_06" < "healpix_10").
            sort_key=f"{grid}_{level:02d}",
        )
# Note: typing.get_args() doesn't work with Python 3.12+ `type` statement
# aliases (they are TypeAliasType objects, not the underlying Literal), so we
# define explicit runtime lists that must be kept in sync with the aliases above.
all_tasks: list[Task] = ["binary", "count", "density"]
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
all_l2_source_datasets: list[L2SourceDataset] = [
    "ArcticDEM",
    "ERA5-shoulder",
    "ERA5-seasonal",
    "ERA5-yearly",
    "AlphaEarth",
]

# All supported grid configurations, ordered coarse -> fine within each grid
# system (hex levels 3-6, healpix levels 6-10).
grid_configs: list[GridConfig] = [
    GridConfig.from_grid_level("hex3"),
    GridConfig.from_grid_level("hex4"),
    GridConfig.from_grid_level("hex5"),
    GridConfig.from_grid_level("hex6"),
    GridConfig.from_grid_level("healpix6"),
    GridConfig.from_grid_level("healpix7"),
    GridConfig.from_grid_level("healpix8"),
    GridConfig.from_grid_level("healpix9"),
    GridConfig.from_grid_level("healpix10"),
]