Add a Dataset Page to deprecate the Training Data page
This commit is contained in:
parent
a6e9a91692
commit
cb7b0f9e6b
7 changed files with 986 additions and 90 deletions
|
|
@ -13,6 +13,7 @@ Pages:
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
|
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
|
||||||
|
from entropice.dashboard.views.dataset_page import render_dataset_page
|
||||||
from entropice.dashboard.views.inference_page import render_inference_page
|
from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.views.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
|
|
@ -26,6 +27,7 @@ def main():
|
||||||
|
|
||||||
# Setup Navigation
|
# Setup Navigation
|
||||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
|
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
||||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
||||||
|
|
@ -35,6 +37,7 @@ def main():
|
||||||
pg = st.navigation(
|
pg = st.navigation(
|
||||||
{
|
{
|
||||||
"Overview": [overview_page],
|
"Overview": [overview_page],
|
||||||
|
"Data": [data_page],
|
||||||
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
||||||
"Model State": [model_state_page],
|
"Model State": [model_state_page],
|
||||||
"Inference": [inference_page],
|
"Inference": [inference_page],
|
||||||
|
|
|
||||||
470
src/entropice/dashboard/plots/targets.py
Normal file
470
src/entropice/dashboard/plots/targets.py
Normal file
|
|
@ -0,0 +1,470 @@
|
||||||
|
"""Plots for visualizing target labels in datasets."""
|
||||||
|
|
||||||
|
import antimeridian
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import pydeck as pdk
|
||||||
|
from plotly.subplots import make_subplots
|
||||||
|
from shapely.geometry import shape
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
|
||||||
|
from entropice.ml.dataset import TrainingSet
|
||||||
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_distribution_plot(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]) -> go.Figure:
|
||||||
|
"""Create a plot showing the distribution of target labels across datasets and tasks.
|
||||||
|
|
||||||
|
This function creates a comprehensive visualization showing:
|
||||||
|
- Classification tasks (binary, count_regimes, density_regimes): categorical distributions
|
||||||
|
- Regression tasks (count, density): histogram distributions
|
||||||
|
- Train/test split information for each combination
|
||||||
|
- Separate subplots for each TargetDataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data_dict: Nested dictionary with structure:
|
||||||
|
{target_dataset: {task: TrainingSet}}
|
||||||
|
where target_dataset is in ["darts_v1", "darts_mllabels"]
|
||||||
|
and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with subplots showing target distributions.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Define task types and their properties
|
||||||
|
classification_tasks: list[Task] = ["binary", "count_regimes", "density_regimes"]
|
||||||
|
|
||||||
|
task_titles: dict[Task, str] = {
|
||||||
|
"binary": "Binary",
|
||||||
|
"count_regimes": "Count Regimes",
|
||||||
|
"density_regimes": "Density Regimes",
|
||||||
|
"count": "Count (Regression)",
|
||||||
|
"density": "Density (Regression)",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get all available target datasets and tasks
|
||||||
|
target_datasets = sorted(train_data_dict.keys())
|
||||||
|
all_tasks = sorted(
|
||||||
|
{task for tasks in train_data_dict.values() for task in tasks.keys()},
|
||||||
|
key=lambda x: (x not in classification_tasks, x), # Classification first
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create subplots: one row per target dataset, one column per task
|
||||||
|
n_rows = len(target_datasets)
|
||||||
|
n_cols = len(all_tasks)
|
||||||
|
|
||||||
|
# Create column titles (tasks) for the subplots
|
||||||
|
column_titles = [task_titles.get(task, str(task)) for task in all_tasks] # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Create row titles (target datasets)
|
||||||
|
row_titles = [target_ds.replace("_", " ").title() for target_ds in target_datasets]
|
||||||
|
|
||||||
|
fig = make_subplots(
|
||||||
|
rows=n_rows,
|
||||||
|
cols=n_cols,
|
||||||
|
column_titles=column_titles,
|
||||||
|
vertical_spacing=0.20 / max(n_rows, 1),
|
||||||
|
horizontal_spacing=0.08 / max(n_cols, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate through each target dataset and task
|
||||||
|
for row_idx, target_dataset in enumerate(target_datasets, start=1):
|
||||||
|
task_dict = train_data_dict[target_dataset] # type: ignore[index]
|
||||||
|
|
||||||
|
for col_idx, task in enumerate(all_tasks, start=1):
|
||||||
|
if task not in task_dict:
|
||||||
|
# Skip if this task is not available for this target dataset
|
||||||
|
continue
|
||||||
|
|
||||||
|
dataset = task_dict[task] # type: ignore[index]
|
||||||
|
|
||||||
|
# Determine if this is a classification or regression task
|
||||||
|
if task in classification_tasks:
|
||||||
|
_add_classification_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
|
||||||
|
else:
|
||||||
|
_add_regression_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
|
||||||
|
|
||||||
|
# Update layout
|
||||||
|
fig.update_layout(
|
||||||
|
height=400 * n_rows,
|
||||||
|
showlegend=True,
|
||||||
|
legend={
|
||||||
|
"orientation": "h",
|
||||||
|
"yanchor": "top",
|
||||||
|
"y": -0.15,
|
||||||
|
"xanchor": "center",
|
||||||
|
"x": 0.5,
|
||||||
|
},
|
||||||
|
margin={"l": 120, "r": 40, "t": 80, "b": 100},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update all x-axes and y-axes
|
||||||
|
fig.update_xaxes(title_font={"size": 11})
|
||||||
|
fig.update_yaxes(title_text="Count", title_font={"size": 11})
|
||||||
|
|
||||||
|
# Manually add row title annotations (left-aligned dataset names)
|
||||||
|
for row_idx, row_title in enumerate(row_titles, start=1):
|
||||||
|
# Calculate the y position for each row (at the top of the row)
|
||||||
|
# Account for vertical spacing between rows
|
||||||
|
spacing = 0.20 / max(n_rows, 1)
|
||||||
|
row_height = (1 - spacing * (n_rows - 1)) / n_rows
|
||||||
|
y_position = 1.08 - (row_idx - 1) * (row_height + spacing)
|
||||||
|
|
||||||
|
fig.add_annotation(
|
||||||
|
text=f"<b>{row_title}</b>",
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=-0.02,
|
||||||
|
y=y_position,
|
||||||
|
xanchor="left",
|
||||||
|
yanchor="top",
|
||||||
|
textangle=0,
|
||||||
|
font={"size": 20},
|
||||||
|
showarrow=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def _add_classification_subplot(
|
||||||
|
fig: go.Figure,
|
||||||
|
dataset: TrainingSet,
|
||||||
|
task: str,
|
||||||
|
row: int,
|
||||||
|
col: int,
|
||||||
|
n_rows: int,
|
||||||
|
) -> None:
|
||||||
|
"""Add a classification task subplot showing categorical distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: Plotly figure to add subplot to.
|
||||||
|
dataset: TrainingSet containing the data.
|
||||||
|
task: Task name for color selection.
|
||||||
|
row: Subplot row index.
|
||||||
|
col: Subplot column index.
|
||||||
|
n_rows: Total number of rows (for legend control).
|
||||||
|
|
||||||
|
"""
|
||||||
|
y_binned = dataset.targets["y"]
|
||||||
|
categories = y_binned.cat.categories.tolist()
|
||||||
|
colors = get_palette(task, len(categories) + 2)[1:-1] # Avoid too light/dark colors
|
||||||
|
|
||||||
|
# Calculate counts for train and test splits
|
||||||
|
train_counts = [((y_binned == cat) & (dataset.split == "train")).sum() for cat in categories]
|
||||||
|
test_counts = [((y_binned == cat) & (dataset.split == "test")).sum() for cat in categories]
|
||||||
|
|
||||||
|
# Add train bars
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
name="Train",
|
||||||
|
x=categories,
|
||||||
|
y=train_counts,
|
||||||
|
marker_color=colors,
|
||||||
|
opacity=0.9,
|
||||||
|
text=train_counts,
|
||||||
|
textposition="inside",
|
||||||
|
textfont={"size": 9, "color": "white"},
|
||||||
|
legendgroup="split",
|
||||||
|
showlegend=(row == 1 and col == 1), # Only show legend once
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add test bars
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
name="Test",
|
||||||
|
x=categories,
|
||||||
|
y=test_counts,
|
||||||
|
marker_color=colors,
|
||||||
|
opacity=0.5,
|
||||||
|
text=test_counts,
|
||||||
|
textposition="inside",
|
||||||
|
textfont={"size": 9, "color": "white"},
|
||||||
|
legendgroup="split",
|
||||||
|
showlegend=(row == 1 and col == 1), # Only show legend once
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update subplot layout
|
||||||
|
fig.update_xaxes(tickangle=-45, row=row, col=col)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_regression_subplot(
|
||||||
|
fig: go.Figure,
|
||||||
|
dataset: TrainingSet,
|
||||||
|
task: str,
|
||||||
|
row: int,
|
||||||
|
col: int,
|
||||||
|
n_rows: int,
|
||||||
|
) -> None:
|
||||||
|
"""Add a regression task subplot showing histogram distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: Plotly figure to add subplot to.
|
||||||
|
dataset: TrainingSet containing the data.
|
||||||
|
task: Task name for color selection.
|
||||||
|
row: Subplot row index.
|
||||||
|
col: Subplot column index.
|
||||||
|
n_rows: Total number of rows (for legend control).
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Get raw values
|
||||||
|
z_values = dataset.targets["z"]
|
||||||
|
|
||||||
|
# Split into train and test
|
||||||
|
train_values = z_values[dataset.split == "train"]
|
||||||
|
test_values = z_values[dataset.split == "test"]
|
||||||
|
|
||||||
|
# Determine bin edges based on combined data
|
||||||
|
all_values = pd.concat([train_values, test_values])
|
||||||
|
# Use a reasonable number of bins
|
||||||
|
n_bins = min(30, int(np.sqrt(len(all_values))))
|
||||||
|
bin_edges = np.histogram_bin_edges(all_values, bins=n_bins)
|
||||||
|
|
||||||
|
# Get colors for the task
|
||||||
|
colors = get_palette(task, 6)[1:5:2] # Use less bright/dark colors
|
||||||
|
|
||||||
|
# Add train histogram
|
||||||
|
fig.add_trace(
|
||||||
|
go.Histogram(
|
||||||
|
name="Train",
|
||||||
|
x=train_values,
|
||||||
|
xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
|
||||||
|
marker_color=colors[0],
|
||||||
|
opacity=0.7,
|
||||||
|
legendgroup="split",
|
||||||
|
showlegend=(row == 1 and col == 1), # Only show legend once
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add test histogram
|
||||||
|
fig.add_trace(
|
||||||
|
go.Histogram(
|
||||||
|
name="Test",
|
||||||
|
x=test_values,
|
||||||
|
xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
|
||||||
|
marker_color=colors[1],
|
||||||
|
opacity=0.5,
|
||||||
|
legendgroup="split",
|
||||||
|
showlegend=(row == 1 and col == 1), # Only show legend once
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update barmode to overlay for histograms
|
||||||
|
fig.update_xaxes(title_text=task.capitalize(), row=row, col=col)
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_split_colors(gdf: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Assign colors based on train/test split.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gdf: GeoDataFrame with 'split' column.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with 'fill_color' column added.
|
||||||
|
|
||||||
|
"""
|
||||||
|
split_colors = get_palette("split", 2)
|
||||||
|
color_map = {"train": hex_to_rgb(split_colors[0]), "test": hex_to_rgb(split_colors[1])}
|
||||||
|
# Convert to list to avoid categorical hashing issues with list values
|
||||||
|
gdf["fill_color"] = [color_map[x] for x in gdf["split"]]
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_classification_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
|
||||||
|
"""Assign colors for classification tasks based on categorical labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gdf: GeoDataFrame with categorical 'y' column.
|
||||||
|
task: Task name for color selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with 'fill_color' column added.
|
||||||
|
|
||||||
|
"""
|
||||||
|
categories = gdf["y"].cat.categories.tolist()
|
||||||
|
colors_palette = get_palette(task, len(categories))
|
||||||
|
color_map = {cat: hex_to_rgb(colors_palette[i]) for i, cat in enumerate(categories)}
|
||||||
|
# Convert to list to avoid categorical hashing issues with list values
|
||||||
|
gdf["fill_color"] = [color_map[x] for x in gdf["y"]]
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_regression_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
|
||||||
|
"""Assign colors for regression tasks based on continuous values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gdf: GeoDataFrame with numeric 'z' column.
|
||||||
|
task: Task name for color selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with 'fill_color' column added.
|
||||||
|
|
||||||
|
"""
|
||||||
|
z_values = gdf["z"]
|
||||||
|
z_min, z_max = z_values.min(), z_values.max()
|
||||||
|
|
||||||
|
# Normalize to [0, 1]
|
||||||
|
if z_max > z_min:
|
||||||
|
z_normalized = (z_values - z_min) / (z_max - z_min)
|
||||||
|
else:
|
||||||
|
z_normalized = pd.Series([0.5] * len(z_values), index=z_values.index)
|
||||||
|
|
||||||
|
# Get colormap and map normalized values to RGB
|
||||||
|
cmap = get_cmap(task)
|
||||||
|
|
||||||
|
def value_to_rgb(normalized_val):
|
||||||
|
rgba = cmap(normalized_val)
|
||||||
|
return [int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)]
|
||||||
|
|
||||||
|
gdf["fill_color"] = z_normalized.apply(value_to_rgb)
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_elevation(gdf: pd.DataFrame, make_3d_map: bool, is_classification: bool) -> pd.DataFrame:
|
||||||
|
"""Assign elevation values for 3D visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gdf: GeoDataFrame to add elevation to.
|
||||||
|
make_3d_map: Whether to create a 3D map.
|
||||||
|
is_classification: Whether this is a classification task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with 'elevation' column added.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not make_3d_map:
|
||||||
|
gdf["elevation"] = 0
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
print(f"{make_3d_map=}, {is_classification=}")
|
||||||
|
if is_classification:
|
||||||
|
# For classification, use category codes for elevation
|
||||||
|
# Convert to int to avoid overflow with int8 categorical codes
|
||||||
|
z_values = gdf["y"].cat.codes.astype(int)
|
||||||
|
else:
|
||||||
|
# For regression, use z values for elevation
|
||||||
|
z_values = gdf["z"]
|
||||||
|
z_min, z_max = z_values.min(), z_values.max()
|
||||||
|
|
||||||
|
if z_max > z_min:
|
||||||
|
gdf["elevation"] = ((z_values - z_min) / (z_max - z_min)).fillna(0)
|
||||||
|
else:
|
||||||
|
gdf["elevation"] = 0
|
||||||
|
|
||||||
|
print(gdf["elevation"].describe())
|
||||||
|
return gdf
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_spatial_distribution_map(
|
||||||
|
training_set: TrainingSet, make_3d_map: bool, show_split: bool, task: Task
|
||||||
|
) -> pdk.Deck:
|
||||||
|
"""Create a spatial distribution map of target labels using PyDeck.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_set (TrainingSet): The training set containing spatial data and target labels.
|
||||||
|
make_3d_map (bool): Whether to create a 3D map (True) or a 2D map (False).
|
||||||
|
show_split (bool): Whether to color by train/test split (True) or by target labels (False).
|
||||||
|
task (Task): The task type to determine color mapping.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pdk.Deck: A PyDeck Deck object representing the spatial distribution map.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Get the targets GeoDataFrame
|
||||||
|
gdf = training_set.targets.copy()
|
||||||
|
|
||||||
|
# Add split information
|
||||||
|
gdf["split"] = training_set.split.to_numpy()
|
||||||
|
|
||||||
|
# Determine if this is a classification or regression task
|
||||||
|
is_classification = gdf["y"].dtype.name == "category"
|
||||||
|
|
||||||
|
# Convert to WGS84 for pydeck
|
||||||
|
gdf_wgs84 = gdf.to_crs("EPSG:4326")
|
||||||
|
|
||||||
|
# Fix antimeridian issues
|
||||||
|
def fix_hex_geometry(geom):
|
||||||
|
try:
|
||||||
|
return shape(antimeridian.fix_shape(geom))
|
||||||
|
except (ValueError, Exception):
|
||||||
|
return geom
|
||||||
|
|
||||||
|
gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
|
||||||
|
|
||||||
|
# Assign colors based on mode
|
||||||
|
if show_split:
|
||||||
|
gdf_wgs84 = _assign_split_colors(gdf_wgs84)
|
||||||
|
elif is_classification:
|
||||||
|
gdf_wgs84 = _assign_classification_colors(gdf_wgs84, task)
|
||||||
|
else:
|
||||||
|
gdf_wgs84 = _assign_regression_colors(gdf_wgs84, task)
|
||||||
|
|
||||||
|
# Add elevation for 3D visualization
|
||||||
|
gdf_wgs84 = _assign_elevation(gdf_wgs84, make_3d_map, is_classification)
|
||||||
|
|
||||||
|
# Convert to GeoJSON format
|
||||||
|
geojson_data = []
|
||||||
|
for _, row in gdf_wgs84.iterrows():
|
||||||
|
feature = {
|
||||||
|
"type": "Feature",
|
||||||
|
"geometry": row["geometry"].__geo_interface__,
|
||||||
|
"properties": {
|
||||||
|
"fill_color": row["fill_color"],
|
||||||
|
"elevation": row["elevation"],
|
||||||
|
"split": row["split"],
|
||||||
|
"y_label": str(row["y"]),
|
||||||
|
"z_value": float(row["z"]) if pd.notna(row["z"]) else None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
geojson_data.append(feature)
|
||||||
|
|
||||||
|
# Create pydeck layer
|
||||||
|
layer = pdk.Layer(
|
||||||
|
"GeoJsonLayer",
|
||||||
|
geojson_data,
|
||||||
|
opacity=0.7,
|
||||||
|
stroked=True,
|
||||||
|
filled=True,
|
||||||
|
extruded=make_3d_map,
|
||||||
|
wireframe=False,
|
||||||
|
get_fill_color="properties.fill_color",
|
||||||
|
get_line_color=[80, 80, 80],
|
||||||
|
line_width_min_pixels=0.5,
|
||||||
|
get_elevation="properties.elevation" if make_3d_map else 0,
|
||||||
|
elevation_scale=500000 if make_3d_map else 0,
|
||||||
|
pickable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set initial view state (centered on the Arctic)
|
||||||
|
view_state = pdk.ViewState(
|
||||||
|
latitude=70,
|
||||||
|
longitude=0,
|
||||||
|
zoom=2 if not make_3d_map else 1.5,
|
||||||
|
pitch=0 if not make_3d_map else 45,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tooltip
|
||||||
|
tooltip_html = "<b>Split:</b> {split}<br/><b>Label:</b> {y_label}<br/><b>Value:</b> {z_value}<br/>"
|
||||||
|
|
||||||
|
# Create deck
|
||||||
|
deck = pdk.Deck(
|
||||||
|
layers=[layer],
|
||||||
|
initial_view_state=view_state,
|
||||||
|
tooltip={
|
||||||
|
"html": tooltip_html,
|
||||||
|
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||||
|
},
|
||||||
|
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
return deck
|
||||||
|
|
@ -14,7 +14,7 @@ from entropice.dashboard.plots.dataset_statistics import (
|
||||||
create_sample_count_bar_chart,
|
create_sample_count_bar_chart,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
from entropice.dashboard.utils.colors import get_palette
|
||||||
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
|
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble
|
||||||
from entropice.utils.types import (
|
from entropice.utils.types import (
|
||||||
GridConfig,
|
GridConfig,
|
||||||
|
|
@ -436,9 +436,104 @@ def _render_aggregation_selection(
|
||||||
return dimension_filters
|
return dimension_filters
|
||||||
|
|
||||||
|
|
||||||
|
def render_ensemble_details(
|
||||||
|
selected_members: list[L2SourceDataset],
|
||||||
|
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
|
||||||
|
):
|
||||||
|
"""Render ensemble-specific details: feature breakdown and member information.
|
||||||
|
|
||||||
|
This function displays the feature breakdown by data source and detailed
|
||||||
|
member information, independent of task or target selection. Can be reused
|
||||||
|
on other pages by supplying a dataset ensemble.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selected_members: List of selected member datasets
|
||||||
|
selected_member_stats: Statistics for selected members
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Calculate total features for selected members
|
||||||
|
total_features = sum(ms.feature_count for ms in selected_member_stats.values())
|
||||||
|
|
||||||
|
# Feature breakdown by source
|
||||||
|
st.markdown("#### Feature Breakdown by Data Source")
|
||||||
|
|
||||||
|
breakdown_data = []
|
||||||
|
for member, member_stats in selected_member_stats.items():
|
||||||
|
breakdown_data.append(
|
||||||
|
{
|
||||||
|
"Data Source": member,
|
||||||
|
"Number of Features": member_stats.feature_count,
|
||||||
|
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
breakdown_df = pd.DataFrame(breakdown_data)
|
||||||
|
|
||||||
|
# Get all unique data sources and create color map
|
||||||
|
unique_members_for_color = sorted(selected_member_stats.keys())
|
||||||
|
source_color_map_raw = {}
|
||||||
|
for member in unique_members_for_color:
|
||||||
|
source = member.split("-")[0]
|
||||||
|
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
|
||||||
|
palette = get_palette(source, n_colors=n_members + 2)
|
||||||
|
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
|
||||||
|
source_color_map_raw[member] = palette[idx + 1]
|
||||||
|
|
||||||
|
# Create and display pie chart
|
||||||
|
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
|
||||||
|
st.plotly_chart(fig, width="stretch")
|
||||||
|
|
||||||
|
# Show detailed table
|
||||||
|
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
# Detailed member information
|
||||||
|
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||||
|
# Create detailed table
|
||||||
|
member_details_dict = {
|
||||||
|
member: {
|
||||||
|
"feature_count": ms.feature_count,
|
||||||
|
"variable_names": ms.variable_names,
|
||||||
|
"dimensions": ms.dimensions,
|
||||||
|
"size_bytes": ms.size_bytes,
|
||||||
|
}
|
||||||
|
for member, ms in selected_member_stats.items()
|
||||||
|
}
|
||||||
|
details_df = create_member_details_table(member_details_dict)
|
||||||
|
st.dataframe(details_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
|
# Individual member details
|
||||||
|
for member, member_stats in selected_member_stats.items():
|
||||||
|
st.markdown(f"### {member}")
|
||||||
|
|
||||||
|
# Variables
|
||||||
|
st.markdown("**Variables:**")
|
||||||
|
vars_html = " ".join(
|
||||||
|
[
|
||||||
|
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||||
|
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||||
|
for v in member_stats.variable_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
st.markdown(vars_html, unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# Dimensions
|
||||||
|
st.markdown("**Dimensions:**")
|
||||||
|
dim_html = " ".join(
|
||||||
|
[
|
||||||
|
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||||
|
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||||
|
f"{dim_name}: {dim_size:,}</span>"
|
||||||
|
for dim_name, dim_size in member_stats.dimensions.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
st.markdown(dim_html, unsafe_allow_html=True)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
|
||||||
def _render_configuration_summary(
|
def _render_configuration_summary(
|
||||||
selected_members: list[L2SourceDataset],
|
selected_members: list[L2SourceDataset],
|
||||||
selected_member_stats: dict[str, MemberStatistics],
|
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
|
||||||
selected_target: TargetDataset,
|
selected_target: TargetDataset,
|
||||||
selected_task: Task,
|
selected_task: Task,
|
||||||
selected_temporal_mode: TemporalMode,
|
selected_temporal_mode: TemporalMode,
|
||||||
|
|
@ -527,81 +622,8 @@ def _render_configuration_summary(
|
||||||
)
|
)
|
||||||
st.dataframe(class_dist_df, hide_index=True, width="stretch")
|
st.dataframe(class_dist_df, hide_index=True, width="stretch")
|
||||||
|
|
||||||
# Feature breakdown by source
|
# Render ensemble-specific details
|
||||||
st.markdown("#### Feature Breakdown by Data Source")
|
render_ensemble_details(selected_members, selected_member_stats)
|
||||||
|
|
||||||
breakdown_data = []
|
|
||||||
for member, member_stats in selected_member_stats.items():
|
|
||||||
breakdown_data.append(
|
|
||||||
{
|
|
||||||
"Data Source": member,
|
|
||||||
"Number of Features": member_stats.feature_count,
|
|
||||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
breakdown_df = pd.DataFrame(breakdown_data)
|
|
||||||
|
|
||||||
# Get all unique data sources and create color map
|
|
||||||
unique_members_for_color = sorted(selected_member_stats.keys())
|
|
||||||
source_color_map_raw = {}
|
|
||||||
for member in unique_members_for_color:
|
|
||||||
source = member.split("-")[0]
|
|
||||||
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
|
|
||||||
palette = get_palette(source, n_colors=n_members + 2)
|
|
||||||
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
|
|
||||||
source_color_map_raw[member] = palette[idx + 1]
|
|
||||||
|
|
||||||
# Create and display pie chart
|
|
||||||
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
|
|
||||||
st.plotly_chart(fig, width="stretch")
|
|
||||||
|
|
||||||
# Show detailed table
|
|
||||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
|
||||||
|
|
||||||
# Detailed member information
|
|
||||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
|
||||||
# Create detailed table
|
|
||||||
member_details_dict = {
|
|
||||||
member: {
|
|
||||||
"feature_count": ms.feature_count,
|
|
||||||
"variable_names": ms.variable_names,
|
|
||||||
"dimensions": ms.dimensions,
|
|
||||||
"size_bytes": ms.size_bytes,
|
|
||||||
}
|
|
||||||
for member, ms in selected_member_stats.items()
|
|
||||||
}
|
|
||||||
details_df = create_member_details_table(member_details_dict)
|
|
||||||
st.dataframe(details_df, hide_index=True, width="stretch")
|
|
||||||
|
|
||||||
# Individual member details
|
|
||||||
for member, member_stats in selected_member_stats.items():
|
|
||||||
st.markdown(f"### {member}")
|
|
||||||
|
|
||||||
# Variables
|
|
||||||
st.markdown("**Variables:**")
|
|
||||||
vars_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
|
||||||
for v in member_stats.variable_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(vars_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Dimensions
|
|
||||||
st.markdown("**Dimensions:**")
|
|
||||||
dim_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
|
||||||
f"{dim_name}: {dim_size:,}</span>"
|
|
||||||
for dim_name, dim_size in member_stats.dimensions.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
@st.fragment
|
||||||
|
|
@ -650,7 +672,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||||
# Use pre-computed statistics (much faster)
|
# Use pre-computed statistics (much faster)
|
||||||
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
|
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
|
||||||
# Filter to selected members only
|
# Filter to selected members only
|
||||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
|
||||||
|
m: stats.members[m] for m in selected_members if m in stats.members
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# Create the actual ensemble with selected members and dimension filters
|
# Create the actual ensemble with selected members and dimension filters
|
||||||
ensemble = DatasetEnsemble(
|
ensemble = DatasetEnsemble(
|
||||||
|
|
@ -663,7 +687,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||||
# Recompute stats with the filtered ensemble
|
# Recompute stats with the filtered ensemble
|
||||||
stats = DatasetStatistics.from_ensemble(ensemble)
|
stats = DatasetStatistics.from_ensemble(ensemble)
|
||||||
# Filter to selected members only
|
# Filter to selected members only
|
||||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
|
||||||
|
m: stats.members[m] for m in selected_members if m in stats.members
|
||||||
|
}
|
||||||
|
|
||||||
# Render configuration summary and statistics
|
# Render configuration summary and statistics
|
||||||
_render_configuration_summary(
|
_render_configuration_summary(
|
||||||
|
|
@ -678,16 +704,16 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
||||||
st.info("👆 Select at least one data source to see feature statistics")
|
st.info("👆 Select at least one data source to see feature statistics")
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_statistics(all_stats: DatasetStatsCache):
|
def render_dataset_statistics(
|
||||||
|
all_stats: DatasetStatsCache,
|
||||||
|
training_sample_df: pd.DataFrame,
|
||||||
|
feature_breakdown_df: pd.DataFrame,
|
||||||
|
comparison_df: pd.DataFrame,
|
||||||
|
inference_sample_df: pd.DataFrame,
|
||||||
|
):
|
||||||
"""Render the dataset statistics section with sample and feature counts."""
|
"""Render the dataset statistics section with sample and feature counts."""
|
||||||
st.header("📈 Dataset Statistics")
|
st.header("📈 Dataset Statistics")
|
||||||
|
|
||||||
all_stats: DatasetStatsCache = load_all_default_dataset_statistics()
|
|
||||||
training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
|
|
||||||
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
|
||||||
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
|
|
||||||
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
|
|
||||||
|
|
||||||
# Create tabs for different analysis views
|
# Create tabs for different analysis views
|
||||||
analysis_tabs = st.tabs(
|
analysis_tabs = st.tabs(
|
||||||
[
|
[
|
||||||
|
|
|
||||||
225
src/entropice/dashboard/sections/targets.py
Normal file
225
src/entropice/dashboard/sections/targets.py
Normal file
|
|
@ -0,0 +1,225 @@
|
||||||
|
"""Target dataset dashboard section."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.targets import create_target_distribution_plot, create_target_spatial_distribution_map
|
||||||
|
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
|
||||||
|
from entropice.ml.dataset import TrainingSet
|
||||||
|
from entropice.utils.types import TargetDataset, Task
|
||||||
|
|
||||||
|
|
||||||
|
def _render_distribution(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render the target-label distribution plot together with an explanatory note."""
    st.subheader("Target Labels Distribution")
    dist_fig = create_target_distribution_plot(train_data_dict)
    st.plotly_chart(dist_fig, width="stretch")
    st.markdown(
        """
        The above plot shows the distribution of target labels in the training and test datasets.
        You can use this information to understand the balance of classes and identify any potential
        issues with class imbalance that may affect model training.
        """
    )
||||||
|
|
||||||
|
|
||||||
|
def _render_split_legend(training_set: TrainingSet):
    """Render legend for train/test split visualization."""
    st.markdown("**Data Split:**")
    split_colors = get_palette("split", 2)

    # One colored swatch per split, laid out side by side in two columns.
    for column, split_label, swatch_hex in zip(st.columns(2), ("Train", "Test"), split_colors):
        r, g, b = hex_to_rgb(swatch_hex)
        with column:
            st.markdown(
                f'<div style="display: flex; align-items: center;">'
                f'<div style="width: 20px; height: 20px; '
                f"background-color: rgb({r}, {g}, {b}); "
                f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                f"<span>{split_label}</span></div>",
                unsafe_allow_html=True,
            )

    # Split statistics below the swatches.
    n_train = (training_set.split == "train").sum()
    n_test = (training_set.split == "test").sum()
    n_total = len(training_set)
    st.caption(
        f"Train: {n_train:,} ({n_train / n_total * 100:.1f}%) | "
        f"Test: {n_test:,} ({n_test / n_total * 100:.1f}%)"
    )
||||||
|
|
||||||
|
|
||||||
|
def _render_classification_legend(training_set: TrainingSet, selected_task: Task):
    """Render legend for classification task visualization."""
    st.markdown("**Target Classes:**")
    class_labels = training_set.targets["y"].cat.categories.tolist()
    class_colors = get_palette(selected_task, len(class_labels))
    class_intervals = training_set.target_intervals

    # One swatch + label (with its value interval, if any) per target class.
    for idx, class_label in enumerate(class_labels):
        swatch_color = class_colors[idx]
        lower, upper = class_intervals[idx]
        interval_str = _format_interval(lower, upper, selected_task)
        st.markdown(
            f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
            f'<div style="width: 20px; height: 20px; background-color: {swatch_color}; '
            f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
            f"<span>{class_label}{interval_str}</span></div>",
            unsafe_allow_html=True,
        )

    # Overall sample count below the class list.
    st.caption(f"Total samples: {len(training_set):,}")
||||||
|
|
||||||
|
|
||||||
|
def _format_interval(interval_min, interval_max, selected_task: Task) -> str:
    """Format interval string based on task type.

    Returns a leading-space-prefixed parenthesized interval (e.g. " (1-4)"),
    or an empty string for missing bounds and tasks without intervals.
    """
    # Undefined bounds render as nothing.
    if interval_min is None or interval_max is None:
        return ""

    is_point = interval_min == interval_max

    if selected_task in ["count", "count_regimes"]:
        # Counts are whole numbers.
        lower, upper = int(interval_min), int(interval_max)
        return f" ({lower})" if is_point else f" ({lower}-{upper})"

    if selected_task in ["density", "density_regimes"]:
        # Densities are fractions; display them as percentages.
        lower_pct, upper_pct = interval_min * 100, interval_max * 100
        return f" ({lower_pct:.4f}%)" if is_point else f" ({lower_pct:.4f}%-{upper_pct:.4f}%)"

    # Binary or other tasks carry no interval annotation.
    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _render_regression_legend(training_set: TrainingSet, selected_task: Task):
    """Render legend for regression task visualization."""
    st.markdown("**Target Values:**")
    z_values = training_set.targets["z"]
    z_min = z_values.min()
    z_max = z_values.max()

    st.markdown(
        f"- **Range:** {z_min:.4f} to {z_max:.4f}\n- **Mean:** {z_values.mean():.4f}\n- **Median:** {z_values.median():.4f}"
    )

    # Gradient preview built from the same colormap used on the map itself.
    st.markdown("**Color Gradient:**")
    cmap = get_cmap(selected_task)
    n_stops = 10
    stop_hexes = [mcolors.to_hex(cmap(i / (n_stops - 1))) for i in range(n_stops)]
    gradient_css = ", ".join(stop_hexes)

    st.markdown(
        f'<div style="display: flex; align-items: center; margin-top: 8px;">'
        f'<span style="margin-right: 8px;">Low ({z_min:.4f})</span>'
        f'<div style="width: 200px; height: 20px; '
        f"background: linear-gradient(to right, {gradient_css}); "
        f'border: 1px solid #ccc;"></div>'
        f'<span style="margin-left: 8px;">High ({z_max:.4f})</span>'
        f"</div>",
        unsafe_allow_html=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _render_elevation_info(training_set: TrainingSet):
    """Render 3D elevation information."""
    st.markdown("---")
    st.markdown("**3D Elevation:**")

    # Categorical targets map height to category level; numeric targets map it to the value.
    if training_set.targets["y"].dtype.name == "category":
        st.info("💡 Height represents category level. Rotate the map by holding Ctrl/Cmd and dragging with your mouse.")
    else:
        z_values = training_set.targets["z"]
        st.info(
            f"💡 Height represents target value (normalized): {z_values.min():.4f} (low) → {z_values.max():.4f} (high). "
            "Rotate the map by holding Ctrl/Cmd and dragging with your mouse."
        )
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render the spatial distribution map of target labels, with controls and a legend."""
    st.subheader("Target Labels Spatial Distribution")

    target_col, task_col, options_col = st.columns([2, 2, 1])
    with target_col:
        selected_target = cast(
            TargetDataset,
            st.selectbox(
                "Select Target Dataset",
                options=sorted(train_data_dict.keys()),
                index=0,
            ),
        )
    with task_col:
        selected_task = cast(
            Task,
            st.selectbox(
                "Select Task",
                options=sorted(train_data_dict[selected_target].keys()),
                index=0,
            ),
        )
    with options_col:
        # Controls whether a 3D map or a 2D map is shown
        make_3d_map = cast(bool, st.checkbox("3D Map", value=True))
        # Controls what should be shown, either the split or the labels / values
        show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False))

    training_set = train_data_dict[selected_target][selected_task]
    map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task)
    st.pydeck_chart(map_deck)

    # Legend matches whatever the map is currently colored by.
    with st.expander("🎨 Legend", expanded=True):
        if show_split:
            _render_split_legend(training_set)
        elif training_set.targets["y"].dtype.name == "category":
            # Target labels legend for categorical targets
            _render_classification_legend(training_set, selected_task)
        else:
            _render_regression_legend(training_set, selected_task)

        # Elevation explanation only applies to the 3D view.
        if make_3d_map:
            _render_elevation_info(training_set)
|
||||||
|
|
||||||
|
|
||||||
|
def render_target_information_tab(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render target labels distribution and spatial visualization.

    Args:
        train_data_dict: Nested dictionary with structure:
            {target_dataset: {task: TrainingSet}}
            where target_dataset is in ["darts_v1", "darts_mllabels"]
            and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]

    """
    # Distribution plot on top, spatial map (its own fragment) below.
    _render_distribution(train_data_dict)
    st.divider()
    _render_target_map(train_data_dict)
|
||||||
|
|
@ -31,6 +31,20 @@ import streamlit as st
|
||||||
from pypalettes import load_cmap
|
from pypalettes import load_cmap
|
||||||
|
|
||||||
|
|
||||||
|
def hex_to_rgb(hex_color: str) -> list[int]:
    """Convert hex color to RGB list.

    Args:
        hex_color: Hex color string (e.g., '#FF5733').

    Returns:
        List of [R, G, B] values (0-255).

    """
    digits = hex_color.lstrip("#")
    # Parse each two-character channel (red, green, blue) as a base-16 integer.
    return [int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)]
|
||||||
|
|
||||||
|
|
||||||
def get_cmap(variable: str) -> mcolors.Colormap:
|
def get_cmap(variable: str) -> mcolors.Colormap:
|
||||||
"""Get a color palette by a "data" variable.
|
"""Get a color palette by a "data" variable.
|
||||||
|
|
||||||
|
|
|
||||||
155
src/entropice/dashboard/views/dataset_page.py
Normal file
155
src/entropice/dashboard/views/dataset_page.py
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
"""Data page: Visualization of the data."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
|
||||||
|
from entropice.dashboard.sections.targets import render_target_information_tab
|
||||||
|
from entropice.dashboard.utils.stats import DatasetStatistics
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||||
|
from entropice.utils.types import (
|
||||||
|
GridConfig,
|
||||||
|
L2SourceDataset,
|
||||||
|
TargetDataset,
|
||||||
|
Task,
|
||||||
|
TemporalMode,
|
||||||
|
all_target_datasets,
|
||||||
|
all_tasks,
|
||||||
|
grid_configs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_dataset_configuration_sidebar() -> DatasetEnsemble:
    """Render dataset configuration selector in sidebar with form.

    Renders grid, temporal-mode and member selectors inside a sidebar form and
    builds a :class:`DatasetEnsemble` from the submitted values. If the form has
    not been submitted yet, the script run is halted via ``st.stop()``.

    Returns:
        The configured DatasetEnsemble for the submitted selection.

    """
    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")

        # Grid selection
        grid_options = [gc.display_name for gc in grid_configs]

        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
        )

        # Find the selected grid config (display names are unique per config)
        selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)

        # Temporal mode selection
        temporal_mode = st.selectbox(
            "Temporal Mode",
            options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
            index=0,
            format_func=lambda x: "Synopsis (all years)"
            if x == "synopsis"
            else "Years-as-Features"
            if x == "feature"
            else f"Year {x}",
            help="Select temporal mode: 'synopsis' for temporal features or specific year",
        )

        # Members selection
        st.subheader("Dataset Members")

        all_members = cast(
            list[L2SourceDataset],
            ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
        )
        selected_members: list[L2SourceDataset] = []

        for member in all_members:
            if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
                selected_members.append(member)

        # Form submit button.
        # NOTE: width="stretch" replaces the deprecated use_container_width=True,
        # consistent with the width="stretch" usage elsewhere in the dashboard.
        submitted = st.form_submit_button(
            "Load Dataset",
            type="primary",
            width="stretch",
            disabled=len(selected_members) == 0,
        )

    if not submitted:
        st.info("👆 Click 'Load Dataset' to apply changes.")
        st.stop()

    ensemble = DatasetEnsemble(
        grid=selected_grid_config.grid,
        level=selected_grid_config.level,
        temporal_mode=cast(TemporalMode, temporal_mode),
        members=selected_members,
    )
    return ensemble
|
||||||
|
|
||||||
|
|
||||||
|
def render_dataset_page():
    """Render the Dataset page of the dashboard."""
    st.title("📊 Entropice Data Visualization")
    st.markdown(
        """
        This section allows you to visualize the source datasets and training data
        used in your experiments. Use the sidebar to configure the dataset parameters.
        """
    )

    # Dataset configuration is taken from the sidebar form (halts until submitted).
    ensemble = render_dataset_configuration_sidebar()

    # Display dataset ID in a styled container
    st.info(f"Loaded Dataset with ID `{ensemble.id()}`")

    st.divider()

    # Ensemble-level statistics section
    stats = DatasetStatistics.from_ensemble(ensemble)
    render_ensemble_details(ensemble.members, stats.members)

    st.divider()

    # Eagerly load the training data for every target/task combination,
    # so the tabs below can render without further loading.
    train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {
        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
        for target in all_target_datasets
    }

    era5_members = [member for member in ensemble.members if member.startswith("ERA5")]

    # Tab layout: fixed tabs first, then one tab per member present in the ensemble.
    tab_names = ["🎯 Targets", "📐 Areas"]
    if "AlphaEarth" in ensemble.members:
        tab_names.append("🌍 AlphaEarth")
    if "ArcticDEM" in ensemble.members:
        tab_names.append("🏔️ ArcticDEM")
    if era5_members:
        tab_names.append("🌡️ ERA5")
    tabs = st.tabs(tab_names)

    with tabs[0]:
        st.header("🎯 Target Labels Visualization")
        render_target_information_tab(train_data_dict)
    with tabs[1]:
        st.header("📐 Areas Visualization")

    # Member tabs start after the two fixed tabs; index advances per present member.
    tab_index = 2
    if "AlphaEarth" in ensemble.members:
        with tabs[tab_index]:
            st.header("🌍 AlphaEarth Visualization")
        tab_index += 1
    if "ArcticDEM" in ensemble.members:
        with tabs[tab_index]:
            st.header("🏔️ ArcticDEM Visualization")
        tab_index += 1
    if era5_members:
        with tabs[tab_index]:
            st.header("🌡️ ERA5 Visualization")

    st.balloons()
    stopwatch.summary()
|
||||||
|
|
@ -9,9 +9,7 @@ from entropice.dashboard.sections.experiment_results import (
|
||||||
render_training_results_summary,
|
render_training_results_summary,
|
||||||
)
|
)
|
||||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||||
from entropice.dashboard.utils.stats import (
|
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||||
load_all_default_dataset_statistics,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def render_overview_page():
|
def render_overview_page():
|
||||||
|
|
@ -47,7 +45,12 @@ def render_overview_page():
|
||||||
|
|
||||||
# Render dataset analysis section
|
# Render dataset analysis section
|
||||||
all_stats = load_all_default_dataset_statistics()
|
all_stats = load_all_default_dataset_statistics()
|
||||||
render_dataset_statistics(all_stats)
|
training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
|
||||||
|
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
||||||
|
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
|
||||||
|
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
|
||||||
|
|
||||||
|
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
|
||||||
|
|
||||||
st.balloons()
|
st.balloons()
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue