Refactor other pages

This commit is contained in:
Tobias Hölzer 2026-01-04 18:39:38 +01:00
parent 4260b492ab
commit 393cc968cb
9 changed files with 962 additions and 559 deletions

View file

@ -11,12 +11,11 @@ Pages:
import streamlit as st
# from entropice.dashboard.views.inference_page import render_inference_page
# from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
# from entropice.dashboard.views.training_data_page import render_training_data_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@ -25,17 +24,17 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation(
{
"Overview": [overview_page],
# "Training": [training_data_page, training_analysis_page],
# "Model State": [model_state_page],
# "Inference": [inference_page],
"Training": [training_data_page, training_analysis_page],
"Model State": [model_state_page],
"Inference": [inference_page],
}
)
pg.run()

View file

@ -143,8 +143,8 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
# Extract scale information from settings if available
param_scales = {}
if settings and "param_grid" in settings:
param_grid = settings["param_grid"]
if settings and hasattr(settings, "param_grid"):
param_grid = settings.param_grid
for param_name, param_config in param_grid.items():
if isinstance(param_config, dict) and "distribution" in param_config:
# loguniform distribution indicates log scale
@ -1181,10 +1181,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
preds_gdf = gpd.read_parquet(preds_file)
# Get task and target information from settings
task = settings.get("task", "binary")
target = settings.get("target", "darts_rts")
grid = settings.get("grid", "hex")
level = settings.get("level", 3)
task = settings.task
target = settings.target
grid = settings.grid
level = settings.level
# Create dataset ensemble to get true labels
# We need to load the target data to get true labels

View file

@ -8,7 +8,7 @@ import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.data import TrainingResult
from entropice.dashboard.utils.loaders import TrainingResult
def _fix_hex_geometry(geom):
@ -197,8 +197,8 @@ def render_inference_map(result: TrainingResult):
preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
# Get settings
task = result.settings.get("task", "binary")
grid = result.settings.get("grid", "hex")
task = result.settings.task
grid = result.settings.grid
# Create controls in columns
col1, col2, col3 = st.columns([2, 2, 1])

View file

@ -147,10 +147,11 @@ def load_all_training_data(
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
"""
dataset = e.create(filter_target_col=e.covcol)
return {
"binary": e.create_cat_training_dataset("binary", device="cpu"),
"count": e.create_cat_training_dataset("count", device="cpu"),
"density": e.create_cat_training_dataset("density", device="cpu"),
"binary": e._cat_and_split(dataset, "binary", device="cpu"),
"count": e._cat_and_split(dataset, "count", device="cpu"),
"density": e._cat_and_split(dataset, "density", device="cpu"),
}

View file

@ -1,7 +1,8 @@
"""Inference page: Visualization of model inference results across the study region."""
import geopandas as gpd
import streamlit as st
from entropice.dashboard.utils.data import load_all_training_results
from stopuhr import stopwatch
from entropice.dashboard.plots.inference import (
render_class_comparison,
@ -10,66 +11,177 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics,
render_spatial_distribution_stats,
)
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
    """Render the sidebar widgets for choosing a training run and return the choice.

    This function is deliberately NOT decorated with ``@st.fragment``. The caller
    uses the returned TrainingResult to render the rest of the page, but Streamlit
    ignores a fragment's return value during fragment reruns: changing the
    selectbox would only rerun the sidebar while the main page kept showing the
    previously selected run. As a plain function, a selection change triggers a
    full script rerun and the whole page follows the new selection.

    Args:
        training_results: List of available TrainingResult objects. The caller is
            expected to pass a non-empty list (the selectbox preselects index 0
            and the chosen name is looked up without a fallback).

    Returns:
        The TrainingResult whose display name is currently selected.

    """
    st.header("Select Training Run")

    # Map task-first display names to results. Runs that produce the same display
    # name collapse into one entry (the later one wins).
    training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}

    selected_name = st.selectbox(
        "Training Run",
        options=list(training_options.keys()),
        index=0,
        help="Select a training run to view inference results",
        key="inference_run_select",
    )

    selected_result = training_options[selected_name]

    st.divider()

    # Show run information in sidebar
    st.subheader("Run Information")

    st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}")
    st.markdown(f"**Model:** {selected_result.settings.model.upper()}")
    st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}")
    st.markdown(f"**Level:** {selected_result.settings.level}")
    st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}")

    return selected_result
def render_run_information(selected_result: TrainingResult):
    """Show the selected run's key settings as a row of five metric tiles.

    Args:
        selected_result: Training run whose settings (task, model, grid, level,
            target) are displayed.

    """
    st.header("📋 Run Configuration")

    columns = st.columns(5)
    run_settings = selected_result.settings

    # One (label, value) pair per column, in display order.
    tiles = (
        ("Task", run_settings.task.capitalize()),
        ("Model", run_settings.model.upper()),
        ("Grid", run_settings.grid.capitalize()),
        ("Level", run_settings.level),
        ("Target", run_settings.target.replace("darts_", "")),
    )

    for column, (label, value) in zip(columns, tiles):
        with column:
            st.metric(label, value)
def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Emit the "Inference Summary" header, then delegate to the statistics renderer.

    Args:
        predictions_gdf: GeoDataFrame with predictions; forwarded unchanged.
        task: Task type ('binary', 'count', 'density'); forwarded unchanged.

    """
    section_title = "📊 Inference Summary"
    st.header(section_title)
    render_inference_statistics(predictions_gdf, task)
def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame):
    """Emit the "Spatial Coverage" header, then delegate to the coverage-stats renderer.

    Args:
        predictions_gdf: GeoDataFrame with predictions; forwarded unchanged.

    """
    section_title = "🌍 Spatial Coverage"
    st.header(section_title)
    render_spatial_distribution_stats(predictions_gdf)
def render_map_visualization_section(selected_result: TrainingResult):
    """Emit the map header and intro text, then delegate to the inference map renderer.

    Args:
        selected_result: The selected TrainingResult object; forwarded unchanged.

    """
    st.header("🗺️ Interactive Prediction Map")

    # Intro paragraph, assembled line by line (leading and trailing newline kept).
    intro_text = (
        "\n"
        "3D visualization of predictions across the study region. The map shows predicted\n"
        "classes with color coding and spatial distribution of model outputs.\n"
    )
    st.markdown(intro_text)

    render_inference_map(selected_result)
def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Emit the "Class Distribution" header and caption, then delegate to the histogram renderer.

    Args:
        predictions_gdf: GeoDataFrame with predictions; forwarded unchanged.
        task: Task type ('binary', 'count', 'density'); forwarded unchanged.

    """
    section_title = "📈 Class Distribution"
    section_caption = "Distribution of predicted classes across all inference cells."

    st.header(section_title)
    st.markdown(section_caption)
    render_class_distribution_histogram(predictions_gdf, task)
def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str):
    """Emit the class-comparison header and intro text, then delegate to the comparison renderer.

    Args:
        predictions_gdf: GeoDataFrame with predictions; forwarded unchanged.
        task: Task type ('binary', 'count', 'density'); forwarded unchanged.

    """
    st.header("🔍 Class Comparison Analysis")

    # Intro paragraph, assembled line by line (leading and trailing newline kept).
    intro_text = (
        "\n"
        "Detailed comparison of predicted classes showing probability distributions\n"
        "and confidence metrics for different class predictions.\n"
    )
    st.markdown(intro_text)

    render_class_comparison(predictions_gdf, task)
def render_inference_page():
"""Render the Inference page of the dashboard."""
st.title("🗺️ Inference Results")
st.markdown(
"""
Explore spatial predictions from trained models across the Arctic permafrost region.
Select a training run from the sidebar to visualize prediction maps, class distributions,
and spatial coverage statistics.
"""
)
# Load all available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.training`")
st.info("Run training using: `pixi run python -m entropice.ml.training`")
return
st.success(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
selected_result = render_sidebar_selection(training_results)
# Create selection options with task-first naming
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to view inference results",
)
selected_result = training_options[selected_name]
st.divider()
# Show run information in sidebar
st.subheader("Run Information")
settings = selected_result.settings
st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
# Main content area - Run Information at the top
st.header("📋 Run Configuration")
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
with col2:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
with col3:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
with col4:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
with col5:
st.metric(
"Target",
selected_result.settings.get("target", "Unknown").replace("darts_", ""),
)
# Main content area - Run Information
render_run_information(selected_result)
st.divider()
@ -80,31 +192,33 @@ def render_inference_page():
st.info("Inference results are generated automatically during training.")
return
# Load predictions for statistics
import geopandas as gpd
predictions_gdf = gpd.read_parquet(preds_file)
task = selected_result.settings.get("task", "binary")
# Load predictions
with st.spinner("Loading inference results..."):
predictions_gdf = gpd.read_parquet(preds_file)
task = selected_result.settings.task
# Inference Statistics Section
render_inference_statistics(predictions_gdf, task)
render_inference_statistics_section(predictions_gdf, task)
st.divider()
# Spatial Coverage Section
render_spatial_distribution_stats(predictions_gdf)
render_spatial_coverage_section(predictions_gdf)
st.divider()
# 3D Map Visualization Section
render_inference_map(selected_result)
render_map_visualization_section(selected_result)
st.divider()
# Class Distribution Section
render_class_distribution_histogram(predictions_gdf, task)
render_class_distribution_section(predictions_gdf, task)
st.divider()
# Class Comparison Section
render_class_comparison(predictions_gdf, task)
render_class_comparison_section(predictions_gdf, task)
st.balloons()
stopwatch.summary()

View file

@ -1,16 +1,8 @@
"""Model State page for the Entropice dashboard."""
"""Model State page: Visualization of model internal state and feature importance."""
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.data import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
from stopuhr import stopwatch
from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap,
@ -26,46 +18,64 @@ from entropice.dashboard.plots.model_state import (
plot_top_features,
)
from entropice.dashboard.utils.colors import generate_unified_colormap
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
from entropice.dashboard.utils.unsembler import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
)
from entropice.utils.types import L2SourceDataset
def render_model_state_page():
"""Render the Model State page of the dashboard."""
st.title("Model State")
st.markdown("Comprehensive visualization of the best model's internal state and feature importance")
def get_members_from_settings(settings) -> list[L2SourceDataset]:
"""Extract dataset members from training settings.
# Load available training results
training_results = load_all_training_results()
Args:
settings: TrainingSettings object containing dataset configuration.
if not training_results:
st.error("No training results found. Please run a training search first.")
return
Returns:
List of L2SourceDataset members used in training.
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
"""
return settings.members
# Result selection with model-first naming
result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(result_options.keys()),
help="Choose a training result to visualize model state",
)
selected_result = result_options[selected_name]
st.divider()
@st.fragment
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
"""Render sidebar for training run selection.
# Get the model type from settings
model_type = selected_result.settings.get("model", "espa")
Args:
training_results: List of available TrainingResult objects.
# Load model state
with st.spinner("Loading model state..."):
model_state = load_model_state(selected_result)
if model_state is None:
st.error("Could not load model state for this result.")
return
Returns:
Selected TrainingResult object.
# Display basic model state info
"""
st.header("Select Training Run")
# Result selection with task-first naming
result_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(result_options.keys()),
index=0,
help="Choose a training result to visualize model state",
key="model_state_training_run_select",
)
selected_result = result_options[selected_name]
return selected_result
def render_model_info(model_state: xr.Dataset, model_type: str):
"""Render basic model state information.
Args:
model_state: Xarray dataset containing model state.
model_type: Type of model (espa, xgboost, rf, knn).
"""
with st.expander("Model State Information", expanded=False):
st.write(f"**Model Type:** {model_type.upper()}")
st.write(f"**Variables:** {list(model_state.data_vars)}")
@ -73,15 +83,23 @@ def render_model_state_page():
st.write(f"**Coordinates:** {list(model_state.coords)}")
st.write(f"**Attributes:** {dict(model_state.attrs)}")
# Display dataset members summary
st.header("📊 Training Data Summary")
members = get_members_from_settings(selected_result.settings)
st.markdown(f"""
def render_training_data_summary(members: list[L2SourceDataset]):
"""Render summary of training data sources.
Args:
members: List of dataset members used in training.
"""
st.header("📊 Training Data Summary")
st.markdown(
f"""
**Dataset Members Used in Training:** {len(members)}
The following data sources were used to train this model:
""")
"""
)
# Create a nice display of members with emojis
member_display = {
@ -98,6 +116,52 @@ def render_model_state_page():
display_name = member_display.get(member, f"📁 {member}")
st.info(display_name)
def render_model_state_page():
"""Render the Model State page of the dashboard."""
st.title("🔬 Model State")
st.markdown(
"""
Comprehensive visualization of the best model's internal state and feature importance.
Select a training run from the sidebar to explore model parameters, feature weights,
and data source contributions.
"""
)
# Load available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.ml.training`")
return
st.success(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Sidebar: Training run selection
with st.sidebar:
selected_result = render_sidebar_selection(training_results)
# Get the model type from settings
model_type = selected_result.settings.model
# Load model state
with st.spinner("Loading model state..."):
model_state = selected_result.load_model_state()
if model_state is None:
st.error("Could not load model state for this result.")
st.info("The model state file (best_estimator_state.nc) may be missing from the training results.")
return
# Display basic model state info
render_model_info(model_state, model_type)
# Display dataset members summary
members = get_members_from_settings(selected_result.settings)
render_training_data_summary(members)
st.divider()
# Render model-specific visualizations
@ -112,9 +176,18 @@ def render_model_state_page():
else:
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
st.balloons()
stopwatch.summary()
def render_espa_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for ESPA model."""
def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for ESPA model.
Args:
model_state: Xarray dataset containing ESPA model state.
selected_result: TrainingResult object containing training configuration.
"""
# Scale feature weights by number of features
n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features
@ -143,8 +216,9 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
common_feature_array = extract_common_features(model_state)
# Generate unified colormaps
_, _, altair_colors = generate_unified_colormap(selected_result.settings)
# Generate unified colormaps (convert dataclass to dict)
settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes}
_, _, altair_colors = generate_unified_colormap(settings_dict)
# Feature importance section
st.header("Feature Importance")
@ -255,8 +329,14 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for XGBoost model."""
def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for XGBoost model.
Args:
model_state: Xarray dataset containing XGBoost model state.
selected_result: TrainingResult object containing training configuration.
"""
from entropice.dashboard.plots.model_state import (
plot_xgboost_feature_importance,
plot_xgboost_importance_comparison,
@ -382,8 +462,14 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
def render_rf_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for Random Forest model."""
def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for Random Forest model.
Args:
model_state: Xarray dataset containing Random Forest model state.
selected_result: TrainingResult object containing training configuration.
"""
from entropice.dashboard.plots.model_state import plot_rf_feature_importance
st.header("🌳 Random Forest Model Analysis")
@ -529,8 +615,14 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array)
def render_knn_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for KNN model."""
def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for KNN model.
Args:
model_state: Xarray dataset containing KNN model state.
selected_result: TrainingResult object containing training configuration.
"""
st.header("🔍 K-Nearest Neighbors Model Analysis")
st.markdown(
"""
@ -568,8 +660,13 @@ def render_knn_model_state(model_state: xr.Dataset, selected_result):
# Helper functions for embedding/era5/common features
def render_embedding_features(embedding_feature_array):
"""Render embedding feature visualizations."""
def render_embedding_features(embedding_feature_array: xr.DataArray):
"""Render embedding feature visualizations.
Args:
embedding_feature_array: DataArray containing AlphaEarth embedding feature weights.
"""
with st.container(border=True):
st.header("🛰️ Embedding Feature Analysis")
st.markdown(
@ -619,7 +716,7 @@ def render_embedding_features(embedding_feature_array):
st.dataframe(top_emb, width="stretch")
def render_era5_features(era5_feature_array, temporal_group: str = ""):
def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""):
"""Render ERA5 feature visualizations.
Args:
@ -631,9 +728,10 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
with st.container(border=True):
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
temporal_suffix = f" for {temporal_group.lower()} aggregation" if temporal_group else ""
st.markdown(
f"""
Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods
Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods
are most important for the model predictions.
"""
)
@ -709,8 +807,13 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
st.dataframe(top_era5, width="stretch")
def render_arcticdem_features(arcticdem_feature_array):
"""Render ArcticDEM feature visualizations."""
def render_arcticdem_features(arcticdem_feature_array: xr.DataArray):
"""Render ArcticDEM feature visualizations.
Args:
arcticdem_feature_array: DataArray containing ArcticDEM feature weights.
"""
with st.container(border=True):
st.header("🏔️ ArcticDEM Feature Analysis")
st.markdown(
@ -758,8 +861,13 @@ def render_arcticdem_features(arcticdem_feature_array):
st.dataframe(top_arcticdem, width="stretch")
def render_common_features(common_feature_array):
"""Render common feature visualizations."""
def render_common_features(common_feature_array: xr.DataArray):
"""Render common feature visualizations.
Args:
common_feature_array: DataArray containing common feature weights.
"""
with st.container(border=True):
st.header("🗺️ Common Feature Analysis")
st.markdown(

View file

@ -658,5 +658,4 @@ def render_overview_page():
render_dataset_analysis()
st.balloons()
stopwatch.summary()

View file

@ -1,6 +1,7 @@
"""Training Results Analysis page: Analysis of training results and model performance."""
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
@ -12,169 +13,176 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
render_performance_summary,
render_top_configurations,
)
from entropice.dashboard.utils.data import load_all_training_results
from entropice.dashboard.utils.training import (
format_metric_name,
get_available_metrics,
get_cv_statistics,
get_parameter_space_summary,
)
from entropice.dashboard.utils.formatters import format_metric_name
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import CVResultsStatistics
def render_training_analysis_page():
"""Render the Training Results Analysis page of the dashboard."""
st.title("🦾 Training Results Analysis")
@st.fragment
def render_analysis_settings_sidebar(training_results):
"""Render sidebar for training run and analysis settings selection.
# Load all available training results
training_results = load_all_training_results()
Args:
training_results: List of available TrainingResult objects.
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.training`")
return
Returns:
Tuple of (selected_result, selected_metric, refit_metric, top_n).
# Sidebar: Training run selection
with st.sidebar:
st.header("Select Training Run")
"""
st.header("Select Training Run")
# Create selection options with task-first naming
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
# Create selection options with task-first naming
training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to analyze",
)
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to analyze",
key="training_run_select",
)
selected_result = training_options[selected_name]
selected_result = training_options[selected_name]
st.divider()
st.divider()
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
# Metric selection for detailed analysis
st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results)
available_metrics = selected_result.available_metrics
# Try to get refit metric from settings
refit_metric = selected_result.settings.get("refit_metric")
# Try to get refit metric from settings
refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric
task = selected_result.settings.get("task", "binary")
if task == "binary" and "f1" in available_metrics:
refit_metric = "f1"
elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted"
elif "accuracy" in available_metrics:
refit_metric = "accuracy"
elif available_metrics:
refit_metric = available_metrics[0]
else:
st.error("No metrics found in results.")
return
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric
task = selected_result.settings.task
if task == "binary" and "f1" in available_metrics:
refit_metric = "f1"
elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted"
elif "accuracy" in available_metrics:
refit_metric = "accuracy"
elif available_metrics:
refit_metric = available_metrics[0]
else:
default_metric_idx = 0
st.error("No metrics found in results.")
return None, None, None, None
selected_metric = st.selectbox(
"Primary Metric for Analysis",
options=available_metrics,
index=default_metric_idx,
format_func=format_metric_name,
help="Select the metric to focus on for detailed analysis",
)
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
else:
default_metric_idx = 0
# Top N configurations
top_n = st.slider(
"Top N Configurations",
min_value=5,
max_value=50,
value=10,
step=5,
help="Number of top configurations to display",
)
selected_metric = st.selectbox(
"Primary Metric for Analysis",
options=available_metrics,
index=default_metric_idx,
format_func=format_metric_name,
help="Select the metric to focus on for detailed analysis",
key="metric_select",
)
# Main content area - Run Information at the top
# Top N configurations
top_n = st.slider(
"Top N Configurations",
min_value=5,
max_value=50,
value=10,
step=5,
help="Number of top configurations to display",
key="top_n_slider",
)
return selected_result, selected_metric, refit_metric, top_n
def render_run_information(selected_result, refit_metric):
"""Render training run configuration overview.
Args:
selected_result: The selected TrainingResult object.
refit_metric: The refit metric used for model selection.
"""
st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6)
with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
st.metric("Task", selected_result.settings.task.capitalize())
with col2:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
st.metric("Grid", selected_result.settings.grid.capitalize())
with col3:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
st.metric("Level", selected_result.settings.level)
with col4:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
st.metric("Model", selected_result.settings.model.upper())
with col5:
st.metric("Trials", len(selected_result.results))
with col6:
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
st.metric("CV Splits", selected_result.settings.cv_splits)
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Main content area
results = selected_result.results
settings = selected_result.settings
def render_cv_statistics_section(selected_result, selected_metric):
"""Render cross-validation statistics for selected metric.
# Performance Summary Section
st.header("📊 Performance Overview")
Args:
selected_result: The selected TrainingResult object.
selected_metric: The metric to display statistics for.
render_performance_summary(results, refit_metric)
st.divider()
# Confusion Matrix Map Section
st.header("🗺️ Prediction Results Map")
render_confusion_matrix_map(selected_result.path, settings)
st.divider()
# Quick Statistics
"""
st.header("📈 Cross-Validation Statistics")
cv_stats = get_cv_statistics(results, selected_metric)
from entropice.dashboard.utils.stats import CVMetricStatistics
if cv_stats:
col1, col2, col3, col4, col5 = st.columns(5)
cv_stats = CVMetricStatistics.compute(selected_result, selected_metric)
with col1:
st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
col1, col2, col3, col4, col5 = st.columns(5)
with col2:
st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
with col1:
st.metric("Best Score", f"{cv_stats.best_score:.4f}")
with col3:
st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
with col2:
st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
with col4:
st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
with col3:
st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
with col5:
st.metric("Median Score", f"{cv_stats['median_score']:.4f}")
with col4:
st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
if "mean_cv_std" in cv_stats:
st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")
with col5:
st.metric("Median Score", f"{cv_stats.median_score:.4f}")
st.divider()
if cv_stats.mean_cv_std is not None:
st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
# Parameter Space Analysis
def render_parameter_space_section(selected_result, selected_metric):
"""Render parameter space analysis section.
Args:
selected_result: The selected TrainingResult object.
selected_metric: The metric to analyze parameters against.
"""
st.header("🔍 Parameter Space Analysis")
# Compute CV results statistics
cv_results_stats = CVResultsStatistics.compute(selected_result)
# Show parameter space summary
with st.expander("📋 Parameter Space Summary", expanded=False):
param_summary = get_parameter_space_summary(results)
if not param_summary.empty:
st.dataframe(param_summary, hide_index=True, width="stretch")
param_summary_df = cv_results_stats.parameters_to_dataframe()
if not param_summary_df.empty:
st.dataframe(param_summary_df, hide_index=True, width="stretch")
else:
st.info("No parameter information available.")
results = selected_result.results
settings = selected_result.settings
# Parameter distributions
st.subheader("📈 Parameter Distributions")
render_parameter_distributions(results, settings)
@ -183,7 +191,7 @@ def render_training_analysis_page():
st.subheader("🎨 Binned Parameter Space")
# Check if this is an ESPA model and show ESPA-specific plots
model_type = settings.get("model", "espa")
model_type = settings.model
if model_type == "espa":
# Show ESPA-specific binned plots (eps_cl vs eps_e binned by K)
render_espa_binned_parameter_space(results, selected_metric)
@ -196,31 +204,15 @@ def render_training_analysis_page():
# For non-ESPA models, show the generic binned plots
render_binned_parameter_space(results, selected_metric)
st.divider()
# Parameter Correlation
st.header("🔗 Parameter Correlation")
def render_data_export_section(results, selected_result):
"""Render data export section with download buttons.
render_parameter_correlation(results, selected_metric)
Args:
results: DataFrame with CV results.
selected_result: The selected TrainingResult object.
st.divider()
# Multi-Metric Comparison
if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Comparison")
render_multi_metric_comparison(results)
st.divider()
# Top Configurations
st.header("🏆 Top Performing Configurations")
render_top_configurations(results, selected_metric, top_n)
st.divider()
# Raw Data Export
"""
with st.expander("💾 Export Data", expanded=False):
st.subheader("Download Results")
@ -234,22 +226,114 @@ def render_training_analysis_page():
data=csv_data,
file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv",
width="stretch",
)
with col2:
# Download settings as text
# Download settings as JSON
import json
settings_json = json.dumps(settings, indent=2)
settings_dict = {
"task": selected_result.settings.task,
"grid": selected_result.settings.grid,
"level": selected_result.settings.level,
"model": selected_result.settings.model,
"cv_splits": selected_result.settings.cv_splits,
"classes": selected_result.settings.classes,
}
settings_json = json.dumps(settings_dict, indent=2)
st.download_button(
label="⚙️ Download Settings (JSON)",
data=settings_json,
file_name=f"{selected_result.path.name}_settings.json",
mime="application/json",
width="stretch",
)
# Show raw data preview
st.subheader("Raw Data Preview")
st.dataframe(results.head(100), width="stretch")
def render_training_analysis_page():
    """Render the Training Results Analysis page of the dashboard.

    Loads every available training result, lets the user pick a run, a metric
    and the number of top configurations in the sidebar, and then renders the
    analysis sections for that run from top to bottom. Returns early when no
    training results exist or when the sidebar yields no selected run.
    """
    st.title("🦾 Training Results Analysis")

    st.markdown(
        """
        Analyze training results, hyperparameter search performance, and model configurations.
        Select a training run from the sidebar to explore detailed metrics and parameter analysis.
        """
    )

    # Nothing to analyse without at least one stored training run.
    all_runs = load_all_training_results()
    if not all_runs:
        st.warning("No training results found. Please run some training experiments first.")
        st.info("Run training using: `pixi run python -m entropice.ml.training`")
        return

    st.success(f"Found **{len(all_runs)}** training result(s)")
    st.divider()

    # Sidebar: run / metric selection.
    with st.sidebar:
        sidebar_choice = render_analysis_settings_sidebar(all_runs)

    # The first element is the selected run; None means nothing was selected.
    if sidebar_choice[0] is None:
        return
    selected_result, selected_metric, refit_metric, top_n = sidebar_choice

    cv_results = selected_result.results
    run_settings = selected_result.settings

    # Run information
    render_run_information(selected_result, refit_metric)
    st.divider()

    # Performance summary
    st.header("📊 Performance Overview")
    render_performance_summary(cv_results, refit_metric)
    st.divider()

    # Confusion matrix map
    st.header("🗺️ Prediction Results Map")
    render_confusion_matrix_map(selected_result.path, run_settings)
    st.divider()

    # Cross-validation statistics
    render_cv_statistics_section(selected_result, selected_metric)
    st.divider()

    # Parameter space analysis
    render_parameter_space_section(selected_result, selected_metric)
    st.divider()

    # Parameter correlation
    st.header("🔗 Parameter Correlation")
    render_parameter_correlation(cv_results, selected_metric)
    st.divider()

    # The multi-metric comparison is only shown when at least two metrics exist.
    if len(selected_result.available_metrics) >= 2:
        st.header("📊 Multi-Metric Comparison")
        render_multi_metric_comparison(cv_results)
        st.divider()

    # Top configurations
    st.header("🏆 Top Performing Configurations")
    render_top_configurations(cv_results, selected_metric, top_n)
    st.divider()

    # Raw data export
    render_data_export_section(cv_results, selected_result)

    st.balloons()
    stopwatch.summary()

View file

@ -1,6 +1,9 @@
"""Training Data page: Visualization of training data distributions."""
from typing import cast
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
@ -19,30 +22,21 @@ from entropice.dashboard.plots.training_data import (
render_spatial_map,
)
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.spatial import grids
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
def render_training_data_page():
"""Render the Training Data page of the dashboard."""
st.title("Training Data")
def render_dataset_configuration_sidebar():
"""Render dataset configuration selector in sidebar with form.
# Sidebar widgets for dataset configuration in a form
Stores the selected ensemble in session state when form is submitted.
"""
with st.sidebar.form("dataset_config_form"):
st.header("Dataset Configuration")
# Combined grid and level selection
grid_options = [
"hex-3",
"hex-4",
"hex-5",
"hex-6",
"healpix-6",
"healpix-7",
"healpix-8",
"healpix-9",
"healpix-10",
]
# Grid selection
grid_options = [gc.display_name for gc in grid_configs]
grid_level_combined = st.selectbox(
"Grid Configuration",
@ -51,9 +45,8 @@ def render_training_data_page():
help="Select the grid system and resolution level",
)
# Parse grid type and level
grid, level_str = grid_level_combined.split("-")
level = int(level_str)
# Find the selected grid config
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
# Target feature selection
target = st.selectbox(
@ -66,317 +59,422 @@ def render_training_data_page():
# Members selection
st.subheader("Dataset Members")
all_members = [
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
]
selected_members = []
all_members = cast(
list[L2SourceDataset],
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
)
selected_members: list[L2SourceDataset] = []
for member in all_members:
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
selected_members.append(member)
selected_members.append(member) # type: ignore[arg-type]
# Form submit button
load_button = st.form_submit_button(
"Load Dataset",
type="primary",
width="stretch",
use_container_width=True,
disabled=len(selected_members) == 0,
)
# Create DatasetEnsemble only when form is submitted
if load_button:
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
ensemble = DatasetEnsemble(
grid=selected_grid_config.grid,
level=selected_grid_config.level,
target=cast(TargetDataset, target),
members=selected_members,
)
# Store ensemble in session state
st.session_state["dataset_ensemble"] = ensemble
st.session_state["dataset_loaded"] = True
# Display dataset information if loaded
if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
ensemble = st.session_state["dataset_ensemble"]
# Display current configuration
st.subheader("📊 Current Configuration")
def render_dataset_statistics(ensemble: DatasetEnsemble):
"""Render dataset statistics and configuration overview.
# Create a visually appealing layout with columns
col1, col2, col3, col4 = st.columns(4)
Args:
ensemble: The dataset ensemble configuration.
with col1:
st.metric(label="Grid Type", value=ensemble.grid.upper())
"""
st.markdown("### 📊 Dataset Configuration")
with col2:
st.metric(label="Grid Level", value=ensemble.level)
# Display current configuration in columns
col1, col2, col3, col4 = st.columns(4)
with col3:
st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
with col1:
st.metric(label="Grid Type", value=ensemble.grid.upper())
with col4:
st.metric(label="Members", value=len(ensemble.members))
with col2:
st.metric(label="Grid Level", value=ensemble.level)
# Display members in an expandable section
with st.expander("🗂️ Dataset Members", expanded=False):
members_cols = st.columns(len(ensemble.members))
for idx, member in enumerate(ensemble.members):
with members_cols[idx]:
st.markdown(f"✓ **{member}**")
with col3:
st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
# Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`")
with col4:
st.metric(label="Members", value=len(ensemble.members))
# Display dataset statistics
st.markdown("---")
st.subheader("📈 Dataset Statistics")
# Display members in an expandable section
with st.expander("🗂️ Dataset Members", expanded=False):
members_cols = st.columns(len(ensemble.members))
for idx, member in enumerate(ensemble.members):
with members_cols[idx]:
st.markdown(f"✓ **{member}**")
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Display dataset ID in a styled container
st.info(f"**Dataset ID:** `{ensemble.id()}`")
# High-level summary metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2:
st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with col3:
st.metric(label="Data Sources", value=len(stats["members"]))
# Display detailed dataset statistics
st.markdown("---")
st.markdown("### 📈 Dataset Statistics")
# Detailed member statistics in expandable section
with st.expander("📦 Data Source Details", expanded=False):
for member, member_stats in stats["members"].items():
st.markdown(f"### {member}")
with st.spinner("Computing dataset statistics..."):
stats = ensemble.get_stats()
# Create metrics for this member
metric_cols = st.columns(4)
with metric_cols[0]:
st.metric("Features", member_stats["num_features"])
with metric_cols[1]:
st.metric("Variables", member_stats["num_variables"])
with metric_cols[2]:
# Display dimensions in a more readable format
dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
# High-level summary metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2:
st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with col3:
st.metric(label="Data Sources", value=len(stats["members"]))
# Show variables as colored badges
st.markdown("**Variables:**")
vars_html = " ".join(
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"]
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Detailed member statistics in expandable section
with st.expander("📦 Data Source Details", expanded=False):
for member, member_stats in stats["members"].items():
st.markdown(f"### {member}")
# Show dimension details
st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
# Create metrics for this member
metric_cols = st.columns(4)
with metric_cols[0]:
st.metric("Features", member_stats["num_features"])
with metric_cols[1]:
st.metric("Variables", member_stats["num_variables"])
with metric_cols[2]:
# Display dimensions in a more readable format
dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
st.metric("Shape", dim_str)
with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
st.markdown("---")
st.markdown("---")
# Create tabs for different data views
tab_names = ["📊 Labels", "📐 Areas"]
# Add tabs for each member
for member in ensemble.members:
if member == "AlphaEarth":
tab_names.append("🌍 AlphaEarth")
elif member == "ArcticDEM":
tab_names.append("🏔️ ArcticDEM")
elif member.startswith("ERA5"):
# Group ERA5 temporal variants
if "🌡️ ERA5" not in tab_names:
tab_names.append("🌡️ ERA5")
tabs = st.tabs(tab_names)
# Labels tab
with tabs[0]:
st.markdown("### Target Labels Distribution and Spatial Visualization")
# Load training data for all three tasks
with st.spinner("Loading training data for all tasks..."):
train_data_dict = load_all_training_data(ensemble)
# Calculate total samples (use binary as reference)
total_samples = len(train_data_dict["binary"])
train_samples = (train_data_dict["binary"].split == "train").sum().item()
test_samples = (train_data_dict["binary"].split == "test").sum().item()
st.success(
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
# Show variables as colored badges
st.markdown("**Variables:**")
vars_html = " ".join(
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"] # type: ignore[union-attr]
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Render distribution histograms
st.markdown("---")
render_all_distribution_histograms(train_data_dict)
# Show dimension details
st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
# Render spatial map
binary_dataset = train_data_dict["binary"]
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
render_spatial_map(train_data_dict)
def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
"""Render target labels distribution and spatial visualization.
# Areas tab
with tabs[1]:
st.markdown("### Grid Cell Areas and Land/Water Distribution")
Args:
ensemble: The dataset ensemble configuration.
train_data_dict: Pre-loaded training data for all tasks.
st.markdown(
"This visualization shows the spatial distribution of cell areas, land areas, "
"water areas, and land ratio across the grid. The grid has been filtered to "
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
"with >10% land coverage."
)
"""
st.markdown("### Target Labels Distribution and Spatial Visualization")
# Load grid data
grid_gdf = grids.open(ensemble.grid, ensemble.level)
# Calculate total samples (use binary as reference)
total_samples = len(train_data_dict["binary"])
train_samples = (train_data_dict["binary"].split == "train").sum().item()
test_samples = (train_data_dict["binary"].split == "test").sum().item()
st.success(
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
)
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
# Show summary statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Cells", f"{len(grid_gdf):,}")
with col2:
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
with col3:
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
with col4:
total_land = grid_gdf["land_area"].sum()
st.metric("Total Land Area", f"{total_land:,.0f} km²")
# Render distribution histograms
st.markdown("---")
render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type]
st.markdown("---")
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_areas_map(grid_gdf, ensemble.grid)
# Render spatial map
binary_dataset = train_data_dict["binary"]
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
# AlphaEarth tab
tab_idx = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### AlphaEarth Embeddings Analysis")
render_spatial_map(train_data_dict)
with st.spinner("Loading AlphaEarth data..."):
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
"""Render grid cell areas and land/water distribution.
render_alphaearth_overview(alphaearth_ds)
render_alphaearth_plots(alphaearth_ds)
Args:
ensemble: The dataset ensemble configuration.
grid_gdf: Pre-loaded grid GeoDataFrame.
st.markdown("---")
"""
st.markdown("### Grid Cell Areas and Land/Water Distribution")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
st.markdown(
"This visualization shows the spatial distribution of cell areas, land areas, "
"water areas, and land ratio across the grid. The grid has been filtered to "
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
"with >10% land coverage."
)
tab_idx += 1
st.success(
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
)
# ArcticDEM tab
if "ArcticDEM" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### ArcticDEM Terrain Analysis")
# Show summary statistics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Cells", f"{len(grid_gdf):,}")
with col2:
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
with col3:
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
with col4:
total_land = grid_gdf["land_area"].sum()
st.metric("Total Land Area", f"{total_land:,.0f} km²")
with st.spinner("Loading ArcticDEM data..."):
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
render_arcticdem_overview(arcticdem_ds)
render_arcticdem_plots(arcticdem_ds)
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
tab_idx += 1
# ERA5 tab (combining all temporal variants)
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
if era5_members:
with tabs[tab_idx]:
st.markdown("### ERA5 Climate Data Analysis")
# Let user select which ERA5 temporal aggregation to view
era5_options = {
"ERA5-yearly": "Yearly",
"ERA5-seasonal": "Seasonal (Winter/Summer)",
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
}
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
selected_era5 = st.selectbox(
"Select ERA5 temporal aggregation",
options=list(available_era5.keys()),
format_func=lambda x: available_era5[x],
key="era5_temporal_select",
)
if selected_era5:
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
with st.spinner(f"Loading {selected_era5} data..."):
era5_ds, targets = load_source_data(ensemble, selected_era5)
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
render_era5_overview(era5_ds, temporal_type)
render_era5_plots(era5_ds, temporal_type)
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
# Show balloons once after all tabs are rendered
st.balloons()
st.markdown("---")
# Check if we should skip map rendering for performance
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
"due to performance considerations."
)
else:
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
render_areas_map(grid_gdf, ensemble.grid)
def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
    """Render the AlphaEarth embeddings tab.

    Shows a success banner with the number of cells, the overview and plots,
    and finally the spatial map unless the grid configuration is one of the
    two for which map rendering is disabled.

    Args:
        ensemble: The dataset ensemble configuration.
        alphaearth_ds: Pre-loaded AlphaEarth dataset.
        targets: Pre-loaded targets GeoDataFrame.

    """
    st.markdown("### AlphaEarth Embeddings Analysis")

    cell_count = len(alphaearth_ds["cell_ids"])
    st.success(f"Loaded AlphaEarth data with {cell_count} cells")

    render_alphaearth_overview(alphaearth_ds)
    render_alphaearth_plots(alphaearth_ds)

    st.markdown("---")

    # hex-6 and healpix-10 are excluded from map rendering for performance reasons.
    map_disabled = (ensemble.grid, ensemble.level) in (("hex", 6), ("healpix", 10))
    if map_disabled:
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
        return

    render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
    """Render the ArcticDEM terrain tab.

    Shows a success banner with the number of cells, the overview and plots,
    and finally the spatial map unless the grid configuration is one of the
    two for which map rendering is disabled.

    Args:
        ensemble: The dataset ensemble configuration.
        arcticdem_ds: Pre-loaded ArcticDEM dataset.
        targets: Pre-loaded targets GeoDataFrame.

    """
    st.markdown("### ArcticDEM Terrain Analysis")

    cell_count = len(arcticdem_ds["cell_ids"])
    st.success(f"Loaded ArcticDEM data with {cell_count} cells")

    render_arcticdem_overview(arcticdem_ds)
    render_arcticdem_plots(arcticdem_ds)

    st.markdown("---")

    # hex-6 and healpix-10 are excluded from map rendering for performance reasons.
    map_disabled = (ensemble.grid, ensemble.level) in (("hex", 6), ("healpix", 10))
    if map_disabled:
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
        return

    render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
    """Render the ERA5 climate data tab.

    Offers a select box with the ERA5 temporal aggregations present in
    ``era5_data`` and renders overview, plots and (unless disabled for the
    grid configuration) the spatial map for the chosen aggregation. Nothing
    beyond the heading and select box is rendered when no valid aggregation
    is selected.

    Args:
        ensemble: The dataset ensemble configuration.
        era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
        targets: Pre-loaded targets GeoDataFrame.

    """
    st.markdown("### ERA5 Climate Data Analysis")

    # Human-readable labels for the known ERA5 members; only loaded ones are offered.
    era5_options = {
        "ERA5-yearly": "Yearly",
        "ERA5-seasonal": "Seasonal (Winter/Summer)",
        "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
    }
    available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}

    selected_era5 = st.selectbox(
        "Select ERA5 temporal aggregation",
        options=list(available_era5.keys()),
        format_func=lambda x: available_era5[x],
        key="era5_temporal_select",
    )

    # Guard against an empty selection or one that has no loaded dataset.
    if not selected_era5 or selected_era5 not in era5_data:
        return

    era5_ds, temporal_type = era5_data[selected_era5]

    render_era5_overview(era5_ds, temporal_type)
    render_era5_plots(era5_ds, temporal_type)

    st.markdown("---")

    # hex-6 and healpix-10 are excluded from map rendering for performance reasons.
    if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
        # Same map icon and wording as the other source-data views.
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
    else:
        render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
def render_training_data_page():
    """Render the Training Data page of the dashboard.

    Draws the sidebar configuration form, and once a dataset ensemble is
    stored in the session state, loads all data needed by the views and
    renders the dataset statistics followed by one tab per data view.
    """
    st.title("🎯 Training Data")

    st.markdown(
        """
        Explore and visualize the training data for RTS prediction models.
        Configure your dataset by selecting grid configuration, target dataset,
        and data sources in the sidebar, then click "Load Dataset" to begin.
        """
    )

    # Sidebar form; it stores the ensemble in the session state on submit.
    render_dataset_configuration_sidebar()

    # Nothing to show until a dataset has been configured and loaded.
    dataset_ready = st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state
    if not dataset_ready:
        st.info(
            "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
        )
        return

    ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]

    st.divider()

    # Load everything once up front so the individual views only render.
    with st.spinner("Loading dataset..."):
        # Training data for all tasks
        train_data_dict = load_all_training_data(ensemble)

        # Grid cells
        grid_gdf = grids.open(ensemble.grid, ensemble.level)

        # Targets (needed by all source data views)
        targets = ensemble._read_target()

        # Optional members: only loaded when part of the ensemble.
        has_alphaearth = "AlphaEarth" in ensemble.members
        alphaearth_ds = None
        if has_alphaearth:
            alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")

        has_arcticdem = "ArcticDEM" in ensemble.members
        arcticdem_ds = None
        if has_arcticdem:
            arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")

        # ERA5: one entry per temporal aggregation, keyed by member name.
        # The part after the dash is 'yearly', 'seasonal', or 'shoulder'.
        era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
        era5_data = {
            era5_member: (load_source_data(ensemble, era5_member)[0], era5_member.split("-")[1])
            for era5_member in era5_members
        }

    sample_count = len(train_data_dict["binary"])
    feature_count = ensemble.get_stats()["total_features"]
    st.success(f"Loaded dataset with {sample_count} samples and {feature_count} features")

    # Dataset statistics
    render_dataset_statistics(ensemble)

    st.markdown("---")

    # One (tab title, renderer) pair per view, in display order.
    views = [
        ("📊 Labels", lambda: render_labels_view(ensemble, train_data_dict)),
        ("📐 Areas", lambda: render_areas_view(ensemble, grid_gdf)),
    ]
    if has_alphaearth:
        views.append(("🌍 AlphaEarth", lambda: render_alphaearth_view(ensemble, alphaearth_ds, targets)))
    if has_arcticdem:
        views.append(("🏔️ ArcticDEM", lambda: render_arcticdem_view(ensemble, arcticdem_ds, targets)))
    # All ERA5 temporal variants share a single tab.
    if era5_members:
        views.append(("🌡️ ERA5", lambda: render_era5_view(ensemble, era5_data, targets)))

    tabs = st.tabs([title for title, _ in views])
    for tab, (_, render_view) in zip(tabs, views):
        with tab:
            render_view()

    # Show balloons once after all tabs are rendered
    st.balloons()
    stopwatch.summary()