diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 8e4555f..713f7d0 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -11,12 +11,11 @@ Pages:
import streamlit as st
-# from entropice.dashboard.views.inference_page import render_inference_page
-# from entropice.dashboard.views.model_state_page import render_model_state_page
+from entropice.dashboard.views.inference_page import render_inference_page
+from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
-
-# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
-# from entropice.dashboard.views.training_data_page import render_training_data_page
+from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
+from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@@ -25,17 +24,17 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
- # training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
- # training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
- # model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
- # inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
+ training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
+ training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
+ model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
+ inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
- # "Training": [training_data_page, training_analysis_page],
- # "Model State": [model_state_page],
- # "Inference": [inference_page],
+ "Training": [training_data_page, training_analysis_page],
+ "Model State": [model_state_page],
+ "Inference": [inference_page],
}
)
pg.run()
diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py
index e63620f..79d1867 100644
--- a/src/entropice/dashboard/plots/hyperparameter_analysis.py
+++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py
@@ -143,8 +143,8 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
# Extract scale information from settings if available
param_scales = {}
- if settings and "param_grid" in settings:
- param_grid = settings["param_grid"]
+ if settings and hasattr(settings, "param_grid"):
+ param_grid = settings.param_grid
for param_name, param_config in param_grid.items():
if isinstance(param_config, dict) and "distribution" in param_config:
# loguniform distribution indicates log scale
@@ -1181,10 +1181,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
preds_gdf = gpd.read_parquet(preds_file)
# Get task and target information from settings
- task = settings.get("task", "binary")
- target = settings.get("target", "darts_rts")
- grid = settings.get("grid", "hex")
- level = settings.get("level", 3)
+ task = settings.task
+ target = settings.target
+ grid = settings.grid
+ level = settings.level
# Create dataset ensemble to get true labels
# We need to load the target data to get true labels
diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py
index a229161..29804f7 100644
--- a/src/entropice/dashboard/plots/inference.py
+++ b/src/entropice/dashboard/plots/inference.py
@@ -8,7 +8,7 @@ import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_palette
-from entropice.dashboard.utils.data import TrainingResult
+from entropice.dashboard.utils.loaders import TrainingResult
def _fix_hex_geometry(geom):
@@ -197,8 +197,8 @@ def render_inference_map(result: TrainingResult):
preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
# Get settings
- task = result.settings.get("task", "binary")
- grid = result.settings.get("grid", "hex")
+ task = result.settings.task
+ grid = result.settings.grid
# Create controls in columns
col1, col2, col3 = st.columns([2, 2, 1])
diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py
index 34233a1..bb3b467 100644
--- a/src/entropice/dashboard/utils/loaders.py
+++ b/src/entropice/dashboard/utils/loaders.py
@@ -147,10 +147,11 @@ def load_all_training_data(
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
"""
+ dataset = e.create(filter_target_col=e.covcol)
return {
- "binary": e.create_cat_training_dataset("binary", device="cpu"),
- "count": e.create_cat_training_dataset("count", device="cpu"),
- "density": e.create_cat_training_dataset("density", device="cpu"),
+ "binary": e._cat_and_split(dataset, "binary", device="cpu"),
+ "count": e._cat_and_split(dataset, "count", device="cpu"),
+ "density": e._cat_and_split(dataset, "density", device="cpu"),
}
diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py
index 20d1759..b5f2835 100644
--- a/src/entropice/dashboard/views/inference_page.py
+++ b/src/entropice/dashboard/views/inference_page.py
@@ -1,7 +1,8 @@
"""Inference page: Visualization of model inference results across the study region."""
+import geopandas as gpd
import streamlit as st
-from entropice.dashboard.utils.data import load_all_training_results
+from stopuhr import stopwatch
from entropice.dashboard.plots.inference import (
render_class_comparison,
@@ -10,66 +11,177 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics,
render_spatial_distribution_stats,
)
+from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
+
+
+@st.fragment
+def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
+ """Render sidebar for training run selection.
+
+ Args:
+ training_results: List of available TrainingResult objects.
+
+ Returns:
+ Selected TrainingResult object.
+
+ """
+ st.header("Select Training Run")
+
+ # Create selection options with task-first naming
+ training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
+
+ selected_name = st.selectbox(
+ "Training Run",
+ options=list(training_options.keys()),
+ index=0,
+ help="Select a training run to view inference results",
+ key="inference_run_select",
+ )
+
+ selected_result = training_options[selected_name]
+
+ st.divider()
+
+ # Show run information in sidebar
+ st.subheader("Run Information")
+
+ st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}")
+ st.markdown(f"**Model:** {selected_result.settings.model.upper()}")
+ st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}")
+ st.markdown(f"**Level:** {selected_result.settings.level}")
+ st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}")
+
+ return selected_result
+
+
+def render_run_information(selected_result: TrainingResult):
+ """Render training run configuration overview.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+
+ """
+ st.header("📋 Run Configuration")
+
+ col1, col2, col3, col4, col5 = st.columns(5)
+
+ with col1:
+ st.metric("Task", selected_result.settings.task.capitalize())
+
+ with col2:
+ st.metric("Model", selected_result.settings.model.upper())
+
+ with col3:
+ st.metric("Grid", selected_result.settings.grid.capitalize())
+
+ with col4:
+ st.metric("Level", selected_result.settings.level)
+
+ with col5:
+ st.metric("Target", selected_result.settings.target.replace("darts_", ""))
+
+
+def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str):
+ """Render inference summary statistics section.
+
+ Args:
+ predictions_gdf: GeoDataFrame with predictions.
+ task: Task type ('binary', 'count', 'density').
+
+ """
+ st.header("📊 Inference Summary")
+ render_inference_statistics(predictions_gdf, task)
+
+
+def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame):
+ """Render spatial coverage statistics section.
+
+ Args:
+ predictions_gdf: GeoDataFrame with predictions.
+
+ """
+ st.header("🌍 Spatial Coverage")
+ render_spatial_distribution_stats(predictions_gdf)
+
+
+def render_map_visualization_section(selected_result: TrainingResult):
+ """Render 3D map visualization section.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+
+ """
+ st.header("🗺️ Interactive Prediction Map")
+ st.markdown(
+ """
+ 3D visualization of predictions across the study region. The map shows predicted
+ classes with color coding and spatial distribution of model outputs.
+ """
+ )
+ render_inference_map(selected_result)
+
+
+def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str):
+ """Render class distribution histogram section.
+
+ Args:
+ predictions_gdf: GeoDataFrame with predictions.
+ task: Task type ('binary', 'count', 'density').
+
+ """
+ st.header("📈 Class Distribution")
+ st.markdown("Distribution of predicted classes across all inference cells.")
+ render_class_distribution_histogram(predictions_gdf, task)
+
+
+def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str):
+ """Render class comparison analysis section.
+
+ Args:
+ predictions_gdf: GeoDataFrame with predictions.
+ task: Task type ('binary', 'count', 'density').
+
+ """
+ st.header("🔍 Class Comparison Analysis")
+ st.markdown(
+ """
+ Detailed comparison of predicted classes showing probability distributions
+ and confidence metrics for different class predictions.
+ """
+ )
+ render_class_comparison(predictions_gdf, task)
def render_inference_page():
"""Render the Inference page of the dashboard."""
st.title("🗺️ Inference Results")
+ st.markdown(
+ """
+ Explore spatial predictions from trained models across the Arctic permafrost region.
+ Select a training run from the sidebar to visualize prediction maps, class distributions,
+ and spatial coverage statistics.
+ """
+ )
+
# Load all available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
- st.info("Run training using: `pixi run python -m entropice.training`")
+ st.info("Run training using: `pixi run python -m entropice.ml.training`")
return
+ st.success(f"Found **{len(training_results)}** training result(s)")
+
+ st.divider()
+
# Sidebar: Training run selection
with st.sidebar:
- st.header("Select Training Run")
+ selected_result = render_sidebar_selection(training_results)
- # Create selection options with task-first naming
- training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
-
- selected_name = st.selectbox(
- "Training Run",
- options=list(training_options.keys()),
- index=0,
- help="Select a training run to view inference results",
- )
-
- selected_result = training_options[selected_name]
-
- st.divider()
-
- # Show run information in sidebar
- st.subheader("Run Information")
-
- settings = selected_result.settings
-
- st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
- st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
- st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
- st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
- st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
-
- # Main content area - Run Information at the top
- st.header("📋 Run Configuration")
-
- col1, col2, col3, col4, col5 = st.columns(5)
- with col1:
- st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
- with col2:
- st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
- with col3:
- st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
- with col4:
- st.metric("Level", selected_result.settings.get("level", "Unknown"))
- with col5:
- st.metric(
- "Target",
- selected_result.settings.get("target", "Unknown").replace("darts_", ""),
- )
+ # Main content area - Run Information
+ render_run_information(selected_result)
st.divider()
@@ -80,31 +192,33 @@ def render_inference_page():
st.info("Inference results are generated automatically during training.")
return
- # Load predictions for statistics
- import geopandas as gpd
-
- predictions_gdf = gpd.read_parquet(preds_file)
- task = selected_result.settings.get("task", "binary")
+ # Load predictions
+ with st.spinner("Loading inference results..."):
+ predictions_gdf = gpd.read_parquet(preds_file)
+ task = selected_result.settings.task
# Inference Statistics Section
- render_inference_statistics(predictions_gdf, task)
+ render_inference_statistics_section(predictions_gdf, task)
st.divider()
# Spatial Coverage Section
- render_spatial_distribution_stats(predictions_gdf)
+ render_spatial_coverage_section(predictions_gdf)
st.divider()
# 3D Map Visualization Section
- render_inference_map(selected_result)
+ render_map_visualization_section(selected_result)
st.divider()
# Class Distribution Section
- render_class_distribution_histogram(predictions_gdf, task)
+ render_class_distribution_section(predictions_gdf, task)
st.divider()
# Class Comparison Section
- render_class_comparison(predictions_gdf, task)
+ render_class_comparison_section(predictions_gdf, task)
+
+ st.balloons()
+ stopwatch.summary()
diff --git a/src/entropice/dashboard/views/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py
index 5cd90f5..c536d99 100644
--- a/src/entropice/dashboard/views/model_state_page.py
+++ b/src/entropice/dashboard/views/model_state_page.py
@@ -1,16 +1,8 @@
-"""Model State page for the Entropice dashboard."""
+"""Model State page: Visualization of model internal state and feature importance."""
import streamlit as st
import xarray as xr
-from entropice.dashboard.utils.data import (
- extract_arcticdem_features,
- extract_common_features,
- extract_embedding_features,
- extract_era5_features,
- get_members_from_settings,
- load_all_training_results,
-)
-from entropice.dashboard.utils.training import load_model_state
+from stopuhr import stopwatch
from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap,
@@ -26,46 +18,64 @@ from entropice.dashboard.plots.model_state import (
plot_top_features,
)
from entropice.dashboard.utils.colors import generate_unified_colormap
+from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
+from entropice.dashboard.utils.unsembler import (
+ extract_arcticdem_features,
+ extract_common_features,
+ extract_embedding_features,
+ extract_era5_features,
+)
+from entropice.utils.types import L2SourceDataset
-def render_model_state_page():
- """Render the Model State page of the dashboard."""
- st.title("Model State")
- st.markdown("Comprehensive visualization of the best model's internal state and feature importance")
+def get_members_from_settings(settings) -> list[L2SourceDataset]:
+ """Extract dataset members from training settings.
- # Load available training results
- training_results = load_all_training_results()
+ Args:
+ settings: TrainingSettings object containing dataset configuration.
- if not training_results:
- st.error("No training results found. Please run a training search first.")
- return
+ Returns:
+ List of L2SourceDataset members used in training.
- # Sidebar: Training run selection
- with st.sidebar:
- st.header("Select Training Run")
+ """
+ return settings.members
- # Result selection with model-first naming
- result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
- selected_name = st.selectbox(
- "Training Run",
- options=list(result_options.keys()),
- help="Choose a training result to visualize model state",
- )
- selected_result = result_options[selected_name]
- st.divider()
+@st.fragment
+def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
+ """Render sidebar for training run selection.
- # Get the model type from settings
- model_type = selected_result.settings.get("model", "espa")
+ Args:
+ training_results: List of available TrainingResult objects.
- # Load model state
- with st.spinner("Loading model state..."):
- model_state = load_model_state(selected_result)
- if model_state is None:
- st.error("Could not load model state for this result.")
- return
+ Returns:
+ Selected TrainingResult object.
- # Display basic model state info
+ """
+ st.header("Select Training Run")
+
+ # Result selection with task-first naming
+ result_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
+ selected_name = st.selectbox(
+ "Training Run",
+ options=list(result_options.keys()),
+ index=0,
+ help="Choose a training result to visualize model state",
+ key="model_state_training_run_select",
+ )
+ selected_result = result_options[selected_name]
+
+ return selected_result
+
+
+def render_model_info(model_state: xr.Dataset, model_type: str):
+ """Render basic model state information.
+
+ Args:
+ model_state: Xarray dataset containing model state.
+ model_type: Type of model (espa, xgboost, rf, knn).
+
+ """
with st.expander("Model State Information", expanded=False):
st.write(f"**Model Type:** {model_type.upper()}")
st.write(f"**Variables:** {list(model_state.data_vars)}")
@@ -73,15 +83,23 @@ def render_model_state_page():
st.write(f"**Coordinates:** {list(model_state.coords)}")
st.write(f"**Attributes:** {dict(model_state.attrs)}")
- # Display dataset members summary
- st.header("📊 Training Data Summary")
- members = get_members_from_settings(selected_result.settings)
- st.markdown(f"""
+def render_training_data_summary(members: list[L2SourceDataset]):
+ """Render summary of training data sources.
+
+ Args:
+ members: List of dataset members used in training.
+
+ """
+ st.header("📊 Training Data Summary")
+
+ st.markdown(
+ f"""
**Dataset Members Used in Training:** {len(members)}
The following data sources were used to train this model:
- """)
+ """
+ )
# Create a nice display of members with emojis
member_display = {
@@ -98,6 +116,52 @@ def render_model_state_page():
display_name = member_display.get(member, f"📁 {member}")
st.info(display_name)
+
+def render_model_state_page():
+ """Render the Model State page of the dashboard."""
+ st.title("🔬 Model State")
+ st.markdown(
+ """
+ Comprehensive visualization of the best model's internal state and feature importance.
+ Select a training run from the sidebar to explore model parameters, feature weights,
+ and data source contributions.
+ """
+ )
+
+ # Load available training results
+ training_results = load_all_training_results()
+
+ if not training_results:
+ st.warning("No training results found. Please run some training experiments first.")
+ st.info("Run training using: `pixi run python -m entropice.ml.training`")
+ return
+
+ st.success(f"Found **{len(training_results)}** training result(s)")
+
+ st.divider()
+
+ # Sidebar: Training run selection
+ with st.sidebar:
+ selected_result = render_sidebar_selection(training_results)
+
+ # Get the model type from settings
+ model_type = selected_result.settings.model
+
+ # Load model state
+ with st.spinner("Loading model state..."):
+ model_state = selected_result.load_model_state()
+ if model_state is None:
+ st.error("Could not load model state for this result.")
+ st.info("The model state file (best_estimator_state.nc) may be missing from the training results.")
+ return
+
+ # Display basic model state info
+ render_model_info(model_state, model_type)
+
+ # Display dataset members summary
+ members = get_members_from_settings(selected_result.settings)
+ render_training_data_summary(members)
+
st.divider()
# Render model-specific visualizations
@@ -112,9 +176,18 @@ def render_model_state_page():
else:
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
+ st.balloons()
+ stopwatch.summary()
-def render_espa_model_state(model_state: xr.Dataset, selected_result):
- """Render visualizations for ESPA model."""
+
+def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
+ """Render visualizations for ESPA model.
+
+ Args:
+ model_state: Xarray dataset containing ESPA model state.
+ selected_result: TrainingResult object containing training configuration.
+
+ """
# Scale feature weights by number of features
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
@@ -143,8 +216,9 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
common_feature_array = extract_common_features(model_state)
- # Generate unified colormaps
- _, _, altair_colors = generate_unified_colormap(selected_result.settings)
+ # Generate unified colormaps (convert dataclass to dict)
+ settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes}
+ _, _, altair_colors = generate_unified_colormap(settings_dict)
# Feature importance section
st.header("Feature Importance")
@@ -255,8 +329,14 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
-def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
- """Render visualizations for XGBoost model."""
+def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
+ """Render visualizations for XGBoost model.
+
+ Args:
+ model_state: Xarray dataset containing XGBoost model state.
+ selected_result: TrainingResult object containing training configuration.
+
+ """
from entropice.dashboard.plots.model_state import (
plot_xgboost_feature_importance,
plot_xgboost_importance_comparison,
@@ -382,8 +462,14 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
-def render_rf_model_state(model_state: xr.Dataset, selected_result):
- """Render visualizations for Random Forest model."""
+def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
+ """Render visualizations for Random Forest model.
+
+ Args:
+ model_state: Xarray dataset containing Random Forest model state.
+ selected_result: TrainingResult object containing training configuration.
+
+ """
from entropice.dashboard.plots.model_state import plot_rf_feature_importance
st.header("🌳 Random Forest Model Analysis")
@@ -529,8 +615,14 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
-def render_knn_model_state(model_state: xr.Dataset, selected_result):
- """Render visualizations for KNN model."""
+def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
+ """Render visualizations for KNN model.
+
+ Args:
+ model_state: Xarray dataset containing KNN model state.
+ selected_result: TrainingResult object containing training configuration.
+
+ """
st.header("🔍 K-Nearest Neighbors Model Analysis")
st.markdown(
"""
@@ -568,8 +660,13 @@ def render_knn_model_state(model_state: xr.Dataset, selected_result):
# Helper functions for embedding/era5/common features
-def render_embedding_features(embedding_feature_array):
- """Render embedding feature visualizations."""
+def render_embedding_features(embedding_feature_array: xr.DataArray):
+ """Render embedding feature visualizations.
+
+ Args:
+ embedding_feature_array: DataArray containing AlphaEarth embedding feature weights.
+
+ """
with st.container(border=True):
st.header("🛰️ Embedding Feature Analysis")
st.markdown(
@@ -619,7 +716,7 @@ def render_embedding_features(embedding_feature_array):
st.dataframe(top_emb, width="stretch")
-def render_era5_features(era5_feature_array, temporal_group: str = ""):
+def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""):
"""Render ERA5 feature visualizations.
Args:
@@ -631,9 +728,10 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
with st.container(border=True):
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
+ temporal_suffix = f" for {temporal_group.lower()} aggregation" if temporal_group else ""
st.markdown(
f"""
- Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods
+ Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods
are most important for the model predictions.
"""
)
@@ -709,8 +807,13 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
st.dataframe(top_era5, width="stretch")
-def render_arcticdem_features(arcticdem_feature_array):
- """Render ArcticDEM feature visualizations."""
+def render_arcticdem_features(arcticdem_feature_array: xr.DataArray):
+ """Render ArcticDEM feature visualizations.
+
+ Args:
+ arcticdem_feature_array: DataArray containing ArcticDEM feature weights.
+
+ """
with st.container(border=True):
st.header("🏔️ ArcticDEM Feature Analysis")
st.markdown(
@@ -758,8 +861,13 @@ def render_arcticdem_features(arcticdem_feature_array):
st.dataframe(top_arcticdem, width="stretch")
-def render_common_features(common_feature_array):
- """Render common feature visualizations."""
+def render_common_features(common_feature_array: xr.DataArray):
+ """Render common feature visualizations.
+
+ Args:
+ common_feature_array: DataArray containing common feature weights.
+
+ """
with st.container(border=True):
st.header("🗺️ Common Feature Analysis")
st.markdown(
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py
index 55c0220..8cf243b 100644
--- a/src/entropice/dashboard/views/overview_page.py
+++ b/src/entropice/dashboard/views/overview_page.py
@@ -658,5 +658,4 @@ def render_overview_page():
render_dataset_analysis()
st.balloons()
-
stopwatch.summary()
diff --git a/src/entropice/dashboard/views/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py
index e2484ff..8609331 100644
--- a/src/entropice/dashboard/views/training_analysis_page.py
+++ b/src/entropice/dashboard/views/training_analysis_page.py
@@ -1,6 +1,7 @@
"""Training Results Analysis page: Analysis of training results and model performance."""
import streamlit as st
+from stopuhr import stopwatch
from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
@@ -12,169 +13,176 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
render_performance_summary,
render_top_configurations,
)
-from entropice.dashboard.utils.data import load_all_training_results
-from entropice.dashboard.utils.training import (
- format_metric_name,
- get_available_metrics,
- get_cv_statistics,
- get_parameter_space_summary,
-)
+from entropice.dashboard.utils.formatters import format_metric_name
+from entropice.dashboard.utils.loaders import load_all_training_results
+from entropice.dashboard.utils.stats import CVResultsStatistics
-def render_training_analysis_page():
- """Render the Training Results Analysis page of the dashboard."""
- st.title("🦾 Training Results Analysis")
+@st.fragment
+def render_analysis_settings_sidebar(training_results):
+ """Render sidebar for training run and analysis settings selection.
- # Load all available training results
- training_results = load_all_training_results()
+ Args:
+ training_results: List of available TrainingResult objects.
- if not training_results:
- st.warning("No training results found. Please run some training experiments first.")
- st.info("Run training using: `pixi run python -m entropice.training`")
- return
+ Returns:
+ Tuple of (selected_result, selected_metric, refit_metric, top_n).
- # Sidebar: Training run selection
- with st.sidebar:
- st.header("Select Training Run")
+ """
+ st.header("Select Training Run")
- # Create selection options with task-first naming
- training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
+ # Create selection options with task-first naming
+ training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
- selected_name = st.selectbox(
- "Training Run",
- options=list(training_options.keys()),
- index=0,
- help="Select a training run to analyze",
- )
+ selected_name = st.selectbox(
+ "Training Run",
+ options=list(training_options.keys()),
+ index=0,
+ help="Select a training run to analyze",
+ key="training_run_select",
+ )
- selected_result = training_options[selected_name]
+ selected_result = training_options[selected_name]
- st.divider()
+ st.divider()
- # Metric selection for detailed analysis
- st.subheader("Analysis Settings")
+ # Metric selection for detailed analysis
+ st.subheader("Analysis Settings")
- available_metrics = get_available_metrics(selected_result.results)
+ available_metrics = selected_result.available_metrics
- # Try to get refit metric from settings
- refit_metric = selected_result.settings.get("refit_metric")
+ # Try to get refit metric from settings
+ refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
- if not refit_metric or refit_metric not in available_metrics:
- # Infer from task or use first available metric
- task = selected_result.settings.get("task", "binary")
- if task == "binary" and "f1" in available_metrics:
- refit_metric = "f1"
- elif "f1_weighted" in available_metrics:
- refit_metric = "f1_weighted"
- elif "accuracy" in available_metrics:
- refit_metric = "accuracy"
- elif available_metrics:
- refit_metric = available_metrics[0]
- else:
- st.error("No metrics found in results.")
- return
-
- if refit_metric in available_metrics:
- default_metric_idx = available_metrics.index(refit_metric)
+ if not refit_metric or refit_metric not in available_metrics:
+ # Infer from task or use first available metric
+ task = selected_result.settings.task
+ if task == "binary" and "f1" in available_metrics:
+ refit_metric = "f1"
+ elif "f1_weighted" in available_metrics:
+ refit_metric = "f1_weighted"
+ elif "accuracy" in available_metrics:
+ refit_metric = "accuracy"
+ elif available_metrics:
+ refit_metric = available_metrics[0]
else:
- default_metric_idx = 0
+ st.error("No metrics found in results.")
+ return None, None, None, None
- selected_metric = st.selectbox(
- "Primary Metric for Analysis",
- options=available_metrics,
- index=default_metric_idx,
- format_func=format_metric_name,
- help="Select the metric to focus on for detailed analysis",
- )
+ if refit_metric in available_metrics:
+ default_metric_idx = available_metrics.index(refit_metric)
+ else:
+ default_metric_idx = 0
- # Top N configurations
- top_n = st.slider(
- "Top N Configurations",
- min_value=5,
- max_value=50,
- value=10,
- step=5,
- help="Number of top configurations to display",
- )
+ selected_metric = st.selectbox(
+ "Primary Metric for Analysis",
+ options=available_metrics,
+ index=default_metric_idx,
+ format_func=format_metric_name,
+ help="Select the metric to focus on for detailed analysis",
+ key="metric_select",
+ )
- # Main content area - Run Information at the top
+ # Top N configurations
+ top_n = st.slider(
+ "Top N Configurations",
+ min_value=5,
+ max_value=50,
+ value=10,
+ step=5,
+ help="Number of top configurations to display",
+ key="top_n_slider",
+ )
+
+ return selected_result, selected_metric, refit_metric, top_n
+
+
+def render_run_information(selected_result, refit_metric):
+ """Render training run configuration overview.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+ refit_metric: The refit metric used for model selection.
+
+ """
st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6)
with col1:
- st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
+ st.metric("Task", selected_result.settings.task.capitalize())
with col2:
- st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
+ st.metric("Grid", selected_result.settings.grid.capitalize())
with col3:
- st.metric("Level", selected_result.settings.get("level", "Unknown"))
+ st.metric("Level", selected_result.settings.level)
with col4:
- st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
+ st.metric("Model", selected_result.settings.model.upper())
with col5:
st.metric("Trials", len(selected_result.results))
with col6:
- st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
+ st.metric("CV Splits", selected_result.settings.cv_splits)
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
- st.divider()
- # Main content area
- results = selected_result.results
- settings = selected_result.settings
+def render_cv_statistics_section(selected_result, selected_metric):
+ """Render cross-validation statistics for selected metric.
- # Performance Summary Section
- st.header("📊 Performance Overview")
+ Args:
+ selected_result: The selected TrainingResult object.
+ selected_metric: The metric to display statistics for.
- render_performance_summary(results, refit_metric)
-
- st.divider()
-
- # Confusion Matrix Map Section
- st.header("🗺️ Prediction Results Map")
-
- render_confusion_matrix_map(selected_result.path, settings)
-
- st.divider()
-
- # Quick Statistics
+ """
st.header("📈 Cross-Validation Statistics")
- cv_stats = get_cv_statistics(results, selected_metric)
+ from entropice.dashboard.utils.stats import CVMetricStatistics
- if cv_stats:
- col1, col2, col3, col4, col5 = st.columns(5)
+ cv_stats = CVMetricStatistics.compute(selected_result, selected_metric)
- with col1:
- st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
+ col1, col2, col3, col4, col5 = st.columns(5)
- with col2:
- st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
+ with col1:
+ st.metric("Best Score", f"{cv_stats.best_score:.4f}")
- with col3:
- st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
+ with col2:
+ st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
- with col4:
- st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
+ with col3:
+ st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
- with col5:
- st.metric("Median Score", f"{cv_stats['median_score']:.4f}")
+ with col4:
+ st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
- if "mean_cv_std" in cv_stats:
- st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")
+ with col5:
+ st.metric("Median Score", f"{cv_stats.median_score:.4f}")
- st.divider()
+ if cv_stats.mean_cv_std is not None:
+ st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
- # Parameter Space Analysis
+
+def render_parameter_space_section(selected_result, selected_metric):
+ """Render parameter space analysis section.
+
+ Args:
+ selected_result: The selected TrainingResult object.
+ selected_metric: The metric to analyze parameters against.
+
+ """
st.header("🔍 Parameter Space Analysis")
+ # Compute CV results statistics
+ cv_results_stats = CVResultsStatistics.compute(selected_result)
+
# Show parameter space summary
with st.expander("📋 Parameter Space Summary", expanded=False):
- param_summary = get_parameter_space_summary(results)
- if not param_summary.empty:
- st.dataframe(param_summary, hide_index=True, width="stretch")
+ param_summary_df = cv_results_stats.parameters_to_dataframe()
+ if not param_summary_df.empty:
+ st.dataframe(param_summary_df, hide_index=True, width="stretch")
else:
st.info("No parameter information available.")
+ results = selected_result.results
+ settings = selected_result.settings
+
# Parameter distributions
st.subheader("📈 Parameter Distributions")
render_parameter_distributions(results, settings)
@@ -183,7 +191,7 @@ def render_training_analysis_page():
st.subheader("🎨 Binned Parameter Space")
# Check if this is an ESPA model and show ESPA-specific plots
- model_type = settings.get("model", "espa")
+ model_type = settings.model
if model_type == "espa":
# Show ESPA-specific binned plots (eps_cl vs eps_e binned by K)
render_espa_binned_parameter_space(results, selected_metric)
@@ -196,31 +204,15 @@ def render_training_analysis_page():
# For non-ESPA models, show the generic binned plots
render_binned_parameter_space(results, selected_metric)
- st.divider()
- # Parameter Correlation
- st.header("🔗 Parameter Correlation")
+def render_data_export_section(results, selected_result):
+ """Render data export section with download buttons.
- render_parameter_correlation(results, selected_metric)
+ Args:
+ results: DataFrame with CV results.
+ selected_result: The selected TrainingResult object.
- st.divider()
-
- # Multi-Metric Comparison
- if len(available_metrics) >= 2:
- st.header("📊 Multi-Metric Comparison")
-
- render_multi_metric_comparison(results)
-
- st.divider()
-
- # Top Configurations
- st.header("🏆 Top Performing Configurations")
-
- render_top_configurations(results, selected_metric, top_n)
-
- st.divider()
-
- # Raw Data Export
+ """
with st.expander("💾 Export Data", expanded=False):
st.subheader("Download Results")
@@ -234,22 +226,114 @@ def render_training_analysis_page():
data=csv_data,
file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv",
- width="stretch",
)
with col2:
- # Download settings as text
+ # Download settings as JSON
import json
- settings_json = json.dumps(settings, indent=2)
+ settings_dict = {
+ "task": selected_result.settings.task,
+ "grid": selected_result.settings.grid,
+ "level": selected_result.settings.level,
+ "model": selected_result.settings.model,
+ "cv_splits": selected_result.settings.cv_splits,
+ "classes": selected_result.settings.classes,
+ }
+ settings_json = json.dumps(settings_dict, indent=2)
st.download_button(
label="⚙️ Download Settings (JSON)",
data=settings_json,
file_name=f"{selected_result.path.name}_settings.json",
mime="application/json",
- width="stretch",
)
# Show raw data preview
st.subheader("Raw Data Preview")
st.dataframe(results.head(100), width="stretch")
+
+
+def render_training_analysis_page():
+ """Render the Training Results Analysis page of the dashboard."""
+ st.title("🦾 Training Results Analysis")
+
+ st.markdown(
+ """
+ Analyze training results, hyperparameter search performance, and model configurations.
+ Select a training run from the sidebar to explore detailed metrics and parameter analysis.
+ """
+ )
+
+ # Load all available training results
+ training_results = load_all_training_results()
+
+ if not training_results:
+ st.warning("No training results found. Please run some training experiments first.")
+ st.info("Run training using: `pixi run python -m entropice.ml.training`")
+ return
+
+ st.success(f"Found **{len(training_results)}** training result(s)")
+
+ st.divider()
+
+ # Sidebar: Training run selection
+ with st.sidebar:
+ selection_result = render_analysis_settings_sidebar(training_results)
+ if selection_result[0] is None:
+ return
+ selected_result, selected_metric, refit_metric, top_n = selection_result
+
+ # Main content area
+ results = selected_result.results
+ settings = selected_result.settings
+
+ # Run Information
+ render_run_information(selected_result, refit_metric)
+
+ st.divider()
+
+ # Performance Summary Section
+ st.header("📊 Performance Overview")
+ render_performance_summary(results, refit_metric)
+
+ st.divider()
+
+ # Confusion Matrix Map Section
+ st.header("🗺️ Prediction Results Map")
+ render_confusion_matrix_map(selected_result.path, settings)
+
+ st.divider()
+
+ # Cross-Validation Statistics
+ render_cv_statistics_section(selected_result, selected_metric)
+
+ st.divider()
+
+ # Parameter Space Analysis
+ render_parameter_space_section(selected_result, selected_metric)
+
+ st.divider()
+
+ # Parameter Correlation
+ st.header("🔗 Parameter Correlation")
+ render_parameter_correlation(results, selected_metric)
+
+ st.divider()
+
+ # Multi-Metric Comparison
+ if len(selected_result.available_metrics) >= 2:
+ st.header("📊 Multi-Metric Comparison")
+ render_multi_metric_comparison(results)
+ st.divider()
+
+ # Top Configurations
+ st.header("🏆 Top Performing Configurations")
+ render_top_configurations(results, selected_metric, top_n)
+
+ st.divider()
+
+ # Raw Data Export
+ render_data_export_section(results, selected_result)
+
+ st.balloons()
+ stopwatch.summary()
diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py
index 4523f4c..5ee37de 100644
--- a/src/entropice/dashboard/views/training_data_page.py
+++ b/src/entropice/dashboard/views/training_data_page.py
@@ -1,6 +1,9 @@
"""Training Data page: Visualization of training data distributions."""
+from typing import cast
+
import streamlit as st
+from stopuhr import stopwatch
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
@@ -19,30 +22,21 @@ from entropice.dashboard.plots.training_data import (
render_spatial_map,
)
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
-from entropice.ml.dataset import DatasetEnsemble
+from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.spatial import grids
+from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
-def render_training_data_page():
- """Render the Training Data page of the dashboard."""
- st.title("Training Data")
+def render_dataset_configuration_sidebar():
+ """Render dataset configuration selector in sidebar with form.
- # Sidebar widgets for dataset configuration in a form
+ Stores the selected ensemble in session state when form is submitted.
+ """
with st.sidebar.form("dataset_config_form"):
st.header("Dataset Configuration")
- # Combined grid and level selection
- grid_options = [
- "hex-3",
- "hex-4",
- "hex-5",
- "hex-6",
- "healpix-6",
- "healpix-7",
- "healpix-8",
- "healpix-9",
- "healpix-10",
- ]
+ # Grid selection
+ grid_options = [gc.display_name for gc in grid_configs]
grid_level_combined = st.selectbox(
"Grid Configuration",
@@ -51,9 +45,8 @@ def render_training_data_page():
help="Select the grid system and resolution level",
)
- # Parse grid type and level
- grid, level_str = grid_level_combined.split("-")
- level = int(level_str)
+ # Find the selected grid config
+ selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
# Target feature selection
target = st.selectbox(
@@ -66,317 +59,422 @@ def render_training_data_page():
# Members selection
st.subheader("Dataset Members")
- all_members = [
- "AlphaEarth",
- "ArcticDEM",
- "ERA5-yearly",
- "ERA5-seasonal",
- "ERA5-shoulder",
- ]
- selected_members = []
+ all_members = cast(
+ list[L2SourceDataset],
+ ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
+ )
+ selected_members: list[L2SourceDataset] = []
for member in all_members:
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
- selected_members.append(member)
+ selected_members.append(member) # type: ignore[arg-type]
# Form submit button
load_button = st.form_submit_button(
"Load Dataset",
type="primary",
- width="stretch",
+            width="stretch",
disabled=len(selected_members) == 0,
)
# Create DatasetEnsemble only when form is submitted
if load_button:
- ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
+ ensemble = DatasetEnsemble(
+ grid=selected_grid_config.grid,
+ level=selected_grid_config.level,
+ target=cast(TargetDataset, target),
+ members=selected_members,
+ )
# Store ensemble in session state
st.session_state["dataset_ensemble"] = ensemble
st.session_state["dataset_loaded"] = True
- # Display dataset information if loaded
- if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
- ensemble = st.session_state["dataset_ensemble"]
- # Display current configuration
- st.subheader("📊 Current Configuration")
+def render_dataset_statistics(ensemble: DatasetEnsemble):
+ """Render dataset statistics and configuration overview.
- # Create a visually appealing layout with columns
- col1, col2, col3, col4 = st.columns(4)
+ Args:
+ ensemble: The dataset ensemble configuration.
- with col1:
- st.metric(label="Grid Type", value=ensemble.grid.upper())
+ """
+ st.markdown("### 📊 Dataset Configuration")
- with col2:
- st.metric(label="Grid Level", value=ensemble.level)
+ # Display current configuration in columns
+ col1, col2, col3, col4 = st.columns(4)
- with col3:
- st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
+ with col1:
+ st.metric(label="Grid Type", value=ensemble.grid.upper())
- with col4:
- st.metric(label="Members", value=len(ensemble.members))
+ with col2:
+ st.metric(label="Grid Level", value=ensemble.level)
- # Display members in an expandable section
- with st.expander("🗂️ Dataset Members", expanded=False):
- members_cols = st.columns(len(ensemble.members))
- for idx, member in enumerate(ensemble.members):
- with members_cols[idx]:
- st.markdown(f"✓ **{member}**")
+ with col3:
+ st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
- # Display dataset ID in a styled container
- st.info(f"**Dataset ID:** `{ensemble.id()}`")
+ with col4:
+ st.metric(label="Members", value=len(ensemble.members))
- # Display dataset statistics
- st.markdown("---")
- st.subheader("📈 Dataset Statistics")
+ # Display members in an expandable section
+ with st.expander("🗂️ Dataset Members", expanded=False):
+ members_cols = st.columns(len(ensemble.members))
+ for idx, member in enumerate(ensemble.members):
+ with members_cols[idx]:
+ st.markdown(f"✓ **{member}**")
- with st.spinner("Computing dataset statistics..."):
- stats = ensemble.get_stats()
+ # Display dataset ID in a styled container
+ st.info(f"**Dataset ID:** `{ensemble.id()}`")
- # High-level summary metrics
- col1, col2, col3 = st.columns(3)
- with col1:
- st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
- with col2:
- st.metric(label="Total Features", value=f"{stats['total_features']:,}")
- with col3:
- st.metric(label="Data Sources", value=len(stats["members"]))
+ # Display detailed dataset statistics
+ st.markdown("---")
+ st.markdown("### 📈 Dataset Statistics")
- # Detailed member statistics in expandable section
- with st.expander("📦 Data Source Details", expanded=False):
- for member, member_stats in stats["members"].items():
- st.markdown(f"### {member}")
+ with st.spinner("Computing dataset statistics..."):
+ stats = ensemble.get_stats()
- # Create metrics for this member
- metric_cols = st.columns(4)
- with metric_cols[0]:
- st.metric("Features", member_stats["num_features"])
- with metric_cols[1]:
- st.metric("Variables", member_stats["num_variables"])
- with metric_cols[2]:
- # Display dimensions in a more readable format
- dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
- st.metric("Shape", dim_str)
- with metric_cols[3]:
- # Calculate total data points
- total_points = 1
- for dim_size in member_stats["dimensions"].values():
- total_points *= dim_size
- st.metric("Data Points", f"{total_points:,}")
+ # High-level summary metrics
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
+ with col2:
+ st.metric(label="Total Features", value=f"{stats['total_features']:,}")
+ with col3:
+ st.metric(label="Data Sources", value=len(stats["members"]))
- # Show variables as colored badges
- st.markdown("**Variables:**")
- vars_html = " ".join(
- [
- f'{v}'
- for v in member_stats["variables"]
- ]
- )
- st.markdown(vars_html, unsafe_allow_html=True)
+ # Detailed member statistics in expandable section
+ with st.expander("📦 Data Source Details", expanded=False):
+ for member, member_stats in stats["members"].items():
+ st.markdown(f"### {member}")
- # Show dimension details
- st.markdown("**Dimensions:**")
- dim_html = " ".join(
- [
- f''
- f"{dim_name}: {dim_size}"
- for dim_name, dim_size in member_stats["dimensions"].items()
- ]
- )
- st.markdown(dim_html, unsafe_allow_html=True)
+ # Create metrics for this member
+ metric_cols = st.columns(4)
+ with metric_cols[0]:
+ st.metric("Features", member_stats["num_features"])
+ with metric_cols[1]:
+ st.metric("Variables", member_stats["num_variables"])
+ with metric_cols[2]:
+ # Display dimensions in a more readable format
+ dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
+ st.metric("Shape", dim_str)
+ with metric_cols[3]:
+ # Calculate total data points
+ total_points = 1
+ for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
+ total_points *= dim_size
+ st.metric("Data Points", f"{total_points:,}")
- st.markdown("---")
-
- st.markdown("---")
-
- # Create tabs for different data views
- tab_names = ["📊 Labels", "📐 Areas"]
-
- # Add tabs for each member
- for member in ensemble.members:
- if member == "AlphaEarth":
- tab_names.append("🌍 AlphaEarth")
- elif member == "ArcticDEM":
- tab_names.append("🏔️ ArcticDEM")
- elif member.startswith("ERA5"):
- # Group ERA5 temporal variants
- if "🌡️ ERA5" not in tab_names:
- tab_names.append("🌡️ ERA5")
-
- tabs = st.tabs(tab_names)
-
- # Labels tab
- with tabs[0]:
- st.markdown("### Target Labels Distribution and Spatial Visualization")
-
- # Load training data for all three tasks
- with st.spinner("Loading training data for all tasks..."):
- train_data_dict = load_all_training_data(ensemble)
-
- # Calculate total samples (use binary as reference)
- total_samples = len(train_data_dict["binary"])
- train_samples = (train_data_dict["binary"].split == "train").sum().item()
- test_samples = (train_data_dict["binary"].split == "test").sum().item()
-
- st.success(
- f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
+ # Show variables as colored badges
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
+ [
+ f'{v}'
+ for v in member_stats["variables"] # type: ignore[union-attr]
+ ]
)
+ st.markdown(vars_html, unsafe_allow_html=True)
- # Render distribution histograms
- st.markdown("---")
- render_all_distribution_histograms(train_data_dict)
+ # Show dimension details
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+ f''
+ f"{dim_name}: {dim_size}"
+ for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
+ ]
+ )
+ st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
- # Render spatial map
- binary_dataset = train_data_dict["binary"]
- assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
- render_spatial_map(train_data_dict)
+def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
+ """Render target labels distribution and spatial visualization.
- # Areas tab
- with tabs[1]:
- st.markdown("### Grid Cell Areas and Land/Water Distribution")
+ Args:
+ ensemble: The dataset ensemble configuration.
+ train_data_dict: Pre-loaded training data for all tasks.
- st.markdown(
- "This visualization shows the spatial distribution of cell areas, land areas, "
- "water areas, and land ratio across the grid. The grid has been filtered to "
- "include only cells in the permafrost region (>50° latitude, <85° latitude) "
- "with >10% land coverage."
- )
+ """
+ st.markdown("### Target Labels Distribution and Spatial Visualization")
- # Load grid data
- grid_gdf = grids.open(ensemble.grid, ensemble.level)
+ # Calculate total samples (use binary as reference)
+ total_samples = len(train_data_dict["binary"])
+ train_samples = (train_data_dict["binary"].split == "train").sum().item()
+ test_samples = (train_data_dict["binary"].split == "test").sum().item()
- st.success(
- f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
- f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
- )
+ st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
- # Show summary statistics
- col1, col2, col3, col4 = st.columns(4)
- with col1:
- st.metric("Total Cells", f"{len(grid_gdf):,}")
- with col2:
- st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
- with col3:
- st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
- with col4:
- total_land = grid_gdf["land_area"].sum()
- st.metric("Total Land Area", f"{total_land:,.0f} km²")
+ # Render distribution histograms
+ st.markdown("---")
+ render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type]
- st.markdown("---")
+ st.markdown("---")
- if (ensemble.grid == "hex" and ensemble.level == 6) or (
- ensemble.grid == "healpix" and ensemble.level == 10
- ):
- st.warning(
- "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
- )
- else:
- render_areas_map(grid_gdf, ensemble.grid)
+ # Render spatial map
+ binary_dataset = train_data_dict["binary"]
+ assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
- # AlphaEarth tab
- tab_idx = 2
- if "AlphaEarth" in ensemble.members:
- with tabs[tab_idx]:
- st.markdown("### AlphaEarth Embeddings Analysis")
+ render_spatial_map(train_data_dict)
- with st.spinner("Loading AlphaEarth data..."):
- alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
- st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
+def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
+ """Render grid cell areas and land/water distribution.
- render_alphaearth_overview(alphaearth_ds)
- render_alphaearth_plots(alphaearth_ds)
+ Args:
+ ensemble: The dataset ensemble configuration.
+ grid_gdf: Pre-loaded grid GeoDataFrame.
- st.markdown("---")
+ """
+ st.markdown("### Grid Cell Areas and Land/Water Distribution")
- if (ensemble.grid == "hex" and ensemble.level == 6) or (
- ensemble.grid == "healpix" and ensemble.level == 10
- ):
- st.warning(
- "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
- )
- else:
- render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
+ st.markdown(
+ "This visualization shows the spatial distribution of cell areas, land areas, "
+ "water areas, and land ratio across the grid. The grid has been filtered to "
+ "include only cells in the permafrost region (>50° latitude, <85° latitude) "
+ "with >10% land coverage."
+ )
- tab_idx += 1
+ st.success(
+ f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
+ f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
+ )
- # ArcticDEM tab
- if "ArcticDEM" in ensemble.members:
- with tabs[tab_idx]:
- st.markdown("### ArcticDEM Terrain Analysis")
+ # Show summary statistics
+ col1, col2, col3, col4 = st.columns(4)
+ with col1:
+ st.metric("Total Cells", f"{len(grid_gdf):,}")
+ with col2:
+ st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
+ with col3:
+ st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
+ with col4:
+ total_land = grid_gdf["land_area"].sum()
+ st.metric("Total Land Area", f"{total_land:,.0f} km²")
- with st.spinner("Loading ArcticDEM data..."):
- arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
-
- st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
-
- render_arcticdem_overview(arcticdem_ds)
- render_arcticdem_plots(arcticdem_ds)
-
- st.markdown("---")
-
- if (ensemble.grid == "hex" and ensemble.level == 6) or (
- ensemble.grid == "healpix" and ensemble.level == 10
- ):
- st.warning(
- "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
- )
- else:
- render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
-
- tab_idx += 1
-
- # ERA5 tab (combining all temporal variants)
- era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
- if era5_members:
- with tabs[tab_idx]:
- st.markdown("### ERA5 Climate Data Analysis")
-
- # Let user select which ERA5 temporal aggregation to view
- era5_options = {
- "ERA5-yearly": "Yearly",
- "ERA5-seasonal": "Seasonal (Winter/Summer)",
- "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
- }
-
- available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
-
- selected_era5 = st.selectbox(
- "Select ERA5 temporal aggregation",
- options=list(available_era5.keys()),
- format_func=lambda x: available_era5[x],
- key="era5_temporal_select",
- )
-
- if selected_era5:
- temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
-
- with st.spinner(f"Loading {selected_era5} data..."):
- era5_ds, targets = load_source_data(ensemble, selected_era5)
-
- st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
-
- render_era5_overview(era5_ds, temporal_type)
- render_era5_plots(era5_ds, temporal_type)
-
- st.markdown("---")
-
- if (ensemble.grid == "hex" and ensemble.level == 6) or (
- ensemble.grid == "healpix" and ensemble.level == 10
- ):
- st.warning(
- "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
- )
- else:
- render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
-
- # Show balloons once after all tabs are rendered
- st.balloons()
+ st.markdown("---")
+ # Check if we should skip map rendering for performance
+ if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
+ st.warning(
+ "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
+ "due to performance considerations."
+ )
else:
- st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
+ render_areas_map(grid_gdf, ensemble.grid)
+
+
+def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
+ """Render AlphaEarth embeddings analysis.
+
+ Args:
+ ensemble: The dataset ensemble configuration.
+ alphaearth_ds: Pre-loaded AlphaEarth dataset.
+ targets: Pre-loaded targets GeoDataFrame.
+
+ """
+ st.markdown("### AlphaEarth Embeddings Analysis")
+
+ st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
+
+ render_alphaearth_overview(alphaearth_ds)
+ render_alphaearth_plots(alphaearth_ds)
+
+ st.markdown("---")
+
+ # Check if we should skip map rendering for performance
+ if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
+ st.warning(
+ "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
+ "due to performance considerations."
+ )
+ else:
+ render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
+
+
+def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
+ """Render ArcticDEM terrain analysis.
+
+ Args:
+ ensemble: The dataset ensemble configuration.
+ arcticdem_ds: Pre-loaded ArcticDEM dataset.
+ targets: Pre-loaded targets GeoDataFrame.
+
+ """
+ st.markdown("### ArcticDEM Terrain Analysis")
+
+ st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
+
+ render_arcticdem_overview(arcticdem_ds)
+ render_arcticdem_plots(arcticdem_ds)
+
+ st.markdown("---")
+
+ # Check if we should skip map rendering for performance
+ if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
+ st.warning(
+ "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
+ "due to performance considerations."
+ )
+ else:
+ render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
+
+
+def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
+ """Render ERA5 climate data analysis.
+
+ Args:
+ ensemble: The dataset ensemble configuration.
+ era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
+ targets: Pre-loaded targets GeoDataFrame.
+
+ """
+ st.markdown("### ERA5 Climate Data Analysis")
+
+ # Let user select which ERA5 temporal aggregation to view
+ era5_options = {
+ "ERA5-yearly": "Yearly",
+ "ERA5-seasonal": "Seasonal (Winter/Summer)",
+ "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
+ }
+
+ available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
+
+ selected_era5 = st.selectbox(
+ "Select ERA5 temporal aggregation",
+ options=list(available_era5.keys()),
+ format_func=lambda x: available_era5[x],
+ key="era5_temporal_select",
+ )
+
+ if selected_era5 and selected_era5 in era5_data:
+ era5_ds, temporal_type = era5_data[selected_era5]
+
+ render_era5_overview(era5_ds, temporal_type)
+ render_era5_plots(era5_ds, temporal_type)
+
+ st.markdown("---")
+
+ # Check if we should skip map rendering for performance
+ if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
+ st.warning(
+                "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
+ "due to performance considerations."
+ )
+ else:
+ render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
+
+
+def render_training_data_page():
+ """Render the Training Data page of the dashboard."""
+ st.title("🎯 Training Data")
+
+ st.markdown(
+ """
+ Explore and visualize the training data for RTS prediction models.
+ Configure your dataset by selecting grid configuration, target dataset,
+ and data sources in the sidebar, then click "Load Dataset" to begin.
+ """
+ )
+
+ # Render sidebar configuration
+ render_dataset_configuration_sidebar()
+
+ # Check if dataset is loaded in session state
+ if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state:
+ st.info(
+ "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
+ )
+ return
+
+ # Get ensemble from session state
+ ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]
+
+ st.divider()
+
+ # Load all necessary data once
+ with st.spinner("Loading dataset..."):
+ # Load training data for all tasks
+ train_data_dict = load_all_training_data(ensemble)
+
+ # Load grid data
+ grid_gdf = grids.open(ensemble.grid, ensemble.level)
+
+ # Load targets (needed by all source data views)
+ targets = ensemble._read_target()
+
+ # Load AlphaEarth data if in members
+ alphaearth_ds = None
+ if "AlphaEarth" in ensemble.members:
+ alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")
+
+ # Load ArcticDEM data if in members
+ arcticdem_ds = None
+ if "ArcticDEM" in ensemble.members:
+ arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")
+
+ # Load ERA5 data for all temporal aggregations in members
+ era5_data = {}
+ era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
+ for era5_member in era5_members:
+ era5_ds, _ = load_source_data(ensemble, era5_member)
+ temporal_type = era5_member.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
+ era5_data[era5_member] = (era5_ds, temporal_type)
+
+ st.success(
+ f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features"
+ )
+
+ # Render dataset statistics
+ render_dataset_statistics(ensemble)
+
+ st.markdown("---")
+
+ # Create tabs for different data views
+ tab_names = ["📊 Labels", "📐 Areas"]
+
+ # Add tabs for each member based on what's in the ensemble
+ if "AlphaEarth" in ensemble.members:
+ tab_names.append("🌍 AlphaEarth")
+ if "ArcticDEM" in ensemble.members:
+ tab_names.append("🏔️ ArcticDEM")
+
+ # Check for ERA5 members
+ if era5_members:
+ tab_names.append("🌡️ ERA5")
+
+ tabs = st.tabs(tab_names)
+
+ # Track current tab index
+ tab_idx = 0
+
+ # Labels tab
+ with tabs[tab_idx]:
+ render_labels_view(ensemble, train_data_dict)
+ tab_idx += 1
+
+ # Areas tab
+ with tabs[tab_idx]:
+ render_areas_view(ensemble, grid_gdf)
+ tab_idx += 1
+
+ # AlphaEarth tab
+ if "AlphaEarth" in ensemble.members:
+ with tabs[tab_idx]:
+ render_alphaearth_view(ensemble, alphaearth_ds, targets)
+ tab_idx += 1
+
+ # ArcticDEM tab
+ if "ArcticDEM" in ensemble.members:
+ with tabs[tab_idx]:
+ render_arcticdem_view(ensemble, arcticdem_ds, targets)
+ tab_idx += 1
+
+ # ERA5 tab (combining all temporal variants)
+ if era5_members:
+ with tabs[tab_idx]:
+ render_era5_view(ensemble, era5_data, targets)
+
+ # Show balloons once after all tabs are rendered
+ st.balloons()
+ stopwatch.summary()