diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 8e4555f..713f7d0 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -11,12 +11,11 @@ Pages: import streamlit as st -# from entropice.dashboard.views.inference_page import render_inference_page -# from entropice.dashboard.views.model_state_page import render_model_state_page +from entropice.dashboard.views.inference_page import render_inference_page +from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.overview_page import render_overview_page - -# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page -# from entropice.dashboard.views.training_data_page import render_training_data_page +from entropice.dashboard.views.training_analysis_page import render_training_analysis_page +from entropice.dashboard.views.training_data_page import render_training_data_page def main(): @@ -25,17 +24,17 @@ def main(): # Setup Navigation overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) - # training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") - # training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") - # model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") - # inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") + training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") + training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") + model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮") + inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️") pg = st.navigation( { "Overview": [overview_page], - # "Training": [training_data_page, training_analysis_page], - # "Model 
State": [model_state_page], - # "Inference": [inference_page], + "Training": [training_data_page, training_analysis_page], + "Model State": [model_state_page], + "Inference": [inference_page], } ) pg.run() diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index e63620f..79d1867 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -143,8 +143,8 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None # Extract scale information from settings if available param_scales = {} - if settings and "param_grid" in settings: - param_grid = settings["param_grid"] + if settings and hasattr(settings, "param_grid"): + param_grid = settings.param_grid for param_name, param_config in param_grid.items(): if isinstance(param_config, dict) and "distribution" in param_config: # loguniform distribution indicates log scale @@ -1181,10 +1181,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): preds_gdf = gpd.read_parquet(preds_file) # Get task and target information from settings - task = settings.get("task", "binary") - target = settings.get("target", "darts_rts") - grid = settings.get("grid", "hex") - level = settings.get("level", 3) + task = settings.task + target = settings.target + grid = settings.grid + level = settings.level # Create dataset ensemble to get true labels # We need to load the target data to get true labels diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py index a229161..29804f7 100644 --- a/src/entropice/dashboard/plots/inference.py +++ b/src/entropice/dashboard/plots/inference.py @@ -8,7 +8,7 @@ import streamlit as st from shapely.geometry import shape from entropice.dashboard.utils.colors import get_palette -from entropice.dashboard.utils.data import TrainingResult +from entropice.dashboard.utils.loaders 
import TrainingResult def _fix_hex_geometry(geom): @@ -197,8 +197,8 @@ def render_inference_map(result: TrainingResult): preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet") # Get settings - task = result.settings.get("task", "binary") - grid = result.settings.get("grid", "hex") + task = result.settings.task + grid = result.settings.grid # Create controls in columns col1, col2, col3 = st.columns([2, 2, 1]) diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index 34233a1..bb3b467 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -147,10 +147,11 @@ def load_all_training_data( Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values. """ + dataset = e.create(filter_target_col=e.covcol) return { - "binary": e.create_cat_training_dataset("binary", device="cpu"), - "count": e.create_cat_training_dataset("count", device="cpu"), - "density": e.create_cat_training_dataset("density", device="cpu"), + "binary": e._cat_and_split(dataset, "binary", device="cpu"), + "count": e._cat_and_split(dataset, "count", device="cpu"), + "density": e._cat_and_split(dataset, "density", device="cpu"), } diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py index 20d1759..b5f2835 100644 --- a/src/entropice/dashboard/views/inference_page.py +++ b/src/entropice/dashboard/views/inference_page.py @@ -1,7 +1,8 @@ """Inference page: Visualization of model inference results across the study region.""" +import geopandas as gpd import streamlit as st -from entropice.dashboard.utils.data import load_all_training_results +from stopuhr import stopwatch from entropice.dashboard.plots.inference import ( render_class_comparison, @@ -10,66 +11,177 @@ from entropice.dashboard.plots.inference import ( render_inference_statistics, render_spatial_distribution_stats, ) +from 
entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results + + +@st.fragment +def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult: + """Render sidebar for training run selection. + + Args: + training_results: List of available TrainingResult objects. + + Returns: + Selected TrainingResult object. + + """ + st.header("Select Training Run") + + # Create selection options with task-first naming + training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results} + + selected_name = st.selectbox( + "Training Run", + options=list(training_options.keys()), + index=0, + help="Select a training run to view inference results", + key="inference_run_select", + ) + + selected_result = training_options[selected_name] + + st.divider() + + # Show run information in sidebar + st.subheader("Run Information") + + st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}") + st.markdown(f"**Model:** {selected_result.settings.model.upper()}") + st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}") + st.markdown(f"**Level:** {selected_result.settings.level}") + st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}") + + return selected_result + + +def render_run_information(selected_result: TrainingResult): + """Render training run configuration overview. + + Args: + selected_result: The selected TrainingResult object. 
+ + """ + st.header("📋 Run Configuration") + + col1, col2, col3, col4, col5 = st.columns(5) + + with col1: + st.metric("Task", selected_result.settings.task.capitalize()) + + with col2: + st.metric("Model", selected_result.settings.model.upper()) + + with col3: + st.metric("Grid", selected_result.settings.grid.capitalize()) + + with col4: + st.metric("Level", selected_result.settings.level) + + with col5: + st.metric("Target", selected_result.settings.target.replace("darts_", "")) + + +def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str): + """Render inference summary statistics section. + + Args: + predictions_gdf: GeoDataFrame with predictions. + task: Task type ('binary', 'count', 'density'). + + """ + st.header("📊 Inference Summary") + render_inference_statistics(predictions_gdf, task) + + +def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame): + """Render spatial coverage statistics section. + + Args: + predictions_gdf: GeoDataFrame with predictions. + + """ + st.header("🌍 Spatial Coverage") + render_spatial_distribution_stats(predictions_gdf) + + +def render_map_visualization_section(selected_result: TrainingResult): + """Render 3D map visualization section. + + Args: + selected_result: The selected TrainingResult object. + + """ + st.header("🗺️ Interactive Prediction Map") + st.markdown( + """ + 3D visualization of predictions across the study region. The map shows predicted + classes with color coding and spatial distribution of model outputs. + """ + ) + render_inference_map(selected_result) + + +def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str): + """Render class distribution histogram section. + + Args: + predictions_gdf: GeoDataFrame with predictions. + task: Task type ('binary', 'count', 'density'). 
+ + """ + st.header("📈 Class Distribution") + st.markdown("Distribution of predicted classes across all inference cells.") + render_class_distribution_histogram(predictions_gdf, task) + + +def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str): + """Render class comparison analysis section. + + Args: + predictions_gdf: GeoDataFrame with predictions. + task: Task type ('binary', 'count', 'density'). + + """ + st.header("🔍 Class Comparison Analysis") + st.markdown( + """ + Detailed comparison of predicted classes showing probability distributions + and confidence metrics for different class predictions. + """ + ) + render_class_comparison(predictions_gdf, task) def render_inference_page(): """Render the Inference page of the dashboard.""" st.title("🗺️ Inference Results") + st.markdown( + """ + Explore spatial predictions from trained models across the Arctic permafrost region. + Select a training run from the sidebar to visualize prediction maps, class distributions, + and spatial coverage statistics. + """ + ) + # Load all available training results training_results = load_all_training_results() if not training_results: st.warning("No training results found. 
Please run some training experiments first.") - st.info("Run training using: `pixi run python -m entropice.training`") + st.info("Run training using: `pixi run python -m entropice.ml.training`") return + st.success(f"Found **{len(training_results)}** training result(s)") + + st.divider() + # Sidebar: Training run selection with st.sidebar: - st.header("Select Training Run") + selected_result = render_sidebar_selection(training_results) - # Create selection options with task-first naming - training_options = {tr.get_display_name("task_first"): tr for tr in training_results} - - selected_name = st.selectbox( - "Training Run", - options=list(training_options.keys()), - index=0, - help="Select a training run to view inference results", - ) - - selected_result = training_options[selected_name] - - st.divider() - - # Show run information in sidebar - st.subheader("Run Information") - - settings = selected_result.settings - - st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}") - st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}") - st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}") - st.markdown(f"**Level:** {settings.get('level', 'Unknown')}") - st.markdown(f"**Target:** {settings.get('target', 'Unknown')}") - - # Main content area - Run Information at the top - st.header("📋 Run Configuration") - - col1, col2, col3, col4, col5 = st.columns(5) - with col1: - st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize()) - with col2: - st.metric("Model", selected_result.settings.get("model", "Unknown").upper()) - with col3: - st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize()) - with col4: - st.metric("Level", selected_result.settings.get("level", "Unknown")) - with col5: - st.metric( - "Target", - selected_result.settings.get("target", "Unknown").replace("darts_", ""), - ) + # Main content area - Run Information + render_run_information(selected_result) st.divider() 
@@ -80,31 +192,33 @@ def render_inference_page(): st.info("Inference results are generated automatically during training.") return - # Load predictions for statistics - import geopandas as gpd - - predictions_gdf = gpd.read_parquet(preds_file) - task = selected_result.settings.get("task", "binary") + # Load predictions + with st.spinner("Loading inference results..."): + predictions_gdf = gpd.read_parquet(preds_file) + task = selected_result.settings.task # Inference Statistics Section - render_inference_statistics(predictions_gdf, task) + render_inference_statistics_section(predictions_gdf, task) st.divider() # Spatial Coverage Section - render_spatial_distribution_stats(predictions_gdf) + render_spatial_coverage_section(predictions_gdf) st.divider() # 3D Map Visualization Section - render_inference_map(selected_result) + render_map_visualization_section(selected_result) st.divider() # Class Distribution Section - render_class_distribution_histogram(predictions_gdf, task) + render_class_distribution_section(predictions_gdf, task) st.divider() # Class Comparison Section - render_class_comparison(predictions_gdf, task) + render_class_comparison_section(predictions_gdf, task) + + st.balloons() + stopwatch.summary() diff --git a/src/entropice/dashboard/views/model_state_page.py b/src/entropice/dashboard/views/model_state_page.py index 5cd90f5..c536d99 100644 --- a/src/entropice/dashboard/views/model_state_page.py +++ b/src/entropice/dashboard/views/model_state_page.py @@ -1,16 +1,8 @@ -"""Model State page for the Entropice dashboard.""" +"""Model State page: Visualization of model internal state and feature importance.""" import streamlit as st import xarray as xr -from entropice.dashboard.utils.data import ( - extract_arcticdem_features, - extract_common_features, - extract_embedding_features, - extract_era5_features, - get_members_from_settings, - load_all_training_results, -) -from entropice.dashboard.utils.training import load_model_state +from stopuhr import 
stopwatch from entropice.dashboard.plots.model_state import ( plot_arcticdem_heatmap, @@ -26,46 +18,64 @@ from entropice.dashboard.plots.model_state import ( plot_top_features, ) from entropice.dashboard.utils.colors import generate_unified_colormap +from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results +from entropice.dashboard.utils.unsembler import ( + extract_arcticdem_features, + extract_common_features, + extract_embedding_features, + extract_era5_features, +) +from entropice.utils.types import L2SourceDataset -def render_model_state_page(): - """Render the Model State page of the dashboard.""" - st.title("Model State") - st.markdown("Comprehensive visualization of the best model's internal state and feature importance") +def get_members_from_settings(settings) -> list[L2SourceDataset]: + """Extract dataset members from training settings. - # Load available training results - training_results = load_all_training_results() + Args: + settings: TrainingSettings object containing dataset configuration. - if not training_results: - st.error("No training results found. Please run a training search first.") - return + Returns: + List of L2SourceDataset members used in training. - # Sidebar: Training run selection - with st.sidebar: - st.header("Select Training Run") + """ + return settings.members - # Result selection with model-first naming - result_options = {tr.get_display_name("model_first"): tr for tr in training_results} - selected_name = st.selectbox( - "Training Run", - options=list(result_options.keys()), - help="Choose a training result to visualize model state", - ) - selected_result = result_options[selected_name] - st.divider() +@st.fragment +def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult: + """Render sidebar for training run selection. 
- # Get the model type from settings - model_type = selected_result.settings.get("model", "espa") + Args: + training_results: List of available TrainingResult objects. - # Load model state - with st.spinner("Loading model state..."): - model_state = load_model_state(selected_result) - if model_state is None: - st.error("Could not load model state for this result.") - return + Returns: + Selected TrainingResult object. - # Display basic model state info + """ + st.header("Select Training Run") + + # Result selection with task-first naming + result_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results} + selected_name = st.selectbox( + "Training Run", + options=list(result_options.keys()), + index=0, + help="Choose a training result to visualize model state", + key="model_state_training_run_select", + ) + selected_result = result_options[selected_name] + + return selected_result + + +def render_model_info(model_state: xr.Dataset, model_type: str): + """Render basic model state information. + + Args: + model_state: Xarray dataset containing model state. + model_type: Type of model (espa, xgboost, rf, knn). + + """ with st.expander("Model State Information", expanded=False): st.write(f"**Model Type:** {model_type.upper()}") st.write(f"**Variables:** {list(model_state.data_vars)}") @@ -73,15 +83,23 @@ def render_model_state_page(): st.write(f"**Coordinates:** {list(model_state.coords)}") st.write(f"**Attributes:** {dict(model_state.attrs)}") - # Display dataset members summary - st.header("📊 Training Data Summary") - members = get_members_from_settings(selected_result.settings) - st.markdown(f""" +def render_training_data_summary(members: list[L2SourceDataset]): + """Render summary of training data sources. + + Args: + members: List of dataset members used in training. 
+ + """ + st.header("📊 Training Data Summary") + + st.markdown( + f""" **Dataset Members Used in Training:** {len(members)} The following data sources were used to train this model: - """) + """ + ) # Create a nice display of members with emojis member_display = { @@ -98,6 +116,52 @@ def render_model_state_page(): display_name = member_display.get(member, f"📁 {member}") st.info(display_name) + +def render_model_state_page(): + """Render the Model State page of the dashboard.""" + st.title("🔬 Model State") + st.markdown( + """ + Comprehensive visualization of the best model's internal state and feature importance. + Select a training run from the sidebar to explore model parameters, feature weights, + and data source contributions. + """ + ) + + # Load available training results + training_results = load_all_training_results() + + if not training_results: + st.warning("No training results found. Please run some training experiments first.") + st.info("Run training using: `pixi run python -m entropice.ml.training`") + return + + st.success(f"Found **{len(training_results)}** training result(s)") + + st.divider() + + # Sidebar: Training run selection + with st.sidebar: + selected_result = render_sidebar_selection(training_results) + + # Get the model type from settings + model_type = selected_result.settings.model + + # Load model state + with st.spinner("Loading model state..."): + model_state = selected_result.load_model_state() + if model_state is None: + st.error("Could not load model state for this result.") + st.info("The model state file (best_estimator_state.nc) may be missing from the training results.") + return + + # Display basic model state info + render_model_info(model_state, model_type) + + # Display dataset members summary + members = get_members_from_settings(selected_result.settings) + render_training_data_summary(members) + st.divider() # Render model-specific visualizations @@ -112,9 +176,18 @@ def render_model_state_page(): else: 
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.") + st.balloons() + stopwatch.summary() -def render_espa_model_state(model_state: xr.Dataset, selected_result): - """Render visualizations for ESPA model.""" + +def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult): + """Render visualizations for ESPA model. + + Args: + model_state: Xarray dataset containing ESPA model state. + selected_result: TrainingResult object containing training configuration. + + """ # Scale feature weights by number of features n_features = model_state.sizes["feature"] model_state["feature_weights"] *= n_features @@ -143,8 +216,9 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result): common_feature_array = extract_common_features(model_state) - # Generate unified colormaps - _, _, altair_colors = generate_unified_colormap(selected_result.settings) + # Generate unified colormaps (convert dataclass to dict) + settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes} + _, _, altair_colors = generate_unified_colormap(settings_dict) # Feature importance section st.header("Feature Importance") @@ -255,8 +329,14 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result): render_common_features(common_feature_array) -def render_xgboost_model_state(model_state: xr.Dataset, selected_result): - """Render visualizations for XGBoost model.""" +def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult): + """Render visualizations for XGBoost model. + + Args: + model_state: Xarray dataset containing XGBoost model state. + selected_result: TrainingResult object containing training configuration. 
+ + """ from entropice.dashboard.plots.model_state import ( plot_xgboost_feature_importance, plot_xgboost_importance_comparison, @@ -382,8 +462,14 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result): render_common_features(common_feature_array) -def render_rf_model_state(model_state: xr.Dataset, selected_result): - """Render visualizations for Random Forest model.""" +def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult): + """Render visualizations for Random Forest model. + + Args: + model_state: Xarray dataset containing Random Forest model state. + selected_result: TrainingResult object containing training configuration. + + """ from entropice.dashboard.plots.model_state import plot_rf_feature_importance st.header("🌳 Random Forest Model Analysis") @@ -529,8 +615,14 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result): render_common_features(common_feature_array) -def render_knn_model_state(model_state: xr.Dataset, selected_result): - """Render visualizations for KNN model.""" +def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult): + """Render visualizations for KNN model. + + Args: + model_state: Xarray dataset containing KNN model state. + selected_result: TrainingResult object containing training configuration. + + """ st.header("🔍 K-Nearest Neighbors Model Analysis") st.markdown( """ @@ -568,8 +660,13 @@ def render_knn_model_state(model_state: xr.Dataset, selected_result): # Helper functions for embedding/era5/common features -def render_embedding_features(embedding_feature_array): - """Render embedding feature visualizations.""" +def render_embedding_features(embedding_feature_array: xr.DataArray): + """Render embedding feature visualizations. + + Args: + embedding_feature_array: DataArray containing AlphaEarth embedding feature weights. 
+ + """ with st.container(border=True): st.header("🛰️ Embedding Feature Analysis") st.markdown( @@ -619,7 +716,7 @@ def render_embedding_features(embedding_feature_array): st.dataframe(top_emb, width="stretch") -def render_era5_features(era5_feature_array, temporal_group: str = ""): +def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""): """Render ERA5 feature visualizations. Args: @@ -631,9 +728,10 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""): with st.container(border=True): st.header(f"⛅ ERA5 Feature Analysis{group_suffix}") + temporal_suffix = f" for {temporal_group.lower()} aggregation" if temporal_group else "" st.markdown( f""" - Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods + Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods are most important for the model predictions. """ ) @@ -709,8 +807,13 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""): st.dataframe(top_era5, width="stretch") -def render_arcticdem_features(arcticdem_feature_array): - """Render ArcticDEM feature visualizations.""" +def render_arcticdem_features(arcticdem_feature_array: xr.DataArray): + """Render ArcticDEM feature visualizations. + + Args: + arcticdem_feature_array: DataArray containing ArcticDEM feature weights. + + """ with st.container(border=True): st.header("🏔️ ArcticDEM Feature Analysis") st.markdown( @@ -758,8 +861,13 @@ def render_arcticdem_features(arcticdem_feature_array): st.dataframe(top_arcticdem, width="stretch") -def render_common_features(common_feature_array): - """Render common feature visualizations.""" +def render_common_features(common_feature_array: xr.DataArray): + """Render common feature visualizations. + + Args: + common_feature_array: DataArray containing common feature weights. 
+ + """ with st.container(border=True): st.header("🗺️ Common Feature Analysis") st.markdown( diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index 55c0220..8cf243b 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -658,5 +658,4 @@ def render_overview_page(): render_dataset_analysis() st.balloons() - stopwatch.summary() diff --git a/src/entropice/dashboard/views/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py index e2484ff..8609331 100644 --- a/src/entropice/dashboard/views/training_analysis_page.py +++ b/src/entropice/dashboard/views/training_analysis_page.py @@ -1,6 +1,7 @@ """Training Results Analysis page: Analysis of training results and model performance.""" import streamlit as st +from stopuhr import stopwatch from entropice.dashboard.plots.hyperparameter_analysis import ( render_binned_parameter_space, @@ -12,169 +13,176 @@ from entropice.dashboard.plots.hyperparameter_analysis import ( render_performance_summary, render_top_configurations, ) -from entropice.dashboard.utils.data import load_all_training_results -from entropice.dashboard.utils.training import ( - format_metric_name, - get_available_metrics, - get_cv_statistics, - get_parameter_space_summary, -) +from entropice.dashboard.utils.formatters import format_metric_name +from entropice.dashboard.utils.loaders import load_all_training_results +from entropice.dashboard.utils.stats import CVResultsStatistics -def render_training_analysis_page(): - """Render the Training Results Analysis page of the dashboard.""" - st.title("🦾 Training Results Analysis") +@st.fragment +def render_analysis_settings_sidebar(training_results): + """Render sidebar for training run and analysis settings selection. 
- # Load all available training results - training_results = load_all_training_results() + Args: + training_results: List of available TrainingResult objects. - if not training_results: - st.warning("No training results found. Please run some training experiments first.") - st.info("Run training using: `pixi run python -m entropice.training`") - return + Returns: + Tuple of (selected_result, selected_metric, refit_metric, top_n). - # Sidebar: Training run selection - with st.sidebar: - st.header("Select Training Run") + """ + st.header("Select Training Run") - # Create selection options with task-first naming - training_options = {tr.get_display_name("task_first"): tr for tr in training_results} + # Create selection options with task-first naming + training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results} - selected_name = st.selectbox( - "Training Run", - options=list(training_options.keys()), - index=0, - help="Select a training run to analyze", - ) + selected_name = st.selectbox( + "Training Run", + options=list(training_options.keys()), + index=0, + help="Select a training run to analyze", + key="training_run_select", + ) - selected_result = training_options[selected_name] + selected_result = training_options[selected_name] - st.divider() + st.divider() - # Metric selection for detailed analysis - st.subheader("Analysis Settings") + # Metric selection for detailed analysis + st.subheader("Analysis Settings") - available_metrics = get_available_metrics(selected_result.results) + available_metrics = selected_result.available_metrics - # Try to get refit metric from settings - refit_metric = selected_result.settings.get("refit_metric") + # Try to get refit metric from settings + refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None - if not refit_metric or refit_metric not in available_metrics: - # Infer from task or use first available metric - task = 
selected_result.settings.get("task", "binary") - if task == "binary" and "f1" in available_metrics: - refit_metric = "f1" - elif "f1_weighted" in available_metrics: - refit_metric = "f1_weighted" - elif "accuracy" in available_metrics: - refit_metric = "accuracy" - elif available_metrics: - refit_metric = available_metrics[0] - else: - st.error("No metrics found in results.") - return - - if refit_metric in available_metrics: - default_metric_idx = available_metrics.index(refit_metric) + if not refit_metric or refit_metric not in available_metrics: + # Infer from task or use first available metric + task = selected_result.settings.task + if task == "binary" and "f1" in available_metrics: + refit_metric = "f1" + elif "f1_weighted" in available_metrics: + refit_metric = "f1_weighted" + elif "accuracy" in available_metrics: + refit_metric = "accuracy" + elif available_metrics: + refit_metric = available_metrics[0] else: - default_metric_idx = 0 + st.error("No metrics found in results.") + return None, None, None, None - selected_metric = st.selectbox( - "Primary Metric for Analysis", - options=available_metrics, - index=default_metric_idx, - format_func=format_metric_name, - help="Select the metric to focus on for detailed analysis", - ) + if refit_metric in available_metrics: + default_metric_idx = available_metrics.index(refit_metric) + else: + default_metric_idx = 0 - # Top N configurations - top_n = st.slider( - "Top N Configurations", - min_value=5, - max_value=50, - value=10, - step=5, - help="Number of top configurations to display", - ) + selected_metric = st.selectbox( + "Primary Metric for Analysis", + options=available_metrics, + index=default_metric_idx, + format_func=format_metric_name, + help="Select the metric to focus on for detailed analysis", + key="metric_select", + ) - # Main content area - Run Information at the top + # Top N configurations + top_n = st.slider( + "Top N Configurations", + min_value=5, + max_value=50, + value=10, + step=5, + 
help="Number of top configurations to display", + key="top_n_slider", + ) + + return selected_result, selected_metric, refit_metric, top_n + + +def render_run_information(selected_result, refit_metric): + """Render training run configuration overview. + + Args: + selected_result: The selected TrainingResult object. + refit_metric: The refit metric used for model selection. + + """ st.header("📋 Run Information") col1, col2, col3, col4, col5, col6 = st.columns(6) with col1: - st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize()) + st.metric("Task", selected_result.settings.task.capitalize()) with col2: - st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize()) + st.metric("Grid", selected_result.settings.grid.capitalize()) with col3: - st.metric("Level", selected_result.settings.get("level", "Unknown")) + st.metric("Level", selected_result.settings.level) with col4: - st.metric("Model", selected_result.settings.get("model", "Unknown").upper()) + st.metric("Model", selected_result.settings.model.upper()) with col5: st.metric("Trials", len(selected_result.results)) with col6: - st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown")) + st.metric("CV Splits", selected_result.settings.cv_splits) st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}") - st.divider() - # Main content area - results = selected_result.results - settings = selected_result.settings +def render_cv_statistics_section(selected_result, selected_metric): + """Render cross-validation statistics for selected metric. - # Performance Summary Section - st.header("📊 Performance Overview") + Args: + selected_result: The selected TrainingResult object. + selected_metric: The metric to display statistics for. 
- render_performance_summary(results, refit_metric) - - st.divider() - - # Confusion Matrix Map Section - st.header("🗺️ Prediction Results Map") - - render_confusion_matrix_map(selected_result.path, settings) - - st.divider() - - # Quick Statistics + """ st.header("📈 Cross-Validation Statistics") - cv_stats = get_cv_statistics(results, selected_metric) + from entropice.dashboard.utils.stats import CVMetricStatistics - if cv_stats: - col1, col2, col3, col4, col5 = st.columns(5) + cv_stats = CVMetricStatistics.compute(selected_result, selected_metric) - with col1: - st.metric("Best Score", f"{cv_stats['best_score']:.4f}") + col1, col2, col3, col4, col5 = st.columns(5) - with col2: - st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}") + with col1: + st.metric("Best Score", f"{cv_stats.best_score:.4f}") - with col3: - st.metric("Std Dev", f"{cv_stats['std_score']:.4f}") + with col2: + st.metric("Mean Score", f"{cv_stats.mean_score:.4f}") - with col4: - st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}") + with col3: + st.metric("Std Dev", f"{cv_stats.std_score:.4f}") - with col5: - st.metric("Median Score", f"{cv_stats['median_score']:.4f}") + with col4: + st.metric("Worst Score", f"{cv_stats.worst_score:.4f}") - if "mean_cv_std" in cv_stats: - st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds") + with col5: + st.metric("Median Score", f"{cv_stats.median_score:.4f}") - st.divider() + if cv_stats.mean_cv_std is not None: + st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds") - # Parameter Space Analysis + +def render_parameter_space_section(selected_result, selected_metric): + """Render parameter space analysis section. + + Args: + selected_result: The selected TrainingResult object. + selected_metric: The metric to analyze parameters against. 
+ + """ st.header("🔍 Parameter Space Analysis") + # Compute CV results statistics + cv_results_stats = CVResultsStatistics.compute(selected_result) + # Show parameter space summary with st.expander("📋 Parameter Space Summary", expanded=False): - param_summary = get_parameter_space_summary(results) - if not param_summary.empty: - st.dataframe(param_summary, hide_index=True, width="stretch") + param_summary_df = cv_results_stats.parameters_to_dataframe() + if not param_summary_df.empty: + st.dataframe(param_summary_df, hide_index=True, width="stretch") else: st.info("No parameter information available.") + results = selected_result.results + settings = selected_result.settings + # Parameter distributions st.subheader("📈 Parameter Distributions") render_parameter_distributions(results, settings) @@ -183,7 +191,7 @@ def render_training_analysis_page(): st.subheader("🎨 Binned Parameter Space") # Check if this is an ESPA model and show ESPA-specific plots - model_type = settings.get("model", "espa") + model_type = settings.model if model_type == "espa": # Show ESPA-specific binned plots (eps_cl vs eps_e binned by K) render_espa_binned_parameter_space(results, selected_metric) @@ -196,31 +204,15 @@ def render_training_analysis_page(): # For non-ESPA models, show the generic binned plots render_binned_parameter_space(results, selected_metric) - st.divider() - # Parameter Correlation - st.header("🔗 Parameter Correlation") +def render_data_export_section(results, selected_result): + """Render data export section with download buttons. - render_parameter_correlation(results, selected_metric) + Args: + results: DataFrame with CV results. + selected_result: The selected TrainingResult object. 
- st.divider() - - # Multi-Metric Comparison - if len(available_metrics) >= 2: - st.header("📊 Multi-Metric Comparison") - - render_multi_metric_comparison(results) - - st.divider() - - # Top Configurations - st.header("🏆 Top Performing Configurations") - - render_top_configurations(results, selected_metric, top_n) - - st.divider() - - # Raw Data Export + """ with st.expander("💾 Export Data", expanded=False): st.subheader("Download Results") @@ -234,22 +226,114 @@ def render_training_analysis_page(): data=csv_data, file_name=f"{selected_result.path.name}_results.csv", mime="text/csv", - width="stretch", ) with col2: - # Download settings as text + # Download settings as JSON import json - settings_json = json.dumps(settings, indent=2) + settings_dict = { + "task": selected_result.settings.task, + "grid": selected_result.settings.grid, + "level": selected_result.settings.level, + "model": selected_result.settings.model, + "cv_splits": selected_result.settings.cv_splits, + "classes": selected_result.settings.classes, + } + settings_json = json.dumps(settings_dict, indent=2) st.download_button( label="⚙️ Download Settings (JSON)", data=settings_json, file_name=f"{selected_result.path.name}_settings.json", mime="application/json", - width="stretch", ) # Show raw data preview st.subheader("Raw Data Preview") st.dataframe(results.head(100), width="stretch") + + +def render_training_analysis_page(): + """Render the Training Results Analysis page of the dashboard.""" + st.title("🦾 Training Results Analysis") + + st.markdown( + """ + Analyze training results, hyperparameter search performance, and model configurations. + Select a training run from the sidebar to explore detailed metrics and parameter analysis. + """ + ) + + # Load all available training results + training_results = load_all_training_results() + + if not training_results: + st.warning("No training results found. 
Please run some training experiments first.") + st.info("Run training using: `pixi run python -m entropice.ml.training`") + return + + st.success(f"Found **{len(training_results)}** training result(s)") + + st.divider() + + # Sidebar: Training run selection + with st.sidebar: + selection_result = render_analysis_settings_sidebar(training_results) + if selection_result[0] is None: + return + selected_result, selected_metric, refit_metric, top_n = selection_result + + # Main content area + results = selected_result.results + settings = selected_result.settings + + # Run Information + render_run_information(selected_result, refit_metric) + + st.divider() + + # Performance Summary Section + st.header("📊 Performance Overview") + render_performance_summary(results, refit_metric) + + st.divider() + + # Confusion Matrix Map Section + st.header("🗺️ Prediction Results Map") + render_confusion_matrix_map(selected_result.path, settings) + + st.divider() + + # Cross-Validation Statistics + render_cv_statistics_section(selected_result, selected_metric) + + st.divider() + + # Parameter Space Analysis + render_parameter_space_section(selected_result, selected_metric) + + st.divider() + + # Parameter Correlation + st.header("🔗 Parameter Correlation") + render_parameter_correlation(results, selected_metric) + + st.divider() + + # Multi-Metric Comparison + if len(selected_result.available_metrics) >= 2: + st.header("📊 Multi-Metric Comparison") + render_multi_metric_comparison(results) + st.divider() + + # Top Configurations + st.header("🏆 Top Performing Configurations") + render_top_configurations(results, selected_metric, top_n) + + st.divider() + + # Raw Data Export + render_data_export_section(results, selected_result) + + st.balloons() + stopwatch.summary() diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py index 4523f4c..5ee37de 100644 --- a/src/entropice/dashboard/views/training_data_page.py +++ 
b/src/entropice/dashboard/views/training_data_page.py @@ -1,6 +1,9 @@ """Training Data page: Visualization of training data distributions.""" +from typing import cast + import streamlit as st +from stopuhr import stopwatch from entropice.dashboard.plots.source_data import ( render_alphaearth_map, @@ -19,30 +22,21 @@ from entropice.dashboard.plots.training_data import ( render_spatial_map, ) from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data -from entropice.ml.dataset import DatasetEnsemble +from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble from entropice.spatial import grids +from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs -def render_training_data_page(): - """Render the Training Data page of the dashboard.""" - st.title("Training Data") +def render_dataset_configuration_sidebar(): + """Render dataset configuration selector in sidebar with form. - # Sidebar widgets for dataset configuration in a form + Stores the selected ensemble in session state when form is submitted. 
+ """ with st.sidebar.form("dataset_config_form"): st.header("Dataset Configuration") - # Combined grid and level selection - grid_options = [ - "hex-3", - "hex-4", - "hex-5", - "hex-6", - "healpix-6", - "healpix-7", - "healpix-8", - "healpix-9", - "healpix-10", - ] + # Grid selection + grid_options = [gc.display_name for gc in grid_configs] grid_level_combined = st.selectbox( "Grid Configuration", @@ -51,9 +45,8 @@ def render_training_data_page(): help="Select the grid system and resolution level", ) - # Parse grid type and level - grid, level_str = grid_level_combined.split("-") - level = int(level_str) + # Find the selected grid config + selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined) # Target feature selection target = st.selectbox( @@ -66,317 +59,422 @@ def render_training_data_page(): # Members selection st.subheader("Dataset Members") - all_members = [ - "AlphaEarth", - "ArcticDEM", - "ERA5-yearly", - "ERA5-seasonal", - "ERA5-shoulder", - ] - selected_members = [] + all_members = cast( + list[L2SourceDataset], + ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"], + ) + selected_members: list[L2SourceDataset] = [] for member in all_members: if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): - selected_members.append(member) + selected_members.append(member) # type: ignore[arg-type] # Form submit button load_button = st.form_submit_button( "Load Dataset", type="primary", - width="stretch", + use_container_width=True, disabled=len(selected_members) == 0, ) # Create DatasetEnsemble only when form is submitted if load_button: - ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members) + ensemble = DatasetEnsemble( + grid=selected_grid_config.grid, + level=selected_grid_config.level, + target=cast(TargetDataset, target), + members=selected_members, + ) # Store ensemble in session state 
st.session_state["dataset_ensemble"] = ensemble st.session_state["dataset_loaded"] = True - # Display dataset information if loaded - if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state: - ensemble = st.session_state["dataset_ensemble"] - # Display current configuration - st.subheader("📊 Current Configuration") +def render_dataset_statistics(ensemble: DatasetEnsemble): + """Render dataset statistics and configuration overview. - # Create a visually appealing layout with columns - col1, col2, col3, col4 = st.columns(4) + Args: + ensemble: The dataset ensemble configuration. - with col1: - st.metric(label="Grid Type", value=ensemble.grid.upper()) + """ + st.markdown("### 📊 Dataset Configuration") - with col2: - st.metric(label="Grid Level", value=ensemble.level) + # Display current configuration in columns + col1, col2, col3, col4 = st.columns(4) - with col3: - st.metric(label="Target Feature", value=ensemble.target.replace("darts_", "")) + with col1: + st.metric(label="Grid Type", value=ensemble.grid.upper()) - with col4: - st.metric(label="Members", value=len(ensemble.members)) + with col2: + st.metric(label="Grid Level", value=ensemble.level) - # Display members in an expandable section - with st.expander("🗂️ Dataset Members", expanded=False): - members_cols = st.columns(len(ensemble.members)) - for idx, member in enumerate(ensemble.members): - with members_cols[idx]: - st.markdown(f"✓ **{member}**") + with col3: + st.metric(label="Target Feature", value=ensemble.target.replace("darts_", "")) - # Display dataset ID in a styled container - st.info(f"**Dataset ID:** `{ensemble.id()}`") + with col4: + st.metric(label="Members", value=len(ensemble.members)) - # Display dataset statistics - st.markdown("---") - st.subheader("📈 Dataset Statistics") + # Display members in an expandable section + with st.expander("🗂️ Dataset Members", expanded=False): + members_cols = st.columns(len(ensemble.members)) + for idx, member in 
enumerate(ensemble.members): + with members_cols[idx]: + st.markdown(f"✓ **{member}**") - with st.spinner("Computing dataset statistics..."): - stats = ensemble.get_stats() + # Display dataset ID in a styled container + st.info(f"**Dataset ID:** `{ensemble.id()}`") - # High-level summary metrics - col1, col2, col3 = st.columns(3) - with col1: - st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}") - with col2: - st.metric(label="Total Features", value=f"{stats['total_features']:,}") - with col3: - st.metric(label="Data Sources", value=len(stats["members"])) + # Display detailed dataset statistics + st.markdown("---") + st.markdown("### 📈 Dataset Statistics") - # Detailed member statistics in expandable section - with st.expander("📦 Data Source Details", expanded=False): - for member, member_stats in stats["members"].items(): - st.markdown(f"### {member}") + with st.spinner("Computing dataset statistics..."): + stats = ensemble.get_stats() - # Create metrics for this member - metric_cols = st.columns(4) - with metric_cols[0]: - st.metric("Features", member_stats["num_features"]) - with metric_cols[1]: - st.metric("Variables", member_stats["num_variables"]) - with metric_cols[2]: - # Display dimensions in a more readable format - dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) - st.metric("Shape", dim_str) - with metric_cols[3]: - # Calculate total data points - total_points = 1 - for dim_size in member_stats["dimensions"].values(): - total_points *= dim_size - st.metric("Data Points", f"{total_points:,}") + # High-level summary metrics + col1, col2, col3 = st.columns(3) + with col1: + st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}") + with col2: + st.metric(label="Total Features", value=f"{stats['total_features']:,}") + with col3: + st.metric(label="Data Sources", value=len(stats["members"])) - # Show variables as colored badges - st.markdown("**Variables:**") - vars_html = " ".join( - [ 
- f'{v}' - for v in member_stats["variables"] - ] - ) - st.markdown(vars_html, unsafe_allow_html=True) + # Detailed member statistics in expandable section + with st.expander("📦 Data Source Details", expanded=False): + for member, member_stats in stats["members"].items(): + st.markdown(f"### {member}") - # Show dimension details - st.markdown("**Dimensions:**") - dim_html = " ".join( - [ - f'' - f"{dim_name}: {dim_size}" - for dim_name, dim_size in member_stats["dimensions"].items() - ] - ) - st.markdown(dim_html, unsafe_allow_html=True) + # Create metrics for this member + metric_cols = st.columns(4) + with metric_cols[0]: + st.metric("Features", member_stats["num_features"]) + with metric_cols[1]: + st.metric("Variables", member_stats["num_variables"]) + with metric_cols[2]: + # Display dimensions in a more readable format + dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr] + st.metric("Shape", dim_str) + with metric_cols[3]: + # Calculate total data points + total_points = 1 + for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr] + total_points *= dim_size + st.metric("Data Points", f"{total_points:,}") - st.markdown("---") - - st.markdown("---") - - # Create tabs for different data views - tab_names = ["📊 Labels", "📐 Areas"] - - # Add tabs for each member - for member in ensemble.members: - if member == "AlphaEarth": - tab_names.append("🌍 AlphaEarth") - elif member == "ArcticDEM": - tab_names.append("🏔️ ArcticDEM") - elif member.startswith("ERA5"): - # Group ERA5 temporal variants - if "🌡️ ERA5" not in tab_names: - tab_names.append("🌡️ ERA5") - - tabs = st.tabs(tab_names) - - # Labels tab - with tabs[0]: - st.markdown("### Target Labels Distribution and Spatial Visualization") - - # Load training data for all three tasks - with st.spinner("Loading training data for all tasks..."): - train_data_dict = load_all_training_data(ensemble) - - # Calculate total samples (use binary as 
reference) - total_samples = len(train_data_dict["binary"]) - train_samples = (train_data_dict["binary"].split == "train").sum().item() - test_samples = (train_data_dict["binary"].split == "test").sum().item() - - st.success( - f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks" + # Show variables as colored badges + st.markdown("**Variables:**") + vars_html = " ".join( + [ + f'{v}' + for v in member_stats["variables"] # type: ignore[union-attr] + ] ) + st.markdown(vars_html, unsafe_allow_html=True) - # Render distribution histograms - st.markdown("---") - render_all_distribution_histograms(train_data_dict) + # Show dimension details + st.markdown("**Dimensions:**") + dim_html = " ".join( + [ + f'' + f"{dim_name}: {dim_size}" + for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr] + ] + ) + st.markdown(dim_html, unsafe_allow_html=True) st.markdown("---") - # Render spatial map - binary_dataset = train_data_dict["binary"] - assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" - render_spatial_map(train_data_dict) +def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]): + """Render target labels distribution and spatial visualization. - # Areas tab - with tabs[1]: - st.markdown("### Grid Cell Areas and Land/Water Distribution") + Args: + ensemble: The dataset ensemble configuration. + train_data_dict: Pre-loaded training data for all tasks. - st.markdown( - "This visualization shows the spatial distribution of cell areas, land areas, " - "water areas, and land ratio across the grid. The grid has been filtered to " - "include only cells in the permafrost region (>50° latitude, <85° latitude) " - "with >10% land coverage." 
- ) + """ + st.markdown("### Target Labels Distribution and Spatial Visualization") - # Load grid data - grid_gdf = grids.open(ensemble.grid, ensemble.level) + # Calculate total samples (use binary as reference) + total_samples = len(train_data_dict["binary"]) + train_samples = (train_data_dict["binary"].split == "train").sum().item() + test_samples = (train_data_dict["binary"].split == "test").sum().item() - st.success( - f"Loaded {len(grid_gdf)} grid cells with areas ranging from " - f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²" - ) + st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks") - # Show summary statistics - col1, col2, col3, col4 = st.columns(4) - with col1: - st.metric("Total Cells", f"{len(grid_gdf):,}") - with col2: - st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²") - with col3: - st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}") - with col4: - total_land = grid_gdf["land_area"].sum() - st.metric("Total Land Area", f"{total_land:,.0f} km²") + # Render distribution histograms + st.markdown("---") + render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type] - st.markdown("---") + st.markdown("---") - if (ensemble.grid == "hex" and ensemble.level == 6) or ( - ensemble.grid == "healpix" and ensemble.level == 10 - ): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." 
- ) - else: - render_areas_map(grid_gdf, ensemble.grid) + # Render spatial map + binary_dataset = train_data_dict["binary"] + assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset" - # AlphaEarth tab - tab_idx = 2 - if "AlphaEarth" in ensemble.members: - with tabs[tab_idx]: - st.markdown("### AlphaEarth Embeddings Analysis") + render_spatial_map(train_data_dict) - with st.spinner("Loading AlphaEarth data..."): - alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth") - st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells") +def render_areas_view(ensemble: DatasetEnsemble, grid_gdf): + """Render grid cell areas and land/water distribution. - render_alphaearth_overview(alphaearth_ds) - render_alphaearth_plots(alphaearth_ds) + Args: + ensemble: The dataset ensemble configuration. + grid_gdf: Pre-loaded grid GeoDataFrame. - st.markdown("---") + """ + st.markdown("### Grid Cell Areas and Land/Water Distribution") - if (ensemble.grid == "hex" and ensemble.level == 6) or ( - ensemble.grid == "healpix" and ensemble.level == 10 - ): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." - ) - else: - render_alphaearth_map(alphaearth_ds, targets, ensemble.grid) + st.markdown( + "This visualization shows the spatial distribution of cell areas, land areas, " + "water areas, and land ratio across the grid. The grid has been filtered to " + "include only cells in the permafrost region (>50° latitude, <85° latitude) " + "with >10% land coverage." 
+ ) - tab_idx += 1 + st.success( + f"Loaded {len(grid_gdf)} grid cells with areas ranging from " + f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²" + ) - # ArcticDEM tab - if "ArcticDEM" in ensemble.members: - with tabs[tab_idx]: - st.markdown("### ArcticDEM Terrain Analysis") + # Show summary statistics + col1, col2, col3, col4 = st.columns(4) + with col1: + st.metric("Total Cells", f"{len(grid_gdf):,}") + with col2: + st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²") + with col3: + st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}") + with col4: + total_land = grid_gdf["land_area"].sum() + st.metric("Total Land Area", f"{total_land:,.0f} km²") - with st.spinner("Loading ArcticDEM data..."): - arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM") - - st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells") - - render_arcticdem_overview(arcticdem_ds) - render_arcticdem_plots(arcticdem_ds) - - st.markdown("---") - - if (ensemble.grid == "hex" and ensemble.level == 6) or ( - ensemble.grid == "healpix" and ensemble.level == 10 - ): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." 
- ) - else: - render_arcticdem_map(arcticdem_ds, targets, ensemble.grid) - - tab_idx += 1 - - # ERA5 tab (combining all temporal variants) - era5_members = [m for m in ensemble.members if m.startswith("ERA5")] - if era5_members: - with tabs[tab_idx]: - st.markdown("### ERA5 Climate Data Analysis") - - # Let user select which ERA5 temporal aggregation to view - era5_options = { - "ERA5-yearly": "Yearly", - "ERA5-seasonal": "Seasonal (Winter/Summer)", - "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)", - } - - available_era5 = {k: v for k, v in era5_options.items() if k in era5_members} - - selected_era5 = st.selectbox( - "Select ERA5 temporal aggregation", - options=list(available_era5.keys()), - format_func=lambda x: available_era5[x], - key="era5_temporal_select", - ) - - if selected_era5: - temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder' - - with st.spinner(f"Loading {selected_era5} data..."): - era5_ds, targets = load_source_data(ensemble, selected_era5) - - st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells") - - render_era5_overview(era5_ds, temporal_type) - render_era5_plots(era5_ds, temporal_type) - - st.markdown("---") - - if (ensemble.grid == "hex" and ensemble.level == 6) or ( - ensemble.grid == "healpix" and ensemble.level == 10 - ): - st.warning( - "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations." - ) - else: - render_era5_map(era5_ds, targets, ensemble.grid, temporal_type) - - # Show balloons once after all tabs are rendered - st.balloons() + st.markdown("---") + # Check if we should skip map rendering for performance + if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): + st.warning( + "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " + "due to performance considerations." 
+ ) else: - st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.") + render_areas_map(grid_gdf, ensemble.grid) + + +def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets): + """Render AlphaEarth embeddings analysis. + + Args: + ensemble: The dataset ensemble configuration. + alphaearth_ds: Pre-loaded AlphaEarth dataset. + targets: Pre-loaded targets GeoDataFrame. + + """ + st.markdown("### AlphaEarth Embeddings Analysis") + + st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells") + + render_alphaearth_overview(alphaearth_ds) + render_alphaearth_plots(alphaearth_ds) + + st.markdown("---") + + # Check if we should skip map rendering for performance + if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): + st.warning( + "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " + "due to performance considerations." + ) + else: + render_alphaearth_map(alphaearth_ds, targets, ensemble.grid) + + +def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets): + """Render ArcticDEM terrain analysis. + + Args: + ensemble: The dataset ensemble configuration. + arcticdem_ds: Pre-loaded ArcticDEM dataset. + targets: Pre-loaded targets GeoDataFrame. + + """ + st.markdown("### ArcticDEM Terrain Analysis") + + st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells") + + render_arcticdem_overview(arcticdem_ds) + render_arcticdem_plots(arcticdem_ds) + + st.markdown("---") + + # Check if we should skip map rendering for performance + if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10): + st.warning( + "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) " + "due to performance considerations." 
+    )
+    else:
+        render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
+
+
+def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
+    """Render ERA5 climate data analysis.
+
+    Args:
+        ensemble: The dataset ensemble configuration.
+        era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
+        targets: Pre-loaded targets GeoDataFrame.
+
+    """
+    st.markdown("### ERA5 Climate Data Analysis")
+
+    # Let user select which ERA5 temporal aggregation to view
+    era5_options = {
+        "ERA5-yearly": "Yearly",
+        "ERA5-seasonal": "Seasonal (Winter/Summer)",
+        "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
+    }
+
+    available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
+
+    selected_era5 = st.selectbox(
+        "Select ERA5 temporal aggregation",
+        options=list(available_era5.keys()),
+        format_func=lambda x: available_era5[x],
+        key="era5_temporal_select",
+    )
+
+    if selected_era5 and selected_era5 in era5_data:
+        era5_ds, temporal_type = era5_data[selected_era5]
+
+        render_era5_overview(era5_ds, temporal_type)
+        render_era5_plots(era5_ds, temporal_type)
+
+        st.markdown("---")
+
+        # Check if we should skip map rendering for performance
+        if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
+            st.warning(
+                "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
+                "due to performance considerations."
+            )
+        else:
+            render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
+
+
+def render_training_data_page():
+    """Render the Training Data page of the dashboard."""
+    st.title("🎯 Training Data")
+
+    st.markdown(
+        """
+    Explore and visualize the training data for RTS prediction models.
+    Configure your dataset by selecting grid configuration, target dataset,
+    and data sources in the sidebar, then click "Load Dataset" to begin.
+ """ + ) + + # Render sidebar configuration + render_dataset_configuration_sidebar() + + # Check if dataset is loaded in session state + if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state: + st.info( + "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data" + ) + return + + # Get ensemble from session state + ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"] + + st.divider() + + # Load all necessary data once + with st.spinner("Loading dataset..."): + # Load training data for all tasks + train_data_dict = load_all_training_data(ensemble) + + # Load grid data + grid_gdf = grids.open(ensemble.grid, ensemble.level) + + # Load targets (needed by all source data views) + targets = ensemble._read_target() + + # Load AlphaEarth data if in members + alphaearth_ds = None + if "AlphaEarth" in ensemble.members: + alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth") + + # Load ArcticDEM data if in members + arcticdem_ds = None + if "ArcticDEM" in ensemble.members: + arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM") + + # Load ERA5 data for all temporal aggregations in members + era5_data = {} + era5_members = [m for m in ensemble.members if m.startswith("ERA5")] + for era5_member in era5_members: + era5_ds, _ = load_source_data(ensemble, era5_member) + temporal_type = era5_member.split("-")[1] # 'yearly', 'seasonal', or 'shoulder' + era5_data[era5_member] = (era5_ds, temporal_type) + + st.success( + f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features" + ) + + # Render dataset statistics + render_dataset_statistics(ensemble) + + st.markdown("---") + + # Create tabs for different data views + tab_names = ["📊 Labels", "📐 Areas"] + + # Add tabs for each member based on what's in the ensemble + if "AlphaEarth" in ensemble.members: + tab_names.append("🌍 AlphaEarth") + if "ArcticDEM" in 
ensemble.members: + tab_names.append("🏔️ ArcticDEM") + + # Check for ERA5 members + if era5_members: + tab_names.append("🌡️ ERA5") + + tabs = st.tabs(tab_names) + + # Track current tab index + tab_idx = 0 + + # Labels tab + with tabs[tab_idx]: + render_labels_view(ensemble, train_data_dict) + tab_idx += 1 + + # Areas tab + with tabs[tab_idx]: + render_areas_view(ensemble, grid_gdf) + tab_idx += 1 + + # AlphaEarth tab + if "AlphaEarth" in ensemble.members: + with tabs[tab_idx]: + render_alphaearth_view(ensemble, alphaearth_ds, targets) + tab_idx += 1 + + # ArcticDEM tab + if "ArcticDEM" in ensemble.members: + with tabs[tab_idx]: + render_arcticdem_view(ensemble, arcticdem_ds, targets) + tab_idx += 1 + + # ERA5 tab (combining all temporal variants) + if era5_members: + with tabs[tab_idx]: + render_era5_view(ensemble, era5_data, targets) + + # Show balloons once after all tabs are rendered + st.balloons() + stopwatch.summary()