Add a Storage Analysis Section

This commit is contained in:
Tobias Hölzer 2026-01-18 20:36:47 +01:00
parent 636c034b55
commit 2664579a75
9 changed files with 370 additions and 1959 deletions

View file

@ -18,7 +18,6 @@ from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
from entropice.dashboard.views.training_data_page import render_training_data_page
def main():
@ -28,7 +27,6 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
@ -38,8 +36,7 @@ def main():
{
"Overview": [overview_page],
"Data": [data_page],
"Training": [training_data_page, training_analysis_page, autogluon_page],
"Model State": [model_state_page],
"Experiments": [training_analysis_page, autogluon_page, model_state_page],
"Inference": [inference_page],
}
)

File diff suppressed because it is too large Load diff

View file

@ -1,366 +0,0 @@
"""Plotting functions for training data visualizations."""
import geopandas as gpd
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import CategoricalTrainingDataset
def render_all_distribution_histograms(
    train_data_dict: dict[str, CategoricalTrainingDataset],
):
    """Render per-task target distribution histograms in a three-column layout.

    For each task ('binary', 'count', 'density') a grouped bar chart is drawn
    showing how many train and test samples fall into each target category,
    followed by a caption summarising the train/test percentage split.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    st.subheader("📊 Target Distribution by Task")
    # One column per task, in fixed task order.
    columns = st.columns(3)
    titles = {
        "binary": "Binary Classification",
        "count": "Count Classification",
        "density": "Density Classification",
    }
    for column, (task, title) in zip(columns, titles.items()):
        dataset = train_data_dict[task]
        categories = dataset.y.binned.cat.categories.tolist()
        colors = get_palette(task, len(categories))
        with column:
            st.markdown(f"**{title}**")
            # Per-category sample counts, split by train/test membership.
            split_counts = {
                split_name: [((dataset.y.binned == cat) & (dataset.split == split_name)).sum() for cat in categories]
                for split_name in ("train", "test")
            }
            fig = go.Figure()
            # One grouped trace per split; the test trace is more transparent.
            for trace_name, trace_opacity in (("Train", 0.9), ("Test", 0.6)):
                values = split_counts[trace_name.lower()]
                fig.add_trace(
                    go.Bar(
                        name=trace_name,
                        x=categories,
                        y=values,
                        marker_color=colors,
                        opacity=trace_opacity,
                        text=values,
                        textposition="inside",
                        textfont={"size": 10, "color": "white"},
                    )
                )
            fig.update_layout(
                barmode="group",
                height=400,
                margin={"l": 20, "r": 20, "t": 20, "b": 20},
                showlegend=True,
                legend={
                    "orientation": "h",
                    "yanchor": "bottom",
                    "y": 1.02,
                    "xanchor": "right",
                    "x": 1,
                },
                xaxis_title=None,
                yaxis_title="Count",
                xaxis={"tickangle": -45},
            )
            st.plotly_chart(fig, width="stretch")
            # Summary caption with overall sample count and split ratio.
            total = len(dataset)
            train_pct = (dataset.split == "train").sum() / total * 100
            test_pct = (dataset.split == "test").sum() / total * 100
            st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
"""Assign colors to geodataframe based on the selected color mode.
Args:
gdf: GeoDataFrame to add colors to
color_mode: One of 'target_class' or 'split'
dataset: CategoricalTrainingDataset
selected_task: Task name for color palette selection
Returns:
GeoDataFrame with 'fill_color' column added
"""
if color_mode == "target_class":
categories = dataset.y.binned.cat.categories.tolist()
colors_palette = get_palette(selected_task, len(categories))
# Create color mapping
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
gdf["color"] = gdf["target_class"].map(color_map)
# Convert hex colors to RGB
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip("#")
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
elif color_mode == "split":
split_colors = {
"train": [66, 135, 245],
"test": [245, 135, 66],
} # Blue # Orange
gdf["fill_color"] = gdf["split"].map(split_colors)
return gdf
@st.fragment
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render a pydeck spatial map showing training data distribution with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.
    The map colors cells either by target class (for 'binary'/'count'/'density') or by
    train/test split, and extrudes cells in 3D for the 'count' and 'density' modes.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    st.subheader("🗺️ Spatial Distribution Map")
    # Create controls in columns
    col1, col2 = st.columns([3, 1])
    with col1:
        vis_mode = st.selectbox(
            "Visualization mode",
            options=["binary", "count", "density", "split"],
            format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
            key="spatial_map_mode",
        )
    with col2:
        opacity = st.slider(
            "Opacity",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.1,
            key="spatial_map_opacity",
        )
    # Determine which task dataset to use and color mode
    if vis_mode == "split":
        # Use binary dataset for split visualization
        dataset = train_data_dict["binary"]
        color_mode = "split"
        selected_task = "binary"
    else:
        # Use the selected task
        dataset = train_data_dict[vis_mode]
        color_mode = "target_class"
        selected_task = vis_mode
    # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
    gdf: gpd.GeoDataFrame = dataset.dataset.copy()  # type: ignore[assignment]
    # Fix antimeridian issues
    gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
    # Add binned labels and split information from current dataset
    gdf["target_class"] = dataset.y.binned.to_numpy()
    gdf["split"] = dataset.split.to_numpy()
    gdf["raw_value"] = dataset.z.to_numpy()
    # Add information from all three tasks for tooltip
    # (assumes all three datasets share the same row order as `gdf` — TODO confirm)
    gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
    gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
    gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
    gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
    gdf["density_raw"] = train_data_dict["density"].z.to_numpy()
    # Convert to WGS84 for pydeck
    gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326")  # type: ignore[assignment]
    # Assign colors based on the selected mode
    gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)
    # Convert to GeoJSON format and add elevation for 3D visualization
    geojson_data = []
    # Normalize raw values for elevation (only for count and density)
    use_elevation = vis_mode in ["count", "density"]
    if use_elevation:
        raw_values = gdf_wgs84["raw_value"]
        min_val, max_val = raw_values.min(), raw_values.max()
        # Normalize to 0-1 range for better 3D visualization
        if max_val > min_val:
            gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
        else:
            # Constant raw values: flat map (avoids division by zero).
            gdf_wgs84["elevation"] = 0
    # NOTE(review): per-row iterrows over the whole grid may be slow for dense
    # configurations — consider vectorized GeoJSON export; verify on hex-6.
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "target_class": str(row["target_class"]),
                "split": str(row["split"]),
                "raw_value": float(row["raw_value"]),
                "fill_color": row["fill_color"],
                "elevation": float(row["elevation"]) if use_elevation else 0,
                "binary_label": str(row["binary_label"]),
                "count_category": str(row["count_category"]),
                "count_raw": int(row["count_raw"]),
                "density_category": str(row["density_category"]),
                "density_raw": f"{float(row['density_raw']):.4f}",
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )
    # Set initial view state (centered on the Arctic)
    # Adjust pitch and zoom based on whether we're using elevation
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not use_elevation else 1.5,
        pitch=0 if not use_elevation else 45,
    )
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Binary:</b> {binary_label}<br/>"
            "<b>Count Category:</b> {count_category}<br/>"
            "<b>Count Raw:</b> {count_raw}<br/>"
            "<b>Density Category:</b> {density_category}<br/>"
            "<b>Density Raw:</b> {density_raw}<br/>"
            "<b>Split:</b> {split}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    # Render the map
    st.pydeck_chart(deck)
    # Show info about 3D visualization
    if use_elevation:
        st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")
    # Add legend
    with st.expander("Legend", expanded=True):
        if color_mode == "target_class":
            st.markdown("**Target Classes:**")
            categories = dataset.y.binned.cat.categories.tolist()
            colors_palette = get_palette(selected_task, len(categories))
            intervals = dataset.y.intervals
            # For count and density tasks, show intervals
            if selected_task in ["count", "density"]:
                for i, cat in enumerate(categories):
                    color = colors_palette[i]
                    interval_min, interval_max = intervals[i]
                    # Format interval display
                    if interval_min is None or interval_max is None:
                        interval_str = ""
                    elif selected_task == "count":
                        # Integer values for count
                        if interval_min == interval_max:
                            interval_str = f" ({int(interval_min)})"
                        else:
                            interval_str = f" ({int(interval_min)}-{int(interval_max)})"
                    else:  # density
                        # Percentage values for density
                        if interval_min == interval_max:
                            interval_str = f" ({interval_min * 100:.4f}%)"
                        else:
                            interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
                    st.markdown(
                        f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                        f'<div style="width: 20px; height: 20px; background-color: {color}; '
                        f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                        f"<span>{cat}{interval_str}</span></div>",
                        unsafe_allow_html=True,
                    )
            else:
                # Binary task: use original column layout
                legend_cols = st.columns(len(categories))
                for i, cat in enumerate(categories):
                    with legend_cols[i]:
                        color = colors_palette[i]
                        st.markdown(
                            f'<div style="display: flex; align-items: center;">'
                            f'<div style="width: 20px; height: 20px; background-color: {color}; '
                            f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                            f"<span>{cat}</span></div>",
                            unsafe_allow_html=True,
                        )
            if use_elevation:
                st.markdown("---")
                st.markdown("**Elevation (3D):**")
                min_val = gdf_wgs84["raw_value"].min()
                max_val = gdf_wgs84["raw_value"].max()
                st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)")
        elif color_mode == "split":
            st.markdown("**Data Split:**")
            legend_html = (
                '<div style="display: flex; gap: 20px;">'
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Train</span></div>"
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Test</span></div></div>"
            )
            st.markdown(legend_html, unsafe_allow_html=True)

View file

@ -431,7 +431,7 @@ def _render_aggregation_selection(
if not submitted:
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
st.stop()
return dimension_filters
return dimension_filters

View file

@ -103,8 +103,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
)
# Expandable details for each result
st.subheader("Individual Experiment Details")
with st.expander("Show Individual Experiment Details", expanded=False):
for tr in filtered_results:
tr_info = tr.display_info
display_name = tr_info.get_display_name("model_first")
@ -160,7 +159,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
unique_vals = tr.results[param].nunique()
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
with st.expander("Show CV Results DataFrame"):
st.write("**CV Results DataFrame:**")
st.dataframe(tr.results, width="stretch", hide_index=True)
st.write(f"\n**Path:** `{tr.path}`")

View file

@ -0,0 +1,163 @@
"""Storage Statistics Section for Entropice Dashboard."""
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from entropice.dashboard.utils.loaders import StorageInfo, load_storage_statistics
from entropice.utils.paths import DATA_DIR
def _format_bytes(bytes_value: int) -> str:
"""Format bytes into human-readable string."""
value = float(bytes_value)
for unit in ["B", "KB", "MB", "GB", "TB"]:
if value < 1024.0:
return f"{value:.2f} {unit}"
value /= 1024.0
return f"{value:.2f} PB"
def _create_storage_bar_chart(storage_infos: list[StorageInfo]) -> go.Figure:
    """Build a horizontal bar chart of per-subdirectory disk usage (in GB)."""
    if not storage_infos:
        # Nothing to plot; return an empty figure so callers can render unconditionally.
        return go.Figure()

    labels = [entry.name for entry in storage_infos]
    sizes_gb = [entry.size_bytes / (1024**3) for entry in storage_infos]
    counts = [entry.file_count for entry in storage_infos]

    bar = go.Bar(
        y=labels,
        x=sizes_gb,
        orientation="h",
        text=[f"{size:.2f} GB" for size in sizes_gb],
        textposition="auto",
        hovertemplate="<b>%{y}</b><br>Size: %{x:.2f} GB<br>Files: %{customdata:,}<extra></extra>",
        customdata=counts,
        marker={
            "color": sizes_gb,
            "colorscale": "Blues",
            "showscale": False,
        },
    )

    fig = go.Figure(data=[bar])
    fig.update_layout(
        title="Storage Usage by Subdirectory",
        xaxis_title="Size (GB)",
        yaxis_title="Directory",
        # Grow the chart with the number of directories so labels stay readable.
        height=max(400, len(labels) * 40),
        showlegend=False,
        margin={"l": 200, "r": 50, "t": 50, "b": 50},
    )
    return fig
def _create_storage_pie_chart(storage_infos: list[StorageInfo]) -> go.Figure:
    """Build a pie chart of the relative storage share of each subdirectory."""
    if not storage_infos:
        # Nothing to plot; return an empty figure so callers can render unconditionally.
        return go.Figure()

    pie = go.Pie(
        labels=[entry.name for entry in storage_infos],
        values=[entry.size_bytes for entry in storage_infos],
        textinfo="label+percent",
        # Hover shows the dust-reported human-readable size, not the raw byte count.
        hovertemplate="<b>%{label}</b><br>Size: %{customdata}<br>%{percent}<extra></extra>",
        customdata=[entry.display_size for entry in storage_infos],
    )

    fig = go.Figure(data=[pie])
    fig.update_layout(
        title="Storage Distribution",
        height=500,
    )
    return fig
def render_storage_statistics():
    """Render the storage statistics section showing disk usage for DATA_DIR subdirectories.

    Shows summary metrics (total size, file count, subdirectory count) and three
    tabs with a bar chart, a pie chart, and a detailed table with a CSV download.
    Data comes from ``load_storage_statistics`` (cached for 5 minutes).
    """
    st.header("💾 Storage Statistics")
    st.markdown(
        f"""
        This section shows the disk usage of subdirectories in the data directory:
        **`{DATA_DIR}`**

        Data is collected using [dust](https://github.com/bootandy/dust), a modern disk usage analyzer.
        Statistics are cached for 5 minutes to reduce overhead.
        """
    )
    # Load storage statistics
    with st.spinner("Analyzing storage usage..."):
        storage_infos, total_size, total_files = load_storage_statistics()
    if not storage_infos:
        st.warning("No storage data available. The data directory may be empty or inaccessible.")
        return
    # Display summary metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Storage Used", _format_bytes(total_size))
    with col2:
        st.metric("Total Files", f"{total_files:,}")
    with col3:
        st.metric("Number of Subdirectories", len(storage_infos))
    # Create tabs for different visualizations
    tab1, tab2, tab3 = st.tabs(["📊 Bar Chart", "🥧 Pie Chart", "📋 Detailed Table"])
    with tab1:
        # width="stretch" for consistency with the rest of the dashboard
        # (use_container_width is deprecated in recent Streamlit releases).
        st.plotly_chart(_create_storage_bar_chart(storage_infos), width="stretch")
    with tab2:
        st.plotly_chart(_create_storage_pie_chart(storage_infos), width="stretch")
    with tab3:
        # Create DataFrame for detailed view
        df = pd.DataFrame(
            [
                {
                    "Directory": info.name,
                    "Size": info.display_size,
                    "Size (Bytes)": info.size_bytes,
                    "Files": info.file_count,
                    # Guard against a zero total (all subdirectories empty).
                    "Percentage": f"{(info.size_bytes / total_size * 100 if total_size else 0.0):.2f}%",
                }
                for info in storage_infos
            ]
        )
        st.dataframe(
            df[["Directory", "Size", "Files", "Percentage"]],
            width="stretch",
            hide_index=True,
        )
        # Add download button for detailed data
        st.download_button(
            label="📥 Download Storage Statistics (CSV)",
            data=df.to_csv(index=False),
            file_name="entropice_storage_statistics.csv",
            mime="text/csv",
        )

View file

@ -1,6 +1,8 @@
"""Data utilities for Entropice dashboard."""
import json
import pickle
import subprocess
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@ -252,3 +254,148 @@ def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Ta
for task in all_tasks:
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
return train_data_dict
@dataclass
class StorageInfo:
    """Storage information for a single subdirectory of the data directory."""

    # Subdirectory name (basename only, not the full path).
    name: str
    # Total size in bytes (parsed from dust's human-readable size string).
    size_bytes: int
    # Number of files in the subdirectory (from a separate dust -f run).
    file_count: int
    # Human-readable size string exactly as reported by dust (e.g. "1.5M").
    display_size: str
def _parse_size_to_bytes(size_str: str) -> int:
"""Convert dust's human-readable size string to bytes.
Examples: "92K" -> 92*1024, "1.5M" -> 1.5*1024*1024, "928B" -> 928
"""
size_str = size_str.strip().upper()
if not size_str:
return 0
# Extract numeric part and unit
numeric_part = ""
unit = ""
for char in size_str:
if char.isdigit() or char == ".":
numeric_part += char
else:
unit += char
try:
value = float(numeric_part) if numeric_part else 0
except ValueError:
return 0
# Convert based on unit
unit = unit.strip()
multipliers = {
"B": 1,
"K": 1024,
"M": 1024**2,
"G": 1024**3,
"T": 1024**4,
"P": 1024**5,
}
return int(value * multipliers.get(unit, 1))
def _run_dust_command(data_dir: Path, for_files: bool = False) -> dict | None:
    """Run dust over *data_dir* and return its parsed JSON output.

    Args:
        data_dir: Directory to analyze.
        for_files: If True, count files (-f flag); if False, measure disk space.

    Returns:
        Parsed JSON dict, or None if dust failed, timed out, or emitted invalid JSON.
        A missing dust binary (FileNotFoundError) deliberately propagates to the caller.
    """
    # Depth 1: we only care about the immediate subdirectories.
    cmd = ["dust", "-j", "-d", "1", *(["-f"] if for_files else []), str(data_dir)]
    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    except subprocess.TimeoutExpired:
        return None
    if completed.returncode != 0:
        return None
    try:
        return json.loads(completed.stdout)
    except json.JSONDecodeError:
        return None
def _build_file_counts_lookup(files_data: dict | None) -> dict[str, int]:
    """Build a {directory basename: file count} lookup from dust -f JSON output."""
    if not files_data or "children" not in files_data:
        return {}
    # dust -f reports counts in the "size" field; reuse the size parser for them.
    return {
        Path(entry["name"]).name: _parse_size_to_bytes(entry.get("size", "0"))
        for entry in files_data["children"]
    }
@st.cache_data(ttl=300)  # Cache for 5 minutes
def load_storage_statistics() -> tuple[list[StorageInfo], int, int]:
    """Load storage statistics for DATA_DIR subdirectories using dust.

    Runs dust twice (once for disk space, once with -f for file counts), joins
    the two outputs by subdirectory basename, and sorts by size descending.

    Returns:
        Tuple of (subdirectory stats list, total size in bytes, total file count)
    """
    # NOTE(review): dotted access here assumes `import entropice.utils.paths`
    # exists at module level — the visible imports only show json/pickle/subprocess
    # etc.; confirm, otherwise this line raises NameError at first call.
    data_dir = entropice.utils.paths.DATA_DIR
    if not data_dir.exists():
        return [], 0, 0
    try:
        # Run dust for disk space and file counts
        space_data = _run_dust_command(data_dir, for_files=False)
        files_data = _run_dust_command(data_dir, for_files=True)
        if not space_data:
            st.warning("Failed to get storage statistics from dust")
            return [], 0, 0
        # Build lookup dict for file counts
        file_counts = _build_file_counts_lookup(files_data)
        # Extract subdirectory information from space data
        storage_infos = []
        total_size = 0
        total_files = 0
        if "children" in space_data:
            for child in space_data["children"]:
                full_path = child.get("name", "")
                dir_name = Path(full_path).name
                size_str = child.get("size", "0")
                size_bytes = _parse_size_to_bytes(size_str)
                # File count defaults to 0 when the dust -f run had no entry for this dir.
                file_count = file_counts.get(dir_name, 0)
                storage_infos.append(
                    StorageInfo(
                        name=dir_name,
                        size_bytes=size_bytes,
                        file_count=file_count,
                        display_size=size_str,
                    )
                )
                total_size += size_bytes
                total_files += file_count
        # Sort by size descending
        storage_infos.sort(key=lambda x: x.size_bytes, reverse=True)
        return storage_infos, total_size, total_files
    except FileNotFoundError:
        # Raised by subprocess inside _run_dust_command when the binary is absent.
        st.error("dust command not found. Please install dust: https://github.com/bootandy/dust")
        return [], 0, 0
    except Exception as e:
        # Broad catch so the dashboard degrades gracefully instead of crashing.
        st.error(f"Error getting storage statistics: {e}")
        return [], 0, 0

View file

@ -8,6 +8,7 @@ from entropice.dashboard.sections.experiment_results import (
render_experiment_results,
render_training_results_summary,
)
from entropice.dashboard.sections.storage_statistics import render_storage_statistics
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
@ -52,5 +53,10 @@ def render_overview_page():
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
st.divider()
# Render storage statistics section
render_storage_statistics()
st.balloons()
stopwatch.summary()

View file

@ -1,481 +0,0 @@
"""Training Data page: Visualization of training data distributions."""
from typing import cast
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.source_data import (
render_alphaearth_map,
render_alphaearth_overview,
render_alphaearth_plots,
render_arcticdem_map,
render_arcticdem_overview,
render_arcticdem_plots,
render_areas_map,
render_era5_map,
render_era5_overview,
render_era5_plots,
)
from entropice.dashboard.plots.training_data import (
render_all_distribution_histograms,
render_spatial_map,
)
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.spatial import grids
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
def render_dataset_configuration_sidebar():
    """Render dataset configuration selector in sidebar with form.

    Stores the selected ensemble in ``st.session_state["dataset_ensemble"]`` and
    sets ``st.session_state["dataset_loaded"]`` when the form is submitted. Using
    a form means the (potentially expensive) DatasetEnsemble is only constructed
    on an explicit "Load Dataset" click, not on every widget interaction.
    """
    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")
        # Grid selection
        grid_options = [gc.display_name for gc in grid_configs]
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
        )
        # Find the selected grid config (display names are unique per config)
        selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
        # Target feature selection
        target = st.selectbox(
            "Target Feature",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target variable for training",
        )
        # Members selection
        st.subheader("Dataset Members")
        all_members = cast(
            list[L2SourceDataset],
            ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
        )
        selected_members: list[L2SourceDataset] = []
        for member in all_members:
            if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
                selected_members.append(member)  # type: ignore[arg-type]
        # Form submit button; disabled until at least one member is selected
        load_button = st.form_submit_button(
            "Load Dataset",
            type="primary",
            use_container_width=True,
            disabled=len(selected_members) == 0,
        )
    # Create DatasetEnsemble only when form is submitted
    if load_button:
        ensemble = DatasetEnsemble(
            grid=selected_grid_config.grid,
            level=selected_grid_config.level,
            target=cast(TargetDataset, target),
            members=selected_members,
        )
        # Store ensemble in session state
        st.session_state["dataset_ensemble"] = ensemble
        st.session_state["dataset_loaded"] = True
def render_dataset_statistics(ensemble: DatasetEnsemble):
    """Render dataset statistics and configuration overview.

    Shows the current ensemble configuration (grid, level, target, members),
    then per-member statistics (features, variables, dimensions) as computed
    by ``ensemble.get_stats()``.

    Args:
        ensemble: The dataset ensemble configuration.
    """
    st.markdown("### 📊 Dataset Configuration")
    # Display current configuration in columns
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric(label="Grid Type", value=ensemble.grid.upper())
    with col2:
        st.metric(label="Grid Level", value=ensemble.level)
    with col3:
        st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
    with col4:
        st.metric(label="Members", value=len(ensemble.members))
    # Display members in an expandable section
    with st.expander("🗂️ Dataset Members", expanded=False):
        members_cols = st.columns(len(ensemble.members))
        for idx, member in enumerate(ensemble.members):
            with members_cols[idx]:
                st.markdown(f"✓ **{member}**")
    # Display dataset ID in a styled container
    st.info(f"**Dataset ID:** `{ensemble.id()}`")
    # Display detailed dataset statistics
    st.markdown("---")
    st.markdown("### 📈 Dataset Statistics")
    # get_stats() may be expensive, hence the spinner.
    with st.spinner("Computing dataset statistics..."):
        stats = ensemble.get_stats()
    # High-level summary metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
    with col2:
        st.metric(label="Total Features", value=f"{stats['total_features']:,}")
    with col3:
        st.metric(label="Data Sources", value=len(stats["members"]))
    # Detailed member statistics in expandable section
    with st.expander("📦 Data Source Details", expanded=False):
        for member, member_stats in stats["members"].items():
            st.markdown(f"### {member}")
            # Create metrics for this member
            metric_cols = st.columns(4)
            with metric_cols[0]:
                st.metric("Features", member_stats["num_features"])
            with metric_cols[1]:
                st.metric("Variables", member_stats["num_variables"])
            with metric_cols[2]:
                # Display dimensions in a more readable format
                dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()])  # type: ignore[union-attr]
                st.metric("Shape", dim_str)
            with metric_cols[3]:
                # Calculate total data points (product of all dimension sizes)
                total_points = 1
                for dim_size in member_stats["dimensions"].values():  # type: ignore[union-attr]
                    total_points *= dim_size
                st.metric("Data Points", f"{total_points:,}")
            # Show variables as colored badges
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats["variables"]  # type: ignore[union-attr]
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)
            # Show dimension details
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size}</span>"
                    for dim_name, dim_size in member_stats["dimensions"].items()  # type: ignore[union-attr]
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)
            st.markdown("---")
def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
    """Render target labels distribution and spatial visualization.

    Args:
        ensemble: The dataset ensemble configuration.
        train_data_dict: Pre-loaded training data for all tasks.
    """
    st.markdown("### Target Labels Distribution and Spatial Visualization")
    # The binary task serves as the reference for overall sample counts.
    reference = train_data_dict["binary"]
    n_total = len(reference)
    n_train = (reference.split == "train").sum().item()
    n_test = (reference.split == "test").sum().item()
    st.success(f"Loaded {n_total} samples ({n_train} train, {n_test} test) for all three tasks")
    # Per-task distribution histograms.
    st.markdown("---")
    render_all_distribution_histograms(train_data_dict)  # type: ignore[arg-type]
    st.markdown("---")
    # The spatial map requires geometry on the reference dataset.
    assert "geometry" in reference.dataset.columns, "Geometry column missing in dataset"
    render_spatial_map(train_data_dict)
def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
    """Render grid cell areas and land/water distribution.

    Args:
        ensemble: The dataset ensemble configuration.
        grid_gdf: Pre-loaded grid GeoDataFrame.
    """
    st.markdown("### Grid Cell Areas and Land/Water Distribution")
    st.markdown(
        "This visualization shows the spatial distribution of cell areas, land areas, "
        "water areas, and land ratio across the grid. The grid has been filtered to "
        "include only cells in the permafrost region (>50° latitude, <85° latitude) "
        "with >10% land coverage."
    )
    st.success(
        f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
        f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
    )
    # Summary metrics row.
    summary = [
        ("Total Cells", f"{len(grid_gdf):,}"),
        ("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²"),
        ("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}"),
        ("Total Land Area", f"{grid_gdf['land_area'].sum():,.0f} km²"),
    ]
    for column, (label, value) in zip(st.columns(4), summary):
        with column:
            st.metric(label, value)
    st.markdown("---")
    # The densest grid configurations are too heavy to render interactively.
    too_dense = (ensemble.grid == "hex" and ensemble.level == 6) or (
        ensemble.grid == "healpix" and ensemble.level == 10
    )
    if too_dense:
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
    else:
        render_areas_map(grid_gdf, ensemble.grid)
def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
    """Render AlphaEarth embeddings analysis.

    Args:
        ensemble: The dataset ensemble configuration.
        alphaearth_ds: Pre-loaded AlphaEarth dataset.
        targets: Pre-loaded targets GeoDataFrame.
    """
    st.markdown("### AlphaEarth Embeddings Analysis")
    st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
    render_alphaearth_overview(alphaearth_ds)
    render_alphaearth_plots(alphaearth_ds)
    st.markdown("---")
    # The densest grid configurations are too heavy to render interactively.
    too_dense = (ensemble.grid == "hex" and ensemble.level == 6) or (
        ensemble.grid == "healpix" and ensemble.level == 10
    )
    if not too_dense:
        render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
    else:
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
    """Render ArcticDEM terrain analysis.

    Args:
        ensemble: The dataset ensemble configuration.
        arcticdem_ds: Pre-loaded ArcticDEM dataset.
        targets: Pre-loaded targets GeoDataFrame.
    """
    st.markdown("### ArcticDEM Terrain Analysis")
    st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
    render_arcticdem_overview(arcticdem_ds)
    render_arcticdem_plots(arcticdem_ds)
    st.markdown("---")
    # The densest grid configurations are too heavy to render interactively.
    too_dense = (ensemble.grid == "hex" and ensemble.level == 6) or (
        ensemble.grid == "healpix" and ensemble.level == 10
    )
    if not too_dense:
        render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
    else:
        st.warning(
            "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
            "due to performance considerations."
        )
@st.fragment
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
    """Render ERA5 climate data analysis.

    Runs as a fragment so switching the temporal aggregation only reruns this view.

    Args:
        ensemble: The dataset ensemble configuration.
        era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
        targets: Pre-loaded targets GeoDataFrame.
    """
    st.markdown("### ERA5 Climate Data Analysis")
    # Let user select which ERA5 temporal aggregation to view
    era5_options = {
        "ERA5-yearly": "Yearly",
        "ERA5-seasonal": "Seasonal (Winter/Summer)",
        "ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
    }
    # Only offer aggregations that were actually loaded.
    available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
    selected_era5 = st.selectbox(
        "Select ERA5 temporal aggregation",
        options=list(available_era5.keys()),
        format_func=lambda x: available_era5[x],
        key="era5_temporal_select",
    )
    if selected_era5 and selected_era5 in era5_data:
        era5_ds, temporal_type = era5_data[selected_era5]
        render_era5_overview(era5_ds, temporal_type)
        render_era5_plots(era5_ds, temporal_type)
        st.markdown("---")
        # Check if we should skip map rendering for performance
        if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
            # Fixed: use the map emoji (🗺️) to match the identical warnings in the
            # other source-data views (previously a dagger 🗡️ typo).
            st.warning(
                "🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
                "due to performance considerations."
            )
        else:
            render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
def render_training_data_page():
    """Render the Training Data page of the dashboard."""
    st.title("🎯 Training Data")
    st.markdown(
        """
        Explore and visualize the training data for RTS prediction models.
        Configure your dataset by selecting grid configuration, target dataset,
        and data sources in the sidebar, then click "Load Dataset" to begin.
        """
    )

    # Sidebar controls populate session state when the user loads a dataset.
    render_dataset_configuration_sidebar()

    # Nothing to show until a dataset has been loaded via the sidebar.
    if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state:
        st.info(
            "👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
        )
        return

    ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]
    st.divider()

    # Load everything the tabs need up front, under a single spinner.
    with st.spinner("Loading dataset..."):
        # Training data for all tasks.
        train_data_dict = load_all_training_data(ensemble)
        # Grid geometries for the Areas view.
        grid_gdf = grids.open(ensemble.grid, ensemble.level)
        # Targets are shared by every source-data view.
        targets = ensemble._read_target()

        alphaearth_ds = None
        if "AlphaEarth" in ensemble.members:
            alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")

        arcticdem_ds = None
        if "ArcticDEM" in ensemble.members:
            arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")

        # ERA5 comes in several temporal aggregations; load every one present.
        era5_data = {}
        era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
        for era5_member in era5_members:
            era5_ds, _ = load_source_data(ensemble, era5_member)
            temporal_type = era5_member.split("-")[1]  # 'yearly', 'seasonal', or 'shoulder'
            era5_data[era5_member] = (era5_ds, temporal_type)

    st.success(
        f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features"
    )

    render_dataset_statistics(ensemble)
    st.markdown("---")

    # Build (tab label, renderer) pairs in display order; optional members only
    # get a tab when they are part of the ensemble.
    tab_specs = [
        ("📊 Labels", lambda: render_labels_view(ensemble, train_data_dict)),
        ("📐 Areas", lambda: render_areas_view(ensemble, grid_gdf)),
    ]
    if "AlphaEarth" in ensemble.members:
        tab_specs.append(("🌍 AlphaEarth", lambda: render_alphaearth_view(ensemble, alphaearth_ds, targets)))
    if "ArcticDEM" in ensemble.members:
        tab_specs.append(("🏔️ ArcticDEM", lambda: render_arcticdem_view(ensemble, arcticdem_ds, targets)))
    if era5_members:
        tab_specs.append(("🌡️ ERA5", lambda: render_era5_view(ensemble, era5_data, targets)))

    tabs = st.tabs([label for label, _ in tab_specs])
    for tab, (_, renderer) in zip(tabs, tab_specs):
        with tab:
            renderer()

    # Show balloons once after all tabs are rendered
    st.balloons()
    stopwatch.summary()