diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py
index 6eea4d9..6a54837 100644
--- a/src/entropice/dashboard/app.py
+++ b/src/entropice/dashboard/app.py
@@ -13,6 +13,7 @@ Pages:
import streamlit as st
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
+from entropice.dashboard.views.dataset_page import render_dataset_page
from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.views.overview_page import render_overview_page
@@ -26,6 +27,7 @@ def main():
# Setup Navigation
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
+ data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
@@ -35,6 +37,7 @@ def main():
pg = st.navigation(
{
"Overview": [overview_page],
+ "Data": [data_page],
"Training": [training_data_page, training_analysis_page, autogluon_page],
"Model State": [model_state_page],
"Inference": [inference_page],
diff --git a/src/entropice/dashboard/plots/targets.py b/src/entropice/dashboard/plots/targets.py
new file mode 100644
index 0000000..2948b00
--- /dev/null
+++ b/src/entropice/dashboard/plots/targets.py
@@ -0,0 +1,470 @@
+"""Plots for visualizing target labels in datasets."""
+
+import antimeridian
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import pydeck as pdk
+from plotly.subplots import make_subplots
+from shapely.geometry import shape
+
+from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
+from entropice.ml.dataset import TrainingSet
+from entropice.utils.types import TargetDataset, Task
+
+
+def create_target_distribution_plot(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]) -> go.Figure:
+ """Create a plot showing the distribution of target labels across datasets and tasks.
+
+ This function creates a comprehensive visualization showing:
+ - Classification tasks (binary, count_regimes, density_regimes): categorical distributions
+ - Regression tasks (count, density): histogram distributions
+ - Train/test split information for each combination
+ - Separate subplots for each TargetDataset
+
+ Args:
+ train_data_dict: Nested dictionary with structure:
+ {target_dataset: {task: TrainingSet}}
+ where target_dataset is in ["darts_v1", "darts_mllabels"]
+ and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]
+
+ Returns:
+ Plotly Figure with subplots showing target distributions.
+
+ """
+ # Define task types and their properties
+ classification_tasks: list[Task] = ["binary", "count_regimes", "density_regimes"]
+
+ task_titles: dict[Task, str] = {
+ "binary": "Binary",
+ "count_regimes": "Count Regimes",
+ "density_regimes": "Density Regimes",
+ "count": "Count (Regression)",
+ "density": "Density (Regression)",
+ }
+
+ # Get all available target datasets and tasks
+ target_datasets = sorted(train_data_dict.keys())
+ all_tasks = sorted(
+ {task for tasks in train_data_dict.values() for task in tasks.keys()},
+ key=lambda x: (x not in classification_tasks, x), # Classification first
+ )
+
+ # Create subplots: one row per target dataset, one column per task
+ n_rows = len(target_datasets)
+ n_cols = len(all_tasks)
+
+ # Create column titles (tasks) for the subplots
+ column_titles = [task_titles.get(task, str(task)) for task in all_tasks] # type: ignore[arg-type]
+
+ # Create row titles (target datasets)
+ row_titles = [target_ds.replace("_", " ").title() for target_ds in target_datasets]
+
+ fig = make_subplots(
+ rows=n_rows,
+ cols=n_cols,
+ column_titles=column_titles,
+ vertical_spacing=0.20 / max(n_rows, 1),
+ horizontal_spacing=0.08 / max(n_cols, 1),
+ )
+
+ # Iterate through each target dataset and task
+ for row_idx, target_dataset in enumerate(target_datasets, start=1):
+ task_dict = train_data_dict[target_dataset] # type: ignore[index]
+
+ for col_idx, task in enumerate(all_tasks, start=1):
+ if task not in task_dict:
+ # Skip if this task is not available for this target dataset
+ continue
+
+ dataset = task_dict[task] # type: ignore[index]
+
+ # Determine if this is a classification or regression task
+ if task in classification_tasks:
+ _add_classification_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
+ else:
+ _add_regression_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
+
+ # Update layout
+ fig.update_layout(
+ height=400 * n_rows,
+ showlegend=True,
+ legend={
+ "orientation": "h",
+ "yanchor": "top",
+ "y": -0.15,
+ "xanchor": "center",
+ "x": 0.5,
+ },
+ margin={"l": 120, "r": 40, "t": 80, "b": 100},
+ )
+
+ # Update all x-axes and y-axes
+ fig.update_xaxes(title_font={"size": 11})
+ fig.update_yaxes(title_text="Count", title_font={"size": 11})
+
+ # Manually add row title annotations (left-aligned dataset names)
+ for row_idx, row_title in enumerate(row_titles, start=1):
+ # Calculate the y position for each row (at the top of the row)
+ # Account for vertical spacing between rows
+ spacing = 0.20 / max(n_rows, 1)
+ row_height = (1 - spacing * (n_rows - 1)) / n_rows
+ y_position = 1.08 - (row_idx - 1) * (row_height + spacing)
+
+ fig.add_annotation(
+ text=f"{row_title}",
+ xref="paper",
+ yref="paper",
+ x=-0.02,
+ y=y_position,
+ xanchor="left",
+ yanchor="top",
+ textangle=0,
+ font={"size": 20},
+ showarrow=False,
+ )
+
+ return fig
+
+
+def _add_classification_subplot(
+ fig: go.Figure,
+ dataset: TrainingSet,
+ task: str,
+ row: int,
+ col: int,
+ n_rows: int,
+) -> None:
+ """Add a classification task subplot showing categorical distribution.
+
+ Args:
+ fig: Plotly figure to add subplot to.
+ dataset: TrainingSet containing the data.
+ task: Task name for color selection.
+ row: Subplot row index.
+ col: Subplot column index.
+ n_rows: Total number of rows (for legend control).
+
+ """
+ y_binned = dataset.targets["y"]
+ categories = y_binned.cat.categories.tolist()
+ colors = get_palette(task, len(categories) + 2)[1:-1] # Avoid too light/dark colors
+
+ # Calculate counts for train and test splits
+ train_counts = [((y_binned == cat) & (dataset.split == "train")).sum() for cat in categories]
+ test_counts = [((y_binned == cat) & (dataset.split == "test")).sum() for cat in categories]
+
+ # Add train bars
+ fig.add_trace(
+ go.Bar(
+ name="Train",
+ x=categories,
+ y=train_counts,
+ marker_color=colors,
+ opacity=0.9,
+ text=train_counts,
+ textposition="inside",
+ textfont={"size": 9, "color": "white"},
+ legendgroup="split",
+ showlegend=(row == 1 and col == 1), # Only show legend once
+ ),
+ row=row,
+ col=col,
+ )
+
+ # Add test bars
+ fig.add_trace(
+ go.Bar(
+ name="Test",
+ x=categories,
+ y=test_counts,
+ marker_color=colors,
+ opacity=0.5,
+ text=test_counts,
+ textposition="inside",
+ textfont={"size": 9, "color": "white"},
+ legendgroup="split",
+ showlegend=(row == 1 and col == 1), # Only show legend once
+ ),
+ row=row,
+ col=col,
+ )
+
+ # Update subplot layout
+ fig.update_xaxes(tickangle=-45, row=row, col=col)
+
+
+def _add_regression_subplot(
+ fig: go.Figure,
+ dataset: TrainingSet,
+ task: str,
+ row: int,
+ col: int,
+ n_rows: int,
+) -> None:
+ """Add a regression task subplot showing histogram distribution.
+
+ Args:
+ fig: Plotly figure to add subplot to.
+ dataset: TrainingSet containing the data.
+ task: Task name for color selection.
+ row: Subplot row index.
+ col: Subplot column index.
+ n_rows: Total number of rows (for legend control).
+
+ """
+ # Get raw values
+ z_values = dataset.targets["z"]
+
+ # Split into train and test
+ train_values = z_values[dataset.split == "train"]
+ test_values = z_values[dataset.split == "test"]
+
+ # Determine bin edges based on combined data
+ all_values = pd.concat([train_values, test_values])
+ # Use a reasonable number of bins
+ n_bins = min(30, int(np.sqrt(len(all_values))))
+ bin_edges = np.histogram_bin_edges(all_values, bins=n_bins)
+
+ # Get colors for the task
+ colors = get_palette(task, 6)[1:5:2] # Use less bright/dark colors
+
+ # Add train histogram
+ fig.add_trace(
+ go.Histogram(
+ name="Train",
+ x=train_values,
+ xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
+ marker_color=colors[0],
+ opacity=0.7,
+ legendgroup="split",
+ showlegend=(row == 1 and col == 1), # Only show legend once
+ ),
+ row=row,
+ col=col,
+ )
+
+ # Add test histogram
+ fig.add_trace(
+ go.Histogram(
+ name="Test",
+ x=test_values,
+ xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
+ marker_color=colors[1],
+ opacity=0.5,
+ legendgroup="split",
+ showlegend=(row == 1 and col == 1), # Only show legend once
+ ),
+ row=row,
+ col=col,
+ )
+
+ # Update barmode to overlay for histograms
+ fig.update_xaxes(title_text=task.capitalize(), row=row, col=col)
+
+
+def _assign_split_colors(gdf: pd.DataFrame) -> pd.DataFrame:
+ """Assign colors based on train/test split.
+
+ Args:
+ gdf: GeoDataFrame with 'split' column.
+
+ Returns:
+ GeoDataFrame with 'fill_color' column added.
+
+ """
+ split_colors = get_palette("split", 2)
+ color_map = {"train": hex_to_rgb(split_colors[0]), "test": hex_to_rgb(split_colors[1])}
+ # Convert to list to avoid categorical hashing issues with list values
+ gdf["fill_color"] = [color_map[x] for x in gdf["split"]]
+ return gdf
+
+
+def _assign_classification_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
+ """Assign colors for classification tasks based on categorical labels.
+
+ Args:
+ gdf: GeoDataFrame with categorical 'y' column.
+ task: Task name for color selection.
+
+ Returns:
+ GeoDataFrame with 'fill_color' column added.
+
+ """
+ categories = gdf["y"].cat.categories.tolist()
+ colors_palette = get_palette(task, len(categories))
+ color_map = {cat: hex_to_rgb(colors_palette[i]) for i, cat in enumerate(categories)}
+ # Convert to list to avoid categorical hashing issues with list values
+ gdf["fill_color"] = [color_map[x] for x in gdf["y"]]
+ return gdf
+
+
+def _assign_regression_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
+ """Assign colors for regression tasks based on continuous values.
+
+ Args:
+ gdf: GeoDataFrame with numeric 'z' column.
+ task: Task name for color selection.
+
+ Returns:
+ GeoDataFrame with 'fill_color' column added.
+
+ """
+ z_values = gdf["z"]
+ z_min, z_max = z_values.min(), z_values.max()
+
+ # Normalize to [0, 1]
+ if z_max > z_min:
+ z_normalized = (z_values - z_min) / (z_max - z_min)
+ else:
+ z_normalized = pd.Series([0.5] * len(z_values), index=z_values.index)
+
+ # Get colormap and map normalized values to RGB
+ cmap = get_cmap(task)
+
+ def value_to_rgb(normalized_val):
+ rgba = cmap(normalized_val)
+ return [int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)]
+
+ gdf["fill_color"] = z_normalized.apply(value_to_rgb)
+ return gdf
+
+
+def _assign_elevation(gdf: pd.DataFrame, make_3d_map: bool, is_classification: bool) -> pd.DataFrame:
+ """Assign elevation values for 3D visualization.
+
+ Args:
+ gdf: GeoDataFrame to add elevation to.
+ make_3d_map: Whether to create a 3D map.
+ is_classification: Whether this is a classification task.
+
+ Returns:
+ GeoDataFrame with 'elevation' column added.
+
+ """
+ if not make_3d_map:
+ gdf["elevation"] = 0
+ return gdf
+
+    # NOTE(review): removed leftover debug print of make_3d_map/is_classification
+ if is_classification:
+ # For classification, use category codes for elevation
+ # Convert to int to avoid overflow with int8 categorical codes
+ z_values = gdf["y"].cat.codes.astype(int)
+ else:
+ # For regression, use z values for elevation
+ z_values = gdf["z"]
+ z_min, z_max = z_values.min(), z_values.max()
+
+ if z_max > z_min:
+ gdf["elevation"] = ((z_values - z_min) / (z_max - z_min)).fillna(0)
+ else:
+ gdf["elevation"] = 0
+
+    # NOTE(review): removed leftover debug print of elevation distribution
+ return gdf
+
+
+def create_target_spatial_distribution_map(
+ training_set: TrainingSet, make_3d_map: bool, show_split: bool, task: Task
+) -> pdk.Deck:
+ """Create a spatial distribution map of target labels using PyDeck.
+
+ Args:
+ training_set (TrainingSet): The training set containing spatial data and target labels.
+ make_3d_map (bool): Whether to create a 3D map (True) or a 2D map (False).
+ show_split (bool): Whether to color by train/test split (True) or by target labels (False).
+ task (Task): The task type to determine color mapping.
+
+ Returns:
+ pdk.Deck: A PyDeck Deck object representing the spatial distribution map.
+
+ """
+ # Get the targets GeoDataFrame
+ gdf = training_set.targets.copy()
+
+ # Add split information
+ gdf["split"] = training_set.split.to_numpy()
+
+ # Determine if this is a classification or regression task
+ is_classification = gdf["y"].dtype.name == "category"
+
+ # Convert to WGS84 for pydeck
+ gdf_wgs84 = gdf.to_crs("EPSG:4326")
+
+ # Fix antimeridian issues
+ def fix_hex_geometry(geom):
+ try:
+ return shape(antimeridian.fix_shape(geom))
+        except Exception:  # best-effort antimeridian fix; fall back to original geometry
+ return geom
+
+ gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)
+
+ # Assign colors based on mode
+ if show_split:
+ gdf_wgs84 = _assign_split_colors(gdf_wgs84)
+ elif is_classification:
+ gdf_wgs84 = _assign_classification_colors(gdf_wgs84, task)
+ else:
+ gdf_wgs84 = _assign_regression_colors(gdf_wgs84, task)
+
+ # Add elevation for 3D visualization
+ gdf_wgs84 = _assign_elevation(gdf_wgs84, make_3d_map, is_classification)
+
+ # Convert to GeoJSON format
+ geojson_data = []
+ for _, row in gdf_wgs84.iterrows():
+ feature = {
+ "type": "Feature",
+ "geometry": row["geometry"].__geo_interface__,
+ "properties": {
+ "fill_color": row["fill_color"],
+ "elevation": row["elevation"],
+ "split": row["split"],
+ "y_label": str(row["y"]),
+ "z_value": float(row["z"]) if pd.notna(row["z"]) else None,
+ },
+ }
+ geojson_data.append(feature)
+
+ # Create pydeck layer
+ layer = pdk.Layer(
+ "GeoJsonLayer",
+ geojson_data,
+ opacity=0.7,
+ stroked=True,
+ filled=True,
+ extruded=make_3d_map,
+ wireframe=False,
+ get_fill_color="properties.fill_color",
+ get_line_color=[80, 80, 80],
+ line_width_min_pixels=0.5,
+ get_elevation="properties.elevation" if make_3d_map else 0,
+ elevation_scale=500000 if make_3d_map else 0,
+ pickable=True,
+ )
+
+ # Set initial view state (centered on the Arctic)
+ view_state = pdk.ViewState(
+ latitude=70,
+ longitude=0,
+ zoom=2 if not make_3d_map else 1.5,
+ pitch=0 if not make_3d_map else 45,
+ )
+
+ # Create tooltip
+    tooltip_html = "Split: {split}<br/>Label: {y_label}<br/>Value: {z_value}<br/>"
+
+ # Create deck
+ deck = pdk.Deck(
+ layers=[layer],
+ initial_view_state=view_state,
+ tooltip={
+ "html": tooltip_html,
+ "style": {"backgroundColor": "steelblue", "color": "white"},
+ },
+ map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
+ )
+
+ return deck
diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py
index c0c1a56..7105402 100644
--- a/src/entropice/dashboard/sections/dataset_statistics.py
+++ b/src/entropice/dashboard/sections/dataset_statistics.py
@@ -14,7 +14,7 @@ from entropice.dashboard.plots.dataset_statistics import (
create_sample_count_bar_chart,
)
from entropice.dashboard.utils.colors import get_palette
-from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
+from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import (
GridConfig,
@@ -436,9 +436,104 @@ def _render_aggregation_selection(
return dimension_filters
+def render_ensemble_details(
+ selected_members: list[L2SourceDataset],
+ selected_member_stats: dict[L2SourceDataset, MemberStatistics],
+):
+ """Render ensemble-specific details: feature breakdown and member information.
+
+ This function displays the feature breakdown by data source and detailed
+ member information, independent of task or target selection. Can be reused
+ on other pages by supplying a dataset ensemble.
+
+ Args:
+ selected_members: List of selected member datasets
+ selected_member_stats: Statistics for selected members
+
+ """
+ # Calculate total features for selected members
+ total_features = sum(ms.feature_count for ms in selected_member_stats.values())
+
+ # Feature breakdown by source
+ st.markdown("#### Feature Breakdown by Data Source")
+
+ breakdown_data = []
+ for member, member_stats in selected_member_stats.items():
+ breakdown_data.append(
+ {
+ "Data Source": member,
+ "Number of Features": member_stats.feature_count,
+ "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
+ }
+ )
+
+ breakdown_df = pd.DataFrame(breakdown_data)
+
+ # Get all unique data sources and create color map
+ unique_members_for_color = sorted(selected_member_stats.keys())
+ source_color_map_raw = {}
+ for member in unique_members_for_color:
+ source = member.split("-")[0]
+ n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
+ palette = get_palette(source, n_colors=n_members + 2)
+ idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
+ source_color_map_raw[member] = palette[idx + 1]
+
+ # Create and display pie chart
+ fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
+ st.plotly_chart(fig, width="stretch")
+
+ # Show detailed table
+ st.dataframe(breakdown_df, hide_index=True, width="stretch")
+
+ # Detailed member information
+ with st.expander("📦 Detailed Source Information", expanded=False):
+ # Create detailed table
+ member_details_dict = {
+ member: {
+ "feature_count": ms.feature_count,
+ "variable_names": ms.variable_names,
+ "dimensions": ms.dimensions,
+ "size_bytes": ms.size_bytes,
+ }
+ for member, ms in selected_member_stats.items()
+ }
+ details_df = create_member_details_table(member_details_dict)
+ st.dataframe(details_df, hide_index=True, width="stretch")
+
+ # Individual member details
+ for member, member_stats in selected_member_stats.items():
+ st.markdown(f"### {member}")
+
+ # Variables
+ st.markdown("**Variables:**")
+ vars_html = " ".join(
+ [
+                f'<span style="background-color: #f0f2f6; border-radius: 4px; padding: 2px 8px; display: inline-block;">{v}</span>'
+ for v in member_stats.variable_names
+ ]
+ )
+ st.markdown(vars_html, unsafe_allow_html=True)
+
+ # Dimensions
+ st.markdown("**Dimensions:**")
+ dim_html = " ".join(
+ [
+                    f'<span style="background-color: #f0f2f6; border-radius: 4px; padding: 2px 8px; display: inline-block;">'
+                    f"{dim_name}: {dim_size:,}</span>"
+ for dim_name, dim_size in member_stats.dimensions.items()
+ ]
+ )
+ st.markdown(dim_html, unsafe_allow_html=True)
+
+ st.markdown("---")
+
+
def _render_configuration_summary(
selected_members: list[L2SourceDataset],
- selected_member_stats: dict[str, MemberStatistics],
+ selected_member_stats: dict[L2SourceDataset, MemberStatistics],
selected_target: TargetDataset,
selected_task: Task,
selected_temporal_mode: TemporalMode,
@@ -527,81 +622,8 @@ def _render_configuration_summary(
)
st.dataframe(class_dist_df, hide_index=True, width="stretch")
- # Feature breakdown by source
- st.markdown("#### Feature Breakdown by Data Source")
-
- breakdown_data = []
- for member, member_stats in selected_member_stats.items():
- breakdown_data.append(
- {
- "Data Source": member,
- "Number of Features": member_stats.feature_count,
- "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
- }
- )
-
- breakdown_df = pd.DataFrame(breakdown_data)
-
- # Get all unique data sources and create color map
- unique_members_for_color = sorted(selected_member_stats.keys())
- source_color_map_raw = {}
- for member in unique_members_for_color:
- source = member.split("-")[0]
- n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
- palette = get_palette(source, n_colors=n_members + 2)
- idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
- source_color_map_raw[member] = palette[idx + 1]
-
- # Create and display pie chart
- fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
- st.plotly_chart(fig, width="stretch")
-
- # Show detailed table
- st.dataframe(breakdown_df, hide_index=True, width="stretch")
-
- # Detailed member information
- with st.expander("📦 Detailed Source Information", expanded=False):
- # Create detailed table
- member_details_dict = {
- member: {
- "feature_count": ms.feature_count,
- "variable_names": ms.variable_names,
- "dimensions": ms.dimensions,
- "size_bytes": ms.size_bytes,
- }
- for member, ms in selected_member_stats.items()
- }
- details_df = create_member_details_table(member_details_dict)
- st.dataframe(details_df, hide_index=True, width="stretch")
-
- # Individual member details
- for member, member_stats in selected_member_stats.items():
- st.markdown(f"### {member}")
-
- # Variables
- st.markdown("**Variables:**")
- vars_html = " ".join(
- [
-                f'<span style="background-color: #f0f2f6; border-radius: 4px; padding: 2px 8px; display: inline-block;">{v}</span>'
- for v in member_stats.variable_names
- ]
- )
- st.markdown(vars_html, unsafe_allow_html=True)
-
- # Dimensions
- st.markdown("**Dimensions:**")
- dim_html = " ".join(
- [
-                    f'<span style="background-color: #f0f2f6; border-radius: 4px; padding: 2px 8px; display: inline-block;">'
-                    f"{dim_name}: {dim_size:,}</span>"
- for dim_name, dim_size in member_stats.dimensions.items()
- ]
- )
- st.markdown(dim_html, unsafe_allow_html=True)
-
- st.markdown("---")
+ # Render ensemble-specific details
+ render_ensemble_details(selected_members, selected_member_stats)
@st.fragment
@@ -650,7 +672,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
# Use pre-computed statistics (much faster)
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
# Filter to selected members only
- selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+ selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
+ m: stats.members[m] for m in selected_members if m in stats.members
+ }
else:
# Create the actual ensemble with selected members and dimension filters
ensemble = DatasetEnsemble(
@@ -663,7 +687,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
# Recompute stats with the filtered ensemble
stats = DatasetStatistics.from_ensemble(ensemble)
# Filter to selected members only
- selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
+ selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
+ m: stats.members[m] for m in selected_members if m in stats.members
+ }
# Render configuration summary and statistics
_render_configuration_summary(
@@ -678,16 +704,16 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
st.info("👆 Select at least one data source to see feature statistics")
-def render_dataset_statistics(all_stats: DatasetStatsCache):
+def render_dataset_statistics(
+ all_stats: DatasetStatsCache,
+ training_sample_df: pd.DataFrame,
+ feature_breakdown_df: pd.DataFrame,
+ comparison_df: pd.DataFrame,
+ inference_sample_df: pd.DataFrame,
+):
"""Render the dataset statistics section with sample and feature counts."""
st.header("📈 Dataset Statistics")
- all_stats: DatasetStatsCache = load_all_default_dataset_statistics()
- training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
- feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
- comparison_df = DatasetStatistics.get_comparison_df(all_stats)
- inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
-
# Create tabs for different analysis views
analysis_tabs = st.tabs(
[
diff --git a/src/entropice/dashboard/sections/targets.py b/src/entropice/dashboard/sections/targets.py
new file mode 100644
index 0000000..8c41cd1
--- /dev/null
+++ b/src/entropice/dashboard/sections/targets.py
@@ -0,0 +1,225 @@
+"""Target dataset dashboard section."""
+
+from typing import cast
+
+import matplotlib.colors as mcolors
+import streamlit as st
+
+from entropice.dashboard.plots.targets import create_target_distribution_plot, create_target_spatial_distribution_map
+from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
+from entropice.ml.dataset import TrainingSet
+from entropice.utils.types import TargetDataset, Task
+
+
+def _render_distribution(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
+ st.subheader("Target Labels Distribution")
+ fig = create_target_distribution_plot(train_data_dict)
+ st.plotly_chart(fig, width="stretch")
+ st.markdown(
+ """
+ The above plot shows the distribution of target labels in the training and test datasets.
+ You can use this information to understand the balance of classes and identify any potential
+ issues with class imbalance that may affect model training.
+ """
+ )
+
+
+def _render_split_legend(training_set: TrainingSet):
+ """Render legend for train/test split visualization."""
+ st.markdown("**Data Split:**")
+ split_colors = get_palette("split", 2)
+
+ col1, col2 = st.columns(2)
+ with col1:
+ color_rgb = hex_to_rgb(split_colors[0])
+ st.markdown(
+ f'