From cb7b0f9e6b4f9bd7eea9128f081d85beafd0ba6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Fri, 16 Jan 2026 23:51:38 +0100 Subject: [PATCH] Add a Dataset Page to deprecate the Training Data page --- src/entropice/dashboard/app.py | 3 + src/entropice/dashboard/plots/targets.py | 470 ++++++++++++++++++ .../dashboard/sections/dataset_statistics.py | 198 ++++---- src/entropice/dashboard/sections/targets.py | 225 +++++++++ src/entropice/dashboard/utils/colors.py | 14 + src/entropice/dashboard/views/dataset_page.py | 155 ++++++ .../dashboard/views/overview_page.py | 11 +- 7 files changed, 986 insertions(+), 90 deletions(-) create mode 100644 src/entropice/dashboard/plots/targets.py create mode 100644 src/entropice/dashboard/sections/targets.py create mode 100644 src/entropice/dashboard/views/dataset_page.py diff --git a/src/entropice/dashboard/app.py b/src/entropice/dashboard/app.py index 6eea4d9..6a54837 100644 --- a/src/entropice/dashboard/app.py +++ b/src/entropice/dashboard/app.py @@ -13,6 +13,7 @@ Pages: import streamlit as st from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page +from entropice.dashboard.views.dataset_page import render_dataset_page from entropice.dashboard.views.inference_page import render_inference_page from entropice.dashboard.views.model_state_page import render_model_state_page from entropice.dashboard.views.overview_page import render_overview_page @@ -26,6 +27,7 @@ def main(): # Setup Navigation overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True) + data_page = st.Page(render_dataset_page, title="Dataset", icon="📊") training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️") training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾") autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖") @@ -35,6 +37,7 @@ def main(): pg = 
st.navigation( { "Overview": [overview_page], + "Data": [data_page], "Training": [training_data_page, training_analysis_page, autogluon_page], "Model State": [model_state_page], "Inference": [inference_page], diff --git a/src/entropice/dashboard/plots/targets.py b/src/entropice/dashboard/plots/targets.py new file mode 100644 index 0000000..2948b00 --- /dev/null +++ b/src/entropice/dashboard/plots/targets.py @@ -0,0 +1,470 @@ +"""Plots for visualizing target labels in datasets.""" + +import antimeridian +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import pydeck as pdk +from plotly.subplots import make_subplots +from shapely.geometry import shape + +from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb +from entropice.ml.dataset import TrainingSet +from entropice.utils.types import TargetDataset, Task + + +def create_target_distribution_plot(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]) -> go.Figure: + """Create a plot showing the distribution of target labels across datasets and tasks. + + This function creates a comprehensive visualization showing: + - Classification tasks (binary, count_regimes, density_regimes): categorical distributions + - Regression tasks (count, density): histogram distributions + - Train/test split information for each combination + - Separate subplots for each TargetDataset + + Args: + train_data_dict: Nested dictionary with structure: + {target_dataset: {task: TrainingSet}} + where target_dataset is in ["darts_v1", "darts_mllabels"] + and task is in ["binary", "count_regimes", "density_regimes", "count", "density"] + + Returns: + Plotly Figure with subplots showing target distributions. 
+ + """ + # Define task types and their properties + classification_tasks: list[Task] = ["binary", "count_regimes", "density_regimes"] + + task_titles: dict[Task, str] = { + "binary": "Binary", + "count_regimes": "Count Regimes", + "density_regimes": "Density Regimes", + "count": "Count (Regression)", + "density": "Density (Regression)", + } + + # Get all available target datasets and tasks + target_datasets = sorted(train_data_dict.keys()) + all_tasks = sorted( + {task for tasks in train_data_dict.values() for task in tasks.keys()}, + key=lambda x: (x not in classification_tasks, x), # Classification first + ) + + # Create subplots: one row per target dataset, one column per task + n_rows = len(target_datasets) + n_cols = len(all_tasks) + + # Create column titles (tasks) for the subplots + column_titles = [task_titles.get(task, str(task)) for task in all_tasks] # type: ignore[arg-type] + + # Create row titles (target datasets) + row_titles = [target_ds.replace("_", " ").title() for target_ds in target_datasets] + + fig = make_subplots( + rows=n_rows, + cols=n_cols, + column_titles=column_titles, + vertical_spacing=0.20 / max(n_rows, 1), + horizontal_spacing=0.08 / max(n_cols, 1), + ) + + # Iterate through each target dataset and task + for row_idx, target_dataset in enumerate(target_datasets, start=1): + task_dict = train_data_dict[target_dataset] # type: ignore[index] + + for col_idx, task in enumerate(all_tasks, start=1): + if task not in task_dict: + # Skip if this task is not available for this target dataset + continue + + dataset = task_dict[task] # type: ignore[index] + + # Determine if this is a classification or regression task + if task in classification_tasks: + _add_classification_subplot(fig, dataset, task, row_idx, col_idx, n_rows) + else: + _add_regression_subplot(fig, dataset, task, row_idx, col_idx, n_rows) + + # Update layout + fig.update_layout( + height=400 * n_rows, + showlegend=True, + legend={ + "orientation": "h", + "yanchor": "top", + "y": 
-0.15, + "xanchor": "center", + "x": 0.5, + }, + margin={"l": 120, "r": 40, "t": 80, "b": 100}, + ) + + # Update all x-axes and y-axes + fig.update_xaxes(title_font={"size": 11}) + fig.update_yaxes(title_text="Count", title_font={"size": 11}) + + # Manually add row title annotations (left-aligned dataset names) + for row_idx, row_title in enumerate(row_titles, start=1): + # Calculate the y position for each row (at the top of the row) + # Account for vertical spacing between rows + spacing = 0.20 / max(n_rows, 1) + row_height = (1 - spacing * (n_rows - 1)) / n_rows + y_position = 1.08 - (row_idx - 1) * (row_height + spacing) + + fig.add_annotation( + text=f"{row_title}", + xref="paper", + yref="paper", + x=-0.02, + y=y_position, + xanchor="left", + yanchor="top", + textangle=0, + font={"size": 20}, + showarrow=False, + ) + + return fig + + +def _add_classification_subplot( + fig: go.Figure, + dataset: TrainingSet, + task: str, + row: int, + col: int, + n_rows: int, +) -> None: + """Add a classification task subplot showing categorical distribution. + + Args: + fig: Plotly figure to add subplot to. + dataset: TrainingSet containing the data. + task: Task name for color selection. + row: Subplot row index. + col: Subplot column index. + n_rows: Total number of rows (for legend control). 
+ + """ + y_binned = dataset.targets["y"] + categories = y_binned.cat.categories.tolist() + colors = get_palette(task, len(categories) + 2)[1:-1] # Avoid too light/dark colors + + # Calculate counts for train and test splits + train_counts = [((y_binned == cat) & (dataset.split == "train")).sum() for cat in categories] + test_counts = [((y_binned == cat) & (dataset.split == "test")).sum() for cat in categories] + + # Add train bars + fig.add_trace( + go.Bar( + name="Train", + x=categories, + y=train_counts, + marker_color=colors, + opacity=0.9, + text=train_counts, + textposition="inside", + textfont={"size": 9, "color": "white"}, + legendgroup="split", + showlegend=(row == 1 and col == 1), # Only show legend once + ), + row=row, + col=col, + ) + + # Add test bars + fig.add_trace( + go.Bar( + name="Test", + x=categories, + y=test_counts, + marker_color=colors, + opacity=0.5, + text=test_counts, + textposition="inside", + textfont={"size": 9, "color": "white"}, + legendgroup="split", + showlegend=(row == 1 and col == 1), # Only show legend once + ), + row=row, + col=col, + ) + + # Update subplot layout + fig.update_xaxes(tickangle=-45, row=row, col=col) + + +def _add_regression_subplot( + fig: go.Figure, + dataset: TrainingSet, + task: str, + row: int, + col: int, + n_rows: int, +) -> None: + """Add a regression task subplot showing histogram distribution. + + Args: + fig: Plotly figure to add subplot to. + dataset: TrainingSet containing the data. + task: Task name for color selection. + row: Subplot row index. + col: Subplot column index. + n_rows: Total number of rows (for legend control). 
+ + """ + # Get raw values + z_values = dataset.targets["z"] + + # Split into train and test + train_values = z_values[dataset.split == "train"] + test_values = z_values[dataset.split == "test"] + + # Determine bin edges based on combined data + all_values = pd.concat([train_values, test_values]) + # Use a reasonable number of bins + n_bins = min(30, int(np.sqrt(len(all_values)))) + bin_edges = np.histogram_bin_edges(all_values, bins=n_bins) + + # Get colors for the task + colors = get_palette(task, 6)[1:5:2] # Use less bright/dark colors + + # Add train histogram + fig.add_trace( + go.Histogram( + name="Train", + x=train_values, + xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]}, + marker_color=colors[0], + opacity=0.7, + legendgroup="split", + showlegend=(row == 1 and col == 1), # Only show legend once + ), + row=row, + col=col, + ) + + # Add test histogram + fig.add_trace( + go.Histogram( + name="Test", + x=test_values, + xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]}, + marker_color=colors[1], + opacity=0.5, + legendgroup="split", + showlegend=(row == 1 and col == 1), # Only show legend once + ), + row=row, + col=col, + ) + + # Update barmode to overlay for histograms + fig.update_xaxes(title_text=task.capitalize(), row=row, col=col) + + +def _assign_split_colors(gdf: pd.DataFrame) -> pd.DataFrame: + """Assign colors based on train/test split. + + Args: + gdf: GeoDataFrame with 'split' column. + + Returns: + GeoDataFrame with 'fill_color' column added. 
+ + """ + split_colors = get_palette("split", 2) + color_map = {"train": hex_to_rgb(split_colors[0]), "test": hex_to_rgb(split_colors[1])} + # Convert to list to avoid categorical hashing issues with list values + gdf["fill_color"] = [color_map[x] for x in gdf["split"]] + return gdf + + +def _assign_classification_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame: + """Assign colors for classification tasks based on categorical labels. + + Args: + gdf: GeoDataFrame with categorical 'y' column. + task: Task name for color selection. + + Returns: + GeoDataFrame with 'fill_color' column added. + + """ + categories = gdf["y"].cat.categories.tolist() + colors_palette = get_palette(task, len(categories)) + color_map = {cat: hex_to_rgb(colors_palette[i]) for i, cat in enumerate(categories)} + # Convert to list to avoid categorical hashing issues with list values + gdf["fill_color"] = [color_map[x] for x in gdf["y"]] + return gdf + + +def _assign_regression_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame: + """Assign colors for regression tasks based on continuous values. + + Args: + gdf: GeoDataFrame with numeric 'z' column. + task: Task name for color selection. + + Returns: + GeoDataFrame with 'fill_color' column added. + + """ + z_values = gdf["z"] + z_min, z_max = z_values.min(), z_values.max() + + # Normalize to [0, 1] + if z_max > z_min: + z_normalized = (z_values - z_min) / (z_max - z_min) + else: + z_normalized = pd.Series([0.5] * len(z_values), index=z_values.index) + + # Get colormap and map normalized values to RGB + cmap = get_cmap(task) + + def value_to_rgb(normalized_val): + rgba = cmap(normalized_val) + return [int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)] + + gdf["fill_color"] = z_normalized.apply(value_to_rgb) + return gdf + + +def _assign_elevation(gdf: pd.DataFrame, make_3d_map: bool, is_classification: bool) -> pd.DataFrame: + """Assign elevation values for 3D visualization. + + Args: + gdf: GeoDataFrame to add elevation to. 
+ make_3d_map: Whether to create a 3D map. + is_classification: Whether this is a classification task. + + Returns: + GeoDataFrame with 'elevation' column added. + + """ + if not make_3d_map: + gdf["elevation"] = 0 + return gdf + + print(f"{make_3d_map=}, {is_classification=}") + if is_classification: + # For classification, use category codes for elevation + # Convert to int to avoid overflow with int8 categorical codes + z_values = gdf["y"].cat.codes.astype(int) + else: + # For regression, use z values for elevation + z_values = gdf["z"] + z_min, z_max = z_values.min(), z_values.max() + + if z_max > z_min: + gdf["elevation"] = ((z_values - z_min) / (z_max - z_min)).fillna(0) + else: + gdf["elevation"] = 0 + + print(gdf["elevation"].describe()) + return gdf + + +def create_target_spatial_distribution_map( + training_set: TrainingSet, make_3d_map: bool, show_split: bool, task: Task +) -> pdk.Deck: + """Create a spatial distribution map of target labels using PyDeck. + + Args: + training_set (TrainingSet): The training set containing spatial data and target labels. + make_3d_map (bool): Whether to create a 3D map (True) or a 2D map (False). + show_split (bool): Whether to color by train/test split (True) or by target labels (False). + task (Task): The task type to determine color mapping. + + Returns: + pdk.Deck: A PyDeck Deck object representing the spatial distribution map. 
+ + """ + # Get the targets GeoDataFrame + gdf = training_set.targets.copy() + + # Add split information + gdf["split"] = training_set.split.to_numpy() + + # Determine if this is a classification or regression task + is_classification = gdf["y"].dtype.name == "category" + + # Convert to WGS84 for pydeck + gdf_wgs84 = gdf.to_crs("EPSG:4326") + + # Fix antimeridian issues + def fix_hex_geometry(geom): + try: + return shape(antimeridian.fix_shape(geom)) + except (ValueError, Exception): + return geom + + gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry) + + # Assign colors based on mode + if show_split: + gdf_wgs84 = _assign_split_colors(gdf_wgs84) + elif is_classification: + gdf_wgs84 = _assign_classification_colors(gdf_wgs84, task) + else: + gdf_wgs84 = _assign_regression_colors(gdf_wgs84, task) + + # Add elevation for 3D visualization + gdf_wgs84 = _assign_elevation(gdf_wgs84, make_3d_map, is_classification) + + # Convert to GeoJSON format + geojson_data = [] + for _, row in gdf_wgs84.iterrows(): + feature = { + "type": "Feature", + "geometry": row["geometry"].__geo_interface__, + "properties": { + "fill_color": row["fill_color"], + "elevation": row["elevation"], + "split": row["split"], + "y_label": str(row["y"]), + "z_value": float(row["z"]) if pd.notna(row["z"]) else None, + }, + } + geojson_data.append(feature) + + # Create pydeck layer + layer = pdk.Layer( + "GeoJsonLayer", + geojson_data, + opacity=0.7, + stroked=True, + filled=True, + extruded=make_3d_map, + wireframe=False, + get_fill_color="properties.fill_color", + get_line_color=[80, 80, 80], + line_width_min_pixels=0.5, + get_elevation="properties.elevation" if make_3d_map else 0, + elevation_scale=500000 if make_3d_map else 0, + pickable=True, + ) + + # Set initial view state (centered on the Arctic) + view_state = pdk.ViewState( + latitude=70, + longitude=0, + zoom=2 if not make_3d_map else 1.5, + pitch=0 if not make_3d_map else 45, + ) + + # Create tooltip + tooltip_html = "Split: 
{split}
Label: {y_label}
Value: {z_value}
" + + # Create deck + deck = pdk.Deck( + layers=[layer], + initial_view_state=view_state, + tooltip={ + "html": tooltip_html, + "style": {"backgroundColor": "steelblue", "color": "white"}, + }, + map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", + ) + + return deck diff --git a/src/entropice/dashboard/sections/dataset_statistics.py b/src/entropice/dashboard/sections/dataset_statistics.py index c0c1a56..7105402 100644 --- a/src/entropice/dashboard/sections/dataset_statistics.py +++ b/src/entropice/dashboard/sections/dataset_statistics.py @@ -14,7 +14,7 @@ from entropice.dashboard.plots.dataset_statistics import ( create_sample_count_bar_chart, ) from entropice.dashboard.utils.colors import get_palette -from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics +from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics from entropice.ml.dataset import DatasetEnsemble from entropice.utils.types import ( GridConfig, @@ -436,9 +436,104 @@ def _render_aggregation_selection( return dimension_filters +def render_ensemble_details( + selected_members: list[L2SourceDataset], + selected_member_stats: dict[L2SourceDataset, MemberStatistics], +): + """Render ensemble-specific details: feature breakdown and member information. + + This function displays the feature breakdown by data source and detailed + member information, independent of task or target selection. Can be reused + on other pages by supplying a dataset ensemble. 
+ + Args: + selected_members: List of selected member datasets + selected_member_stats: Statistics for selected members + + """ + # Calculate total features for selected members + total_features = sum(ms.feature_count for ms in selected_member_stats.values()) + + # Feature breakdown by source + st.markdown("#### Feature Breakdown by Data Source") + + breakdown_data = [] + for member, member_stats in selected_member_stats.items(): + breakdown_data.append( + { + "Data Source": member, + "Number of Features": member_stats.feature_count, + "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%", + } + ) + + breakdown_df = pd.DataFrame(breakdown_data) + + # Get all unique data sources and create color map + unique_members_for_color = sorted(selected_member_stats.keys()) + source_color_map_raw = {} + for member in unique_members_for_color: + source = member.split("-")[0] + n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source) + palette = get_palette(source, n_colors=n_members + 2) + idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member) + source_color_map_raw[member] = palette[idx + 1] + + # Create and display pie chart + fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw) + st.plotly_chart(fig, width="stretch") + + # Show detailed table + st.dataframe(breakdown_df, hide_index=True, width="stretch") + + # Detailed member information + with st.expander("📦 Detailed Source Information", expanded=False): + # Create detailed table + member_details_dict = { + member: { + "feature_count": ms.feature_count, + "variable_names": ms.variable_names, + "dimensions": ms.dimensions, + "size_bytes": ms.size_bytes, + } + for member, ms in selected_member_stats.items() + } + details_df = create_member_details_table(member_details_dict) + st.dataframe(details_df, hide_index=True, width="stretch") + + # Individual member details + for member, member_stats in 
selected_member_stats.items(): + st.markdown(f"### {member}") + + # Variables + st.markdown("**Variables:**") + vars_html = " ".join( + [ + f'{v}' + for v in member_stats.variable_names + ] + ) + st.markdown(vars_html, unsafe_allow_html=True) + + # Dimensions + st.markdown("**Dimensions:**") + dim_html = " ".join( + [ + f'' + f"{dim_name}: {dim_size:,}" + for dim_name, dim_size in member_stats.dimensions.items() + ] + ) + st.markdown(dim_html, unsafe_allow_html=True) + + st.markdown("---") + + def _render_configuration_summary( selected_members: list[L2SourceDataset], - selected_member_stats: dict[str, MemberStatistics], + selected_member_stats: dict[L2SourceDataset, MemberStatistics], selected_target: TargetDataset, selected_task: Task, selected_temporal_mode: TemporalMode, @@ -527,81 +622,8 @@ def _render_configuration_summary( ) st.dataframe(class_dist_df, hide_index=True, width="stretch") - # Feature breakdown by source - st.markdown("#### Feature Breakdown by Data Source") - - breakdown_data = [] - for member, member_stats in selected_member_stats.items(): - breakdown_data.append( - { - "Data Source": member, - "Number of Features": member_stats.feature_count, - "Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%", - } - ) - - breakdown_df = pd.DataFrame(breakdown_data) - - # Get all unique data sources and create color map - unique_members_for_color = sorted(selected_member_stats.keys()) - source_color_map_raw = {} - for member in unique_members_for_color: - source = member.split("-")[0] - n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source) - palette = get_palette(source, n_colors=n_members + 2) - idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member) - source_color_map_raw[member] = palette[idx + 1] - - # Create and display pie chart - fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw) - st.plotly_chart(fig, width="stretch") - - # Show 
detailed table - st.dataframe(breakdown_df, hide_index=True, width="stretch") - - # Detailed member information - with st.expander("📦 Detailed Source Information", expanded=False): - # Create detailed table - member_details_dict = { - member: { - "feature_count": ms.feature_count, - "variable_names": ms.variable_names, - "dimensions": ms.dimensions, - "size_bytes": ms.size_bytes, - } - for member, ms in selected_member_stats.items() - } - details_df = create_member_details_table(member_details_dict) - st.dataframe(details_df, hide_index=True, width="stretch") - - # Individual member details - for member, member_stats in selected_member_stats.items(): - st.markdown(f"### {member}") - - # Variables - st.markdown("**Variables:**") - vars_html = " ".join( - [ - f'{v}' - for v in member_stats.variable_names - ] - ) - st.markdown(vars_html, unsafe_allow_html=True) - - # Dimensions - st.markdown("**Dimensions:**") - dim_html = " ".join( - [ - f'' - f"{dim_name}: {dim_size:,}" - for dim_name, dim_size in member_stats.dimensions.items() - ] - ) - st.markdown(dim_html, unsafe_allow_html=True) - - st.markdown("---") + # Render ensemble-specific details + render_ensemble_details(selected_members, selected_member_stats) @st.fragment @@ -650,7 +672,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache): # Use pre-computed statistics (much faster) stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index] # Filter to selected members only - selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members} + selected_member_stats: dict[L2SourceDataset, MemberStatistics] = { + m: stats.members[m] for m in selected_members if m in stats.members + } else: # Create the actual ensemble with selected members and dimension filters ensemble = DatasetEnsemble( @@ -663,7 +687,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache): # Recompute stats with the filtered ensemble stats = 
DatasetStatistics.from_ensemble(ensemble) # Filter to selected members only - selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members} + selected_member_stats: dict[L2SourceDataset, MemberStatistics] = { + m: stats.members[m] for m in selected_members if m in stats.members + } # Render configuration summary and statistics _render_configuration_summary( @@ -678,16 +704,16 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache): st.info("👆 Select at least one data source to see feature statistics") -def render_dataset_statistics(all_stats: DatasetStatsCache): +def render_dataset_statistics( + all_stats: DatasetStatsCache, + training_sample_df: pd.DataFrame, + feature_breakdown_df: pd.DataFrame, + comparison_df: pd.DataFrame, + inference_sample_df: pd.DataFrame, +): """Render the dataset statistics section with sample and feature counts.""" st.header("📈 Dataset Statistics") - all_stats: DatasetStatsCache = load_all_default_dataset_statistics() - training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats) - feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) - comparison_df = DatasetStatistics.get_comparison_df(all_stats) - inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats) - # Create tabs for different analysis views analysis_tabs = st.tabs( [ diff --git a/src/entropice/dashboard/sections/targets.py b/src/entropice/dashboard/sections/targets.py new file mode 100644 index 0000000..8c41cd1 --- /dev/null +++ b/src/entropice/dashboard/sections/targets.py @@ -0,0 +1,225 @@ +"""Target dataset dashboard section.""" + +from typing import cast + +import matplotlib.colors as mcolors +import streamlit as st + +from entropice.dashboard.plots.targets import create_target_distribution_plot, create_target_spatial_distribution_map +from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb +from entropice.ml.dataset import TrainingSet +from 
entropice.utils.types import TargetDataset, Task + + +def _render_distribution(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]): + st.subheader("Target Labels Distribution") + fig = create_target_distribution_plot(train_data_dict) + st.plotly_chart(fig, width="stretch") + st.markdown( + """ + The above plot shows the distribution of target labels in the training and test datasets. + You can use this information to understand the balance of classes and identify any potential + issues with class imbalance that may affect model training. + """ + ) + + +def _render_split_legend(training_set: TrainingSet): + """Render legend for train/test split visualization.""" + st.markdown("**Data Split:**") + split_colors = get_palette("split", 2) + + col1, col2 = st.columns(2) + with col1: + color_rgb = hex_to_rgb(split_colors[0]) + st.markdown( + f'
' + f'
' + f"Train
", + unsafe_allow_html=True, + ) + with col2: + color_rgb = hex_to_rgb(split_colors[1]) + st.markdown( + f'
' + f'
' + f"Test
", + unsafe_allow_html=True, + ) + + # Show split statistics + train_count = (training_set.split == "train").sum() + test_count = (training_set.split == "test").sum() + total = len(training_set) + st.caption( + f"Train: {train_count:,} ({train_count / total * 100:.1f}%) | " + f"Test: {test_count:,} ({test_count / total * 100:.1f}%)" + ) + + +def _render_classification_legend(training_set: TrainingSet, selected_task: Task): + """Render legend for classification task visualization.""" + st.markdown("**Target Classes:**") + categories = training_set.targets["y"].cat.categories.tolist() + colors_palette = get_palette(selected_task, len(categories)) + intervals = training_set.target_intervals + + # Show categories with colors and intervals + for i, cat in enumerate(categories): + color = colors_palette[i] + interval_min, interval_max = intervals[i] + + # Format interval display + interval_str = _format_interval(interval_min, interval_max, selected_task) + + st.markdown( + f'
' + f'
' + f"{cat}{interval_str}
", + unsafe_allow_html=True, + ) + + # Show total samples + st.caption(f"Total samples: {len(training_set):,}") + + +def _format_interval(interval_min, interval_max, selected_task: Task) -> str: + """Format interval string based on task type.""" + if interval_min is None or interval_max is None: + return "" + + if selected_task in ["count", "count_regimes"]: + # Integer values for count + if interval_min == interval_max: + return f" ({int(interval_min)})" + return f" ({int(interval_min)}-{int(interval_max)})" + + if selected_task in ["density", "density_regimes"]: + # Percentage values for density + if interval_min == interval_max: + return f" ({interval_min * 100:.4f}%)" + return f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)" + + # Binary or other + return "" + + +def _render_regression_legend(training_set: TrainingSet, selected_task: Task): + """Render legend for regression task visualization.""" + st.markdown("**Target Values:**") + z_values = training_set.targets["z"] + z_min, z_max = z_values.min(), z_values.max() + z_mean = z_values.mean() + z_median = z_values.median() + + st.markdown(f"- **Range:** {z_min:.4f} to {z_max:.4f}\n- **Mean:** {z_mean:.4f}\n- **Median:** {z_median:.4f}") + + # Show gradient visualization using the actual colormap + st.markdown("**Color Gradient:**") + + # Get the same colormap used in the map + cmap = get_cmap(selected_task) + # Sample colors from the colormap + n_samples = 10 + gradient_colors = [mcolors.to_hex(cmap(i / (n_samples - 1))) for i in range(n_samples)] + gradient_str = ", ".join(gradient_colors) + + st.markdown( + f'
' + f'Low ({z_min:.4f})' + f'
' + f'High ({z_max:.4f})' + f"
", + unsafe_allow_html=True, + ) + + +def _render_elevation_info(training_set: TrainingSet): + """Render 3D elevation information.""" + st.markdown("---") + st.markdown("**3D Elevation:**") + + is_classification = training_set.targets["y"].dtype.name == "category" + if is_classification: + st.info("💡 Height represents category level. Rotate the map by holding Ctrl/Cmd and dragging with your mouse.") + else: + z_values = training_set.targets["z"] + z_min, z_max = z_values.min(), z_values.max() + st.info( + f"💡 Height represents target value (normalized): {z_min:.4f} (low) → {z_max:.4f} (high). " + "Rotate the map by holding Ctrl/Cmd and dragging with your mouse." + ) + + +@st.fragment +def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]): + st.subheader("Target Labels Spatial Distribution") + + cols = st.columns([2, 2, 1]) + with cols[0]: + selected_target = cast( + TargetDataset, + st.selectbox( + "Select Target Dataset", + options=sorted(train_data_dict.keys()), + index=0, + ), + ) + with cols[1]: + selected_task = cast( + Task, + st.selectbox( + "Select Task", + options=sorted(train_data_dict[selected_target].keys()), + index=0, + ), + ) + with cols[2]: + # Controls weather a 3D map or a 2D map is shown + make_3d_map = cast(bool, st.checkbox("3D Map", value=True)) + # Controls what should be shows, either the split or the labels / values + show_split = cast(bool, st.checkbox("Show Train/Test Split", value=False)) + + training_set = train_data_dict[selected_target][selected_task] + map_deck = create_target_spatial_distribution_map(training_set, make_3d_map, show_split, selected_task) + st.pydeck_chart(map_deck) + + # Add legend + with st.expander("🎨 Legend", expanded=True): + if show_split: + _render_split_legend(training_set) + else: + # Show target labels legend + is_classification = training_set.targets["y"].dtype.name == "category" + if is_classification: + _render_classification_legend(training_set, selected_task) + else: + 
_render_regression_legend(training_set, selected_task) + + # Add elevation info for 3D maps + if make_3d_map: + _render_elevation_info(training_set) + + +def render_target_information_tab(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]): + """Render target labels distribution and spatial visualization. + + Args: + train_data_dict: Nested dictionary with structure: + {target_dataset: {task: TrainingSet}} + where target_dataset is in ["darts_v1", "darts_mllabels"] + and task is in ["binary", "count_regimes", "density_regimes", "count", "density"] + + """ + _render_distribution(train_data_dict) + + st.divider() + + _render_target_map(train_data_dict) diff --git a/src/entropice/dashboard/utils/colors.py b/src/entropice/dashboard/utils/colors.py index aa1a758..fe4f0a0 100644 --- a/src/entropice/dashboard/utils/colors.py +++ b/src/entropice/dashboard/utils/colors.py @@ -31,6 +31,20 @@ import streamlit as st from pypalettes import load_cmap +def hex_to_rgb(hex_color: str) -> list[int]: + """Convert hex color to RGB list. + + Args: + hex_color: Hex color string (e.g., '#FF5733'). + + Returns: + List of [R, G, B] values (0-255). + + """ + hex_color = hex_color.lstrip("#") + return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)] + + def get_cmap(variable: str) -> mcolors.Colormap: """Get a color palette by a "data" variable. 
def render_dataset_configuration_sidebar() -> DatasetEnsemble:
    """Render the dataset configuration form in the sidebar.

    Blocks the script run (via ``st.stop``) until the form is submitted,
    then builds the configured ensemble from the form values.

    Returns:
        The :class:`DatasetEnsemble` assembled from the submitted form.

    """

    def _temporal_label(mode) -> str:
        # Human-readable labels for the temporal-mode selectbox entries.
        if mode == "synopsis":
            return "Synopsis (all years)"
        if mode == "feature":
            return "Years-as-Features"
        return f"Year {mode}"

    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")

        # Grid system / resolution level selection.
        chosen_grid_name = st.selectbox(
            "Grid Configuration",
            options=[gc.display_name for gc in grid_configs],
            index=0,
            help="Select the grid system and resolution level",
        )
        grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == chosen_grid_name)

        # Temporal mode: synopsis, years-as-features, or a single year.
        chosen_mode = st.selectbox(
            "Temporal Mode",
            options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
            index=0,
            format_func=_temporal_label,
            help="Select temporal mode: 'synopsis' for temporal features or specific year",
        )

        # One checkbox per candidate member; all enabled by default.
        st.subheader("Dataset Members")
        candidate_members = cast(
            list[L2SourceDataset],
            ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
        )
        picked: list[L2SourceDataset] = [
            member
            for member in candidate_members
            if st.checkbox(member, value=True, help=f"Include {member} in the dataset")
        ]

        # Submission stays disabled while no member is selected.
        submitted = st.form_submit_button(
            "Load Dataset",
            type="primary",
            use_container_width=True,
            disabled=len(picked) == 0,
        )

    if not submitted:
        st.info("👆 Click 'Load Dataset' to apply changes.")
        st.stop()

    return DatasetEnsemble(
        grid=grid_config.grid,
        level=grid_config.level,
        temporal_mode=cast(TemporalMode, chosen_mode),
        members=picked,
    )
def render_dataset_page():
    """Render the Dataset page of the dashboard."""
    st.title("📊 Entropice Data Visualization")
    st.markdown(
        """
    This section allows you to visualize the source datasets and training data
    used in your experiments. Use the sidebar to configure the dataset parameters.
    """
    )

    # The sidebar form st.stop()s the run until a configuration is submitted.
    ensemble = render_dataset_configuration_sidebar()

    # Display dataset ID in a styled container.
    st.info(f"Loaded Dataset with ID `{ensemble.id()}`")

    st.divider()

    # Dataset statistics section.
    stats = DatasetStatistics.from_ensemble(ensemble)
    render_ensemble_details(ensemble.members, stats.members)

    st.divider()

    # Pre-load the training data for every target/task combination.
    train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {
        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
        for target in all_target_datasets
    }

    era5_members = [m for m in ensemble.members if m.startswith("ERA5")]

    # Tab set depends on which members are part of the ensemble;
    # "Targets" and "Areas" are always present.
    tab_names = ["🎯 Targets", "📐 Areas"]
    if "AlphaEarth" in ensemble.members:
        tab_names.append("🌍 AlphaEarth")
    if "ArcticDEM" in ensemble.members:
        tab_names.append("🏔️ ArcticDEM")
    if era5_members:
        tab_names.append("🌡️ ERA5")
    # Map tab name -> tab container to avoid manual index bookkeeping.
    tab_by_name = dict(zip(tab_names, st.tabs(tab_names)))

    with tab_by_name["🎯 Targets"]:
        st.header("🎯 Target Labels Visualization")
        render_target_information_tab(train_data_dict)
    with tab_by_name["📐 Areas"]:
        st.header("📐 Areas Visualization")
    if "🌍 AlphaEarth" in tab_by_name:
        with tab_by_name["🌍 AlphaEarth"]:
            st.header("🌍 AlphaEarth Visualization")
    if "🏔️ ArcticDEM" in tab_by_name:
        with tab_by_name["🏔️ ArcticDEM"]:
            st.header("🏔️ ArcticDEM Visualization")
    if "🌡️ ERA5" in tab_by_name:
        with tab_by_name["🌡️ ERA5"]:
            st.header("🌡️ ERA5 Visualization")

    st.balloons()
    stopwatch.summary()
b/src/entropice/dashboard/views/overview_page.py @@ -9,9 +9,7 @@ from entropice.dashboard.sections.experiment_results import ( render_training_results_summary, ) from entropice.dashboard.utils.loaders import load_all_training_results -from entropice.dashboard.utils.stats import ( - load_all_default_dataset_statistics, -) +from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics def render_overview_page(): @@ -47,7 +45,12 @@ def render_overview_page(): # Render dataset analysis section all_stats = load_all_default_dataset_statistics() - render_dataset_statistics(all_stats) + training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats) + feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) + comparison_df = DatasetStatistics.get_comparison_df(all_stats) + inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats) + + render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df) st.balloons() stopwatch.summary()