Overview Page refactor done
This commit is contained in:
parent
36f8737075
commit
fca232da91
7 changed files with 270 additions and 363 deletions
|
|
@ -11,11 +11,12 @@ Pages:
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.views.inference_page import render_inference_page
|
# from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
# from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.views.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
|
||||||
from entropice.dashboard.views.training_data_page import render_training_data_page
|
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||||
|
# from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -24,17 +25,17 @@ def main():
|
||||||
|
|
||||||
# Setup Navigation
|
# Setup Navigation
|
||||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||||
|
|
||||||
pg = st.navigation(
|
pg = st.navigation(
|
||||||
{
|
{
|
||||||
"Overview": [overview_page],
|
"Overview": [overview_page],
|
||||||
"Training": [training_data_page, training_analysis_page],
|
# "Training": [training_data_page, training_analysis_page],
|
||||||
"Model State": [model_state_page],
|
# "Model State": [model_state_page],
|
||||||
"Inference": [inference_page],
|
# "Inference": [inference_page],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
pg.run()
|
pg.run()
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,12 @@ class TrainingResult:
|
||||||
result_file = result_path / "search_results.parquet"
|
result_file = result_path / "search_results.parquet"
|
||||||
preds_file = result_path / "predicted_probabilities.parquet"
|
preds_file = result_path / "predicted_probabilities.parquet"
|
||||||
settings_file = result_path / "search_settings.toml"
|
settings_file = result_path / "search_settings.toml"
|
||||||
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
|
if not result_file.exists():
|
||||||
raise FileNotFoundError(f"Missing required files in {result_path}")
|
raise FileNotFoundError(f"Missing results file in {result_path}")
|
||||||
|
if not settings_file.exists():
|
||||||
|
raise FileNotFoundError(f"Missing settings file in {result_path}")
|
||||||
|
if not preds_file.exists():
|
||||||
|
raise FileNotFoundError(f"Missing predictions file in {result_path}")
|
||||||
|
|
||||||
created_at = result_path.stat().st_ctime
|
created_at = result_path.stat().st_ctime
|
||||||
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
|
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
|
||||||
|
|
@ -119,7 +123,11 @@ def load_all_training_results() -> list[TrainingResult]:
|
||||||
for result_path in results_dir.iterdir():
|
for result_path in results_dir.iterdir():
|
||||||
if not result_path.is_dir():
|
if not result_path.is_dir():
|
||||||
continue
|
continue
|
||||||
|
try:
|
||||||
training_result = TrainingResult.from_path(result_path)
|
training_result = TrainingResult.from_path(result_path)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
st.warning(f"Skipping incomplete training result at {result_path}: {e}")
|
||||||
|
continue
|
||||||
training_results.append(training_result)
|
training_results.append(training_result)
|
||||||
|
|
||||||
# Sort by creation time (most recent first)
|
# Sort by creation time (most recent first)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Literal, cast, get_args
|
from typing import Literal
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -17,7 +17,17 @@ import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
from entropice.dashboard.utils.loaders import TrainingResult
|
from entropice.dashboard.utils.loaders import TrainingResult
|
||||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
||||||
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
|
from entropice.utils.types import (
|
||||||
|
Grid,
|
||||||
|
GridLevel,
|
||||||
|
L2SourceDataset,
|
||||||
|
TargetDataset,
|
||||||
|
Task,
|
||||||
|
all_l2_source_datasets,
|
||||||
|
all_target_datasets,
|
||||||
|
all_tasks,
|
||||||
|
grid_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -87,8 +97,7 @@ class TargetStatistics:
|
||||||
training_coverage: dict[Task, float] = {}
|
training_coverage: dict[Task, float] = {}
|
||||||
class_counts: dict[Task, dict[str, int]] = {}
|
class_counts: dict[Task, dict[str, int]] = {}
|
||||||
class_distribution: dict[Task, dict[str, float]] = {}
|
class_distribution: dict[Task, dict[str, float]] = {}
|
||||||
tasks = cast(list[Task], get_args(Task))
|
for task in all_tasks:
|
||||||
for task in tasks:
|
|
||||||
task_col = taskcol[task][target]
|
task_col = taskcol[task][target]
|
||||||
cov_col = covcol[target]
|
cov_col = covcol[target]
|
||||||
|
|
||||||
|
|
@ -132,43 +141,97 @@ class DatasetStatistics:
|
||||||
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
members: dict[L2SourceDataset, MemberStatistics] # Statistics per source dataset member
|
||||||
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
|
target: dict[TargetDataset, TargetStatistics] # Statistics per target dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_sample_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||||
|
"""Convert sample count data to DataFrame."""
|
||||||
|
rows = []
|
||||||
|
for grid_config in grid_configs:
|
||||||
|
stats = all_stats[grid_config.id]
|
||||||
|
for target_name, target_stats in stats.target.items():
|
||||||
|
for task in all_tasks:
|
||||||
|
training_cells = target_stats.training_cells[task]
|
||||||
|
coverage = target_stats.coverage[task]
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"Grid": grid_config.display_name,
|
||||||
|
"Target": target_name.replace("darts_", ""),
|
||||||
|
"Task": task.capitalize(),
|
||||||
|
"Samples (Coverage)": training_cells,
|
||||||
|
"Coverage %": coverage,
|
||||||
|
"Grid_Level_Sort": grid_config.sort_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return pd.DataFrame(rows)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_feature_count_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||||
|
"""Convert feature count data to DataFrame."""
|
||||||
|
rows = []
|
||||||
|
for grid_config in grid_configs:
|
||||||
|
stats = all_stats[grid_config.id]
|
||||||
|
data_sources = list(stats.members.keys())
|
||||||
|
|
||||||
|
# Determine minimum cells across all data sources
|
||||||
|
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in stats.members.values())
|
||||||
|
|
||||||
|
# Get sample count from first target dataset (darts_rts)
|
||||||
|
first_target = stats.target["darts_rts"]
|
||||||
|
total_samples = first_target.training_cells["binary"]
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"Grid": grid_config.display_name,
|
||||||
|
"Total Features": stats.total_features,
|
||||||
|
"Data Sources": len(data_sources),
|
||||||
|
"Inference Cells": min_cells,
|
||||||
|
"Total Samples": total_samples,
|
||||||
|
"Grid_Level_Sort": grid_config.sort_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return pd.DataFrame(rows)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_feature_breakdown_df(all_stats: dict[GridLevel, "DatasetStatistics"]) -> pd.DataFrame:
|
||||||
|
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
||||||
|
rows = []
|
||||||
|
for grid_config in grid_configs:
|
||||||
|
stats = all_stats[grid_config.id]
|
||||||
|
for member_name, member_stats in stats.members.items():
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"Grid": grid_config.display_name,
|
||||||
|
"Data Source": member_name,
|
||||||
|
"Number of Features": member_stats.feature_count,
|
||||||
|
"Grid_Level_Sort": grid_config.sort_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return pd.DataFrame(rows)
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
|
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
|
||||||
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
|
dataset_stats: dict[GridLevel, DatasetStatistics] = {}
|
||||||
grid_levels: set[tuple[Grid, int]] = {
|
for grid_config in grid_configs:
|
||||||
("hex", 3),
|
with stopwatch(f"Loading statistics for grid={grid_config.grid}, level={grid_config.level}"):
|
||||||
("hex", 4),
|
grid_gdf = entropice.spatial.grids.open(grid_config.grid, grid_config.level) # Ensure grid is registered
|
||||||
("hex", 5),
|
|
||||||
("hex", 6),
|
|
||||||
("healpix", 6),
|
|
||||||
("healpix", 7),
|
|
||||||
("healpix", 8),
|
|
||||||
("healpix", 9),
|
|
||||||
("healpix", 10),
|
|
||||||
}
|
|
||||||
for grid, level in grid_levels:
|
|
||||||
with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
|
|
||||||
grid_gdf = entropice.spatial.grids.open(grid, level) # Ensure grid is registered
|
|
||||||
total_cells = len(grid_gdf)
|
total_cells = len(grid_gdf)
|
||||||
assert total_cells > 0, "Grid must contain at least one cell."
|
assert total_cells > 0, "Grid must contain at least one cell."
|
||||||
target_statistics: dict[TargetDataset, TargetStatistics] = {}
|
target_statistics: dict[TargetDataset, TargetStatistics] = {}
|
||||||
targets = cast(list[TargetDataset], get_args(TargetDataset))
|
for target in all_target_datasets:
|
||||||
for target in targets:
|
|
||||||
target_statistics[target] = TargetStatistics.compute(
|
target_statistics[target] = TargetStatistics.compute(
|
||||||
grid=grid, level=level, target=target, total_cells=total_cells
|
grid=grid_config.grid, level=grid_config.level, target=target, total_cells=total_cells
|
||||||
)
|
)
|
||||||
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
|
member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
|
||||||
members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
|
for member in all_l2_source_datasets:
|
||||||
for member in members:
|
member_statistics[member] = MemberStatistics.compute(
|
||||||
member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
|
grid=grid_config.grid, level=grid_config.level, member=member
|
||||||
|
)
|
||||||
|
|
||||||
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
total_features = sum(ms.feature_count for ms in member_statistics.values())
|
||||||
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
|
total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
|
||||||
ts.size_bytes for ts in target_statistics.values()
|
ts.size_bytes for ts in target_statistics.values()
|
||||||
)
|
)
|
||||||
grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
|
dataset_stats[grid_config.id] = DatasetStatistics(
|
||||||
dataset_stats[grid_level] = DatasetStatistics(
|
|
||||||
total_features=total_features,
|
total_features=total_features,
|
||||||
total_cells=total_cells,
|
total_cells=total_cells,
|
||||||
size_bytes=total_size_bytes,
|
size_bytes=total_size_bytes,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Inference page: Visualization of model inference results across the study region."""
|
"""Inference page: Visualization of model inference results across the study region."""
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
from entropice.dashboard.utils.data import load_all_training_results
|
||||||
|
|
||||||
from entropice.dashboard.plots.inference import (
|
from entropice.dashboard.plots.inference import (
|
||||||
render_class_comparison,
|
render_class_comparison,
|
||||||
|
|
@ -9,7 +10,6 @@ from entropice.dashboard.plots.inference import (
|
||||||
render_inference_statistics,
|
render_inference_statistics,
|
||||||
render_spatial_distribution_stats,
|
render_spatial_distribution_stats,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.data import load_all_training_results
|
|
||||||
|
|
||||||
|
|
||||||
def render_inference_page():
|
def render_inference_page():
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,15 @@
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from entropice.dashboard.utils.data import (
|
||||||
|
extract_arcticdem_features,
|
||||||
|
extract_common_features,
|
||||||
|
extract_embedding_features,
|
||||||
|
extract_era5_features,
|
||||||
|
get_members_from_settings,
|
||||||
|
load_all_training_results,
|
||||||
|
)
|
||||||
|
from entropice.dashboard.utils.training import load_model_state
|
||||||
|
|
||||||
from entropice.dashboard.plots.model_state import (
|
from entropice.dashboard.plots.model_state import (
|
||||||
plot_arcticdem_heatmap,
|
plot_arcticdem_heatmap,
|
||||||
|
|
@ -17,15 +26,6 @@ from entropice.dashboard.plots.model_state import (
|
||||||
plot_top_features,
|
plot_top_features,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.colors import generate_unified_colormap
|
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||||
from entropice.dashboard.utils.data import (
|
|
||||||
extract_arcticdem_features,
|
|
||||||
extract_common_features,
|
|
||||||
extract_embedding_features,
|
|
||||||
extract_era5_features,
|
|
||||||
get_members_from_settings,
|
|
||||||
load_all_training_results,
|
|
||||||
)
|
|
||||||
from entropice.dashboard.utils.training import load_model_state
|
|
||||||
|
|
||||||
|
|
||||||
def render_model_state_page():
|
def render_model_state_page():
|
||||||
|
|
|
||||||
|
|
@ -1,238 +1,20 @@
|
||||||
"""Overview page: List of available result directories with some summary statistics."""
|
"""Overview page: List of available result directories with some summary statistics."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TypedDict
|
from typing import cast
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||||
from entropice.utils.types import Grid, TargetDataset, Task
|
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, grid_configs
|
||||||
|
|
||||||
|
|
||||||
# Type definitions for dataset statistics
|
def render_sample_count_overview():
|
||||||
class GridConfig(TypedDict):
|
|
||||||
"""Grid configuration specification with metadata."""
|
|
||||||
|
|
||||||
grid: Grid
|
|
||||||
level: int
|
|
||||||
grid_name: str
|
|
||||||
grid_sort_key: str
|
|
||||||
disable_alphaearth: bool
|
|
||||||
|
|
||||||
|
|
||||||
class SampleCountData(TypedDict):
|
|
||||||
"""Sample count statistics for a specific grid/target/task combination."""
|
|
||||||
|
|
||||||
grid_config: GridConfig
|
|
||||||
target: str
|
|
||||||
task: str
|
|
||||||
samples_coverage: int
|
|
||||||
samples_labels: int
|
|
||||||
samples_both: int
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureCountData(TypedDict):
|
|
||||||
"""Feature count statistics for a specific grid configuration."""
|
|
||||||
|
|
||||||
grid_config: GridConfig
|
|
||||||
total_features: int
|
|
||||||
data_sources: list[str]
|
|
||||||
inference_cells: int
|
|
||||||
total_samples: int
|
|
||||||
member_breakdown: dict[str, int]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class DatasetAnalysisCache:
|
|
||||||
"""Cache for dataset analysis data to avoid redundant computations."""
|
|
||||||
|
|
||||||
grid_configs: list[GridConfig]
|
|
||||||
sample_counts: list[SampleCountData]
|
|
||||||
feature_counts: list[FeatureCountData]
|
|
||||||
|
|
||||||
def get_sample_count_df(self) -> pd.DataFrame:
|
|
||||||
"""Convert sample count data to DataFrame."""
|
|
||||||
rows = []
|
|
||||||
for item in self.sample_counts:
|
|
||||||
rows.append(
|
|
||||||
{
|
|
||||||
"Grid": item["grid_config"]["grid_name"],
|
|
||||||
"Grid Type": item["grid_config"]["grid"],
|
|
||||||
"Level": item["grid_config"]["level"],
|
|
||||||
"Target": item["target"].replace("darts_", ""),
|
|
||||||
"Task": item["task"].capitalize(),
|
|
||||||
"Samples (Coverage)": item["samples_coverage"],
|
|
||||||
"Samples (Labels)": item["samples_labels"],
|
|
||||||
"Samples (Both)": item["samples_both"],
|
|
||||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return pd.DataFrame(rows)
|
|
||||||
|
|
||||||
def get_feature_count_df(self) -> pd.DataFrame:
|
|
||||||
"""Convert feature count data to DataFrame."""
|
|
||||||
rows = []
|
|
||||||
for item in self.feature_counts:
|
|
||||||
rows.append(
|
|
||||||
{
|
|
||||||
"Grid": item["grid_config"]["grid_name"],
|
|
||||||
"Grid Type": item["grid_config"]["grid"],
|
|
||||||
"Level": item["grid_config"]["level"],
|
|
||||||
"Total Features": item["total_features"],
|
|
||||||
"Data Sources": len(item["data_sources"]),
|
|
||||||
"Inference Cells": item["inference_cells"],
|
|
||||||
"Total Samples": item["total_samples"],
|
|
||||||
"AlphaEarth": "AlphaEarth" in item["data_sources"],
|
|
||||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return pd.DataFrame(rows)
|
|
||||||
|
|
||||||
def get_feature_breakdown_df(self) -> pd.DataFrame:
|
|
||||||
"""Convert feature breakdown data to DataFrame for stacked/donut charts."""
|
|
||||||
rows = []
|
|
||||||
for item in self.feature_counts:
|
|
||||||
for source, count in item["member_breakdown"].items():
|
|
||||||
rows.append(
|
|
||||||
{
|
|
||||||
"Grid": item["grid_config"]["grid_name"],
|
|
||||||
"Data Source": source,
|
|
||||||
"Number of Features": count,
|
|
||||||
"Grid_Level_Sort": item["grid_config"]["grid_sort_key"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return pd.DataFrame(rows)
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data(show_spinner=False)
|
|
||||||
def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
|
||||||
"""Load and cache all dataset analysis data.
|
|
||||||
|
|
||||||
This function computes both sample counts and feature counts for all grid configurations.
|
|
||||||
Results are cached to avoid redundant computations across different tabs.
|
|
||||||
"""
|
|
||||||
# Define all possible grid configurations
|
|
||||||
grid_configs_raw: list[tuple[Grid, int]] = [
|
|
||||||
("hex", 3),
|
|
||||||
("hex", 4),
|
|
||||||
("hex", 5),
|
|
||||||
("hex", 6),
|
|
||||||
("healpix", 6),
|
|
||||||
("healpix", 7),
|
|
||||||
("healpix", 8),
|
|
||||||
("healpix", 9),
|
|
||||||
("healpix", 10),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create structured grid config objects
|
|
||||||
grid_configs: list[GridConfig] = []
|
|
||||||
for grid, level in grid_configs_raw:
|
|
||||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
|
||||||
grid_configs.append(
|
|
||||||
{
|
|
||||||
"grid": grid,
|
|
||||||
"level": level,
|
|
||||||
"grid_name": f"{grid}-{level}",
|
|
||||||
"grid_sort_key": f"{grid}_{level:02d}",
|
|
||||||
"disable_alphaearth": disable_alphaearth,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute sample counts
|
|
||||||
sample_counts: list[SampleCountData] = []
|
|
||||||
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
|
||||||
tasks: list[Task] = ["binary", "count", "density"]
|
|
||||||
|
|
||||||
for grid_config in grid_configs:
|
|
||||||
for target in target_datasets:
|
|
||||||
# Create minimal ensemble just to get target data
|
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid=grid_config["grid"],
|
|
||||||
level=grid_config["level"],
|
|
||||||
target=target,
|
|
||||||
members=[], # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
targets = ensemble._read_target()
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
# Get task-specific column
|
|
||||||
taskcol = ensemble.taskcol(task) # type: ignore[arg-type]
|
|
||||||
covcol = ensemble.covcol
|
|
||||||
|
|
||||||
# Count samples with coverage and valid labels
|
|
||||||
if covcol in targets.columns and taskcol in targets.columns:
|
|
||||||
valid_coverage = targets[covcol].sum()
|
|
||||||
valid_labels = targets[taskcol].notna().sum()
|
|
||||||
valid_both = (targets[covcol] & targets[taskcol].notna()).sum()
|
|
||||||
|
|
||||||
sample_counts.append(
|
|
||||||
{
|
|
||||||
"grid_config": grid_config,
|
|
||||||
"target": target,
|
|
||||||
"task": task,
|
|
||||||
"samples_coverage": valid_coverage,
|
|
||||||
"samples_labels": valid_labels,
|
|
||||||
"samples_both": valid_both,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute feature counts
|
|
||||||
feature_counts: list[FeatureCountData] = []
|
|
||||||
|
|
||||||
for grid_config in grid_configs:
|
|
||||||
# Determine which members are available for this configuration
|
|
||||||
if grid_config["disable_alphaearth"]:
|
|
||||||
members = ["ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
else:
|
|
||||||
members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
|
||||||
|
|
||||||
# Use darts_rts as default target for comparison
|
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid=grid_config["grid"],
|
|
||||||
level=grid_config["level"],
|
|
||||||
target="darts_rts",
|
|
||||||
members=members, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
stats = ensemble.get_stats()
|
|
||||||
|
|
||||||
# Calculate minimum cells across all data sources
|
|
||||||
min_cells = min(
|
|
||||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
|
||||||
for member_stats in stats["members"].values()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build member breakdown including lon/lat
|
|
||||||
member_breakdown = {}
|
|
||||||
for member, member_stats in stats["members"].items():
|
|
||||||
member_breakdown[member] = member_stats["num_features"]
|
|
||||||
|
|
||||||
if ensemble.add_lonlat:
|
|
||||||
member_breakdown["Lon/Lat"] = 2
|
|
||||||
|
|
||||||
feature_counts.append(
|
|
||||||
{
|
|
||||||
"grid_config": grid_config,
|
|
||||||
"total_features": stats["total_features"],
|
|
||||||
"data_sources": members + (["Lon/Lat"] if ensemble.add_lonlat else []),
|
|
||||||
"inference_cells": min_cells,
|
|
||||||
"total_samples": stats["num_target_samples"],
|
|
||||||
"member_breakdown": member_breakdown,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return DatasetAnalysisCache(
|
|
||||||
grid_configs=grid_configs,
|
|
||||||
sample_counts=sample_counts,
|
|
||||||
feature_counts=feature_counts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def render_sample_count_overview(cache: DatasetAnalysisCache):
|
|
||||||
"""Render overview of sample counts per task+target+grid+level combination."""
|
"""Render overview of sample counts per task+target+grid+level combination."""
|
||||||
st.subheader("📊 Sample Counts by Configuration")
|
st.subheader("📊 Sample Counts by Configuration")
|
||||||
|
|
||||||
|
|
@ -247,7 +29,8 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get sample count DataFrame from cache
|
# Get sample count DataFrame from cache
|
||||||
sample_df = cache.get_sample_count_df()
|
all_stats = load_all_default_dataset_statistics()
|
||||||
|
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
|
||||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||||
|
|
||||||
# Create tabs for different views
|
# Create tabs for different views
|
||||||
|
|
@ -255,14 +38,14 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
|
|
||||||
with tab1:
|
with tab1:
|
||||||
st.markdown("### Sample Counts Heatmap")
|
st.markdown("### Sample Counts Heatmap")
|
||||||
st.markdown("Showing counts of samples with both coverage and valid labels")
|
st.markdown("Showing counts of samples with coverage")
|
||||||
|
|
||||||
# Create heatmap for each target dataset
|
# Create heatmap for each target dataset
|
||||||
for target in target_datasets:
|
for target in target_datasets:
|
||||||
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
|
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
|
||||||
|
|
||||||
# Pivot for heatmap: Grid x Task
|
# Pivot for heatmap: Grid x Task
|
||||||
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Both)", aggfunc="mean")
|
pivot_df = target_df.pivot_table(index="Grid", columns="Task", values="Samples (Coverage)", aggfunc="mean")
|
||||||
|
|
||||||
# Sort index by grid type and level
|
# Sort index by grid type and level
|
||||||
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
|
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
|
||||||
|
|
@ -289,7 +72,7 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
|
|
||||||
with tab2:
|
with tab2:
|
||||||
st.markdown("### Sample Counts Bar Chart")
|
st.markdown("### Sample Counts Bar Chart")
|
||||||
st.markdown("Showing counts of samples with both coverage and valid labels")
|
st.markdown("Showing counts of samples with coverage")
|
||||||
|
|
||||||
# Create a faceted bar chart showing both targets side by side
|
# Create a faceted bar chart showing both targets side by side
|
||||||
# Get color palette for tasks
|
# Get color palette for tasks
|
||||||
|
|
@ -299,12 +82,12 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
fig = px.bar(
|
fig = px.bar(
|
||||||
sample_df,
|
sample_df,
|
||||||
x="Grid",
|
x="Grid",
|
||||||
y="Samples (Both)",
|
y="Samples (Coverage)",
|
||||||
color="Task",
|
color="Task",
|
||||||
facet_col="Target",
|
facet_col="Target",
|
||||||
barmode="group",
|
barmode="group",
|
||||||
title="Sample Counts by Grid Configuration and Target Dataset",
|
title="Sample Counts by Grid Configuration and Target Dataset",
|
||||||
labels={"Grid": "Grid Configuration", "Samples (Both)": "Number of Samples"},
|
labels={"Grid": "Grid Configuration", "Samples (Coverage)": "Number of Samples"},
|
||||||
color_discrete_sequence=task_colors,
|
color_discrete_sequence=task_colors,
|
||||||
height=500,
|
height=500,
|
||||||
)
|
)
|
||||||
|
|
@ -318,25 +101,25 @@ def render_sample_count_overview(cache: DatasetAnalysisCache):
|
||||||
st.markdown("### Detailed Sample Counts")
|
st.markdown("### Detailed Sample Counts")
|
||||||
|
|
||||||
# Display full table with formatting
|
# Display full table with formatting
|
||||||
display_df = sample_df[
|
display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
|
||||||
["Grid", "Target", "Task", "Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]
|
|
||||||
].copy()
|
|
||||||
|
|
||||||
# Format numbers with commas
|
# Format numbers with commas
|
||||||
for col in ["Samples (Coverage)", "Samples (Labels)", "Samples (Both)"]:
|
display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}")
|
||||||
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
# Format coverage as percentage with 2 decimal places
|
||||||
|
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
|
||||||
|
|
||||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
def render_feature_count_comparison():
|
||||||
"""Render static comparison of feature counts across all grid configurations."""
|
"""Render static comparison of feature counts across all grid configurations."""
|
||||||
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
st.markdown("### Feature Count Comparison Across Grid Configurations")
|
||||||
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
|
||||||
|
|
||||||
# Get data from cache
|
# Get data from cache
|
||||||
comparison_df = cache.get_feature_count_df()
|
all_stats = load_all_default_dataset_statistics()
|
||||||
breakdown_df = cache.get_feature_breakdown_df()
|
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
|
||||||
|
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
||||||
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
|
||||||
|
|
||||||
# Create tabs for different comparison views
|
# Create tabs for different comparison views
|
||||||
|
|
@ -443,27 +226,24 @@ def render_feature_count_comparison(cache: DatasetAnalysisCache):
|
||||||
|
|
||||||
# Display full comparison table with formatting
|
# Display full comparison table with formatting
|
||||||
display_df = comparison_df[
|
display_df = comparison_df[
|
||||||
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples", "AlphaEarth"]
|
["Grid", "Total Features", "Data Sources", "Inference Cells", "Total Samples"]
|
||||||
].copy()
|
].copy()
|
||||||
|
|
||||||
# Format numbers with commas
|
# Format numbers with commas
|
||||||
for col in ["Total Features", "Inference Cells", "Total Samples"]:
|
for col in ["Total Features", "Inference Cells", "Total Samples"]:
|
||||||
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
|
||||||
|
|
||||||
# Format boolean as Yes/No
|
|
||||||
display_df["AlphaEarth"] = display_df["AlphaEarth"].apply(lambda x: "✓" if x else "✗")
|
|
||||||
|
|
||||||
st.dataframe(display_df, hide_index=True, width="stretch")
|
st.dataframe(display_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
@st.fragment
|
||||||
def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
def render_feature_count_explorer():
|
||||||
"""Render interactive detailed configuration explorer using fragments."""
|
"""Render interactive detailed configuration explorer using fragments."""
|
||||||
st.markdown("### Detailed Configuration Explorer")
|
st.markdown("### Detailed Configuration Explorer")
|
||||||
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
st.markdown("Select specific grid configuration and data sources for detailed statistics")
|
||||||
|
|
||||||
# Grid selection
|
# Grid selection
|
||||||
grid_options = [gc["grid_name"] for gc in cache.grid_configs]
|
grid_options = [gc.display_name for gc in grid_configs]
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
|
@ -486,83 +266,75 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find the selected grid config
|
# Find the selected grid config
|
||||||
selected_grid_config = next(gc for gc in cache.grid_configs if gc["grid_name"] == grid_level_combined)
|
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
||||||
grid = selected_grid_config["grid"]
|
|
||||||
level = selected_grid_config["level"]
|
# Get available members from the stats
|
||||||
disable_alphaearth = selected_grid_config["disable_alphaearth"]
|
all_stats = load_all_default_dataset_statistics()
|
||||||
|
stats = all_stats[selected_grid_config.id]
|
||||||
|
available_members = cast(list[L2SourceDataset], list(stats.members.keys()))
|
||||||
|
|
||||||
# Members selection
|
# Members selection
|
||||||
st.markdown("#### Select Data Sources")
|
st.markdown("#### Select Data Sources")
|
||||||
|
|
||||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
all_members = cast(
|
||||||
|
list[L2SourceDataset],
|
||||||
|
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
|
||||||
|
)
|
||||||
|
|
||||||
# Use columns for checkboxes
|
# Use columns for checkboxes
|
||||||
cols = st.columns(len(all_members))
|
cols = st.columns(len(all_members))
|
||||||
selected_members = []
|
selected_members: list[L2SourceDataset] = []
|
||||||
|
|
||||||
for idx, member in enumerate(all_members):
|
for idx, member in enumerate(all_members):
|
||||||
with cols[idx]:
|
with cols[idx]:
|
||||||
if member == "AlphaEarth" and disable_alphaearth:
|
default_value = member in available_members
|
||||||
st.checkbox(
|
if st.checkbox(member, value=default_value, key=f"feature_member_{member}"):
|
||||||
member,
|
selected_members.append(cast(L2SourceDataset, member))
|
||||||
value=False,
|
|
||||||
disabled=True,
|
|
||||||
help=f"Not available for {grid} level {level}",
|
|
||||||
key=f"feature_member_{member}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if st.checkbox(member, value=True, key=f"feature_member_{member}"):
|
|
||||||
selected_members.append(member)
|
|
||||||
|
|
||||||
# Show results if at least one member is selected
|
# Show results if at least one member is selected
|
||||||
if selected_members:
|
if selected_members:
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
|
|
||||||
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
|
# Get statistics from cache (already loaded)
|
||||||
|
grid_stats = all_stats[selected_grid_config.id]
|
||||||
|
|
||||||
with st.spinner("Computing dataset statistics..."):
|
# Filter to selected members only
|
||||||
stats = ensemble.get_stats()
|
selected_member_stats = {m: grid_stats.members[m] for m in selected_members if m in grid_stats.members}
|
||||||
|
|
||||||
|
# Calculate total features for selected members
|
||||||
|
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||||
|
|
||||||
|
# Get target stats
|
||||||
|
target_stats = grid_stats.target[cast(TargetDataset, target)]
|
||||||
|
|
||||||
# High-level metrics
|
# High-level metrics
|
||||||
col1, col2, col3, col4, col5 = st.columns(5)
|
col1, col2, col3, col4, col5 = st.columns(5)
|
||||||
with col1:
|
with col1:
|
||||||
st.metric("Total Features", f"{stats['total_features']:,}")
|
st.metric("Total Features", f"{total_features:,}")
|
||||||
with col2:
|
with col2:
|
||||||
# Calculate minimum cells across all data sources (for inference capability)
|
# Calculate minimum cells across all data sources (for inference capability)
|
||||||
min_cells = min(
|
min_cells = min(member_stats.dimensions["cell_ids"] for member_stats in selected_member_stats.values())
|
||||||
member_stats["dimensions"]["cell_ids"] # type: ignore[index]
|
|
||||||
for member_stats in stats["members"].values()
|
|
||||||
)
|
|
||||||
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
|
st.metric("Inference Cells", f"{min_cells:,}", help="Number of union of cells across all data sources")
|
||||||
with col3:
|
with col3:
|
||||||
st.metric("Data Sources", len(selected_members))
|
st.metric("Data Sources", len(selected_members))
|
||||||
with col4:
|
with col4:
|
||||||
st.metric("Total Samples", f"{stats['num_target_samples']:,}")
|
# Use binary task training cells as sample count
|
||||||
|
st.metric("Total Samples", f"{target_stats.training_cells['binary']:,}")
|
||||||
with col5:
|
with col5:
|
||||||
# Calculate total data points
|
# Calculate total data points
|
||||||
total_points = stats["total_features"] * stats["num_target_samples"]
|
total_points = total_features * target_stats.training_cells["binary"]
|
||||||
st.metric("Total Data Points", f"{total_points:,}")
|
st.metric("Total Data Points", f"{total_points:,}")
|
||||||
|
|
||||||
# Feature breakdown by source
|
# Feature breakdown by source
|
||||||
st.markdown("#### Feature Breakdown by Data Source")
|
st.markdown("#### Feature Breakdown by Data Source")
|
||||||
|
|
||||||
breakdown_data = []
|
breakdown_data = []
|
||||||
for member, member_stats in stats["members"].items():
|
for member, member_stats in selected_member_stats.items():
|
||||||
breakdown_data.append(
|
breakdown_data.append(
|
||||||
{
|
{
|
||||||
"Data Source": member,
|
"Data Source": member,
|
||||||
"Number of Features": member_stats["num_features"],
|
"Number of Features": member_stats.feature_count,
|
||||||
"Percentage": f"{member_stats['num_features'] / stats['total_features'] * 100:.1f}%", # type: ignore[operator]
|
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add lon/lat
|
|
||||||
if ensemble.add_lonlat:
|
|
||||||
breakdown_data.append(
|
|
||||||
{
|
|
||||||
"Data Source": "Lon/Lat",
|
|
||||||
"Number of Features": 2,
|
|
||||||
"Percentage": f"{2 / stats['total_features'] * 100:.1f}%",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -589,20 +361,20 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||||
|
|
||||||
# Detailed member information
|
# Detailed member information
|
||||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||||
for member, member_stats in stats["members"].items():
|
for member, member_stats in selected_member_stats.items():
|
||||||
st.markdown(f"### {member}")
|
st.markdown(f"### {member}")
|
||||||
|
|
||||||
metric_cols = st.columns(4)
|
metric_cols = st.columns(4)
|
||||||
with metric_cols[0]:
|
with metric_cols[0]:
|
||||||
st.metric("Features", member_stats["num_features"])
|
st.metric("Features", member_stats.feature_count)
|
||||||
with metric_cols[1]:
|
with metric_cols[1]:
|
||||||
st.metric("Variables", member_stats["num_variables"])
|
st.metric("Variables", len(member_stats.variable_names))
|
||||||
with metric_cols[2]:
|
with metric_cols[2]:
|
||||||
dim_str = " x ".join([str(dim) for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
|
dim_str = " x ".join([str(dim) for dim in member_stats.dimensions.values()])
|
||||||
st.metric("Shape", dim_str)
|
st.metric("Shape", dim_str)
|
||||||
with metric_cols[3]:
|
with metric_cols[3]:
|
||||||
total_points = 1
|
total_points = 1
|
||||||
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
|
for dim_size in member_stats.dimensions.values():
|
||||||
total_points *= dim_size
|
total_points *= dim_size
|
||||||
st.metric("Data Points", f"{total_points:,}")
|
st.metric("Data Points", f"{total_points:,}")
|
||||||
|
|
||||||
|
|
@ -612,7 +384,7 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||||
[
|
[
|
||||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||||
for v in member_stats["variables"] # type: ignore[union-attr]
|
for v in member_stats.variable_names
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
st.markdown(vars_html, unsafe_allow_html=True)
|
st.markdown(vars_html, unsafe_allow_html=True)
|
||||||
|
|
@ -624,17 +396,21 @@ def render_feature_count_explorer(cache: DatasetAnalysisCache):
|
||||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||||
f"{dim_name}: {dim_size}</span>"
|
f"{dim_name}: {dim_size}</span>"
|
||||||
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
|
for dim_name, dim_size in member_stats.dimensions.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
st.markdown(dim_html, unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# Size on disk
|
||||||
|
size_mb = member_stats.size_bytes / (1024 * 1024)
|
||||||
|
st.markdown(f"**Size on Disk:** {size_mb:.2f} MB")
|
||||||
|
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
else:
|
else:
|
||||||
st.info("👆 Select at least one data source to see feature statistics")
|
st.info("👆 Select at least one data source to see feature statistics")
|
||||||
|
|
||||||
|
|
||||||
def render_feature_count_section(cache: DatasetAnalysisCache):
|
def render_feature_count_section():
|
||||||
"""Render the feature count section with comparison and explorer."""
|
"""Render the feature count section with comparison and explorer."""
|
||||||
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
st.subheader("🔢 Feature Counts by Dataset Configuration")
|
||||||
|
|
||||||
|
|
@ -646,30 +422,26 @@ def render_feature_count_section(cache: DatasetAnalysisCache):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Static comparison across all grids
|
# Static comparison across all grids
|
||||||
render_feature_count_comparison(cache)
|
render_feature_count_comparison()
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Interactive explorer for detailed analysis
|
# Interactive explorer for detailed analysis
|
||||||
render_feature_count_explorer(cache)
|
render_feature_count_explorer()
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_analysis():
|
def render_dataset_analysis():
|
||||||
"""Render the dataset analysis section with sample and feature counts."""
|
"""Render the dataset analysis section with sample and feature counts."""
|
||||||
st.header("📈 Dataset Analysis")
|
st.header("📈 Dataset Analysis")
|
||||||
|
|
||||||
# Load all data once and cache it
|
|
||||||
with st.spinner("Loading dataset analysis data..."):
|
|
||||||
cache = load_dataset_analysis_data()
|
|
||||||
|
|
||||||
# Create tabs for the two different analyses
|
# Create tabs for the two different analyses
|
||||||
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
|
||||||
|
|
||||||
with analysis_tabs[0]:
|
with analysis_tabs[0]:
|
||||||
render_sample_count_overview(cache)
|
render_sample_count_overview()
|
||||||
|
|
||||||
with analysis_tabs[1]:
|
with analysis_tabs[1]:
|
||||||
render_feature_count_section(cache)
|
render_feature_count_section()
|
||||||
|
|
||||||
|
|
||||||
def render_training_results_summary(training_results):
|
def render_training_results_summary(training_results):
|
||||||
|
|
@ -678,15 +450,15 @@ def render_training_results_summary(training_results):
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
|
tasks = {tr.settings.task for tr in training_results}
|
||||||
st.metric("Tasks", len(tasks))
|
st.metric("Tasks", len(tasks))
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
|
grids = {tr.settings.grid for tr in training_results}
|
||||||
st.metric("Grid Types", len(grids))
|
st.metric("Grid Types", len(grids))
|
||||||
|
|
||||||
with col3:
|
with col3:
|
||||||
models = {tr.settings.get("model", "Unknown") for tr in training_results}
|
models = {tr.settings.model for tr in training_results}
|
||||||
st.metric("Model Types", len(models))
|
st.metric("Model Types", len(models))
|
||||||
|
|
||||||
with col4:
|
with col4:
|
||||||
|
|
@ -725,10 +497,10 @@ def render_experiment_results(training_results):
|
||||||
summary_data.append(
|
summary_data.append(
|
||||||
{
|
{
|
||||||
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
|
"Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
|
||||||
"Task": tr.settings.get("task", "Unknown"),
|
"Task": tr.settings.task,
|
||||||
"Grid": tr.settings.get("grid", "Unknown"),
|
"Grid": tr.settings.grid,
|
||||||
"Level": tr.settings.get("level", "Unknown"),
|
"Level": tr.settings.level,
|
||||||
"Model": tr.settings.get("model", "Unknown"),
|
"Model": tr.settings.model,
|
||||||
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
|
f"Best {primary_metric.title()}": f"{primary_score:.4f}",
|
||||||
"Trials": len(tr.results),
|
"Trials": len(tr.results),
|
||||||
"Path": str(tr.path.name),
|
"Path": str(tr.path.name),
|
||||||
|
|
@ -750,17 +522,20 @@ def render_experiment_results(training_results):
|
||||||
st.subheader("Individual Experiment Details")
|
st.subheader("Individual Experiment Details")
|
||||||
|
|
||||||
for tr in training_results:
|
for tr in training_results:
|
||||||
with st.expander(tr.get_display_name("task_first")):
|
display_name = (
|
||||||
|
f"{tr.display_info.task} | {tr.display_info.model} | {tr.display_info.grid}{tr.display_info.level}"
|
||||||
|
)
|
||||||
|
with st.expander(display_name):
|
||||||
col1, col2 = st.columns([1, 2])
|
col1, col2 = st.columns([1, 2])
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
st.write("**Configuration:**")
|
st.write("**Configuration:**")
|
||||||
st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
|
st.write(f"- **Task:** {tr.settings.task}")
|
||||||
st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
|
st.write(f"- **Grid:** {tr.settings.grid}")
|
||||||
st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
|
st.write(f"- **Level:** {tr.settings.level}")
|
||||||
st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
|
st.write(f"- **Model:** {tr.settings.model}")
|
||||||
st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
|
st.write(f"- **CV Splits:** {tr.settings.cv_splits}")
|
||||||
st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
|
st.write(f"- **Classes:** {tr.settings.classes}")
|
||||||
|
|
||||||
st.write("\n**Files:**")
|
st.write("\n**Files:**")
|
||||||
st.write("- 📊 search_results.parquet")
|
st.write("- 📊 search_results.parquet")
|
||||||
|
|
@ -844,3 +619,5 @@ def render_overview_page():
|
||||||
render_dataset_analysis()
|
render_dataset_analysis()
|
||||||
|
|
||||||
st.balloons()
|
st.balloons()
|
||||||
|
|
||||||
|
stopwatch.summary()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Shared types used across the entropice codebase."""
|
"""Shared types used across the entropice codebase."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
type Grid = Literal["hex", "healpix"]
|
type Grid = Literal["hex", "healpix"]
|
||||||
|
|
@ -9,3 +10,60 @@ type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||||
type Task = Literal["binary", "count", "density"]
|
type Task = Literal["binary", "count", "density"]
|
||||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GridConfig:
|
||||||
|
"""Grid configuration specification with metadata."""
|
||||||
|
|
||||||
|
grid: Grid
|
||||||
|
level: int
|
||||||
|
id: GridLevel
|
||||||
|
display_name: str
|
||||||
|
sort_key: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_grid_level(cls, grid_level: GridLevel) -> "GridConfig":
|
||||||
|
"""Create a GridConfig from a GridLevel string."""
|
||||||
|
if grid_level.startswith("hex"):
|
||||||
|
grid = "hex"
|
||||||
|
level = int(grid_level[3:])
|
||||||
|
elif grid_level.startswith("healpix"):
|
||||||
|
grid = "healpix"
|
||||||
|
level = int(grid_level[7:])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid grid level: {grid_level}")
|
||||||
|
|
||||||
|
display_name = f"{grid.capitalize()}-{level}"
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
grid=grid,
|
||||||
|
level=level,
|
||||||
|
id=grid_level,
|
||||||
|
display_name=display_name,
|
||||||
|
sort_key=f"{grid}_{level:02d}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
||||||
|
all_tasks: list[Task] = ["binary", "count", "density"]
|
||||||
|
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
||||||
|
all_l2_source_datasets: list[L2SourceDataset] = [
|
||||||
|
"ArcticDEM",
|
||||||
|
"ERA5-shoulder",
|
||||||
|
"ERA5-seasonal",
|
||||||
|
"ERA5-yearly",
|
||||||
|
"AlphaEarth",
|
||||||
|
]
|
||||||
|
|
||||||
|
grid_configs: list[GridConfig] = [
|
||||||
|
GridConfig.from_grid_level("hex3"),
|
||||||
|
GridConfig.from_grid_level("hex4"),
|
||||||
|
GridConfig.from_grid_level("hex5"),
|
||||||
|
GridConfig.from_grid_level("hex6"),
|
||||||
|
GridConfig.from_grid_level("healpix6"),
|
||||||
|
GridConfig.from_grid_level("healpix7"),
|
||||||
|
GridConfig.from_grid_level("healpix8"),
|
||||||
|
GridConfig.from_grid_level("healpix9"),
|
||||||
|
GridConfig.from_grid_level("healpix10"),
|
||||||
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue