Overview Page refactor done
This commit is contained in:
parent
36f8737075
commit
fca232da91
7 changed files with 270 additions and 363 deletions
|
|
@ -11,11 +11,12 @@ Pages:
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
# from entropice.dashboard.views.inference_page import render_inference_page
|
||||
# from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||
|
||||
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
# from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -24,17 +25,17 @@ def main():
|
|||
|
||||
# Setup Navigation
|
||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Overview": [overview_page],
|
||||
"Training": [training_data_page, training_analysis_page],
|
||||
"Model State": [model_state_page],
|
||||
"Inference": [inference_page],
|
||||
# "Training": [training_data_page, training_analysis_page],
|
||||
# "Model State": [model_state_page],
|
||||
# "Inference": [inference_page],
|
||||
}
|
||||
)
|
||||
pg.run()
|
||||
|
|
|
|||
|
|
@ -44,8 +44,12 @@ class TrainingResult:
|
|||
result_file = result_path / "search_results.parquet"
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
settings_file = result_path / "search_settings.toml"
|
||||
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
|
||||
raise FileNotFoundError(f"Missing required files in {result_path}")
|
||||
if not result_file.exists():
|
||||
raise FileNotFoundError(f"Missing results file in {result_path}")
|
||||
if not settings_file.exists():
|
||||
raise FileNotFoundError(f"Missing settings file in {result_path}")
|
||||
if not preds_file.exists():
|
||||
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
|
||||
|
|
@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]:
|
|||
for result_path in results_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
training_result = TrainingResult.from_path(result_path)
|
||||
try:
|
||||
training_result = TrainingResult.from_path(result_path)
|
||||
except FileNotFoundError as e:
|
||||
st.warning(f"Skipping incomplete training result at {result_path}: {e}")
|
||||
continue
|
||||
training_results.append(training_result)
|
||||
|
||||
# Sort by creation time (most recent first)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal, cast, get_args
|
||||
from typing import Literal
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
|
|
@ -17,7 +17,17 @@ import entropice.spatial.grids
|
|||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
||||
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
|
||||
from entropice.utils.types import (
|
||||
Grid,
|
||||
GridLevel,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
Task,
|
||||
all_l2_source_datasets,
|
||||
all_target_datasets,
|
||||
all_tasks,
|
||||
grid_configs,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -87,8 +97,7 @@ class TargetStatistics:
|
|||
training_coverage: dict[Task, float] = {}
|
||||
class_counts: dict[Task, dict[str, int]] = {}
|
||||
class_distribution: dict[Task, dict[str, float]] = {}
|
||||
tasks = cast(list[Task], get_args(Task))
|
||||
for task in tasks:
|
||||
for task in all_tasks:
|
||||
task_col = taskcol[task][target]
|
||||
cov_col = covcol[target]
|
||||
|
||||
|
|
@ -132,43 +141,97 @@ class DatasetStatistics:
|
|||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
|
||||
|
||||
@staticmethod
|
||||
def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||
"""Convert sample count data to DataFrame."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
for target_name, target_stats in stats.target.items():
|
||||
for task in all_tasks:
|
||||
training_cells = target_stats.training_cells[task]
|
||||
coverage = target_stats.coverage[task]
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Target": target_name.replace("darts_", ""),
|
||||
"Task": task.capitalize(),
|
||||
"Samples (Coverage)": training_cells,
|
||||
"Coverage %": coverage,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||
"""Convert feature count data to DataFrame."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
data_sources = list(stats.members.keys())
|
||||
|
||||
# Determine minimum cells across all data sources
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
|
||||
|
||||
# Get sample count from first target dataset (darts_rts)
|
||||
first_target = stats.target["darts_rts"]
|
||||
total_samples = first_target.training_cells["binary"]
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Total Features": stats.total_features,
|
||||
"Data Sources": len(data_sources),
|
||||
"Inference Cells": min_cells,
|
||||
"Total Samples": total_samples,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
@staticmethod
|
||||
def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
||||
rows = []
|
||||
for grid_config in grid_configs:
|
||||
stats = all_stats[grid_config.id]
|
||||
for member_name, member_stats in stats.members.items():
|
||||
rows.append(
|
||||
{
|
||||
"Grid": grid_config.display_name,
|
||||
"Data Source": member_name,
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Grid_Level_Sort": grid_config.sort_key,
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
|
||||
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
|
||||
grid_levels: set[tuple[Grid, int]] = {
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
("hex", 6),
|
||||
("healpix", 6),
|
||||
("healpix", 7),
|
||||
("healpix", 8),
|
||||
("healpix", 9),
|
||||
("healpix", 10),
|
||||
}
|
||||
for grid, level in grid_levels:
|
||||
with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
|
||||
grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered
|
||||
for grid_config in grid_configs:
|
||||
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
|
||||
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
|
||||
total_cells = len(grid_gdf)
|
||||
assert total_cells > 0, "Grid must contain at least one cell."
|
||||
target_statistics: dict[TargetDataset, TargetStatistics] = {}
|
||||
targets = cast(list[TargetDataset], get_args(TargetDataset))
|
||||
for target in targets:
|
||||
for target in all_target_datasets:
|
||||
target_statistics[target] = TargetStatistics.compute(
|
||||
grid=grid, level=level, target=target, total_cells=total_cells
|
||||
grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
|
||||
)
|
||||
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
|
||||
members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
|
||||
for member in members:
|
||||
member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
|
||||
for member in all_l2_source_datasets:
|
||||
member_statistics[member] = MemberStatistics.compute(
|
||||
grid=grid_config.grid, level=grid_config.level, member=member
|
||||
)
|
||||
|
||||
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
||||
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
|
||||
ts.size_bytes for ts in target_statistics.values()
|
||||
)
|
||||
grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
|
||||
dataset_stats[grid_level] = DatasetStatistics(
|
||||
dataset_stats[grid_config.id] = DatasetStatistics(
|
||||
total_features=total_features,
|
||||
total_cells=total_cells,
|
||||
size_bytes=total_size_bytes,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Inference page: Visualization of model inference results across the study region."""
|
||||
|
||||
import streamlit as st
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
|
||||
from entropice.dashboard.plots.inference import (
|
||||
render_class_comparison,
|
||||
|
|
@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import (
|
|||
render_inference_statistics,
|
||||
render_spatial_distribution_stats,
|
||||
)
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
|
||||
|
||||
def render_inference_page():
|
||||
|
|
|
|||
|
|
@ -2,6 +2,15 @@
|
|||
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
get_members_from_settings,
|
||||
load_all_training_results,
|
||||
)
|
||||
from entropice.dashboard.utils.training import load_model_state
|
||||
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_arcticdem_heatmap,
|
||||
|
|
@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import (
|
|||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
get_members_from_settings,
|
||||
load_all_training_results,
|
||||
)
|
||||
from entropice.dashboard.utils.training import load_model_state
|
||||
|
||||
|
||||
def render_model_state_page():
|
||||
|
|
|
|||
|
|
@ -1,238 +1,20 @@
|
|||
"""Overview page: List of available result directories with some summary statistics."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import Grid, TargetDataset, Task
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
|
||||
|
||||
|
||||
# Type definitions for dataset statistics
|
||||
class GridConfig(TypedDict):
|
||||
"""Grid configuration specification with metadata."""
|
||||
|
||||
grid: Grid
|
||||
level: int
|
||||
grid_name: str
|
||||
grid_sort_key: str
|
||||
disable_alphaearth: bool
|
||||
|
||||
|
||||
class SampleCountData(TypedDict):
|
||||
"""Sample count statistics for a specific grid/target/task combination."""
|
||||
|
||||
grid_config: GridConfig
|
||||
target: str
|
||||
task: str
|
||||
samples_coverage: int
|
||||
samples_labels: int
|
||||
samples_both: int
|
||||
|
||||
|
||||
class FeatureCountData(TypedDict):
|
||||
"""Feature count statistics for a specific grid configuration."""
|
||||
|
||||
grid_config: GridConfig
|
||||
total_features: int
|
||||
data_sources: list[str]
|
||||
inference_cells: int
|
||||
total_samples: int
|
||||
member_breakdown: dict[str, int]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatasetAnalysisCache:
|
||||
"""Cache for dataset analysis data to avoid redundant computations."""
|
||||
|
||||
grid_configs: list[GridConfig]
|
||||
sample_counts: list[SampleCountData]
|
||||
feature_counts: list[FeatureCountData]
|
||||
|
||||
def get_sample_count_df(self) -> pd.DataFrame:
|
||||
"""Convert sample count data to DataFrame."""
|
||||
rows = []
|
||||
for item in self.sample_counts:
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Grid Type": item["grid_config"]["grid"],
|
||||
"Level": item["grid_config"]["level"],
|
||||
"Target": item["target"].replace("darts_", ""),
|
||||
"Task": item["task"].capitalize(),
|
||||
"Samples (Coverage)": item["samples_coverage"],
|
||||
"Samples (Labels)": item["samples_labels"],
|
||||
"Samples (Both)": item["samples_both"],
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
def get_feature_count_df(self) -> pd.DataFrame:
|
||||
"""Convert feature count data to DataFrame."""
|
||||
rows = []
|
||||
for item in self.feature_counts:
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Grid Type": item["grid_config"]["grid"],
|
||||
"Level": item["grid_config"]["level"],
|
||||
"Total Features": item["total_features"],
|
||||
"Data Sources": len(item["data_sources"]),
|
||||
"Inference Cells": item["inference_cells"],
|
||||
"Total Samples": item["total_samples"],
|
||||
"AlphaEarth": "AlphaEarth" in item["data_sources"],
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
def get_feature_breakdown_df(self) -> pd.DataFrame:
|
||||
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
||||
rows = []
|
||||
for item in self.feature_counts:
|
||||
for source, count in item["member_breakdown"].items():
|
||||
rows.append(
|
||||
{
|
||||
"Grid": item["grid_config"]["grid_name"],
|
||||
"Data Source": source,
|
||||
"Number of Features": count,
|
||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
||||
"""Load and cache all dataset analysis data.
|
||||
|
||||
This function computes both sample counts and feature counts for all grid configurations.
|
||||
Results are cached to avoid redundant computations across different tabs.
|
||||
"""
|
||||
# Define all possible grid configurations
|
||||
grid_configs_raw: list[tuple[Grid, int]] = [
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
("hex", 6),
|
||||
("healpix", 6),
|
||||
("healpix", 7),
|
||||
("healpix", 8),
|
||||
("healpix", 9),
|
||||
("healpix", 10),
|
||||
]
|
||||
|
||||
# Create structured grid config objects
|
||||
grid_configs: list[GridConfig] = []
|
||||
for grid, level in grid_configs_raw:
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
grid_configs.append(
|
||||
{
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"grid_name": f"{grid}-{level}",
|
||||
"grid_sort_key": f"{grid}_{level:02d}",
|
||||
"disable_alphaearth": disable_alphaearth,
|
||||
}
|
||||
)
|
||||
|
||||
# Compute sample counts
|
||||
sample_counts: list[SampleCountData] = []
|
||||
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
||||
tasks: list[Task] = ["binary", "count", "density"]
|
||||
|
||||
for grid_config in grid_configs:
|
||||
for target in target_datasets:
|
||||
# Create minimal ensemble just to get target data
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid_config["grid"],
|
||||
level=grid_config["level"],
|
||||
target=target,
|
||||
members=[], # type: ignore[arg-type]
|
||||
)
|
||||
targets = ensemble._read_target()
|
||||
|
||||
for task in tasks:
|
||||
# Get task-specific column
|
||||
taskcol = ensemble.taskcol(task) # type: ignore[arg-type]
|
||||
covcol = ensemble.covcol
|
||||
|
||||
# Count samples with coverage and valid labels
|
||||
if covcol in targets.columns and taskcol in targets.columns:
|
||||
valid_coverage = targets[covcol].sum()
|
||||
valid_labels = targets[taskcol].notna().sum()
|
||||
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
|
||||
|
||||
sample_counts.append(
|
||||
{
|
||||
"grid_config": grid_config,
|
||||
"target": target,
|
||||
"task": task,
|
||||
"samples_coverage": valid_coverage,
|
||||
"samples_labels": valid_labels,
|
||||
"samples_both": valid_both,
|
||||
}
|
||||
)
|
||||
|
||||
# Compute feature counts
|
||||
feature_counts: list[FeatureCountData] = []
|
||||
|
||||
for grid_config in grid_configs:
|
||||
# Determine which members are available for this configuration
|
||||
if grid_config["disable_alphaearth"]:
|
||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
else:
|
||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
# Use darts_rts as default target for comparison
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid_config["grid"],
|
||||
level=grid_config["level"],
|
||||
target="darts_rts",
|
||||
members=members, # type: ignore[arg-type]
|
||||
)
|
||||
stats = ensemble.get_stats()
|
||||
|
||||
# Calculate minimum cells across all data sources
|
||||
min_cells = min(
|
||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
||||
for member_stats in stats["members"].values()
|
||||
)
|
||||
|
||||
# Build member breakdown including lon/lat
|
||||
member_breakdown = {}
|
||||
for member, member_stats in stats["members"].items():
|
||||
member_breakdown[member] = member_stats["num_features"]
|
||||
|
||||
if ensemble.add_lonlat:
|
||||
member_breakdown["Lon/Lat"] = 2
|
||||
|
||||
feature_counts.append(
|
||||
{
|
||||
"grid_config": grid_config,
|
||||
"total_features": stats["total_features"],
|
||||
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
|
||||
"inference_cells": min_cells,
|
||||
"total_samples": stats["num_target_samples"],
|
||||
"member_breakdown": member_breakdown,
|
||||
}
|
||||
)
|
||||
|
||||
return DatasetAnalysisCache(
|
||||
grid_configs=grid_configs,
|
||||
sample_counts=sample_counts,
|
||||
feature_counts=feature_counts,
|
||||
)
|
||||
|
||||
|
||||
def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||
def render_sample_count_overview():
|
||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
||||
st.subheader("📊 Sample Counts by Configuration")
|
||||
|
||||
|
|
@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|||
)
|
||||
|
||||
# Get sample count DataFrame from cache
|
||||
sample_df = cache.get_sample_count_df()
|
||||
all_stats = load_all_default_dataset_statistics()
|
||||
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
|
||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||
|
||||
# Create tabs for different views
|
||||
|
|
@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|||
|
||||
with tab1:
|
||||
st.markdown("### Sample Counts Heatmap")
|
||||
st.markdown("Showing counts of samples with both coverage and valid labels")
|
||||
st.markdown("Showing counts of samples with coverage")
|
||||
|
||||
# Create heatmap for each target dataset
|
||||
for target in target_datasets:
|
||||
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
|
||||
|
||||
# Pivot for heatmap: Grid x Task
|
||||
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
|
||||
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
|
||||
|
||||
# Sort index by grid type and level
|
||||
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
|
||||
|
|
@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|||
|
||||
with tab2:
|
||||
st.markdown("### Sample Counts Bar Chart")
|
||||
st.markdown("Showing counts of samples with both coverage and valid labels")
|
||||
st.markdown("Showing counts of samples with coverage")
|
||||
|
||||
# Create a faceted bar chart showing both targets side by side
|
||||
# Get color palette for tasks
|
||||
|
|
@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|||
fig = px.bar(
|
||||
sample_df,
|
||||
x="Grid",
|
||||
y="Samples (Both)",
|
||||
y="Samples (Coverage)",
|
||||
color="Task",
|
||||
facet_col="Target",
|
||||
barmode="group",
|
||||
title="Sample Counts by Grid Configuration and Target Dataset",
|
||||
labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
|
||||
labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
|
||||
color_discrete_sequence=task_colors,
|
||||
height=500,
|
||||
)
|
||||
|
|
@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|||
st.markdown("### Detailed Sample Counts")
|
||||
|
||||
# Display full table with formatting
|
||||
display_df = sample_df[
|
||||
["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
|
||||
].copy()
|
||||
display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
|
||||
|
||||
# Format numbers with commas
|
||||
for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
|
||||
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
||||
display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
|
||||
# Format coverage as percentage with 2 decimal places
|
||||
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
|
||||
|
||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
||||
def render_feature_count_comparison():
|
||||
"""Render static comparison of feature counts across all grid configurations."""
|
||||
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
||||
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
||||
|
||||
# Get data from cache
|
||||
comparison_df = cache.get_feature_count_df()
|
||||
breakdown_df = cache.get_feature_breakdown_df()
|
||||
all_stats = load_all_default_dataset_statistics()
|
||||
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
|
||||
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
||||
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
||||
|
||||
# Create tabs for different comparison views
|
||||
|
|
@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
|||
|
||||
# Display full comparison table with formatting
|
||||
display_df = comparison_df[
|
||||
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
|
||||
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
|
||||
].copy()
|
||||
|
||||
# Format numbers with commas
|
||||
for col in ["Total Features", "Inference Cells", "Total Samples"]:
|
||||
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
||||
|
||||
# Format boolean as Yes/No
|
||||
display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")
|
||||
|
||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||
def render_feature_count_explorer():
|
||||
"""Render interactive detailed configuration explorer using fragments."""
|
||||
st.markdown("### Detailed Configuration Explorer")
|
||||
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
||||
|
||||
# Grid selection
|
||||
grid_options = [gc["grid_name"] for gc in cache.grid_configs]
|
||||
grid_options = [gc.display_name for gc in grid_configs]
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
|
|
@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
|||
)
|
||||
|
||||
# Find the selected grid config
|
||||
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
|
||||
grid = selected_grid_config["grid"]
|
||||
level = selected_grid_config["level"]
|
||||
disable_alphaearth = selected_grid_config["disable_alphaearth"]
|
||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
||||
|
||||
# Get available members from the stats
|
||||
all_stats = load_all_default_dataset_statistics()
|
||||
stats = all_stats[selected_grid_config.id]
|
||||
available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
|
||||
|
||||
# Members selection
|
||||
st.markdown("#### Select Data Sources")
|
||||
|
||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
all_members = cast(
|
||||
list[L2SourceDataset],
|
||||
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
|
||||
)
|
||||
|
||||
# Use columns for checkboxes
|
||||
cols = st.columns(len(all_members))
|
||||
selected_members = []
|
||||
selected_members: list[L2SourceDataset] = []
|
||||
|
||||
for idx, member in enumerate(all_members):
|
||||
with cols[idx]:
|
||||
if member == "AlphaEarth" and disable_alphaearth:
|
||||
st.checkbox(
|
||||
member,
|
||||
value=False,
|
||||
disabled=True,
|
||||
help=f"Not available for {grid} level {level}",
|
||||
key=f"feature_member_{member}",
|
||||
)
|
||||
else:
|
||||
if st.checkbox(member, value=True, key=f"feature_member_{member}"):
|
||||
selected_members.append(member)
|
||||
default_value = member in available_members
|
||||
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
||||
selected_members.append(cast(L2SourceDataset, member))
|
||||
|
||||
# Show results if at least one member is selected
|
||||
if selected_members:
|
||||
st.markdown("---")
|
||||
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
|
||||
# Get statistics from cache (already loaded)
|
||||
grid_stats = all_stats[selected_grid_config.id]
|
||||
|
||||
with st.spinner("Computing dataset statistics..."):
|
||||
stats = ensemble.get_stats()
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
|
||||
|
||||
# Calculate total features for selected members
|
||||
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||
|
||||
# Get target stats
|
||||
target_stats = grid_stats.target[cast(TargetDataset, target)]
|
||||
|
||||
# High-level metrics
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
with col1:
|
||||
st.metric("Total Features", f"{stats['total_features']:,}")
|
||||
st.metric("Total Features", f"{total_features:,}")
|
||||
with col2:
|
||||
# Calculate minimum cells across all data sources (for inference capability)
|
||||
min_cells = min(
|
||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
||||
for member_stats in stats["members"].values()
|
||||
)
|
||||
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
|
||||
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
|
||||
with col3:
|
||||
st.metric("Data Sources", len(selected_members))
|
||||
with col4:
|
||||
st.metric("Total Samples", f"{stats['num_target_samples']:,}")
|
||||
# Use binary task training cells as sample count
|
||||
st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
|
||||
with col5:
|
||||
# Calculate total data points
|
||||
total_points = stats["total_features"] * stats["num_target_samples"]
|
||||
total_points = total_features * target_stats.training_cells["binary"]
|
||||
st.metric("Total Data Points", f"{total_points:,}")
|
||||
|
||||
# Feature breakdown by source
|
||||
st.markdown("#### Feature Breakdown by Data Source")
|
||||
|
||||
breakdown_data = []
|
||||
for member, member_stats in stats["members"].items():
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
breakdown_data.append(
|
||||
{
|
||||
"Data Source": member,
|
||||
"Number of Features": member_stats["num_features"],
|
||||
"Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator]
|
||||
}
|
||||
)
|
||||
|
||||
# Add lon/lat
|
||||
if ensemble.add_lonlat:
|
||||
breakdown_data.append(
|
||||
{
|
||||
"Data Source": "Lon/Lat",
|
||||
"Number of Features": 2,
|
||||
"Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
|||
|
||||
# Detailed member information
|
||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||
for member, member_stats in stats["members"].items():
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
st.markdown(f"### {member}")
|
||||
|
||||
metric_cols = st.columns(4)
|
||||
with metric_cols[0]:
|
||||
st.metric("Features", member_stats["num_features"])
|
||||
st.metric("Features", member_stats.feature_count)
|
||||
with metric_cols[1]:
|
||||
st.metric("Variables", member_stats["num_variables"])
|
||||
st.metric("Variables", len(member_stats.variable_names))
|
||||
with metric_cols[2]:
|
||||
dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
|
||||
dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
|
||||
st.metric("Shape", dim_str)
|
||||
with metric_cols[3]:
|
||||
total_points = 1
|
||||
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
|
||||
for dim_size in member_stats.dimensions.values():
|
||||
total_points *= dim_size
|
||||
st.metric("Data Points", f"{total_points:,}")
|
||||
|
||||
|
|
@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
|||
[
|
||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||
for v in member_stats["variables"] # type: ignore[union-attr]
|
||||
for v in member_stats.variable_names
|
||||
]
|
||||
)
|
||||
st.markdown(vars_html, unsafe_allow_html=True)
|
||||
|
|
@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
|||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||
f"{dim_name}: {dim_size}</span>"
|
||||
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
|
||||
for dim_name, dim_size in member_stats.dimensions.items()
|
||||
]
|
||||
)
|
||||
st.markdown(dim_html, unsafe_allow_html=True)
|
||||
|
||||
# Size on disk
|
||||
size_mb = member_stats.size_bytes / (1024 * 1024)
|
||||
st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
|
||||
|
||||
st.markdown("---")
|
||||
else:
|
||||
st.info("👆 Select at least one data source to see feature statistics")
|
||||
|
||||
|
||||
def render_feature_count_section(cache: DatasetAnalysisCache):
|
||||
def render_feature_count_section():
|
||||
"""Render the feature count section with comparison and explorer."""
|
||||
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
||||
|
||||
|
|
@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache):
|
|||
)
|
||||
|
||||
# Static comparison across all grids
|
||||
render_feature_count_comparison(cache)
|
||||
render_feature_count_comparison()
|
||||
|
||||
st.divider()
|
||||
|
||||
# Interactive explorer for detailed analysis
|
||||
render_feature_count_explorer(cache)
|
||||
render_feature_count_explorer()
|
||||
|
||||
|
||||
def render_dataset_analysis():
|
||||
"""Render the dataset analysis section with sample and feature counts."""
|
||||
st.header("📈 Dataset Analysis")
|
||||
|
||||
# Load all data once and cache it
|
||||
with st.spinner("Loading dataset analysis data..."):
|
||||
cache = load_dataset_analysis_data()
|
||||
|
||||
# Create tabs for the two different analyses
|
||||
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
||||
|
||||
with analysis_tabs[0]:
|
||||
render_sample_count_overview(cache)
|
||||
render_sample_count_overview()
|
||||
|
||||
with analysis_tabs[1]:
|
||||
render_feature_count_section(cache)
|
||||
render_feature_count_section()
|
||||
|
||||
|
||||
def render_training_results_summary(training_results):
|
||||
|
|
@ -678,15 +450,15 @@ def render_training_results_summary(training_results):
|
|||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
|
||||
tasks = {tr.settings.task for tr in training_results}
|
||||
st.metric("Tasks", len(tasks))
|
||||
|
||||
with col2:
|
||||
grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
|
||||
grids = {tr.settings.grid for tr in training_results}
|
||||
st.metric("Grid Types", len(grids))
|
||||
|
||||
with col3:
|
||||
models = {tr.settings.get("model", "Unknown") for tr in training_results}
|
||||
models = {tr.settings.model for tr in training_results}
|
||||
st.metric("Model Types", len(models))
|
||||
|
||||
with col4:
|
||||
|
|
@ -725,10 +497,10 @@ def render_experiment_results(training_results):
|
|||
summary_data.append(
|
||||
{
|
||||
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
|
||||
"Task": tr.settings.get("task", "Unknown"),
|
||||
"Grid": tr.settings.get("grid", "Unknown"),
|
||||
"Level": tr.settings.get("level", "Unknown"),
|
||||
"Model": tr.settings.get("model", "Unknown"),
|
||||
"Task": tr.settings.task,
|
||||
"Grid": tr.settings.grid,
|
||||
"Level": tr.settings.level,
|
||||
"Model": tr.settings.model,
|
||||
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
|
||||
"Trials": len(tr.results),
|
||||
"Path": str(tr.path.name),
|
||||
|
|
@ -750,17 +522,20 @@ def render_experiment_results(training_results):
|
|||
st.subheader("Individual Experiment Details")
|
||||
|
||||
for tr in training_results:
|
||||
with st.expander(tr.get_display_name("task_first")):
|
||||
display_name = (
|
||||
f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
|
||||
)
|
||||
with st.expander(display_name):
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
st.write("**Configuration:**")
|
||||
st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
|
||||
st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
|
||||
st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
|
||||
st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
|
||||
st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
|
||||
st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
|
||||
st.write(f"- **Task:** {tr.settings.task}")
|
||||
st.write(f"- **Grid:** {tr.settings.grid}")
|
||||
st.write(f"- **Level:** {tr.settings.level}")
|
||||
st.write(f"- **Model:** {tr.settings.model}")
|
||||
st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
|
||||
st.write(f"- **Classes:** {tr.settings.classes}")
|
||||
|
||||
st.write("\n**Files:**")
|
||||
st.write("- 📊 search_results.parquet")
|
||||
|
|
@ -844,3 +619,5 @@ def render_overview_page():
|
|||
render_dataset_analysis()
|
||||
|
||||
st.balloons()
|
||||
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Shared types used across the entropice codebase."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
type Grid = Literal["hex", "healpix"]
|
||||
|
|
@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
|||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||
type Task = Literal["binary", "count", "density"]
|
||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GridConfig:
|
||||
"""Grid configuration specification with metadata."""
|
||||
|
||||
grid: Grid
|
||||
level: int
|
||||
id: GridLevel
|
||||
display_name: str
|
||||
sort_key: str
|
||||
|
||||
@classmethod
|
||||
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
|
||||
"""Create a GridConfig from a GridLevel string."""
|
||||
if grid_level.startswith("hex"):
|
||||
grid = "hex"
|
||||
level = int(grid_level[3:])
|
||||
elif grid_level.startswith("healpix"):
|
||||
grid = "healpix"
|
||||
level = int(grid_level[7:])
|
||||
else:
|
||||
raise ValueError(f"Invalid grid level: {grid_level}")
|
||||
|
||||
display_name = f"{grid.capitalize()}-{level}"
|
||||
|
||||
return cls(
|
||||
grid=grid,
|
||||
level=level,
|
||||
id=grid_level,
|
||||
display_name=display_name,
|
||||
sort_key=f"{grid}_{level:02d}",
|
||||
)
|
||||
|
||||
|
||||
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
||||
all_tasks: list[Task] = ["binary", "count", "density"]
|
||||
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
||||
all_l2_source_datasets: list[L2SourceDataset] = [
|
||||
"ArcticDEM",
|
||||
"ERA5-shoulder",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-yearly",
|
||||
"AlphaEarth",
|
||||
]
|
||||
|
||||
grid_configs: list[GridConfig] = [
|
||||
GridConfig.from_grid_level("hex3"),
|
||||
GridConfig.from_grid_level("hex4"),
|
||||
GridConfig.from_grid_level("hex5"),
|
||||
GridConfig.from_grid_level("hex6"),
|
||||
GridConfig.from_grid_level("healpix6"),
|
||||
GridConfig.from_grid_level("healpix7"),
|
||||
GridConfig.from_grid_level("healpix8"),
|
||||
GridConfig.from_grid_level("healpix9"),
|
||||
GridConfig.from_grid_level("healpix10"),
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue