Refactor other pages

This commit is contained in:
Tobias Hölzer 2026-01-04 18:39:38 +01:00
parent 4260b492ab
commit 393cc968cb
9 changed files with 962 additions and 559 deletions

View file

@ -11,12 +11,11 @@ Pages:
import streamlit as st import streamlit as st
# from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.inference_page import render_inference_page
# from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page from entropice.dashboard.views.training_data_page import render_training_data_page
# from entropice.dashboard.views.training_data_page import render_training_data_page
def main(): def main():
@ -25,17 +24,17 @@ def main():
# Setup Navigation # Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
pg = st.navigation( pg = st.navigation(
{ {
"Overview": [overview_page], "Overview": [overview_page],
# "Training": [training_data_page, training_analysis_page], "Training": [training_data_page, training_analysis_page],
# "Model State": [model_state_page], "Model State": [model_state_page],
# "Inference": [inference_page], "Inference": [inference_page],
} }
) )
pg.run() pg.run()

View file

@ -143,8 +143,8 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
# Extract scale information from settings if available # Extract scale information from settings if available
param_scales = {} param_scales = {}
if settings and "param_grid" in settings: if settings and hasattr(settings, "param_grid"):
param_grid = settings["param_grid"] param_grid = settings.param_grid
for param_name, param_config in param_grid.items(): for param_name, param_config in param_grid.items():
if isinstance(param_config, dict) and "distribution" in param_config: if isinstance(param_config, dict) and "distribution" in param_config:
# loguniform distribution indicates log scale # loguniform distribution indicates log scale
@ -1181,10 +1181,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
preds_gdf = gpd.read_parquet(preds_file) preds_gdf = gpd.read_parquet(preds_file)
# Get task and target information from settings # Get task and target information from settings
task = settings.get("task", "binary") task = settings.task
target = settings.get("target", "darts_rts") target = settings.target
grid = settings.get("grid", "hex") grid = settings.grid
level = settings.get("level", 3) level = settings.level
# Create dataset ensemble to get true labels # Create dataset ensemble to get true labels
# We need to load the target data to get true labels # We need to load the target data to get true labels

View file

@ -8,7 +8,7 @@ import streamlit as st
from shapely.geometry import shape from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.data import TrainingResult from entropice.dashboard.utils.loaders import TrainingResult
def _fix_hex_geometry(geom): def _fix_hex_geometry(geom):
@ -197,8 +197,8 @@ def render_inference_map(result: TrainingResult):
preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet") preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
# Get settings # Get settings
task = result.settings.get("task", "binary") task = result.settings.task
grid = result.settings.get("grid", "hex") grid = result.settings.grid
# Create controls in columns # Create controls in columns
col1, col2, col3 = st.columns([2, 2, 1]) col1, col2, col3 = st.columns([2, 2, 1])

View file

@ -147,10 +147,11 @@ def load_all_training_data(
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
""" """
dataset = e.create(filter_target_col=e.covcol)
return { return {
"binary": e.create_cat_training_dataset("binary", device="cpu"), "binary": e._cat_and_split(dataset, "binary", device="cpu"),
"count": e.create_cat_training_dataset("count", device="cpu"), "count": e._cat_and_split(dataset, "count", device="cpu"),
"density": e.create_cat_training_dataset("density", device="cpu"), "density": e._cat_and_split(dataset, "density", device="cpu"),
} }

View file

@ -1,7 +1,8 @@
"""Inference page: Visualization of model inference results across the study region.""" """Inference page: Visualization of model inference results across the study region."""
import geopandas as gpd
import streamlit as st import streamlit as st
from entropice.dashboard.utils.data import load_all_training_results from stopuhr import stopwatch
from entropice.dashboard.plots.inference import ( from entropice.dashboard.plots.inference import (
render_class_comparison, render_class_comparison,
@ -10,66 +11,177 @@ from entropice.dashboard.plots.inference import (
render_inference_statistics, render_inference_statistics,
render_spatial_distribution_stats, render_spatial_distribution_stats,
) )
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
@st.fragment
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
"""Render sidebar for training run selection.
Args:
training_results: List of available TrainingResult objects.
Returns:
Selected TrainingResult object.
"""
st.header("Select Training Run")
# Create selection options with task-first naming
training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to view inference results",
key="inference_run_select",
)
selected_result = training_options[selected_name]
st.divider()
# Show run information in sidebar
st.subheader("Run Information")
st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}")
st.markdown(f"**Model:** {selected_result.settings.model.upper()}")
st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}")
st.markdown(f"**Level:** {selected_result.settings.level}")
st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}")
return selected_result
def render_run_information(selected_result: TrainingResult):
"""Render training run configuration overview.
Args:
selected_result: The selected TrainingResult object.
"""
st.header("📋 Run Configuration")
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Task", selected_result.settings.task.capitalize())
with col2:
st.metric("Model", selected_result.settings.model.upper())
with col3:
st.metric("Grid", selected_result.settings.grid.capitalize())
with col4:
st.metric("Level", selected_result.settings.level)
with col5:
st.metric("Target", selected_result.settings.target.replace("darts_", ""))
def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str):
"""Render inference summary statistics section.
Args:
predictions_gdf: GeoDataFrame with predictions.
task: Task type ('binary', 'count', 'density').
"""
st.header("📊 Inference Summary")
render_inference_statistics(predictions_gdf, task)
def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame):
"""Render spatial coverage statistics section.
Args:
predictions_gdf: GeoDataFrame with predictions.
"""
st.header("🌍 Spatial Coverage")
render_spatial_distribution_stats(predictions_gdf)
def render_map_visualization_section(selected_result: TrainingResult):
"""Render 3D map visualization section.
Args:
selected_result: The selected TrainingResult object.
"""
st.header("🗺️ Interactive Prediction Map")
st.markdown(
"""
3D visualization of predictions across the study region. The map shows predicted
classes with color coding and spatial distribution of model outputs.
"""
)
render_inference_map(selected_result)
def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str):
"""Render class distribution histogram section.
Args:
predictions_gdf: GeoDataFrame with predictions.
task: Task type ('binary', 'count', 'density').
"""
st.header("📈 Class Distribution")
st.markdown("Distribution of predicted classes across all inference cells.")
render_class_distribution_histogram(predictions_gdf, task)
def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str):
"""Render class comparison analysis section.
Args:
predictions_gdf: GeoDataFrame with predictions.
task: Task type ('binary', 'count', 'density').
"""
st.header("🔍 Class Comparison Analysis")
st.markdown(
"""
Detailed comparison of predicted classes showing probability distributions
and confidence metrics for different class predictions.
"""
)
render_class_comparison(predictions_gdf, task)
def render_inference_page(): def render_inference_page():
"""Render the Inference page of the dashboard.""" """Render the Inference page of the dashboard."""
st.title("🗺️ Inference Results") st.title("🗺️ Inference Results")
st.markdown(
"""
Explore spatial predictions from trained models across the Arctic permafrost region.
Select a training run from the sidebar to visualize prediction maps, class distributions,
and spatial coverage statistics.
"""
)
# Load all available training results # Load all available training results
training_results = load_all_training_results() training_results = load_all_training_results()
if not training_results: if not training_results:
st.warning("No training results found. Please run some training experiments first.") st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.training`") st.info("Run training using: `pixi run python -m entropice.ml.training`")
return return
st.success(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Sidebar: Training run selection # Sidebar: Training run selection
with st.sidebar: with st.sidebar:
st.header("Select Training Run") selected_result = render_sidebar_selection(training_results)
# Create selection options with task-first naming # Main content area - Run Information
training_options = {tr.get_display_name("task_first"): tr for tr in training_results} render_run_information(selected_result)
selected_name = st.selectbox(
"Training Run",
options=list(training_options.keys()),
index=0,
help="Select a training run to view inference results",
)
selected_result = training_options[selected_name]
st.divider()
# Show run information in sidebar
st.subheader("Run Information")
settings = selected_result.settings
st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
# Main content area - Run Information at the top
st.header("📋 Run Configuration")
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
with col2:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
with col3:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
with col4:
st.metric("Level", selected_result.settings.get("level", "Unknown"))
with col5:
st.metric(
"Target",
selected_result.settings.get("target", "Unknown").replace("darts_", ""),
)
st.divider() st.divider()
@ -80,31 +192,33 @@ def render_inference_page():
st.info("Inference results are generated automatically during training.") st.info("Inference results are generated automatically during training.")
return return
# Load predictions for statistics # Load predictions
import geopandas as gpd with st.spinner("Loading inference results..."):
predictions_gdf = gpd.read_parquet(preds_file)
predictions_gdf = gpd.read_parquet(preds_file) task = selected_result.settings.task
task = selected_result.settings.get("task", "binary")
# Inference Statistics Section # Inference Statistics Section
render_inference_statistics(predictions_gdf, task) render_inference_statistics_section(predictions_gdf, task)
st.divider() st.divider()
# Spatial Coverage Section # Spatial Coverage Section
render_spatial_distribution_stats(predictions_gdf) render_spatial_coverage_section(predictions_gdf)
st.divider() st.divider()
# 3D Map Visualization Section # 3D Map Visualization Section
render_inference_map(selected_result) render_map_visualization_section(selected_result)
st.divider() st.divider()
# Class Distribution Section # Class Distribution Section
render_class_distribution_histogram(predictions_gdf, task) render_class_distribution_section(predictions_gdf, task)
st.divider() st.divider()
# Class Comparison Section # Class Comparison Section
render_class_comparison(predictions_gdf, task) render_class_comparison_section(predictions_gdf, task)
st.balloons()
stopwatch.summary()

View file

@ -1,16 +1,8 @@
"""Model State page for the Entropice dashboard.""" """Model State page: Visualization of model internal state and feature importance."""
import streamlit as st import streamlit as st
import xarray as xr import xarray as xr
from entropice.dashboard.utils.data import ( from stopuhr import stopwatch
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
get_members_from_settings,
load_all_training_results,
)
from entropice.dashboard.utils.training import load_model_state
from entropice.dashboard.plots.model_state import ( from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap, plot_arcticdem_heatmap,
@ -26,46 +18,64 @@ from entropice.dashboard.plots.model_state import (
plot_top_features, plot_top_features,
) )
from entropice.dashboard.utils.colors import generate_unified_colormap from entropice.dashboard.utils.colors import generate_unified_colormap
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
from entropice.dashboard.utils.unsembler import (
extract_arcticdem_features,
extract_common_features,
extract_embedding_features,
extract_era5_features,
)
from entropice.utils.types import L2SourceDataset
def render_model_state_page(): def get_members_from_settings(settings) -> list[L2SourceDataset]:
"""Render the Model State page of the dashboard.""" """Extract dataset members from training settings.
st.title("Model State")
st.markdown("Comprehensive visualization of the best model's internal state and feature importance")
# Load available training results Args:
training_results = load_all_training_results() settings: TrainingSettings object containing dataset configuration.
if not training_results: Returns:
st.error("No training results found. Please run a training search first.") List of L2SourceDataset members used in training.
return
# Sidebar: Training run selection """
with st.sidebar: return settings.members
st.header("Select Training Run")
# Result selection with model-first naming
result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(result_options.keys()),
help="Choose a training result to visualize model state",
)
selected_result = result_options[selected_name]
st.divider() @st.fragment
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
"""Render sidebar for training run selection.
# Get the model type from settings Args:
model_type = selected_result.settings.get("model", "espa") training_results: List of available TrainingResult objects.
# Load model state Returns:
with st.spinner("Loading model state..."): Selected TrainingResult object.
model_state = load_model_state(selected_result)
if model_state is None:
st.error("Could not load model state for this result.")
return
# Display basic model state info """
st.header("Select Training Run")
# Result selection with task-first naming
result_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox(
"Training Run",
options=list(result_options.keys()),
index=0,
help="Choose a training result to visualize model state",
key="model_state_training_run_select",
)
selected_result = result_options[selected_name]
return selected_result
def render_model_info(model_state: xr.Dataset, model_type: str):
"""Render basic model state information.
Args:
model_state: Xarray dataset containing model state.
model_type: Type of model (espa, xgboost, rf, knn).
"""
with st.expander("Model State Information", expanded=False): with st.expander("Model State Information", expanded=False):
st.write(f"**Model Type:** {model_type.upper()}") st.write(f"**Model Type:** {model_type.upper()}")
st.write(f"**Variables:** {list(model_state.data_vars)}") st.write(f"**Variables:** {list(model_state.data_vars)}")
@ -73,15 +83,23 @@ def render_model_state_page():
st.write(f"**Coordinates:** {list(model_state.coords)}") st.write(f"**Coordinates:** {list(model_state.coords)}")
st.write(f"**Attributes:** {dict(model_state.attrs)}") st.write(f"**Attributes:** {dict(model_state.attrs)}")
# Display dataset members summary
st.header("📊 Training Data Summary")
members = get_members_from_settings(selected_result.settings)
st.markdown(f""" def render_training_data_summary(members: list[L2SourceDataset]):
"""Render summary of training data sources.
Args:
members: List of dataset members used in training.
"""
st.header("📊 Training Data Summary")
st.markdown(
f"""
**Dataset Members Used in Training:** {len(members)} **Dataset Members Used in Training:** {len(members)}
The following data sources were used to train this model: The following data sources were used to train this model:
""") """
)
# Create a nice display of members with emojis # Create a nice display of members with emojis
member_display = { member_display = {
@ -98,6 +116,52 @@ def render_model_state_page():
display_name = member_display.get(member, f"📁 {member}") display_name = member_display.get(member, f"📁 {member}")
st.info(display_name) st.info(display_name)
def render_model_state_page():
"""Render the Model State page of the dashboard."""
st.title("🔬 Model State")
st.markdown(
"""
Comprehensive visualization of the best model's internal state and feature importance.
Select a training run from the sidebar to explore model parameters, feature weights,
and data source contributions.
"""
)
# Load available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.ml.training`")
return
st.success(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Sidebar: Training run selection
with st.sidebar:
selected_result = render_sidebar_selection(training_results)
# Get the model type from settings
model_type = selected_result.settings.model
# Load model state
with st.spinner("Loading model state..."):
model_state = selected_result.load_model_state()
if model_state is None:
st.error("Could not load model state for this result.")
st.info("The model state file (best_estimator_state.nc) may be missing from the training results.")
return
# Display basic model state info
render_model_info(model_state, model_type)
# Display dataset members summary
members = get_members_from_settings(selected_result.settings)
render_training_data_summary(members)
st.divider() st.divider()
# Render model-specific visualizations # Render model-specific visualizations
@ -112,9 +176,18 @@ def render_model_state_page():
else: else:
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.") st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
st.balloons()
stopwatch.summary()
def render_espa_model_state(model_state: xr.Dataset, selected_result):
"""Render visualizations for ESPA model.""" def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for ESPA model.
Args:
model_state: Xarray dataset containing ESPA model state.
selected_result: TrainingResult object containing training configuration.
"""
# Scale feature weights by number of features # Scale feature weights by number of features
n_features = model_state.sizes["feature"] n_features = model_state.sizes["feature"]
model_state["feature_weights"] *= n_features model_state["feature_weights"] *= n_features
@ -143,8 +216,9 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
common_feature_array = extract_common_features(model_state) common_feature_array = extract_common_features(model_state)
# Generate unified colormaps # Generate unified colormaps (convert dataclass to dict)
_, _, altair_colors = generate_unified_colormap(selected_result.settings) settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes}
_, _, altair_colors = generate_unified_colormap(settings_dict)
# Feature importance section # Feature importance section
st.header("Feature Importance") st.header("Feature Importance")
@ -255,8 +329,14 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array) render_common_features(common_feature_array)
def render_xgboost_model_state(model_state: xr.Dataset, selected_result): def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for XGBoost model.""" """Render visualizations for XGBoost model.
Args:
model_state: Xarray dataset containing XGBoost model state.
selected_result: TrainingResult object containing training configuration.
"""
from entropice.dashboard.plots.model_state import ( from entropice.dashboard.plots.model_state import (
plot_xgboost_feature_importance, plot_xgboost_feature_importance,
plot_xgboost_importance_comparison, plot_xgboost_importance_comparison,
@ -382,8 +462,14 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array) render_common_features(common_feature_array)
def render_rf_model_state(model_state: xr.Dataset, selected_result): def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for Random Forest model.""" """Render visualizations for Random Forest model.
Args:
model_state: Xarray dataset containing Random Forest model state.
selected_result: TrainingResult object containing training configuration.
"""
from entropice.dashboard.plots.model_state import plot_rf_feature_importance from entropice.dashboard.plots.model_state import plot_rf_feature_importance
st.header("🌳 Random Forest Model Analysis") st.header("🌳 Random Forest Model Analysis")
@ -529,8 +615,14 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
render_common_features(common_feature_array) render_common_features(common_feature_array)
def render_knn_model_state(model_state: xr.Dataset, selected_result): def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
"""Render visualizations for KNN model.""" """Render visualizations for KNN model.
Args:
model_state: Xarray dataset containing KNN model state.
selected_result: TrainingResult object containing training configuration.
"""
st.header("🔍 K-Nearest Neighbors Model Analysis") st.header("🔍 K-Nearest Neighbors Model Analysis")
st.markdown( st.markdown(
""" """
@ -568,8 +660,13 @@ def render_knn_model_state(model_state: xr.Dataset, selected_result):
# Helper functions for embedding/era5/common features # Helper functions for embedding/era5/common features
def render_embedding_features(embedding_feature_array): def render_embedding_features(embedding_feature_array: xr.DataArray):
"""Render embedding feature visualizations.""" """Render embedding feature visualizations.
Args:
embedding_feature_array: DataArray containing AlphaEarth embedding feature weights.
"""
with st.container(border=True): with st.container(border=True):
st.header("🛰️ Embedding Feature Analysis") st.header("🛰️ Embedding Feature Analysis")
st.markdown( st.markdown(
@ -619,7 +716,7 @@ def render_embedding_features(embedding_feature_array):
st.dataframe(top_emb, width="stretch") st.dataframe(top_emb, width="stretch")
def render_era5_features(era5_feature_array, temporal_group: str = ""): def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""):
"""Render ERA5 feature visualizations. """Render ERA5 feature visualizations.
Args: Args:
@ -631,9 +728,10 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
with st.container(border=True): with st.container(border=True):
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}") st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
temporal_suffix = f" for {temporal_group.lower()} aggregation" if temporal_group else ""
st.markdown( st.markdown(
f""" f"""
Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods
are most important for the model predictions. are most important for the model predictions.
""" """
) )
@ -709,8 +807,13 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
st.dataframe(top_era5, width="stretch") st.dataframe(top_era5, width="stretch")
def render_arcticdem_features(arcticdem_feature_array): def render_arcticdem_features(arcticdem_feature_array: xr.DataArray):
"""Render ArcticDEM feature visualizations.""" """Render ArcticDEM feature visualizations.
Args:
arcticdem_feature_array: DataArray containing ArcticDEM feature weights.
"""
with st.container(border=True): with st.container(border=True):
st.header("🏔️ ArcticDEM Feature Analysis") st.header("🏔️ ArcticDEM Feature Analysis")
st.markdown( st.markdown(
@ -758,8 +861,13 @@ def render_arcticdem_features(arcticdem_feature_array):
st.dataframe(top_arcticdem, width="stretch") st.dataframe(top_arcticdem, width="stretch")
def render_common_features(common_feature_array): def render_common_features(common_feature_array: xr.DataArray):
"""Render common feature visualizations.""" """Render common feature visualizations.
Args:
common_feature_array: DataArray containing common feature weights.
"""
with st.container(border=True): with st.container(border=True):
st.header("🗺️ Common Feature Analysis") st.header("🗺️ Common Feature Analysis")
st.markdown( st.markdown(

View file

@ -658,5 +658,4 @@ def render_overview_page():
render_dataset_analysis() render_dataset_analysis()
st.balloons() st.balloons()
stopwatch.summary() stopwatch.summary()

View file

@ -1,6 +1,7 @@
"""Training Results Analysis page: Analysis of training results and model performance.""" """Training Results Analysis page: Analysis of training results and model performance."""
import streamlit as st import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.hyperparameter_analysis import ( from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space, render_binned_parameter_space,
@ -12,169 +13,176 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
render_performance_summary, render_performance_summary,
render_top_configurations, render_top_configurations,
) )
from entropice.dashboard.utils.data import load_all_training_results from entropice.dashboard.utils.formatters import format_metric_name
from entropice.dashboard.utils.training import ( from entropice.dashboard.utils.loaders import load_all_training_results
format_metric_name, from entropice.dashboard.utils.stats import CVResultsStatistics
get_available_metrics,
get_cv_statistics,
get_parameter_space_summary,
)
def render_training_analysis_page(): @st.fragment
"""Render the Training Results Analysis page of the dashboard.""" def render_analysis_settings_sidebar(training_results):
st.title("🦾 Training Results Analysis") """Render sidebar for training run and analysis settings selection.
# Load all available training results Args:
training_results = load_all_training_results() training_results: List of available TrainingResult objects.
if not training_results: Returns:
st.warning("No training results found. Please run some training experiments first.") Tuple of (selected_result, selected_metric, refit_metric, top_n).
st.info("Run training using: `pixi run python -m entropice.training`")
return
# Sidebar: Training run selection """
with st.sidebar: st.header("Select Training Run")
st.header("Select Training Run")
# Create selection options with task-first naming # Create selection options with task-first naming
training_options = {tr.get_display_name("task_first"): tr for tr in training_results} training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
selected_name = st.selectbox( selected_name = st.selectbox(
"Training Run", "Training Run",
options=list(training_options.keys()), options=list(training_options.keys()),
index=0, index=0,
help="Select a training run to analyze", help="Select a training run to analyze",
) key="training_run_select",
)
selected_result = training_options[selected_name] selected_result = training_options[selected_name]
st.divider() st.divider()
# Metric selection for detailed analysis # Metric selection for detailed analysis
st.subheader("Analysis Settings") st.subheader("Analysis Settings")
available_metrics = get_available_metrics(selected_result.results) available_metrics = selected_result.available_metrics
# Try to get refit metric from settings # Try to get refit metric from settings
refit_metric = selected_result.settings.get("refit_metric") refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
if not refit_metric or refit_metric not in available_metrics: if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric # Infer from task or use first available metric
task = selected_result.settings.get("task", "binary") task = selected_result.settings.task
if task == "binary" and "f1" in available_metrics: if task == "binary" and "f1" in available_metrics:
refit_metric = "f1" refit_metric = "f1"
elif "f1_weighted" in available_metrics: elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted" refit_metric = "f1_weighted"
elif "accuracy" in available_metrics: elif "accuracy" in available_metrics:
refit_metric = "accuracy" refit_metric = "accuracy"
elif available_metrics: elif available_metrics:
refit_metric = available_metrics[0] refit_metric = available_metrics[0]
else:
st.error("No metrics found in results.")
return
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
else: else:
default_metric_idx = 0 st.error("No metrics found in results.")
return None, None, None, None
selected_metric = st.selectbox( if refit_metric in available_metrics:
"Primary Metric for Analysis", default_metric_idx = available_metrics.index(refit_metric)
options=available_metrics, else:
index=default_metric_idx, default_metric_idx = 0
format_func=format_metric_name,
help="Select the metric to focus on for detailed analysis",
)
# Top N configurations selected_metric = st.selectbox(
top_n = st.slider( "Primary Metric for Analysis",
"Top N Configurations", options=available_metrics,
min_value=5, index=default_metric_idx,
max_value=50, format_func=format_metric_name,
value=10, help="Select the metric to focus on for detailed analysis",
step=5, key="metric_select",
help="Number of top configurations to display", )
)
# Main content area - Run Information at the top # Top N configurations
top_n = st.slider(
"Top N Configurations",
min_value=5,
max_value=50,
value=10,
step=5,
help="Number of top configurations to display",
key="top_n_slider",
)
return selected_result, selected_metric, refit_metric, top_n
def render_run_information(selected_result, refit_metric):
"""Render training run configuration overview.
Args:
selected_result: The selected TrainingResult object.
refit_metric: The refit metric used for model selection.
"""
st.header("📋 Run Information") st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6) col1, col2, col3, col4, col5, col6 = st.columns(6)
with col1: with col1:
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize()) st.metric("Task", selected_result.settings.task.capitalize())
with col2: with col2:
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize()) st.metric("Grid", selected_result.settings.grid.capitalize())
with col3: with col3:
st.metric("Level", selected_result.settings.get("level", "Unknown")) st.metric("Level", selected_result.settings.level)
with col4: with col4:
st.metric("Model", selected_result.settings.get("model", "Unknown").upper()) st.metric("Model", selected_result.settings.model.upper())
with col5: with col5:
st.metric("Trials", len(selected_result.results)) st.metric("Trials", len(selected_result.results))
with col6: with col6:
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown")) st.metric("CV Splits", selected_result.settings.cv_splits)
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}") st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
st.divider()
# Main content area def render_cv_statistics_section(selected_result, selected_metric):
results = selected_result.results """Render cross-validation statistics for selected metric.
settings = selected_result.settings
# Performance Summary Section Args:
st.header("📊 Performance Overview") selected_result: The selected TrainingResult object.
selected_metric: The metric to display statistics for.
render_performance_summary(results, refit_metric) """
st.divider()
# Confusion Matrix Map Section
st.header("🗺️ Prediction Results Map")
render_confusion_matrix_map(selected_result.path, settings)
st.divider()
# Quick Statistics
st.header("📈 Cross-Validation Statistics") st.header("📈 Cross-Validation Statistics")
cv_stats = get_cv_statistics(results, selected_metric) from entropice.dashboard.utils.stats import CVMetricStatistics
if cv_stats: cv_stats = CVMetricStatistics.compute(selected_result, selected_metric)
col1, col2, col3, col4, col5 = st.columns(5)
with col1: col1, col2, col3, col4, col5 = st.columns(5)
st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
with col2: with col1:
st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}") st.metric("Best Score", f"{cv_stats.best_score:.4f}")
with col3: with col2:
st.metric("Std Dev", f"{cv_stats['std_score']:.4f}") st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
with col4: with col3:
st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}") st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
with col5: with col4:
st.metric("Median Score", f"{cv_stats['median_score']:.4f}") st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
if "mean_cv_std" in cv_stats: with col5:
st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds") st.metric("Median Score", f"{cv_stats.median_score:.4f}")
st.divider() if cv_stats.mean_cv_std is not None:
st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
# Parameter Space Analysis
def render_parameter_space_section(selected_result, selected_metric):
"""Render parameter space analysis section.
Args:
selected_result: The selected TrainingResult object.
selected_metric: The metric to analyze parameters against.
"""
st.header("🔍 Parameter Space Analysis") st.header("🔍 Parameter Space Analysis")
# Compute CV results statistics
cv_results_stats = CVResultsStatistics.compute(selected_result)
# Show parameter space summary # Show parameter space summary
with st.expander("📋 Parameter Space Summary", expanded=False): with st.expander("📋 Parameter Space Summary", expanded=False):
param_summary = get_parameter_space_summary(results) param_summary_df = cv_results_stats.parameters_to_dataframe()
if not param_summary.empty: if not param_summary_df.empty:
st.dataframe(param_summary, hide_index=True, width="stretch") st.dataframe(param_summary_df, hide_index=True, width="stretch")
else: else:
st.info("No parameter information available.") st.info("No parameter information available.")
results = selected_result.results
settings = selected_result.settings
# Parameter distributions # Parameter distributions
st.subheader("📈 Parameter Distributions") st.subheader("📈 Parameter Distributions")
render_parameter_distributions(results, settings) render_parameter_distributions(results, settings)
@ -183,7 +191,7 @@ def render_training_analysis_page():
st.subheader("🎨 Binned Parameter Space") st.subheader("🎨 Binned Parameter Space")
# Check if this is an ESPA model and show ESPA-specific plots # Check if this is an ESPA model and show ESPA-specific plots
model_type = settings.get("model", "espa") model_type = settings.model
if model_type == "espa": if model_type == "espa":
# Show ESPA-specific binned plots (eps_cl vs eps_e binned by K) # Show ESPA-specific binned plots (eps_cl vs eps_e binned by K)
render_espa_binned_parameter_space(results, selected_metric) render_espa_binned_parameter_space(results, selected_metric)
@ -196,31 +204,15 @@ def render_training_analysis_page():
# For non-ESPA models, show the generic binned plots # For non-ESPA models, show the generic binned plots
render_binned_parameter_space(results, selected_metric) render_binned_parameter_space(results, selected_metric)
st.divider()
# Parameter Correlation def render_data_export_section(results, selected_result):
st.header("🔗 Parameter Correlation") """Render data export section with download buttons.
render_parameter_correlation(results, selected_metric) Args:
results: DataFrame with CV results.
selected_result: The selected TrainingResult object.
st.divider() """
# Multi-Metric Comparison
if len(available_metrics) >= 2:
st.header("📊 Multi-Metric Comparison")
render_multi_metric_comparison(results)
st.divider()
# Top Configurations
st.header("🏆 Top Performing Configurations")
render_top_configurations(results, selected_metric, top_n)
st.divider()
# Raw Data Export
with st.expander("💾 Export Data", expanded=False): with st.expander("💾 Export Data", expanded=False):
st.subheader("Download Results") st.subheader("Download Results")
@ -234,22 +226,114 @@ def render_training_analysis_page():
data=csv_data, data=csv_data,
file_name=f"{selected_result.path.name}_results.csv", file_name=f"{selected_result.path.name}_results.csv",
mime="text/csv", mime="text/csv",
width="stretch",
) )
with col2: with col2:
# Download settings as text # Download settings as JSON
import json import json
settings_json = json.dumps(settings, indent=2) settings_dict = {
"task": selected_result.settings.task,
"grid": selected_result.settings.grid,
"level": selected_result.settings.level,
"model": selected_result.settings.model,
"cv_splits": selected_result.settings.cv_splits,
"classes": selected_result.settings.classes,
}
settings_json = json.dumps(settings_dict, indent=2)
st.download_button( st.download_button(
label="⚙️ Download Settings (JSON)", label="⚙️ Download Settings (JSON)",
data=settings_json, data=settings_json,
file_name=f"{selected_result.path.name}_settings.json", file_name=f"{selected_result.path.name}_settings.json",
mime="application/json", mime="application/json",
width="stretch",
) )
# Show raw data preview # Show raw data preview
st.subheader("Raw Data Preview") st.subheader("Raw Data Preview")
st.dataframe(results.head(100), width="stretch") st.dataframe(results.head(100), width="stretch")
def render_training_analysis_page():
"""Render the Training Results Analysis page of the dashboard."""
st.title("🦾 Training Results Analysis")
st.markdown(
"""
Analyze training results, hyperparameter search performance, and model configurations.
Select a training run from the sidebar to explore detailed metrics and parameter analysis.
"""
)
# Load all available training results
training_results = load_all_training_results()
if not training_results:
st.warning("No training results found. Please run some training experiments first.")
st.info("Run training using: `pixi run python -m entropice.ml.training`")
return
st.success(f"Found **{len(training_results)}** training result(s)")
st.divider()
# Sidebar: Training run selection
with st.sidebar:
selection_result = render_analysis_settings_sidebar(training_results)
if selection_result[0] is None:
return
selected_result, selected_metric, refit_metric, top_n = selection_result
# Main content area
results = selected_result.results
settings = selected_result.settings
# Run Information
render_run_information(selected_result, refit_metric)
st.divider()
# Performance Summary Section
st.header("📊 Performance Overview")
render_performance_summary(results, refit_metric)
st.divider()
# Confusion Matrix Map Section
st.header("🗺️ Prediction Results Map")
render_confusion_matrix_map(selected_result.path, settings)
st.divider()
# Cross-Validation Statistics
render_cv_statistics_section(selected_result, selected_metric)
st.divider()
# Parameter Space Analysis
render_parameter_space_section(selected_result, selected_metric)
st.divider()
# Parameter Correlation
st.header("🔗 Parameter Correlation")
render_parameter_correlation(results, selected_metric)
st.divider()
# Multi-Metric Comparison
if len(selected_result.available_metrics) >= 2:
st.header("📊 Multi-Metric Comparison")
render_multi_metric_comparison(results)
st.divider()
# Top Configurations
st.header("🏆 Top Performing Configurations")
render_top_configurations(results, selected_metric, top_n)
st.divider()
# Raw Data Export
render_data_export_section(results, selected_result)
st.balloons()
stopwatch.summary()

View file

@ -1,6 +1,9 @@
"""Training Data page: Visualization of training data distributions.""" """Training Data page: Visualization of training data distributions."""
from typing import cast
import streamlit as st import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.source_data import ( from entropice.dashboard.plots.source_data import (
render_alphaearth_map, render_alphaearth_map,
@ -19,30 +22,21 @@ from entropice.dashboard.plots.training_data import (
render_spatial_map, render_spatial_map,
) )
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.spatial import grids from entropice.spatial import grids
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
def render_training_data_page(): def render_dataset_configuration_sidebar():
"""Render the Training Data page of the dashboard.""" """Render dataset configuration selector in sidebar with form.
st.title("Training Data")
# Sidebar widgets for dataset configuration in a form Stores the selected ensemble in session state when form is submitted.
"""
with st.sidebar.form("dataset_config_form"): with st.sidebar.form("dataset_config_form"):
st.header("Dataset Configuration") st.header("Dataset Configuration")
# Combined grid and level selection # Grid selection
grid_options = [ grid_options = [gc.display_name for gc in grid_configs]
"hex-3",
"hex-4",
"hex-5",
"hex-6",
"healpix-6",
"healpix-7",
"healpix-8",
"healpix-9",
"healpix-10",
]
grid_level_combined = st.selectbox( grid_level_combined = st.selectbox(
"Grid Configuration", "Grid Configuration",
@ -51,9 +45,8 @@ def render_training_data_page():
help="Select the grid system and resolution level", help="Select the grid system and resolution level",
) )
# Parse grid type and level # Find the selected grid config
grid, level_str = grid_level_combined.split("-") selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
level = int(level_str)
# Target feature selection # Target feature selection
target = st.selectbox( target = st.selectbox(
@ -66,317 +59,422 @@ def render_training_data_page():
# Members selection # Members selection
st.subheader("Dataset Members") st.subheader("Dataset Members")
all_members = [ all_members = cast(
"AlphaEarth", list[L2SourceDataset],
"ArcticDEM", ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
"ERA5-yearly", )
"ERA5-seasonal", selected_members: list[L2SourceDataset] = []
"ERA5-shoulder",
]
selected_members = []
for member in all_members: for member in all_members:
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
selected_members.append(member) selected_members.append(member) # type: ignore[arg-type]
# Form submit button # Form submit button
load_button = st.form_submit_button( load_button = st.form_submit_button(
"Load Dataset", "Load Dataset",
type="primary", type="primary",
width="stretch", use_container_width=True,
disabled=len(selected_members) == 0, disabled=len(selected_members) == 0,
) )
# Create DatasetEnsemble only when form is submitted # Create DatasetEnsemble only when form is submitted
if load_button: if load_button:
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members) ensemble = DatasetEnsemble(
grid=selected_grid_config.grid,
level=selected_grid_config.level,
target=cast(TargetDataset, target),
members=selected_members,
)
# Store ensemble in session state # Store ensemble in session state
st.session_state["dataset_ensemble"] = ensemble st.session_state["dataset_ensemble"] = ensemble
st.session_state["dataset_loaded"] = True st.session_state["dataset_loaded"] = True
# Display dataset information if loaded
if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
ensemble = st.session_state["dataset_ensemble"]
# Display current configuration def render_dataset_statistics(ensemble: DatasetEnsemble):
st.subheader("📊 Current Configuration") """Render dataset statistics and configuration overview.
# Create a visually appealing layout with columns Args:
col1, col2, col3, col4 = st.columns(4) ensemble: The dataset ensemble configuration.
with col1: """
st.metric(label="Grid Type", value=ensemble.grid.upper()) st.markdown("### 📊 Dataset Configuration")
with col2: # Display current configuration in columns
st.metric(label="Grid Level", value=ensemble.level) col1, col2, col3, col4 = st.columns(4)
with col3: with col1:
st.metric(label="Target Feature", value=ensemble.target.replace("darts_", "")) st.metric(label="Grid Type", value=ensemble.grid.upper())
with col4: with col2:
st.metric(label="Members", value=len(ensemble.members)) st.metric(label="Grid Level", value=ensemble.level)
# Display members in an expandable section with col3:
with st.expander("🗂️ Dataset Members", expanded=False): st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
members_cols = st.columns(len(ensemble.members))
for idx, member in enumerate(ensemble.members):
with members_cols[idx]:
st.markdown(f"✓ **{member}**")
# Display dataset ID in a styled container with col4:
st.info(f"**Dataset ID:** `{ensemble.id()}`") st.metric(label="Members", value=len(ensemble.members))
# Display dataset statistics # Display members in an expandable section
st.markdown("---") with st.expander("🗂️ Dataset Members", expanded=False):
st.subheader("📈 Dataset Statistics") members_cols = st.columns(len(ensemble.members))
for idx, member in enumerate(ensemble.members):
with members_cols[idx]:
st.markdown(f"✓ **{member}**")
with st.spinner("Computing dataset statistics..."): # Display dataset ID in a styled container
stats = ensemble.get_stats() st.info(f"**Dataset ID:** `{ensemble.id()}`")
# High-level summary metrics # Display detailed dataset statistics
col1, col2, col3 = st.columns(3) st.markdown("---")
with col1: st.markdown("### 📈 Dataset Statistics")
st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with col2:
st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with col3:
st.metric(label="Data Sources", value=len(stats["members"]))
# Detailed member statistics in expandable section with st.spinner("Computing dataset statistics..."):
with st.expander("📦 Data Source Details", expanded=False): stats = ensemble.get_stats()
for member, member_stats in stats["members"].items():
st.markdown(f"### {member}")
# Create metrics for this member # High-level summary metrics
metric_cols = st.columns(4) col1, col2, col3 = st.columns(3)
with metric_cols[0]: with col1:
st.metric("Features", member_stats["num_features"]) st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
with metric_cols[1]: with col2:
st.metric("Variables", member_stats["num_variables"]) st.metric(label="Total Features", value=f"{stats['total_features']:,}")
with metric_cols[2]: with col3:
# Display dimensions in a more readable format st.metric(label="Data Sources", value=len(stats["members"]))
dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
st.metric("Shape", dim_str)
with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values():
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
# Show variables as colored badges # Detailed member statistics in expandable section
st.markdown("**Variables:**") with st.expander("📦 Data Source Details", expanded=False):
vars_html = " ".join( for member, member_stats in stats["members"].items():
[ st.markdown(f"### {member}")
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"]
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Show dimension details # Create metrics for this member
st.markdown("**Dimensions:**") metric_cols = st.columns(4)
dim_html = " ".join( with metric_cols[0]:
[ st.metric("Features", member_stats["num_features"])
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; ' with metric_cols[1]:
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">' st.metric("Variables", member_stats["num_variables"])
f"{dim_name}: {dim_size}</span>" with metric_cols[2]:
for dim_name, dim_size in member_stats["dimensions"].items() # Display dimensions in a more readable format
] dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
) st.metric("Shape", dim_str)
st.markdown(dim_html, unsafe_allow_html=True) with metric_cols[3]:
# Calculate total data points
total_points = 1
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
total_points *= dim_size
st.metric("Data Points", f"{total_points:,}")
st.markdown("---") # Show variables as colored badges
st.markdown("**Variables:**")
st.markdown("---") vars_html = " ".join(
[
# Create tabs for different data views f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
tab_names = ["📊 Labels", "📐 Areas"] f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats["variables"] # type: ignore[union-attr]
# Add tabs for each member ]
for member in ensemble.members:
if member == "AlphaEarth":
tab_names.append("🌍 AlphaEarth")
elif member == "ArcticDEM":
tab_names.append("🏔️ ArcticDEM")
elif member.startswith("ERA5"):
# Group ERA5 temporal variants
if "🌡️ ERA5" not in tab_names:
tab_names.append("🌡️ ERA5")
tabs = st.tabs(tab_names)
# Labels tab
with tabs[0]:
st.markdown("### Target Labels Distribution and Spatial Visualization")
# Load training data for all three tasks
with st.spinner("Loading training data for all tasks..."):
train_data_dict = load_all_training_data(ensemble)
# Calculate total samples (use binary as reference)
total_samples = len(train_data_dict["binary"])
train_samples = (train_data_dict["binary"].split == "train").sum().item()
test_samples = (train_data_dict["binary"].split == "test").sum().item()
st.success(
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
) )
st.markdown(vars_html, unsafe_allow_html=True)
# Render distribution histograms # Show dimension details
st.markdown("---") st.markdown("**Dimensions:**")
render_all_distribution_histograms(train_data_dict) dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size}</span>"
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---") st.markdown("---")
# Render spatial map
binary_dataset = train_data_dict["binary"]
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
render_spatial_map(train_data_dict) def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
"""Render target labels distribution and spatial visualization.
# Areas tab Args:
with tabs[1]: ensemble: The dataset ensemble configuration.
st.markdown("### Grid Cell Areas and Land/Water Distribution") train_data_dict: Pre-loaded training data for all tasks.
st.markdown( """
"This visualization shows the spatial distribution of cell areas, land areas, " st.markdown("### Target Labels Distribution and Spatial Visualization")
"water areas, and land ratio across the grid. The grid has been filtered to "
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
"with >10% land coverage."
)
# Load grid data # Calculate total samples (use binary as reference)
grid_gdf = grids.open(ensemble.grid, ensemble.level) total_samples = len(train_data_dict["binary"])
train_samples = (train_data_dict["binary"].split == "train").sum().item()
test_samples = (train_data_dict["binary"].split == "test").sum().item()
st.success( st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
)
# Show summary statistics # Render distribution histograms
col1, col2, col3, col4 = st.columns(4) st.markdown("---")
with col1: render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type]
st.metric("Total Cells", f"{len(grid_gdf):,}")
with col2:
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
with col3:
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
with col4:
total_land = grid_gdf["land_area"].sum()
st.metric("Total Land Area", f"{total_land:,.0f} km²")
st.markdown("---") st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or ( # Render spatial map
ensemble.grid == "healpix" and ensemble.level == 10 binary_dataset = train_data_dict["binary"]
): assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_areas_map(grid_gdf, ensemble.grid)
# AlphaEarth tab render_spatial_map(train_data_dict)
tab_idx = 2
if "AlphaEarth" in ensemble.members:
with tabs[tab_idx]:
st.markdown("### AlphaEarth Embeddings Analysis")
with st.spinner("Loading AlphaEarth data..."):
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells") def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
"""Render grid cell areas and land/water distribution.
render_alphaearth_overview(alphaearth_ds) Args:
render_alphaearth_plots(alphaearth_ds) ensemble: The dataset ensemble configuration.
grid_gdf: Pre-loaded grid GeoDataFrame.
st.markdown("---") """
st.markdown("### Grid Cell Areas and Land/Water Distribution")
if (ensemble.grid == "hex" and ensemble.level == 6) or ( st.markdown(
ensemble.grid == "healpix" and ensemble.level == 10 "This visualization shows the spatial distribution of cell areas, land areas, "
): "water areas, and land ratio across the grid. The grid has been filtered to "
st.warning( "include only cells in the permafrost region (>50° latitude, <85° latitude) "
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." "with >10% land coverage."
) )
else:
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
tab_idx += 1 st.success(
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
)
# ArcticDEM tab # Show summary statistics
if "ArcticDEM" in ensemble.members: col1, col2, col3, col4 = st.columns(4)
with tabs[tab_idx]: with col1:
st.markdown("### ArcticDEM Terrain Analysis") st.metric("Total Cells", f"{len(grid_gdf):,}")
with col2:
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
with col3:
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
with col4:
total_land = grid_gdf["land_area"].sum()
st.metric("Total Land Area", f"{total_land:,.0f} km²")
with st.spinner("Loading ArcticDEM data..."): st.markdown("---")
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
render_arcticdem_overview(arcticdem_ds)
render_arcticdem_plots(arcticdem_ds)
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
tab_idx += 1
# ERA5 tab (combining all temporal variants)
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
if era5_members:
with tabs[tab_idx]:
st.markdown("### ERA5 Climate Data Analysis")
# Let user select which ERA5 temporal aggregation to view
era5_options = {
"ERA5-yearly": "Yearly",
"ERA5-seasonal": "Seasonal (Winter/Summer)",
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
}
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
selected_era5 = st.selectbox(
"Select ERA5 temporal aggregation",
options=list(available_era5.keys()),
format_func=lambda x: available_era5[x],
key="era5_temporal_select",
)
if selected_era5:
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
with st.spinner(f"Loading {selected_era5} data..."):
era5_ds, targets = load_source_data(ensemble, selected_era5)
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
render_era5_overview(era5_ds, temporal_type)
render_era5_plots(era5_ds, temporal_type)
st.markdown("---")
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
# Show balloons once after all tabs are rendered
st.balloons()
# Check if we should skip map rendering for performance
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
"due to performance considerations."
)
else: else:
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.") render_areas_map(grid_gdf, ensemble.grid)
def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
"""Render AlphaEarth embeddings analysis.
Args:
ensemble: The dataset ensemble configuration.
alphaearth_ds: Pre-loaded AlphaEarth dataset.
targets: Pre-loaded targets GeoDataFrame.
"""
st.markdown("### AlphaEarth Embeddings Analysis")
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
render_alphaearth_overview(alphaearth_ds)
render_alphaearth_plots(alphaearth_ds)
st.markdown("---")
# Check if we should skip map rendering for performance
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
"due to performance considerations."
)
else:
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
"""Render ArcticDEM terrain analysis.
Args:
ensemble: The dataset ensemble configuration.
arcticdem_ds: Pre-loaded ArcticDEM dataset.
targets: Pre-loaded targets GeoDataFrame.
"""
st.markdown("### ArcticDEM Terrain Analysis")
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
render_arcticdem_overview(arcticdem_ds)
render_arcticdem_plots(arcticdem_ds)
st.markdown("---")
# Check if we should skip map rendering for performance
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
"due to performance considerations."
)
else:
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
"""Render ERA5 climate data analysis.
Args:
ensemble: The dataset ensemble configuration.
era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
targets: Pre-loaded targets GeoDataFrame.
"""
st.markdown("### ERA5 Climate Data Analysis")
# Let user select which ERA5 temporal aggregation to view
era5_options = {
"ERA5-yearly": "Yearly",
"ERA5-seasonal": "Seasonal (Winter/Summer)",
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
}
available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
selected_era5 = st.selectbox(
"Select ERA5 temporal aggregation",
options=list(available_era5.keys()),
format_func=lambda x: available_era5[x],
key="era5_temporal_select",
)
if selected_era5 and selected_era5 in era5_data:
era5_ds, temporal_type = era5_data[selected_era5]
render_era5_overview(era5_ds, temporal_type)
render_era5_plots(era5_ds, temporal_type)
st.markdown("---")
# Check if we should skip map rendering for performance
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
st.warning(
"🗡️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
"due to performance considerations."
)
else:
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
def render_training_data_page():
    """Render the Training Data page of the dashboard.

    Flow: draw the sidebar configuration form, bail out with a hint until a
    dataset has been loaded into session state, then load all data (training
    frames, grid geometry, targets, and per-member source datasets) under a
    single spinner and render one tab per available data view.
    """
    st.title("🎯 Training Data")
    st.markdown(
        """
        Explore and visualize the training data for RTS prediction models.
        Configure your dataset by selecting grid configuration, target dataset,
        and data sources in the sidebar, then click "Load Dataset" to begin.
        """
    )
    # Render sidebar configuration (sets session-state flags consumed below)
    render_dataset_configuration_sidebar()
    # Check if dataset is loaded in session state; without both the flag and
    # the ensemble object there is nothing to render yet.
    if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state:
        st.info(
            "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
        )
        return
    # Get ensemble from session state
    ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]
    st.divider()
    # Load all necessary data once, up front, so each tab below only renders.
    with st.spinner("Loading dataset..."):
        # Load training data for all tasks
        train_data_dict = load_all_training_data(ensemble)
        # Load grid data
        grid_gdf = grids.open(ensemble.grid, ensemble.level)
        # Load targets (needed by all source data views)
        # NOTE(review): reaches into a private method of the ensemble — consider
        # exposing a public accessor on DatasetEnsemble.
        targets = ensemble._read_target()
        # Load AlphaEarth data if in members
        alphaearth_ds = None
        if "AlphaEarth" in ensemble.members:
            alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")
        # Load ArcticDEM data if in members
        arcticdem_ds = None
        if "ArcticDEM" in ensemble.members:
            arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")
        # Load ERA5 data for all temporal aggregations in members; keyed by
        # member name so the ERA5 view can offer a per-aggregation selector.
        era5_data = {}
        era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
        for era5_member in era5_members:
            era5_ds, _ = load_source_data(ensemble, era5_member)
            temporal_type = era5_member.split("-")[1]  # 'yearly', 'seasonal', or 'shoulder'
            era5_data[era5_member] = (era5_ds, temporal_type)
    st.success(
        f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features"
    )
    # Render dataset statistics
    render_dataset_statistics(ensemble)
    st.markdown("---")
    # Create tabs for different data views; Labels and Areas are always present.
    tab_names = ["📊 Labels", "📐 Areas"]
    # Add tabs for each member based on what's in the ensemble
    if "AlphaEarth" in ensemble.members:
        tab_names.append("🌍 AlphaEarth")
    if "ArcticDEM" in ensemble.members:
        tab_names.append("🏔️ ArcticDEM")
    # Check for ERA5 members (all temporal variants share one tab)
    if era5_members:
        tab_names.append("🌡️ ERA5")
    tabs = st.tabs(tab_names)
    # Track current tab index manually, since the optional tabs shift positions.
    tab_idx = 0
    # Labels tab
    with tabs[tab_idx]:
        render_labels_view(ensemble, train_data_dict)
    tab_idx += 1
    # Areas tab
    with tabs[tab_idx]:
        render_areas_view(ensemble, grid_gdf)
    tab_idx += 1
    # AlphaEarth tab
    if "AlphaEarth" in ensemble.members:
        with tabs[tab_idx]:
            render_alphaearth_view(ensemble, alphaearth_ds, targets)
        tab_idx += 1
    # ArcticDEM tab
    if "ArcticDEM" in ensemble.members:
        with tabs[tab_idx]:
            render_arcticdem_view(ensemble, arcticdem_ds, targets)
        tab_idx += 1
    # ERA5 tab (combining all temporal variants)
    if era5_members:
        with tabs[tab_idx]:
            render_era5_view(ensemble, era5_data, targets)
    # Show balloons once after all tabs are rendered
    st.balloons()
    stopwatch.summary()