Refactor other pages
This commit is contained in:
parent
4260b492ab
commit
393cc968cb
9 changed files with 962 additions and 559 deletions
|
|
@ -11,12 +11,11 @@ Pages:
|
|||
|
||||
import streamlit as st
|
||||
|
||||
# from entropice.dashboard.views.inference_page import render_inference_page
|
||||
# from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
|
||||
# from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
# from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -25,17 +24,17 @@ def main():
|
|||
|
||||
# Setup Navigation
|
||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||
# training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
# training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
# model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
# inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||
inference_page = st.Page(render_inference_page, title="Inference", icon="🗺️")
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Overview": [overview_page],
|
||||
# "Training": [training_data_page, training_analysis_page],
|
||||
# "Model State": [model_state_page],
|
||||
# "Inference": [inference_page],
|
||||
"Training": [training_data_page, training_analysis_page],
|
||||
"Model State": [model_state_page],
|
||||
"Inference": [inference_page],
|
||||
}
|
||||
)
|
||||
pg.run()
|
||||
|
|
|
|||
|
|
@ -143,8 +143,8 @@ def render_parameter_distributions(results: pd.DataFrame, settings: dict | None
|
|||
|
||||
# Extract scale information from settings if available
|
||||
param_scales = {}
|
||||
if settings and "param_grid" in settings:
|
||||
param_grid = settings["param_grid"]
|
||||
if settings and hasattr(settings, "param_grid"):
|
||||
param_grid = settings.param_grid
|
||||
for param_name, param_config in param_grid.items():
|
||||
if isinstance(param_config, dict) and "distribution" in param_config:
|
||||
# loguniform distribution indicates log scale
|
||||
|
|
@ -1181,10 +1181,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
|
|||
preds_gdf = gpd.read_parquet(preds_file)
|
||||
|
||||
# Get task and target information from settings
|
||||
task = settings.get("task", "binary")
|
||||
target = settings.get("target", "darts_rts")
|
||||
grid = settings.get("grid", "hex")
|
||||
level = settings.get("level", 3)
|
||||
task = settings.task
|
||||
target = settings.target
|
||||
grid = settings.grid
|
||||
level = settings.level
|
||||
|
||||
# Create dataset ensemble to get true labels
|
||||
# We need to load the target data to get true labels
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import streamlit as st
|
|||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.data import TrainingResult
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
|
||||
|
|
@ -197,8 +197,8 @@ def render_inference_map(result: TrainingResult):
|
|||
preds_gdf = gpd.read_parquet(result.path / "predicted_probabilities.parquet")
|
||||
|
||||
# Get settings
|
||||
task = result.settings.get("task", "binary")
|
||||
grid = result.settings.get("grid", "hex")
|
||||
task = result.settings.task
|
||||
grid = result.settings.grid
|
||||
|
||||
# Create controls in columns
|
||||
col1, col2, col3 = st.columns([2, 2, 1])
|
||||
|
|
|
|||
|
|
@ -147,10 +147,11 @@ def load_all_training_data(
|
|||
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
||||
|
||||
"""
|
||||
dataset = e.create(filter_target_col=e.covcol)
|
||||
return {
|
||||
"binary": e.create_cat_training_dataset("binary", device="cpu"),
|
||||
"count": e.create_cat_training_dataset("count", device="cpu"),
|
||||
"density": e.create_cat_training_dataset("density", device="cpu"),
|
||||
"binary": e._cat_and_split(dataset, "binary", device="cpu"),
|
||||
"count": e._cat_and_split(dataset, "count", device="cpu"),
|
||||
"density": e._cat_and_split(dataset, "density", device="cpu"),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Inference page: Visualization of model inference results across the study region."""
|
||||
|
||||
import geopandas as gpd
|
||||
import streamlit as st
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.inference import (
|
||||
render_class_comparison,
|
||||
|
|
@ -10,32 +11,31 @@ from entropice.dashboard.plots.inference import (
|
|||
render_inference_statistics,
|
||||
render_spatial_distribution_stats,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
|
||||
|
||||
|
||||
def render_inference_page():
|
||||
"""Render the Inference page of the dashboard."""
|
||||
st.title("🗺️ Inference Results")
|
||||
@st.fragment
|
||||
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
|
||||
"""Render sidebar for training run selection.
|
||||
|
||||
# Load all available training results
|
||||
training_results = load_all_training_results()
|
||||
Args:
|
||||
training_results: List of available TrainingResult objects.
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
st.info("Run training using: `pixi run python -m entropice.training`")
|
||||
return
|
||||
Returns:
|
||||
Selected TrainingResult object.
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
"""
|
||||
st.header("Select Training Run")
|
||||
|
||||
# Create selection options with task-first naming
|
||||
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
|
||||
training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
|
||||
|
||||
selected_name = st.selectbox(
|
||||
"Training Run",
|
||||
options=list(training_options.keys()),
|
||||
index=0,
|
||||
help="Select a training run to view inference results",
|
||||
key="inference_run_select",
|
||||
)
|
||||
|
||||
selected_result = training_options[selected_name]
|
||||
|
|
@ -45,31 +45,143 @@ def render_inference_page():
|
|||
# Show run information in sidebar
|
||||
st.subheader("Run Information")
|
||||
|
||||
settings = selected_result.settings
|
||||
st.markdown(f"**Task:** {selected_result.settings.task.capitalize()}")
|
||||
st.markdown(f"**Model:** {selected_result.settings.model.upper()}")
|
||||
st.markdown(f"**Grid:** {selected_result.settings.grid.capitalize()}")
|
||||
st.markdown(f"**Level:** {selected_result.settings.level}")
|
||||
st.markdown(f"**Target:** {selected_result.settings.target.replace('darts_', '')}")
|
||||
|
||||
st.markdown(f"**Task:** {settings.get('task', 'Unknown').capitalize()}")
|
||||
st.markdown(f"**Model:** {settings.get('model', 'Unknown').upper()}")
|
||||
st.markdown(f"**Grid:** {settings.get('grid', 'Unknown').capitalize()}")
|
||||
st.markdown(f"**Level:** {settings.get('level', 'Unknown')}")
|
||||
st.markdown(f"**Target:** {settings.get('target', 'Unknown')}")
|
||||
return selected_result
|
||||
|
||||
# Main content area - Run Information at the top
|
||||
|
||||
def render_run_information(selected_result: TrainingResult):
|
||||
"""Render training run configuration overview.
|
||||
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
|
||||
"""
|
||||
st.header("📋 Run Configuration")
|
||||
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
|
||||
with col1:
|
||||
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
|
||||
st.metric("Task", selected_result.settings.task.capitalize())
|
||||
|
||||
with col2:
|
||||
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
|
||||
st.metric("Model", selected_result.settings.model.upper())
|
||||
|
||||
with col3:
|
||||
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
|
||||
st.metric("Grid", selected_result.settings.grid.capitalize())
|
||||
|
||||
with col4:
|
||||
st.metric("Level", selected_result.settings.get("level", "Unknown"))
|
||||
st.metric("Level", selected_result.settings.level)
|
||||
|
||||
with col5:
|
||||
st.metric(
|
||||
"Target",
|
||||
selected_result.settings.get("target", "Unknown").replace("darts_", ""),
|
||||
st.metric("Target", selected_result.settings.target.replace("darts_", ""))
|
||||
|
||||
|
||||
def render_inference_statistics_section(predictions_gdf: gpd.GeoDataFrame, task: str):
|
||||
"""Render inference summary statistics section.
|
||||
|
||||
Args:
|
||||
predictions_gdf: GeoDataFrame with predictions.
|
||||
task: Task type ('binary', 'count', 'density').
|
||||
|
||||
"""
|
||||
st.header("📊 Inference Summary")
|
||||
render_inference_statistics(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_spatial_coverage_section(predictions_gdf: gpd.GeoDataFrame):
|
||||
"""Render spatial coverage statistics section.
|
||||
|
||||
Args:
|
||||
predictions_gdf: GeoDataFrame with predictions.
|
||||
|
||||
"""
|
||||
st.header("🌍 Spatial Coverage")
|
||||
render_spatial_distribution_stats(predictions_gdf)
|
||||
|
||||
|
||||
def render_map_visualization_section(selected_result: TrainingResult):
|
||||
"""Render 3D map visualization section.
|
||||
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
|
||||
"""
|
||||
st.header("🗺️ Interactive Prediction Map")
|
||||
st.markdown(
|
||||
"""
|
||||
3D visualization of predictions across the study region. The map shows predicted
|
||||
classes with color coding and spatial distribution of model outputs.
|
||||
"""
|
||||
)
|
||||
render_inference_map(selected_result)
|
||||
|
||||
|
||||
def render_class_distribution_section(predictions_gdf: gpd.GeoDataFrame, task: str):
|
||||
"""Render class distribution histogram section.
|
||||
|
||||
Args:
|
||||
predictions_gdf: GeoDataFrame with predictions.
|
||||
task: Task type ('binary', 'count', 'density').
|
||||
|
||||
"""
|
||||
st.header("📈 Class Distribution")
|
||||
st.markdown("Distribution of predicted classes across all inference cells.")
|
||||
render_class_distribution_histogram(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_class_comparison_section(predictions_gdf: gpd.GeoDataFrame, task: str):
|
||||
"""Render class comparison analysis section.
|
||||
|
||||
Args:
|
||||
predictions_gdf: GeoDataFrame with predictions.
|
||||
task: Task type ('binary', 'count', 'density').
|
||||
|
||||
"""
|
||||
st.header("🔍 Class Comparison Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
Detailed comparison of predicted classes showing probability distributions
|
||||
and confidence metrics for different class predictions.
|
||||
"""
|
||||
)
|
||||
render_class_comparison(predictions_gdf, task)
|
||||
|
||||
|
||||
def render_inference_page():
|
||||
"""Render the Inference page of the dashboard."""
|
||||
st.title("🗺️ Inference Results")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
Explore spatial predictions from trained models across the Arctic permafrost region.
|
||||
Select a training run from the sidebar to visualize prediction maps, class distributions,
|
||||
and spatial coverage statistics.
|
||||
"""
|
||||
)
|
||||
|
||||
# Load all available training results
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
st.info("Run training using: `pixi run python -m entropice.ml.training`")
|
||||
return
|
||||
|
||||
st.success(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
selected_result = render_sidebar_selection(training_results)
|
||||
|
||||
# Main content area - Run Information
|
||||
render_run_information(selected_result)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -80,31 +192,33 @@ def render_inference_page():
|
|||
st.info("Inference results are generated automatically during training.")
|
||||
return
|
||||
|
||||
# Load predictions for statistics
|
||||
import geopandas as gpd
|
||||
|
||||
# Load predictions
|
||||
with st.spinner("Loading inference results..."):
|
||||
predictions_gdf = gpd.read_parquet(preds_file)
|
||||
task = selected_result.settings.get("task", "binary")
|
||||
task = selected_result.settings.task
|
||||
|
||||
# Inference Statistics Section
|
||||
render_inference_statistics(predictions_gdf, task)
|
||||
render_inference_statistics_section(predictions_gdf, task)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Spatial Coverage Section
|
||||
render_spatial_distribution_stats(predictions_gdf)
|
||||
render_spatial_coverage_section(predictions_gdf)
|
||||
|
||||
st.divider()
|
||||
|
||||
# 3D Map Visualization Section
|
||||
render_inference_map(selected_result)
|
||||
render_map_visualization_section(selected_result)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Class Distribution Section
|
||||
render_class_distribution_histogram(predictions_gdf, task)
|
||||
render_class_distribution_section(predictions_gdf, task)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Class Comparison Section
|
||||
render_class_comparison(predictions_gdf, task)
|
||||
render_class_comparison_section(predictions_gdf, task)
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -1,16 +1,8 @@
|
|||
"""Model State page for the Entropice dashboard."""
|
||||
"""Model State page: Visualization of model internal state and feature importance."""
|
||||
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
get_members_from_settings,
|
||||
load_all_training_results,
|
||||
)
|
||||
from entropice.dashboard.utils.training import load_model_state
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_arcticdem_heatmap,
|
||||
|
|
@ -26,46 +18,64 @@ from entropice.dashboard.plots.model_state import (
|
|||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
|
||||
from entropice.dashboard.utils.unsembler import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
extract_embedding_features,
|
||||
extract_era5_features,
|
||||
)
|
||||
from entropice.utils.types import L2SourceDataset
|
||||
|
||||
|
||||
def render_model_state_page():
|
||||
"""Render the Model State page of the dashboard."""
|
||||
st.title("Model State")
|
||||
st.markdown("Comprehensive visualization of the best model's internal state and feature importance")
|
||||
def get_members_from_settings(settings) -> list[L2SourceDataset]:
|
||||
"""Extract dataset members from training settings.
|
||||
|
||||
# Load available training results
|
||||
training_results = load_all_training_results()
|
||||
Args:
|
||||
settings: TrainingSettings object containing dataset configuration.
|
||||
|
||||
if not training_results:
|
||||
st.error("No training results found. Please run a training search first.")
|
||||
return
|
||||
Returns:
|
||||
List of L2SourceDataset members used in training.
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
"""
|
||||
return settings.members
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
|
||||
"""Render sidebar for training run selection.
|
||||
|
||||
Args:
|
||||
training_results: List of available TrainingResult objects.
|
||||
|
||||
Returns:
|
||||
Selected TrainingResult object.
|
||||
|
||||
"""
|
||||
st.header("Select Training Run")
|
||||
|
||||
# Result selection with model-first naming
|
||||
result_options = {tr.get_display_name("model_first"): tr for tr in training_results}
|
||||
# Result selection with task-first naming
|
||||
result_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
|
||||
selected_name = st.selectbox(
|
||||
"Training Run",
|
||||
options=list(result_options.keys()),
|
||||
index=0,
|
||||
help="Choose a training result to visualize model state",
|
||||
key="model_state_training_run_select",
|
||||
)
|
||||
selected_result = result_options[selected_name]
|
||||
|
||||
st.divider()
|
||||
return selected_result
|
||||
|
||||
# Get the model type from settings
|
||||
model_type = selected_result.settings.get("model", "espa")
|
||||
|
||||
# Load model state
|
||||
with st.spinner("Loading model state..."):
|
||||
model_state = load_model_state(selected_result)
|
||||
if model_state is None:
|
||||
st.error("Could not load model state for this result.")
|
||||
return
|
||||
def render_model_info(model_state: xr.Dataset, model_type: str):
|
||||
"""Render basic model state information.
|
||||
|
||||
# Display basic model state info
|
||||
Args:
|
||||
model_state: Xarray dataset containing model state.
|
||||
model_type: Type of model (espa, xgboost, rf, knn).
|
||||
|
||||
"""
|
||||
with st.expander("Model State Information", expanded=False):
|
||||
st.write(f"**Model Type:** {model_type.upper()}")
|
||||
st.write(f"**Variables:** {list(model_state.data_vars)}")
|
||||
|
|
@ -73,15 +83,23 @@ def render_model_state_page():
|
|||
st.write(f"**Coordinates:** {list(model_state.coords)}")
|
||||
st.write(f"**Attributes:** {dict(model_state.attrs)}")
|
||||
|
||||
# Display dataset members summary
|
||||
st.header("📊 Training Data Summary")
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
|
||||
st.markdown(f"""
|
||||
def render_training_data_summary(members: list[L2SourceDataset]):
|
||||
"""Render summary of training data sources.
|
||||
|
||||
Args:
|
||||
members: List of dataset members used in training.
|
||||
|
||||
"""
|
||||
st.header("📊 Training Data Summary")
|
||||
|
||||
st.markdown(
|
||||
f"""
|
||||
**Dataset Members Used in Training:** {len(members)}
|
||||
|
||||
The following data sources were used to train this model:
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
# Create a nice display of members with emojis
|
||||
member_display = {
|
||||
|
|
@ -98,6 +116,52 @@ def render_model_state_page():
|
|||
display_name = member_display.get(member, f"📁 {member}")
|
||||
st.info(display_name)
|
||||
|
||||
|
||||
def render_model_state_page():
|
||||
"""Render the Model State page of the dashboard."""
|
||||
st.title("🔬 Model State")
|
||||
st.markdown(
|
||||
"""
|
||||
Comprehensive visualization of the best model's internal state and feature importance.
|
||||
Select a training run from the sidebar to explore model parameters, feature weights,
|
||||
and data source contributions.
|
||||
"""
|
||||
)
|
||||
|
||||
# Load available training results
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
st.info("Run training using: `pixi run python -m entropice.ml.training`")
|
||||
return
|
||||
|
||||
st.success(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
selected_result = render_sidebar_selection(training_results)
|
||||
|
||||
# Get the model type from settings
|
||||
model_type = selected_result.settings.model
|
||||
|
||||
# Load model state
|
||||
with st.spinner("Loading model state..."):
|
||||
model_state = selected_result.load_model_state()
|
||||
if model_state is None:
|
||||
st.error("Could not load model state for this result.")
|
||||
st.info("The model state file (best_estimator_state.nc) may be missing from the training results.")
|
||||
return
|
||||
|
||||
# Display basic model state info
|
||||
render_model_info(model_state, model_type)
|
||||
|
||||
# Display dataset members summary
|
||||
members = get_members_from_settings(selected_result.settings)
|
||||
render_training_data_summary(members)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Render model-specific visualizations
|
||||
|
|
@ -112,9 +176,18 @@ def render_model_state_page():
|
|||
else:
|
||||
st.warning(f"Visualization for model type '{model_type}' is not yet implemented.")
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
||||
def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for ESPA model."""
|
||||
|
||||
def render_espa_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
|
||||
"""Render visualizations for ESPA model.
|
||||
|
||||
Args:
|
||||
model_state: Xarray dataset containing ESPA model state.
|
||||
selected_result: TrainingResult object containing training configuration.
|
||||
|
||||
"""
|
||||
# Scale feature weights by number of features
|
||||
n_features = model_state.sizes["feature"]
|
||||
model_state["feature_weights"] *= n_features
|
||||
|
|
@ -143,8 +216,9 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
common_feature_array = extract_common_features(model_state)
|
||||
|
||||
# Generate unified colormaps
|
||||
_, _, altair_colors = generate_unified_colormap(selected_result.settings)
|
||||
# Generate unified colormaps (convert dataclass to dict)
|
||||
settings_dict = {"task": selected_result.settings.task, "classes": selected_result.settings.classes}
|
||||
_, _, altair_colors = generate_unified_colormap(settings_dict)
|
||||
|
||||
# Feature importance section
|
||||
st.header("Feature Importance")
|
||||
|
|
@ -255,8 +329,14 @@ def render_espa_model_state(model_state: xr.Dataset, selected_result):
|
|||
render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for XGBoost model."""
|
||||
def render_xgboost_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
|
||||
"""Render visualizations for XGBoost model.
|
||||
|
||||
Args:
|
||||
model_state: Xarray dataset containing XGBoost model state.
|
||||
selected_result: TrainingResult object containing training configuration.
|
||||
|
||||
"""
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_xgboost_feature_importance,
|
||||
plot_xgboost_importance_comparison,
|
||||
|
|
@ -382,8 +462,14 @@ def render_xgboost_model_state(model_state: xr.Dataset, selected_result):
|
|||
render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for Random Forest model."""
|
||||
def render_rf_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
|
||||
"""Render visualizations for Random Forest model.
|
||||
|
||||
Args:
|
||||
model_state: Xarray dataset containing Random Forest model state.
|
||||
selected_result: TrainingResult object containing training configuration.
|
||||
|
||||
"""
|
||||
from entropice.dashboard.plots.model_state import plot_rf_feature_importance
|
||||
|
||||
st.header("🌳 Random Forest Model Analysis")
|
||||
|
|
@ -529,8 +615,14 @@ def render_rf_model_state(model_state: xr.Dataset, selected_result):
|
|||
render_common_features(common_feature_array)
|
||||
|
||||
|
||||
def render_knn_model_state(model_state: xr.Dataset, selected_result):
|
||||
"""Render visualizations for KNN model."""
|
||||
def render_knn_model_state(model_state: xr.Dataset, selected_result: TrainingResult):
|
||||
"""Render visualizations for KNN model.
|
||||
|
||||
Args:
|
||||
model_state: Xarray dataset containing KNN model state.
|
||||
selected_result: TrainingResult object containing training configuration.
|
||||
|
||||
"""
|
||||
st.header("🔍 K-Nearest Neighbors Model Analysis")
|
||||
st.markdown(
|
||||
"""
|
||||
|
|
@ -568,8 +660,13 @@ def render_knn_model_state(model_state: xr.Dataset, selected_result):
|
|||
|
||||
|
||||
# Helper functions for embedding/era5/common features
|
||||
def render_embedding_features(embedding_feature_array):
|
||||
"""Render embedding feature visualizations."""
|
||||
def render_embedding_features(embedding_feature_array: xr.DataArray):
|
||||
"""Render embedding feature visualizations.
|
||||
|
||||
Args:
|
||||
embedding_feature_array: DataArray containing AlphaEarth embedding feature weights.
|
||||
|
||||
"""
|
||||
with st.container(border=True):
|
||||
st.header("🛰️ Embedding Feature Analysis")
|
||||
st.markdown(
|
||||
|
|
@ -619,7 +716,7 @@ def render_embedding_features(embedding_feature_array):
|
|||
st.dataframe(top_emb, width="stretch")
|
||||
|
||||
|
||||
def render_era5_features(era5_feature_array, temporal_group: str = ""):
|
||||
def render_era5_features(era5_feature_array: xr.DataArray, temporal_group: str = ""):
|
||||
"""Render ERA5 feature visualizations.
|
||||
|
||||
Args:
|
||||
|
|
@ -631,9 +728,10 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
|
|||
|
||||
with st.container(border=True):
|
||||
st.header(f"⛅ ERA5 Feature Analysis{group_suffix}")
|
||||
temporal_suffix = f" for {temporal_group.lower()} aggregation" if temporal_group else ""
|
||||
st.markdown(
|
||||
f"""
|
||||
Analysis of ERA5 climate features{" for " + temporal_group.lower() + " aggregation" if temporal_group else ""} showing which variables and time periods
|
||||
Analysis of ERA5 climate features{temporal_suffix} showing which variables and time periods
|
||||
are most important for the model predictions.
|
||||
"""
|
||||
)
|
||||
|
|
@ -709,8 +807,13 @@ def render_era5_features(era5_feature_array, temporal_group: str = ""):
|
|||
st.dataframe(top_era5, width="stretch")
|
||||
|
||||
|
||||
def render_arcticdem_features(arcticdem_feature_array):
|
||||
"""Render ArcticDEM feature visualizations."""
|
||||
def render_arcticdem_features(arcticdem_feature_array: xr.DataArray):
|
||||
"""Render ArcticDEM feature visualizations.
|
||||
|
||||
Args:
|
||||
arcticdem_feature_array: DataArray containing ArcticDEM feature weights.
|
||||
|
||||
"""
|
||||
with st.container(border=True):
|
||||
st.header("🏔️ ArcticDEM Feature Analysis")
|
||||
st.markdown(
|
||||
|
|
@ -758,8 +861,13 @@ def render_arcticdem_features(arcticdem_feature_array):
|
|||
st.dataframe(top_arcticdem, width="stretch")
|
||||
|
||||
|
||||
def render_common_features(common_feature_array):
|
||||
"""Render common feature visualizations."""
|
||||
def render_common_features(common_feature_array: xr.DataArray):
|
||||
"""Render common feature visualizations.
|
||||
|
||||
Args:
|
||||
common_feature_array: DataArray containing common feature weights.
|
||||
|
||||
"""
|
||||
with st.container(border=True):
|
||||
st.header("🗺️ Common Feature Analysis")
|
||||
st.markdown(
|
||||
|
|
|
|||
|
|
@ -658,5 +658,4 @@ def render_overview_page():
|
|||
render_dataset_analysis()
|
||||
|
||||
st.balloons()
|
||||
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Training Results Analysis page: Analysis of training results and model performance."""
|
||||
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||
render_binned_parameter_space,
|
||||
|
|
@ -12,39 +13,33 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
|
|||
render_performance_summary,
|
||||
render_top_configurations,
|
||||
)
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
from entropice.dashboard.utils.training import (
|
||||
format_metric_name,
|
||||
get_available_metrics,
|
||||
get_cv_statistics,
|
||||
get_parameter_space_summary,
|
||||
)
|
||||
from entropice.dashboard.utils.formatters import format_metric_name
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.dashboard.utils.stats import CVResultsStatistics
|
||||
|
||||
|
||||
def render_training_analysis_page():
|
||||
"""Render the Training Results Analysis page of the dashboard."""
|
||||
st.title("🦾 Training Results Analysis")
|
||||
@st.fragment
|
||||
def render_analysis_settings_sidebar(training_results):
|
||||
"""Render sidebar for training run and analysis settings selection.
|
||||
|
||||
# Load all available training results
|
||||
training_results = load_all_training_results()
|
||||
Args:
|
||||
training_results: List of available TrainingResult objects.
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
st.info("Run training using: `pixi run python -m entropice.training`")
|
||||
return
|
||||
Returns:
|
||||
Tuple of (selected_result, selected_metric, refit_metric, top_n).
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
"""
|
||||
st.header("Select Training Run")
|
||||
|
||||
# Create selection options with task-first naming
|
||||
training_options = {tr.get_display_name("task_first"): tr for tr in training_results}
|
||||
training_options = {tr.display_info.get_display_name("task_first"): tr for tr in training_results}
|
||||
|
||||
selected_name = st.selectbox(
|
||||
"Training Run",
|
||||
options=list(training_options.keys()),
|
||||
index=0,
|
||||
help="Select a training run to analyze",
|
||||
key="training_run_select",
|
||||
)
|
||||
|
||||
selected_result = training_options[selected_name]
|
||||
|
|
@ -54,14 +49,14 @@ def render_training_analysis_page():
|
|||
# Metric selection for detailed analysis
|
||||
st.subheader("Analysis Settings")
|
||||
|
||||
available_metrics = get_available_metrics(selected_result.results)
|
||||
available_metrics = selected_result.available_metrics
|
||||
|
||||
# Try to get refit metric from settings
|
||||
refit_metric = selected_result.settings.get("refit_metric")
|
||||
refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
|
||||
|
||||
if not refit_metric or refit_metric not in available_metrics:
|
||||
# Infer from task or use first available metric
|
||||
task = selected_result.settings.get("task", "binary")
|
||||
task = selected_result.settings.task
|
||||
if task == "binary" and "f1" in available_metrics:
|
||||
refit_metric = "f1"
|
||||
elif "f1_weighted" in available_metrics:
|
||||
|
|
@ -72,7 +67,7 @@ def render_training_analysis_page():
|
|||
refit_metric = available_metrics[0]
|
||||
else:
|
||||
st.error("No metrics found in results.")
|
||||
return
|
||||
return None, None, None, None
|
||||
|
||||
if refit_metric in available_metrics:
|
||||
default_metric_idx = available_metrics.index(refit_metric)
|
||||
|
|
@ -85,6 +80,7 @@ def render_training_analysis_page():
|
|||
index=default_metric_idx,
|
||||
format_func=format_metric_name,
|
||||
help="Select the metric to focus on for detailed analysis",
|
||||
key="metric_select",
|
||||
)
|
||||
|
||||
# Top N configurations
|
||||
|
|
@ -95,86 +91,98 @@ def render_training_analysis_page():
|
|||
value=10,
|
||||
step=5,
|
||||
help="Number of top configurations to display",
|
||||
key="top_n_slider",
|
||||
)
|
||||
|
||||
# Main content area - Run Information at the top
|
||||
return selected_result, selected_metric, refit_metric, top_n
|
||||
|
||||
|
||||
def render_run_information(selected_result, refit_metric):
|
||||
"""Render training run configuration overview.
|
||||
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
refit_metric: The refit metric used for model selection.
|
||||
|
||||
"""
|
||||
st.header("📋 Run Information")
|
||||
|
||||
col1, col2, col3, col4, col5, col6 = st.columns(6)
|
||||
with col1:
|
||||
st.metric("Task", selected_result.settings.get("task", "Unknown").capitalize())
|
||||
st.metric("Task", selected_result.settings.task.capitalize())
|
||||
with col2:
|
||||
st.metric("Grid", selected_result.settings.get("grid", "Unknown").capitalize())
|
||||
st.metric("Grid", selected_result.settings.grid.capitalize())
|
||||
with col3:
|
||||
st.metric("Level", selected_result.settings.get("level", "Unknown"))
|
||||
st.metric("Level", selected_result.settings.level)
|
||||
with col4:
|
||||
st.metric("Model", selected_result.settings.get("model", "Unknown").upper())
|
||||
st.metric("Model", selected_result.settings.model.upper())
|
||||
with col5:
|
||||
st.metric("Trials", len(selected_result.results))
|
||||
with col6:
|
||||
st.metric("CV Splits", selected_result.settings.get("cv_splits", "Unknown"))
|
||||
st.metric("CV Splits", selected_result.settings.cv_splits)
|
||||
|
||||
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Main content area
|
||||
results = selected_result.results
|
||||
settings = selected_result.settings
|
||||
def render_cv_statistics_section(selected_result, selected_metric):
|
||||
"""Render cross-validation statistics for selected metric.
|
||||
|
||||
# Performance Summary Section
|
||||
st.header("📊 Performance Overview")
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
selected_metric: The metric to display statistics for.
|
||||
|
||||
render_performance_summary(results, refit_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Confusion Matrix Map Section
|
||||
st.header("🗺️ Prediction Results Map")
|
||||
|
||||
render_confusion_matrix_map(selected_result.path, settings)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Quick Statistics
|
||||
"""
|
||||
st.header("📈 Cross-Validation Statistics")
|
||||
|
||||
cv_stats = get_cv_statistics(results, selected_metric)
|
||||
from entropice.dashboard.utils.stats import CVMetricStatistics
|
||||
|
||||
cv_stats = CVMetricStatistics.compute(selected_result, selected_metric)
|
||||
|
||||
if cv_stats:
|
||||
col1, col2, col3, col4, col5 = st.columns(5)
|
||||
|
||||
with col1:
|
||||
st.metric("Best Score", f"{cv_stats['best_score']:.4f}")
|
||||
st.metric("Best Score", f"{cv_stats.best_score:.4f}")
|
||||
|
||||
with col2:
|
||||
st.metric("Mean Score", f"{cv_stats['mean_score']:.4f}")
|
||||
st.metric("Mean Score", f"{cv_stats.mean_score:.4f}")
|
||||
|
||||
with col3:
|
||||
st.metric("Std Dev", f"{cv_stats['std_score']:.4f}")
|
||||
st.metric("Std Dev", f"{cv_stats.std_score:.4f}")
|
||||
|
||||
with col4:
|
||||
st.metric("Worst Score", f"{cv_stats['worst_score']:.4f}")
|
||||
st.metric("Worst Score", f"{cv_stats.worst_score:.4f}")
|
||||
|
||||
with col5:
|
||||
st.metric("Median Score", f"{cv_stats['median_score']:.4f}")
|
||||
st.metric("Median Score", f"{cv_stats.median_score:.4f}")
|
||||
|
||||
if "mean_cv_std" in cv_stats:
|
||||
st.info(f"**Mean CV Std:** {cv_stats['mean_cv_std']:.4f} - Average standard deviation across CV folds")
|
||||
if cv_stats.mean_cv_std is not None:
|
||||
st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Parameter Space Analysis
|
||||
def render_parameter_space_section(selected_result, selected_metric):
|
||||
"""Render parameter space analysis section.
|
||||
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
selected_metric: The metric to analyze parameters against.
|
||||
|
||||
"""
|
||||
st.header("🔍 Parameter Space Analysis")
|
||||
|
||||
# Compute CV results statistics
|
||||
cv_results_stats = CVResultsStatistics.compute(selected_result)
|
||||
|
||||
# Show parameter space summary
|
||||
with st.expander("📋 Parameter Space Summary", expanded=False):
|
||||
param_summary = get_parameter_space_summary(results)
|
||||
if not param_summary.empty:
|
||||
st.dataframe(param_summary, hide_index=True, width="stretch")
|
||||
param_summary_df = cv_results_stats.parameters_to_dataframe()
|
||||
if not param_summary_df.empty:
|
||||
st.dataframe(param_summary_df, hide_index=True, width="stretch")
|
||||
else:
|
||||
st.info("No parameter information available.")
|
||||
|
||||
results = selected_result.results
|
||||
settings = selected_result.settings
|
||||
|
||||
# Parameter distributions
|
||||
st.subheader("📈 Parameter Distributions")
|
||||
render_parameter_distributions(results, settings)
|
||||
|
|
@ -183,7 +191,7 @@ def render_training_analysis_page():
|
|||
st.subheader("🎨 Binned Parameter Space")
|
||||
|
||||
# Check if this is an ESPA model and show ESPA-specific plots
|
||||
model_type = settings.get("model", "espa")
|
||||
model_type = settings.model
|
||||
if model_type == "espa":
|
||||
# Show ESPA-specific binned plots (eps_cl vs eps_e binned by K)
|
||||
render_espa_binned_parameter_space(results, selected_metric)
|
||||
|
|
@ -196,31 +204,15 @@ def render_training_analysis_page():
|
|||
# For non-ESPA models, show the generic binned plots
|
||||
render_binned_parameter_space(results, selected_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Parameter Correlation
|
||||
st.header("🔗 Parameter Correlation")
|
||||
def render_data_export_section(results, selected_result):
|
||||
"""Render data export section with download buttons.
|
||||
|
||||
render_parameter_correlation(results, selected_metric)
|
||||
Args:
|
||||
results: DataFrame with CV results.
|
||||
selected_result: The selected TrainingResult object.
|
||||
|
||||
st.divider()
|
||||
|
||||
# Multi-Metric Comparison
|
||||
if len(available_metrics) >= 2:
|
||||
st.header("📊 Multi-Metric Comparison")
|
||||
|
||||
render_multi_metric_comparison(results)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Top Configurations
|
||||
st.header("🏆 Top Performing Configurations")
|
||||
|
||||
render_top_configurations(results, selected_metric, top_n)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Raw Data Export
|
||||
"""
|
||||
with st.expander("💾 Export Data", expanded=False):
|
||||
st.subheader("Download Results")
|
||||
|
||||
|
|
@ -234,22 +226,114 @@ def render_training_analysis_page():
|
|||
data=csv_data,
|
||||
file_name=f"{selected_result.path.name}_results.csv",
|
||||
mime="text/csv",
|
||||
width="stretch",
|
||||
)
|
||||
|
||||
with col2:
|
||||
# Download settings as text
|
||||
# Download settings as JSON
|
||||
import json
|
||||
|
||||
settings_json = json.dumps(settings, indent=2)
|
||||
settings_dict = {
|
||||
"task": selected_result.settings.task,
|
||||
"grid": selected_result.settings.grid,
|
||||
"level": selected_result.settings.level,
|
||||
"model": selected_result.settings.model,
|
||||
"cv_splits": selected_result.settings.cv_splits,
|
||||
"classes": selected_result.settings.classes,
|
||||
}
|
||||
settings_json = json.dumps(settings_dict, indent=2)
|
||||
st.download_button(
|
||||
label="⚙️ Download Settings (JSON)",
|
||||
data=settings_json,
|
||||
file_name=f"{selected_result.path.name}_settings.json",
|
||||
mime="application/json",
|
||||
width="stretch",
|
||||
)
|
||||
|
||||
# Show raw data preview
|
||||
st.subheader("Raw Data Preview")
|
||||
st.dataframe(results.head(100), width="stretch")
|
||||
|
||||
|
||||
def render_training_analysis_page():
|
||||
"""Render the Training Results Analysis page of the dashboard."""
|
||||
st.title("🦾 Training Results Analysis")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
Analyze training results, hyperparameter search performance, and model configurations.
|
||||
Select a training run from the sidebar to explore detailed metrics and parameter analysis.
|
||||
"""
|
||||
)
|
||||
|
||||
# Load all available training results
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
st.info("Run training using: `pixi run python -m entropice.ml.training`")
|
||||
return
|
||||
|
||||
st.success(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.divider()
|
||||
|
||||
# Sidebar: Training run selection
|
||||
with st.sidebar:
|
||||
selection_result = render_analysis_settings_sidebar(training_results)
|
||||
if selection_result[0] is None:
|
||||
return
|
||||
selected_result, selected_metric, refit_metric, top_n = selection_result
|
||||
|
||||
# Main content area
|
||||
results = selected_result.results
|
||||
settings = selected_result.settings
|
||||
|
||||
# Run Information
|
||||
render_run_information(selected_result, refit_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Performance Summary Section
|
||||
st.header("📊 Performance Overview")
|
||||
render_performance_summary(results, refit_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Confusion Matrix Map Section
|
||||
st.header("🗺️ Prediction Results Map")
|
||||
render_confusion_matrix_map(selected_result.path, settings)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Cross-Validation Statistics
|
||||
render_cv_statistics_section(selected_result, selected_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Parameter Space Analysis
|
||||
render_parameter_space_section(selected_result, selected_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Parameter Correlation
|
||||
st.header("🔗 Parameter Correlation")
|
||||
render_parameter_correlation(results, selected_metric)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Multi-Metric Comparison
|
||||
if len(selected_result.available_metrics) >= 2:
|
||||
st.header("📊 Multi-Metric Comparison")
|
||||
render_multi_metric_comparison(results)
|
||||
st.divider()
|
||||
|
||||
# Top Configurations
|
||||
st.header("🏆 Top Performing Configurations")
|
||||
render_top_configurations(results, selected_metric, top_n)
|
||||
|
||||
st.divider()
|
||||
|
||||
# Raw Data Export
|
||||
render_data_export_section(results, selected_result)
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
"""Training Data page: Visualization of training data distributions."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.source_data import (
|
||||
render_alphaearth_map,
|
||||
|
|
@ -19,30 +22,21 @@ from entropice.dashboard.plots.training_data import (
|
|||
render_spatial_map,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||
from entropice.spatial import grids
|
||||
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
|
||||
|
||||
|
||||
def render_training_data_page():
|
||||
"""Render the Training Data page of the dashboard."""
|
||||
st.title("Training Data")
|
||||
def render_dataset_configuration_sidebar():
|
||||
"""Render dataset configuration selector in sidebar with form.
|
||||
|
||||
# Sidebar widgets for dataset configuration in a form
|
||||
Stores the selected ensemble in session state when form is submitted.
|
||||
"""
|
||||
with st.sidebar.form("dataset_config_form"):
|
||||
st.header("Dataset Configuration")
|
||||
|
||||
# Combined grid and level selection
|
||||
grid_options = [
|
||||
"hex-3",
|
||||
"hex-4",
|
||||
"hex-5",
|
||||
"hex-6",
|
||||
"healpix-6",
|
||||
"healpix-7",
|
||||
"healpix-8",
|
||||
"healpix-9",
|
||||
"healpix-10",
|
||||
]
|
||||
# Grid selection
|
||||
grid_options = [gc.display_name for gc in grid_configs]
|
||||
|
||||
grid_level_combined = st.selectbox(
|
||||
"Grid Configuration",
|
||||
|
|
@ -51,9 +45,8 @@ def render_training_data_page():
|
|||
help="Select the grid system and resolution level",
|
||||
)
|
||||
|
||||
# Parse grid type and level
|
||||
grid, level_str = grid_level_combined.split("-")
|
||||
level = int(level_str)
|
||||
# Find the selected grid config
|
||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
||||
|
||||
# Target feature selection
|
||||
target = st.selectbox(
|
||||
|
|
@ -66,42 +59,47 @@ def render_training_data_page():
|
|||
# Members selection
|
||||
st.subheader("Dataset Members")
|
||||
|
||||
all_members = [
|
||||
"AlphaEarth",
|
||||
"ArcticDEM",
|
||||
"ERA5-yearly",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-shoulder",
|
||||
]
|
||||
selected_members = []
|
||||
all_members = cast(
|
||||
list[L2SourceDataset],
|
||||
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
|
||||
)
|
||||
selected_members: list[L2SourceDataset] = []
|
||||
|
||||
for member in all_members:
|
||||
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
||||
selected_members.append(member)
|
||||
selected_members.append(member) # type: ignore[arg-type]
|
||||
|
||||
# Form submit button
|
||||
load_button = st.form_submit_button(
|
||||
"Load Dataset",
|
||||
type="primary",
|
||||
width="stretch",
|
||||
use_container_width=True,
|
||||
disabled=len(selected_members) == 0,
|
||||
)
|
||||
|
||||
# Create DatasetEnsemble only when form is submitted
|
||||
if load_button:
|
||||
ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=selected_grid_config.grid,
|
||||
level=selected_grid_config.level,
|
||||
target=cast(TargetDataset, target),
|
||||
members=selected_members,
|
||||
)
|
||||
# Store ensemble in session state
|
||||
st.session_state["dataset_ensemble"] = ensemble
|
||||
st.session_state["dataset_loaded"] = True
|
||||
|
||||
# Display dataset information if loaded
|
||||
if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
|
||||
ensemble = st.session_state["dataset_ensemble"]
|
||||
|
||||
# Display current configuration
|
||||
st.subheader("📊 Current Configuration")
|
||||
def render_dataset_statistics(ensemble: DatasetEnsemble):
|
||||
"""Render dataset statistics and configuration overview.
|
||||
|
||||
# Create a visually appealing layout with columns
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
|
||||
"""
|
||||
st.markdown("### 📊 Dataset Configuration")
|
||||
|
||||
# Display current configuration in columns
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
|
|
@ -126,9 +124,9 @@ def render_training_data_page():
|
|||
# Display dataset ID in a styled container
|
||||
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
||||
|
||||
# Display dataset statistics
|
||||
# Display detailed dataset statistics
|
||||
st.markdown("---")
|
||||
st.subheader("📈 Dataset Statistics")
|
||||
st.markdown("### 📈 Dataset Statistics")
|
||||
|
||||
with st.spinner("Computing dataset statistics..."):
|
||||
stats = ensemble.get_stats()
|
||||
|
|
@ -155,12 +153,12 @@ def render_training_data_page():
|
|||
st.metric("Variables", member_stats["num_variables"])
|
||||
with metric_cols[2]:
|
||||
# Display dimensions in a more readable format
|
||||
dim_str = " × ".join([f"{dim}" for dim in member_stats["dimensions"].values()])
|
||||
dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
|
||||
st.metric("Shape", dim_str)
|
||||
with metric_cols[3]:
|
||||
# Calculate total data points
|
||||
total_points = 1
|
||||
for dim_size in member_stats["dimensions"].values():
|
||||
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
|
||||
total_points *= dim_size
|
||||
st.metric("Data Points", f"{total_points:,}")
|
||||
|
||||
|
|
@ -170,7 +168,7 @@ def render_training_data_page():
|
|||
[
|
||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||
for v in member_stats["variables"]
|
||||
for v in member_stats["variables"] # type: ignore[union-attr]
|
||||
]
|
||||
)
|
||||
st.markdown(vars_html, unsafe_allow_html=True)
|
||||
|
|
@ -182,51 +180,34 @@ def render_training_data_page():
|
|||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||
f"{dim_name}: {dim_size}</span>"
|
||||
for dim_name, dim_size in member_stats["dimensions"].items()
|
||||
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
|
||||
]
|
||||
)
|
||||
st.markdown(dim_html, unsafe_allow_html=True)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Create tabs for different data views
|
||||
tab_names = ["📊 Labels", "📐 Areas"]
|
||||
def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
|
||||
"""Render target labels distribution and spatial visualization.
|
||||
|
||||
# Add tabs for each member
|
||||
for member in ensemble.members:
|
||||
if member == "AlphaEarth":
|
||||
tab_names.append("🌍 AlphaEarth")
|
||||
elif member == "ArcticDEM":
|
||||
tab_names.append("🏔️ ArcticDEM")
|
||||
elif member.startswith("ERA5"):
|
||||
# Group ERA5 temporal variants
|
||||
if "🌡️ ERA5" not in tab_names:
|
||||
tab_names.append("🌡️ ERA5")
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
train_data_dict: Pre-loaded training data for all tasks.
|
||||
|
||||
tabs = st.tabs(tab_names)
|
||||
|
||||
# Labels tab
|
||||
with tabs[0]:
|
||||
"""
|
||||
st.markdown("### Target Labels Distribution and Spatial Visualization")
|
||||
|
||||
# Load training data for all three tasks
|
||||
with st.spinner("Loading training data for all tasks..."):
|
||||
train_data_dict = load_all_training_data(ensemble)
|
||||
|
||||
# Calculate total samples (use binary as reference)
|
||||
total_samples = len(train_data_dict["binary"])
|
||||
train_samples = (train_data_dict["binary"].split == "train").sum().item()
|
||||
test_samples = (train_data_dict["binary"].split == "test").sum().item()
|
||||
|
||||
st.success(
|
||||
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
|
||||
)
|
||||
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
|
||||
|
||||
# Render distribution histograms
|
||||
st.markdown("---")
|
||||
render_all_distribution_histograms(train_data_dict)
|
||||
render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type]
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
|
|
@ -236,8 +217,15 @@ def render_training_data_page():
|
|||
|
||||
render_spatial_map(train_data_dict)
|
||||
|
||||
# Areas tab
|
||||
with tabs[1]:
|
||||
|
||||
def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
|
||||
"""Render grid cell areas and land/water distribution.
|
||||
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
grid_gdf: Pre-loaded grid GeoDataFrame.
|
||||
|
||||
"""
|
||||
st.markdown("### Grid Cell Areas and Land/Water Distribution")
|
||||
|
||||
st.markdown(
|
||||
|
|
@ -247,9 +235,6 @@ def render_training_data_page():
|
|||
"with >10% land coverage."
|
||||
)
|
||||
|
||||
# Load grid data
|
||||
grid_gdf = grids.open(ensemble.grid, ensemble.level)
|
||||
|
||||
st.success(
|
||||
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
|
||||
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
|
||||
|
|
@ -269,23 +254,26 @@ def render_training_data_page():
|
|||
|
||||
st.markdown("---")
|
||||
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (
|
||||
ensemble.grid == "healpix" and ensemble.level == 10
|
||||
):
|
||||
# Check if we should skip map rendering for performance
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
||||
st.warning(
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
||||
"due to performance considerations."
|
||||
)
|
||||
else:
|
||||
render_areas_map(grid_gdf, ensemble.grid)
|
||||
|
||||
# AlphaEarth tab
|
||||
tab_idx = 2
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
st.markdown("### AlphaEarth Embeddings Analysis")
|
||||
|
||||
with st.spinner("Loading AlphaEarth data..."):
|
||||
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
|
||||
def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
|
||||
"""Render AlphaEarth embeddings analysis.
|
||||
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
alphaearth_ds: Pre-loaded AlphaEarth dataset.
|
||||
targets: Pre-loaded targets GeoDataFrame.
|
||||
|
||||
"""
|
||||
st.markdown("### AlphaEarth Embeddings Analysis")
|
||||
|
||||
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
|
||||
|
||||
|
|
@ -294,25 +282,27 @@ def render_training_data_page():
|
|||
|
||||
st.markdown("---")
|
||||
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (
|
||||
ensemble.grid == "healpix" and ensemble.level == 10
|
||||
):
|
||||
# Check if we should skip map rendering for performance
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
||||
st.warning(
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
||||
"due to performance considerations."
|
||||
)
|
||||
else:
|
||||
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
|
||||
|
||||
tab_idx += 1
|
||||
|
||||
# ArcticDEM tab
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
|
||||
"""Render ArcticDEM terrain analysis.
|
||||
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
arcticdem_ds: Pre-loaded ArcticDEM dataset.
|
||||
targets: Pre-loaded targets GeoDataFrame.
|
||||
|
||||
"""
|
||||
st.markdown("### ArcticDEM Terrain Analysis")
|
||||
|
||||
with st.spinner("Loading ArcticDEM data..."):
|
||||
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
|
||||
|
||||
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
|
||||
|
||||
render_arcticdem_overview(arcticdem_ds)
|
||||
|
|
@ -320,21 +310,25 @@ def render_training_data_page():
|
|||
|
||||
st.markdown("---")
|
||||
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (
|
||||
ensemble.grid == "healpix" and ensemble.level == 10
|
||||
):
|
||||
# Check if we should skip map rendering for performance
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
||||
st.warning(
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
||||
"due to performance considerations."
|
||||
)
|
||||
else:
|
||||
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
|
||||
|
||||
tab_idx += 1
|
||||
|
||||
# ERA5 tab (combining all temporal variants)
|
||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||
if era5_members:
|
||||
with tabs[tab_idx]:
|
||||
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
|
||||
"""Render ERA5 climate data analysis.
|
||||
|
||||
Args:
|
||||
ensemble: The dataset ensemble configuration.
|
||||
era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
|
||||
targets: Pre-loaded targets GeoDataFrame.
|
||||
|
||||
"""
|
||||
st.markdown("### ERA5 Climate Data Analysis")
|
||||
|
||||
# Let user select which ERA5 temporal aggregation to view
|
||||
|
|
@ -344,7 +338,7 @@ def render_training_data_page():
|
|||
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
|
||||
}
|
||||
|
||||
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
|
||||
available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
|
||||
|
||||
selected_era5 = st.selectbox(
|
||||
"Select ERA5 temporal aggregation",
|
||||
|
|
@ -353,30 +347,134 @@ def render_training_data_page():
|
|||
key="era5_temporal_select",
|
||||
)
|
||||
|
||||
if selected_era5:
|
||||
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
|
||||
|
||||
with st.spinner(f"Loading {selected_era5} data..."):
|
||||
era5_ds, targets = load_source_data(ensemble, selected_era5)
|
||||
|
||||
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
|
||||
if selected_era5 and selected_era5 in era5_data:
|
||||
era5_ds, temporal_type = era5_data[selected_era5]
|
||||
|
||||
render_era5_overview(era5_ds, temporal_type)
|
||||
render_era5_plots(era5_ds, temporal_type)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (
|
||||
ensemble.grid == "healpix" and ensemble.level == 10
|
||||
):
|
||||
# Check if we should skip map rendering for performance
|
||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
||||
st.warning(
|
||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
|
||||
"🗡️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
||||
"due to performance considerations."
|
||||
)
|
||||
else:
|
||||
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
|
||||
|
||||
|
||||
def render_training_data_page():
|
||||
"""Render the Training Data page of the dashboard."""
|
||||
st.title("🎯 Training Data")
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
Explore and visualize the training data for RTS prediction models.
|
||||
Configure your dataset by selecting grid configuration, target dataset,
|
||||
and data sources in the sidebar, then click "Load Dataset" to begin.
|
||||
"""
|
||||
)
|
||||
|
||||
# Render sidebar configuration
|
||||
render_dataset_configuration_sidebar()
|
||||
|
||||
# Check if dataset is loaded in session state
|
||||
if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state:
|
||||
st.info(
|
||||
"👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
|
||||
)
|
||||
return
|
||||
|
||||
# Get ensemble from session state
|
||||
ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]
|
||||
|
||||
st.divider()
|
||||
|
||||
# Load all necessary data once
|
||||
with st.spinner("Loading dataset..."):
|
||||
# Load training data for all tasks
|
||||
train_data_dict = load_all_training_data(ensemble)
|
||||
|
||||
# Load grid data
|
||||
grid_gdf = grids.open(ensemble.grid, ensemble.level)
|
||||
|
||||
# Load targets (needed by all source data views)
|
||||
targets = ensemble._read_target()
|
||||
|
||||
# Load AlphaEarth data if in members
|
||||
alphaearth_ds = None
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")
|
||||
|
||||
# Load ArcticDEM data if in members
|
||||
arcticdem_ds = None
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")
|
||||
|
||||
# Load ERA5 data for all temporal aggregations in members
|
||||
era5_data = {}
|
||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||
for era5_member in era5_members:
|
||||
era5_ds, _ = load_source_data(ensemble, era5_member)
|
||||
temporal_type = era5_member.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
|
||||
era5_data[era5_member] = (era5_ds, temporal_type)
|
||||
|
||||
st.success(
|
||||
f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features"
|
||||
)
|
||||
|
||||
# Render dataset statistics
|
||||
render_dataset_statistics(ensemble)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Create tabs for different data views
|
||||
tab_names = ["📊 Labels", "📐 Areas"]
|
||||
|
||||
# Add tabs for each member based on what's in the ensemble
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
tab_names.append("🌍 AlphaEarth")
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
tab_names.append("🏔️ ArcticDEM")
|
||||
|
||||
# Check for ERA5 members
|
||||
if era5_members:
|
||||
tab_names.append("🌡️ ERA5")
|
||||
|
||||
tabs = st.tabs(tab_names)
|
||||
|
||||
# Track current tab index
|
||||
tab_idx = 0
|
||||
|
||||
# Labels tab
|
||||
with tabs[tab_idx]:
|
||||
render_labels_view(ensemble, train_data_dict)
|
||||
tab_idx += 1
|
||||
|
||||
# Areas tab
|
||||
with tabs[tab_idx]:
|
||||
render_areas_view(ensemble, grid_gdf)
|
||||
tab_idx += 1
|
||||
|
||||
# AlphaEarth tab
|
||||
if "AlphaEarth" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
render_alphaearth_view(ensemble, alphaearth_ds, targets)
|
||||
tab_idx += 1
|
||||
|
||||
# ArcticDEM tab
|
||||
if "ArcticDEM" in ensemble.members:
|
||||
with tabs[tab_idx]:
|
||||
render_arcticdem_view(ensemble, arcticdem_ds, targets)
|
||||
tab_idx += 1
|
||||
|
||||
# ERA5 tab (combining all temporal variants)
|
||||
if era5_members:
|
||||
with tabs[tab_idx]:
|
||||
render_era5_view(ensemble, era5_data, targets)
|
||||
|
||||
# Show balloons once after all tabs are rendered
|
||||
st.balloons()
|
||||
|
||||
else:
|
||||
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue