Add a Dataset Page to deprecate the Training Data page

This commit is contained in:
Tobias Hölzer 2026-01-16 23:51:38 +01:00
parent a6e9a91692
commit cb7b0f9e6b
7 changed files with 986 additions and 90 deletions

View file

@ -13,6 +13,7 @@ Pages:
import streamlit as st import streamlit as st
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
from entropice.dashboard.views.dataset_page import render_dataset_page
from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page from entropice.dashboard.views.overview_page import render_overview_page
@ -26,6 +27,7 @@ def main():
# Setup Navigation # Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖") autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
@ -35,6 +37,7 @@ def main():
pg = st.navigation( pg = st.navigation(
{ {
"Overview": [overview_page], "Overview": [overview_page],
"Data": [data_page],
"Training": [training_data_page, training_analysis_page, autogluon_page], "Training": [training_data_page, training_analysis_page, autogluon_page],
"Model State": [model_state_page], "Model State": [model_state_page],
"Inference": [inference_page], "Inference": [inference_page],

View file

@ -0,0 +1,470 @@
"""Plots for visualizing target labels in datasets."""
import antimeridian
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
from plotly.subplots import make_subplots
from shapely.geometry import shape
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
from entropice.ml.dataset import TrainingSet
from entropice.utils.types import TargetDataset, Task
def create_target_distribution_plot(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]) -> go.Figure:
    """Create a plot showing the distribution of target labels across datasets and tasks.

    This function creates a comprehensive visualization showing:
    - Classification tasks (binary, count_regimes, density_regimes): categorical distributions
    - Regression tasks (count, density): histogram distributions
    - Train/test split information for each combination
    - Separate subplot rows for each TargetDataset

    Args:
        train_data_dict: Nested dictionary with structure:
            {target_dataset: {task: TrainingSet}}
            where target_dataset is in ["darts_v1", "darts_mllabels"]
            and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]

    Returns:
        Plotly Figure with subplots showing target distributions.
    """
    # Define task types and their properties
    classification_tasks: list[Task] = ["binary", "count_regimes", "density_regimes"]
    task_titles: dict[Task, str] = {
        "binary": "Binary",
        "count_regimes": "Count Regimes",
        "density_regimes": "Density Regimes",
        "count": "Count (Regression)",
        "density": "Density (Regression)",
    }
    # Get all available target datasets and tasks (union across datasets)
    target_datasets = sorted(train_data_dict.keys())
    all_tasks = sorted(
        {task for tasks in train_data_dict.values() for task in tasks.keys()},
        key=lambda x: (x not in classification_tasks, x),  # Classification first
    )
    # Create subplots: one row per target dataset, one column per task
    n_rows = len(target_datasets)
    n_cols = len(all_tasks)
    # Create column titles (tasks) for the subplots
    column_titles = [task_titles.get(task, str(task)) for task in all_tasks]  # type: ignore[arg-type]
    # Create row titles (target datasets), e.g. "darts_v1" -> "Darts V1"
    row_titles = [target_ds.replace("_", " ").title() for target_ds in target_datasets]
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        column_titles=column_titles,
        # max(..., 1) guards against division by zero on an empty input dict
        vertical_spacing=0.20 / max(n_rows, 1),
        horizontal_spacing=0.08 / max(n_cols, 1),
    )
    # Iterate through each target dataset and task
    for row_idx, target_dataset in enumerate(target_datasets, start=1):
        task_dict = train_data_dict[target_dataset]  # type: ignore[index]
        for col_idx, task in enumerate(all_tasks, start=1):
            if task not in task_dict:
                # Skip if this task is not available for this target dataset
                # (leaves an empty cell in the subplot grid)
                continue
            dataset = task_dict[task]  # type: ignore[index]
            # Determine if this is a classification or regression task
            if task in classification_tasks:
                _add_classification_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
            else:
                _add_regression_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
    # Update layout: shared horizontal legend centered below the grid
    fig.update_layout(
        height=400 * n_rows,
        showlegend=True,
        legend={
            "orientation": "h",
            "yanchor": "top",
            "y": -0.15,
            "xanchor": "center",
            "x": 0.5,
        },
        # Wide left margin leaves room for the row title annotations added below
        margin={"l": 120, "r": 40, "t": 80, "b": 100},
    )
    # Update all x-axes and y-axes
    fig.update_xaxes(title_font={"size": 11})
    fig.update_yaxes(title_text="Count", title_font={"size": 11})
    # Manually add row title annotations (left-aligned dataset names);
    # make_subplots' row_titles would render them rotated on the right side
    for row_idx, row_title in enumerate(row_titles, start=1):
        # Calculate the y position for each row (at the top of the row).
        # Account for vertical spacing between rows; spacing mirrors the
        # vertical_spacing passed to make_subplots above.
        spacing = 0.20 / max(n_rows, 1)
        row_height = (1 - spacing * (n_rows - 1)) / n_rows
        # Start slightly above the paper top (1.08) so titles sit above each row
        y_position = 1.08 - (row_idx - 1) * (row_height + spacing)
        fig.add_annotation(
            text=f"<b>{row_title}</b>",
            xref="paper",
            yref="paper",
            x=-0.02,
            y=y_position,
            xanchor="left",
            yanchor="top",
            textangle=0,
            font={"size": 20},
            showarrow=False,
        )
    return fig
def _add_classification_subplot(
    fig: go.Figure,
    dataset: TrainingSet,
    task: str,
    row: int,
    col: int,
    n_rows: int,
) -> None:
    """Add a classification task subplot showing the categorical distribution.

    Draws one bar trace per data split (train drawn more opaque than test),
    with per-category counts rendered inside the bars.

    Args:
        fig: Plotly figure to add subplot to.
        dataset: TrainingSet containing the data.
        task: Task name for color selection.
        row: Subplot row index.
        col: Subplot column index.
        n_rows: Total number of rows (for legend control).
    """
    labels = dataset.targets["y"]
    categories = labels.cat.categories.tolist()
    # Drop the first and last palette entries to avoid too light/dark colors
    colors = get_palette(task, len(categories) + 2)[1:-1]
    # Only the very first subplot contributes legend entries
    first_subplot = row == 1 and col == 1
    # One bar trace per split: Train (opaque) first, then Test (faded)
    for split_name, alpha in (("Train", 0.9), ("Test", 0.5)):
        counts = [((labels == category) & (dataset.split == split_name.lower())).sum() for category in categories]
        fig.add_trace(
            go.Bar(
                name=split_name,
                x=categories,
                y=counts,
                marker_color=colors,
                opacity=alpha,
                text=counts,
                textposition="inside",
                textfont={"size": 9, "color": "white"},
                legendgroup="split",
                showlegend=first_subplot,
            ),
            row=row,
            col=col,
        )
    # Angle category labels so longer names stay readable
    fig.update_xaxes(tickangle=-45, row=row, col=col)
def _add_regression_subplot(
    fig: go.Figure,
    dataset: TrainingSet,
    task: str,
    row: int,
    col: int,
    n_rows: int,
) -> None:
    """Add a regression task subplot showing histogram distribution.

    Train and test histograms share identical bin edges (computed from the
    combined data) so the two distributions are directly comparable.

    Args:
        fig: Plotly figure to add subplot to.
        dataset: TrainingSet containing the data.
        task: Task name for color selection.
        row: Subplot row index.
        col: Subplot column index.
        n_rows: Total number of rows (for legend control).
    """
    # Get raw values
    z_values = dataset.targets["z"]
    # Split into train and test
    train_values = z_values[dataset.split == "train"]
    test_values = z_values[dataset.split == "test"]
    # Determine bin edges based on combined data
    all_values = pd.concat([train_values, test_values])
    # Square-root rule capped at 30 bins; the max(..., 1) guard keeps
    # np.histogram_bin_edges from being asked for 0 bins on tiny/empty data
    n_bins = max(1, min(30, int(np.sqrt(len(all_values)))))
    bin_edges = np.histogram_bin_edges(all_values, bins=n_bins)
    # Get colors for the task
    colors = get_palette(task, 6)[1:5:2]  # Use less bright/dark colors
    # Add train histogram
    fig.add_trace(
        go.Histogram(
            name="Train",
            x=train_values,
            xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
            marker_color=colors[0],
            opacity=0.7,
            legendgroup="split",
            showlegend=(row == 1 and col == 1),  # Only show legend once
        ),
        row=row,
        col=col,
    )
    # Add test histogram
    fig.add_trace(
        go.Histogram(
            name="Test",
            x=test_values,
            xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
            marker_color=colors[1],
            opacity=0.5,
            legendgroup="split",
            showlegend=(row == 1 and col == 1),  # Only show legend once
        ),
        row=row,
        col=col,
    )
    # Label the x-axis with the task name
    fig.update_xaxes(title_text=task.capitalize(), row=row, col=col)
def _assign_split_colors(gdf: pd.DataFrame) -> pd.DataFrame:
    """Assign colors based on train/test split.

    Args:
        gdf: GeoDataFrame with 'split' column.

    Returns:
        GeoDataFrame with 'fill_color' column added.
    """
    train_hex, test_hex = get_palette("split", 2)
    lookup = {"train": hex_to_rgb(train_hex), "test": hex_to_rgb(test_hex)}
    # Assign via a plain list comprehension to avoid categorical hashing
    # issues when the cell values are lists
    gdf["fill_color"] = [lookup[split] for split in gdf["split"]]
    return gdf
def _assign_classification_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
    """Assign colors for classification tasks based on categorical labels.

    Args:
        gdf: GeoDataFrame with categorical 'y' column.
        task: Task name for color selection.

    Returns:
        GeoDataFrame with 'fill_color' column added.
    """
    categories = gdf["y"].cat.categories.tolist()
    palette = get_palette(task, len(categories))
    # Pair each category with its palette color, converted to RGB triplets
    lookup = {category: hex_to_rgb(color) for category, color in zip(categories, palette)}
    # Assign via a plain list to avoid categorical hashing issues with list values
    gdf["fill_color"] = [lookup[label] for label in gdf["y"]]
    return gdf
def _assign_regression_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
    """Assign colors for regression tasks based on continuous values.

    Values are min-max normalized to [0, 1] and mapped through the task's
    colormap to produce RGB fill colors.

    Args:
        gdf: GeoDataFrame with numeric 'z' column.
        task: Task name for color selection.

    Returns:
        GeoDataFrame with 'fill_color' column added.
    """
    z_values = gdf["z"]
    z_min = z_values.min()
    z_max = z_values.max()
    if z_max > z_min:
        normalized = (z_values - z_min) / (z_max - z_min)
    else:
        # Degenerate case: all values identical — use the colormap midpoint
        normalized = pd.Series([0.5] * len(z_values), index=z_values.index)
    cmap = get_cmap(task)

    def to_rgb_triplet(fraction):
        # Colormap returns RGBA floats in [0, 1]; keep the RGB part scaled to 0-255
        rgba = cmap(fraction)
        return [int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)]

    gdf["fill_color"] = normalized.apply(to_rgb_triplet)
    return gdf
def _assign_elevation(gdf: pd.DataFrame, make_3d_map: bool, is_classification: bool) -> pd.DataFrame:
"""Assign elevation values for 3D visualization.
Args:
gdf: GeoDataFrame to add elevation to.
make_3d_map: Whether to create a 3D map.
is_classification: Whether this is a classification task.
Returns:
GeoDataFrame with 'elevation' column added.
"""
if not make_3d_map:
gdf["elevation"] = 0
return gdf
print(f"{make_3d_map=}, {is_classification=}")
if is_classification:
# For classification, use category codes for elevation
# Convert to int to avoid overflow with int8 categorical codes
z_values = gdf["y"].cat.codes.astype(int)
else:
# For regression, use z values for elevation
z_values = gdf["z"]
z_min, z_max = z_values.min(), z_values.max()
if z_max > z_min:
gdf["elevation"] = ((z_values - z_min) / (z_max - z_min)).fillna(0)
else:
gdf["elevation"] = 0
print(gdf["elevation"].describe())
return gdf
def create_target_spatial_distribution_map(
    training_set: TrainingSet, make_3d_map: bool, show_split: bool, task: Task
) -> pdk.Deck:
    """Create a spatial distribution map of target labels using PyDeck.

    Args:
        training_set (TrainingSet): The training set containing spatial data and target labels.
        make_3d_map (bool): Whether to create a 3D map (True) or a 2D map (False).
        show_split (bool): Whether to color by train/test split (True) or by target labels (False).
        task (Task): The task type to determine color mapping.

    Returns:
        pdk.Deck: A PyDeck Deck object representing the spatial distribution map.
    """
    # Work on a copy so the caller's targets frame is not mutated
    gdf = training_set.targets.copy()
    # Add split information
    gdf["split"] = training_set.split.to_numpy()
    # Determine if this is a classification or regression task
    is_classification = gdf["y"].dtype.name == "category"
    # Convert to WGS84 (lat/lon) for pydeck
    gdf_wgs84 = gdf.to_crs("EPSG:4326")

    # Fix antimeridian issues: geometries crossing the ±180° meridian would
    # otherwise render as artifacts spanning the whole map
    def fix_hex_geometry(geom):
        try:
            return shape(antimeridian.fix_shape(geom))
        except Exception:
            # Best effort: keep the original geometry if fixing fails.
            # (Was `except (ValueError, Exception)` — the ValueError was redundant.)
            return geom

    gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
    # Assign colors based on mode: split membership takes precedence,
    # otherwise color by categorical label or continuous value
    if show_split:
        gdf_wgs84 = _assign_split_colors(gdf_wgs84)
    elif is_classification:
        gdf_wgs84 = _assign_classification_colors(gdf_wgs84, task)
    else:
        gdf_wgs84 = _assign_regression_colors(gdf_wgs84, task)
    # Add elevation for 3D visualization
    gdf_wgs84 = _assign_elevation(gdf_wgs84, make_3d_map, is_classification)
    # Convert to GeoJSON features; properties feed the accessors and tooltip below
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "fill_color": row["fill_color"],
                "elevation": row["elevation"],
                "split": row["split"],
                "y_label": str(row["y"]),
                # JSON cannot carry NaN; map missing values to None
                "z_value": float(row["z"]) if pd.notna(row["z"]) else None,
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        # Elevations are normalized to [0, 1]; scale them up to visible heights
        elevation_scale=500000 if make_3d_map else 0,
        pickable=True,
    )
    # Set initial view state (centered on the Arctic)
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )
    # Create tooltip showing the per-feature properties on hover
    tooltip_html = "<b>Split:</b> {split}<br/><b>Label:</b> {y_label}<br/><b>Value:</b> {z_value}<br/>"
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    return deck

View file

@ -14,7 +14,7 @@ from entropice.dashboard.plots.dataset_statistics import (
create_sample_count_bar_chart, create_sample_count_bar_chart,
) )
from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import ( from entropice.utils.types import (
GridConfig, GridConfig,
@ -436,9 +436,104 @@ def _render_aggregation_selection(
return dimension_filters return dimension_filters
def render_ensemble_details(
    selected_members: list[L2SourceDataset],
    selected_member_stats: dict[L2SourceDataset, MemberStatistics],
):
    """Render ensemble-specific details: feature breakdown and member information.

    This function displays the feature breakdown by data source (pie chart and
    table) and detailed member information, independent of task or target
    selection. Can be reused on other pages by supplying a dataset ensemble.

    Args:
        selected_members: List of selected member datasets
        selected_member_stats: Statistics for selected members
    """
    # NOTE(review): selected_members is not referenced below — the stats dict's
    # keys drive the rendering. Confirm whether the parameter can be dropped or
    # should filter selected_member_stats.
    # Calculate total features for selected members
    total_features = sum(ms.feature_count for ms in selected_member_stats.values())
    # Feature breakdown by source
    st.markdown("#### Feature Breakdown by Data Source")
    breakdown_data = []
    for member, member_stats in selected_member_stats.items():
        breakdown_data.append(
            {
                "Data Source": member,
                "Number of Features": member_stats.feature_count,
                "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
            }
        )
    breakdown_df = pd.DataFrame(breakdown_data)
    # Get all unique data sources and create color map
    unique_members_for_color = sorted(selected_member_stats.keys())
    source_color_map_raw = {}
    for member in unique_members_for_color:
        # Members sharing the same "<source>-" prefix get shades of one palette
        source = member.split("-")[0]
        n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
        palette = get_palette(source, n_colors=n_members + 2)
        idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
        # Offset by one to skip the palette's first (lightest) entry
        source_color_map_raw[member] = palette[idx + 1]
    # Create and display pie chart
    fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
    st.plotly_chart(fig, width="stretch")
    # Show detailed table
    st.dataframe(breakdown_df, hide_index=True, width="stretch")
    # Detailed member information, collapsed by default
    with st.expander("📦 Detailed Source Information", expanded=False):
        # Create detailed table
        member_details_dict = {
            member: {
                "feature_count": ms.feature_count,
                "variable_names": ms.variable_names,
                "dimensions": ms.dimensions,
                "size_bytes": ms.size_bytes,
            }
            for member, ms in selected_member_stats.items()
        }
        details_df = create_member_details_table(member_details_dict)
        st.dataframe(details_df, hide_index=True, width="stretch")
        # Individual member details: variables and dimensions as styled chips
        for member, member_stats in selected_member_stats.items():
            st.markdown(f"### {member}")
            # Variables
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats.variable_names
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)
            # Dimensions
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size:,}</span>"
                    for dim_name, dim_size in member_stats.dimensions.items()
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)
            st.markdown("---")
def _render_configuration_summary( def _render_configuration_summary(
selected_members: list[L2SourceDataset], selected_members: list[L2SourceDataset],
selected_member_stats: dict[str, MemberStatistics], selected_member_stats: dict[L2SourceDataset, MemberStatistics],
selected_target: TargetDataset, selected_target: TargetDataset,
selected_task: Task, selected_task: Task,
selected_temporal_mode: TemporalMode, selected_temporal_mode: TemporalMode,
@ -527,81 +622,8 @@ def _render_configuration_summary(
) )
st.dataframe(class_dist_df, hide_index=True, width="stretch") st.dataframe(class_dist_df, hide_index=True, width="stretch")
# Feature breakdown by source # Render ensemble-specific details
st.markdown("#### Feature Breakdown by Data Source") render_ensemble_details(selected_members, selected_member_stats)
breakdown_data = []
for member, member_stats in selected_member_stats.items():
breakdown_data.append(
{
"Data Source": member,
"Number of Features": member_stats.feature_count,
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
}
)
breakdown_df = pd.DataFrame(breakdown_data)
# Get all unique data sources and create color map
unique_members_for_color = sorted(selected_member_stats.keys())
source_color_map_raw = {}
for member in unique_members_for_color:
source = member.split("-")[0]
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
palette = get_palette(source, n_colors=n_members + 2)
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
source_color_map_raw[member] = palette[idx + 1]
# Create and display pie chart
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
st.plotly_chart(fig, width="stretch")
# Show detailed table
st.dataframe(breakdown_df, hide_index=True, width="stretch")
# Detailed member information
with st.expander("📦 Detailed Source Information", expanded=False):
# Create detailed table
member_details_dict = {
member: {
"feature_count": ms.feature_count,
"variable_names": ms.variable_names,
"dimensions": ms.dimensions,
"size_bytes": ms.size_bytes,
}
for member, ms in selected_member_stats.items()
}
details_df = create_member_details_table(member_details_dict)
st.dataframe(details_df, hide_index=True, width="stretch")
# Individual member details
for member, member_stats in selected_member_stats.items():
st.markdown(f"### {member}")
# Variables
st.markdown("**Variables:**")
vars_html = " ".join(
[
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
for v in member_stats.variable_names
]
)
st.markdown(vars_html, unsafe_allow_html=True)
# Dimensions
st.markdown("**Dimensions:**")
dim_html = " ".join(
[
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
f"{dim_name}: {dim_size:,}</span>"
for dim_name, dim_size in member_stats.dimensions.items()
]
)
st.markdown(dim_html, unsafe_allow_html=True)
st.markdown("---")
@st.fragment @st.fragment
@ -650,7 +672,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
# Use pre-computed statistics (much faster) # Use pre-computed statistics (much faster)
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index] stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
# Filter to selected members only # Filter to selected members only
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members} selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
m: stats.members[m] for m in selected_members if m in stats.members
}
else: else:
# Create the actual ensemble with selected members and dimension filters # Create the actual ensemble with selected members and dimension filters
ensemble = DatasetEnsemble( ensemble = DatasetEnsemble(
@ -663,7 +687,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
# Recompute stats with the filtered ensemble # Recompute stats with the filtered ensemble
stats = DatasetStatistics.from_ensemble(ensemble) stats = DatasetStatistics.from_ensemble(ensemble)
# Filter to selected members only # Filter to selected members only
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members} selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
m: stats.members[m] for m in selected_members if m in stats.members
}
# Render configuration summary and statistics # Render configuration summary and statistics
_render_configuration_summary( _render_configuration_summary(
@ -678,16 +704,16 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
st.info("👆 Select at least one data source to see feature statistics") st.info("👆 Select at least one data source to see feature statistics")
def render_dataset_statistics(all_stats: DatasetStatsCache): def render_dataset_statistics(
all_stats: DatasetStatsCache,
training_sample_df: pd.DataFrame,
feature_breakdown_df: pd.DataFrame,
comparison_df: pd.DataFrame,
inference_sample_df: pd.DataFrame,
):
"""Render the dataset statistics section with sample and feature counts.""" """Render the dataset statistics section with sample and feature counts."""
st.header("📈 Dataset Statistics") st.header("📈 Dataset Statistics")
all_stats: DatasetStatsCache = load_all_default_dataset_statistics()
training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
# Create tabs for different analysis views # Create tabs for different analysis views
analysis_tabs = st.tabs( analysis_tabs = st.tabs(
[ [

View file

@ -0,0 +1,225 @@
"""Target dataset dashboard section."""
from typing import cast
import matplotlib.colors as mcolors
import streamlit as st
from entropice.dashboard.plots.targets import create_target_distribution_plot, create_target_spatial_distribution_map
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
from entropice.ml.dataset import TrainingSet
from entropice.utils.types import TargetDataset, Task
def _render_distribution(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Show the target label distribution plot with a short explanation."""
    st.subheader("Target Labels Distribution")
    distribution_fig = create_target_distribution_plot(train_data_dict)
    st.plotly_chart(distribution_fig, width="stretch")
    st.markdown(
        """
        The above plot shows the distribution of target labels in the training and test datasets.
        You can use this information to understand the balance of classes and identify any potential
        issues with class imbalance that may affect model training.
        """
    )
def _render_split_legend(training_set: TrainingSet):
    """Render legend for train/test split visualization."""
    st.markdown("**Data Split:**")
    split_colors = get_palette("split", 2)
    # One column per split, each holding a colored swatch next to its label
    for column, label, hex_color in zip(st.columns(2), ("Train", "Test"), split_colors):
        with column:
            color_rgb = hex_to_rgb(hex_color)
            st.markdown(
                f'<div style="display: flex; align-items: center;">'
                f'<div style="width: 20px; height: 20px; '
                f"background-color: rgb({color_rgb[0]}, {color_rgb[1]}, {color_rgb[2]}); "
                f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                f"<span>{label}</span></div>",
                unsafe_allow_html=True,
            )
    # Show split statistics (absolute counts and percentages)
    train_count = (training_set.split == "train").sum()
    test_count = (training_set.split == "test").sum()
    total = len(training_set)
    st.caption(
        f"Train: {train_count:,} ({train_count / total * 100:.1f}%) | "
        f"Test: {test_count:,} ({test_count / total * 100:.1f}%)"
    )
def _render_classification_legend(training_set: TrainingSet, selected_task: Task):
    """Render legend for classification task visualization."""
    st.markdown("**Target Classes:**")
    categories = training_set.targets["y"].cat.categories.tolist()
    colors_palette = get_palette(selected_task, len(categories))
    intervals = training_set.target_intervals
    # One legend row per category: colored swatch, label, and value interval
    for index, (category, color) in enumerate(zip(categories, colors_palette)):
        interval_min, interval_max = intervals[index]
        interval_str = _format_interval(interval_min, interval_max, selected_task)
        st.markdown(
            f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
            f'<div style="width: 20px; height: 20px; background-color: {color}; '
            f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
            f"<span>{category}{interval_str}</span></div>",
            unsafe_allow_html=True,
        )
    # Show total samples
    st.caption(f"Total samples: {len(training_set):,}")
def _format_interval(interval_min, interval_max, selected_task: Task) -> str:
"""Format interval string based on task type."""
if interval_min is None or interval_max is None:
return ""
if selected_task in ["count", "count_regimes"]:
# Integer values for count
if interval_min == interval_max:
return f" ({int(interval_min)})"
return f" ({int(interval_min)}-{int(interval_max)})"
if selected_task in ["density", "density_regimes"]:
# Percentage values for density
if interval_min == interval_max:
return f" ({interval_min * 100:.4f}%)"
return f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
# Binary or other
return ""
def _render_regression_legend(training_set: TrainingSet, selected_task: Task):
    """Render legend for regression task visualization."""
    st.markdown("**Target Values:**")
    z_values = training_set.targets["z"]
    z_min = z_values.min()
    z_max = z_values.max()
    z_mean = z_values.mean()
    z_median = z_values.median()
    st.markdown(f"- **Range:** {z_min:.4f} to {z_max:.4f}\n- **Mean:** {z_mean:.4f}\n- **Median:** {z_median:.4f}")
    # Show gradient visualization using the actual colormap
    st.markdown("**Color Gradient:**")
    # Sample the same colormap that colors the map into a CSS linear-gradient
    cmap = get_cmap(selected_task)
    n_samples = 10
    gradient_colors = [mcolors.to_hex(cmap(step / (n_samples - 1))) for step in range(n_samples)]
    gradient_str = ", ".join(gradient_colors)
    st.markdown(
        f'<div style="display: flex; align-items: center; margin-top: 8px;">'
        f'<span style="margin-right: 8px;">Low ({z_min:.4f})</span>'
        f'<div style="width: 200px; height: 20px; '
        f"background: linear-gradient(to right, {gradient_str}); "
        f'border: 1px solid #ccc;"></div>'
        f'<span style="margin-left: 8px;">High ({z_max:.4f})</span>'
        f"</div>",
        unsafe_allow_html=True,
    )
def _render_elevation_info(training_set: TrainingSet):
    """Render 3D elevation information."""
    st.markdown("---")
    st.markdown("**3D Elevation:**")
    # Categorical targets use the category level as height; numeric targets
    # use the (normalized) target value
    if training_set.targets["y"].dtype.name == "category":
        st.info("💡 Height represents category level. Rotate the map by holding Ctrl/Cmd and dragging with your mouse.")
    else:
        z_values = training_set.targets["z"]
        z_min, z_max = z_values.min(), z_values.max()
        st.info(
            f"💡 Height represents target value (normalized): {z_min:.4f} (low) → {z_max:.4f} (high). "
            "Rotate the map by holding Ctrl/Cmd and dragging with your mouse."
        )
@st.fragment
def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render the interactive spatial map of target labels with controls and legend.

    Runs as a Streamlit fragment so changing the controls reruns only this
    section rather than the whole page.

    Args:
        train_data_dict: Nested dictionary {target_dataset: {task: TrainingSet}}.
    """
    st.subheader("Target Labels Spatial Distribution")
    cols = st.columns([2, 2, 1])
    with cols[0]:
        selected_target = cast(
            TargetDataset,
            st.selectbox(
                "Select Target Dataset",
                options=sorted(train_data_dict.keys()),
                index=0,
            ),
        )
    with cols[1]:
        selected_task = cast(
            Task,
            st.selectbox(
                "Select Task",
                # Only tasks available for the chosen target dataset
                options=sorted(train_data_dict[selected_target].keys()),
                index=0,
            ),
        )
    with cols[2]:
        # Controls whether a 3D map or a 2D map is shown
        make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
        # Controls what should be shown: either the split or the labels / values
        show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False))
    training_set = train_data_dict[selected_target][selected_task]
    map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task)
    st.pydeck_chart(map_deck)
    # Add legend matching the active coloring mode
    with st.expander("🎨 Legend", expanded=True):
        if show_split:
            _render_split_legend(training_set)
        else:
            # Show target labels legend
            is_classification = training_set.targets["y"].dtype.name == "category"
            if is_classification:
                _render_classification_legend(training_set, selected_task)
            else:
                _render_regression_legend(training_set, selected_task)
        # Add elevation info for 3D maps
        if make_3d_map:
            _render_elevation_info(training_set)
def render_target_information_tab(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render target labels distribution and spatial visualization.

    Renders the label/value distribution section first, then a divider,
    then the interactive spatial map of the target labels.

    Args:
        train_data_dict: Nested dictionary with structure:
            {target_dataset: {task: TrainingSet}}
            where target_dataset is in ["darts_v1", "darts_mllabels"]
            and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]
    """
    _render_distribution(train_data_dict)
    st.divider()
    _render_target_map(train_data_dict)

View file

@ -31,6 +31,20 @@ import streamlit as st
from pypalettes import load_cmap
def hex_to_rgb(hex_color: str) -> list[int]:
    """Convert hex color to RGB list.

    Supports the 6-digit form (e.g. '#FF5733') as well as the 3-digit CSS
    shorthand (e.g. '#F53', which expands to '#FF5533'). A leading '#' is
    optional. Strings longer than 6 hex digits (e.g. RGBA '#FF5733AA')
    yield the RGB components of the first 6 digits, as before.

    Args:
        hex_color: Hex color string (e.g., '#FF5733').

    Returns:
        List of [R, G, B] values (0-255).
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Expand CSS shorthand: each digit is doubled ('F53' -> 'FF5533').
        hex_color = "".join(ch * 2 for ch in hex_color)
    return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
def get_cmap(variable: str) -> mcolors.Colormap:
    """Get a color palette by a "data" variable.

View file

@ -0,0 +1,155 @@
"""Data page: Visualization of the data."""
from typing import cast
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
from entropice.dashboard.sections.targets import render_target_information_tab
from entropice.dashboard.utils.stats import DatasetStatistics
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.utils.types import (
GridConfig,
L2SourceDataset,
TargetDataset,
Task,
TemporalMode,
all_target_datasets,
all_tasks,
grid_configs,
)
def render_dataset_configuration_sidebar() -> DatasetEnsemble:
    """Render dataset configuration selector in sidebar with form.

    Halts the page run (via ``st.stop``) until the form is submitted, then
    builds and returns the configured :class:`DatasetEnsemble`.
    """
    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")

        # Grid selection (grid system + resolution level).
        display_names = [gc.display_name for gc in grid_configs]
        chosen_name = st.selectbox(
            "Grid Configuration",
            options=display_names,
            index=0,
            help="Select the grid system and resolution level",
        )
        # Resolve the display name back to its grid config.
        grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == chosen_name)

        # Human-readable labels for the temporal-mode options.
        def _temporal_label(mode) -> str:
            if mode == "synopsis":
                return "Synopsis (all years)"
            if mode == "feature":
                return "Years-as-Features"
            return f"Year {mode}"

        temporal_mode = st.selectbox(
            "Temporal Mode",
            options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
            index=0,
            format_func=_temporal_label,
            help="Select temporal mode: 'synopsis' for temporal features or specific year",
        )

        # Members selection: one checkbox per candidate source dataset.
        st.subheader("Dataset Members")
        candidate_members = cast(
            list[L2SourceDataset],
            ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
        )
        selected_members: list[L2SourceDataset] = [
            member
            for member in candidate_members
            if st.checkbox(member, value=True, help=f"Include {member} in the dataset")
        ]

        # Submitting with zero members is disabled outright.
        submitted = st.form_submit_button(
            "Load Dataset",
            type="primary",
            use_container_width=True,
            disabled=len(selected_members) == 0,
        )

    if not submitted:
        st.info("👆 Click 'Load Dataset' to apply changes.")
        st.stop()

    return DatasetEnsemble(
        grid=grid_config.grid,
        level=grid_config.level,
        temporal_mode=cast(TemporalMode, temporal_mode),
        members=selected_members,
    )
def render_dataset_page():
    """Render the Dataset page of the dashboard.

    Loads a DatasetEnsemble configured via the sidebar form, shows ensemble
    statistics, and renders one tab per data view: targets, areas, and one
    tab per member family that is part of the selected ensemble.
    """
    st.title("📊 Entropice Data Visualization")
    st.markdown(
        """
This section allows you to visualize the source datasets and training data
used in your experiments. Use the sidebar to configure the dataset parameters.
"""
    )

    # Get the selected dataset ensemble from sidebar (blocks until submitted).
    ensemble = render_dataset_configuration_sidebar()

    # Display dataset ID in a styled container.
    st.info(f"Loaded Dataset with ID `{ensemble.id()}`")
    st.divider()

    # Render dataset statistics section.
    stats = DatasetStatistics.from_ensemble(ensemble)
    render_ensemble_details(ensemble.members, stats.members)
    st.divider()

    # Load the training data for every (target, task) combination up front
    # so the tabs below can render without re-loading.
    train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {
        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
        for target in all_target_datasets
    }

    era5_members = [m for m in ensemble.members if m.startswith("ERA5")]

    # Build the tab list; member tabs appear only when the member is selected.
    tab_names = ["🎯 Targets", "📐 Areas"]
    if "AlphaEarth" in ensemble.members:
        tab_names.append("🌍 AlphaEarth")
    if "ArcticDEM" in ensemble.members:
        tab_names.append("🏔️ ArcticDEM")
    if era5_members:
        tab_names.append("🌡️ ERA5")

    # Map each tab name to its container so we never rely on fragile
    # positional index bookkeeping when some tabs are absent.
    tabs = dict(zip(tab_names, st.tabs(tab_names)))

    with tabs["🎯 Targets"]:
        st.header("🎯 Target Labels Visualization")
        render_target_information_tab(train_data_dict)
    with tabs["📐 Areas"]:
        st.header("📐 Areas Visualization")
    if "🌍 AlphaEarth" in tabs:
        with tabs["🌍 AlphaEarth"]:
            st.header("🌍 AlphaEarth Visualization")
    if "🏔️ ArcticDEM" in tabs:
        with tabs["🏔️ ArcticDEM"]:
            st.header("🏔️ ArcticDEM Visualization")
    if "🌡️ ERA5" in tabs:
        with tabs["🌡️ ERA5"]:
            st.header("🌡️ ERA5 Visualization")

    st.balloons()
    stopwatch.summary()

View file

@ -9,9 +9,7 @@ from entropice.dashboard.sections.experiment_results import (
    render_training_results_summary,
)
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.stats import ( from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
load_all_default_dataset_statistics,
)
def render_overview_page():
@ -47,7 +45,12 @@ def render_overview_page():
    # Render dataset analysis section
    all_stats = load_all_default_dataset_statistics()
render_dataset_statistics(all_stats) training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
    st.balloons()
    stopwatch.summary()