Add a Dataset Page to deprecate the Training Data page
This commit is contained in:
parent
a6e9a91692
commit
cb7b0f9e6b
7 changed files with 986 additions and 90 deletions
|
|
@ -13,6 +13,7 @@ Pages:
|
|||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.views.autogluon_analysis_page import render_autogluon_analysis_page
|
||||
from entropice.dashboard.views.dataset_page import render_dataset_page
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
|
|
@ -26,6 +27,7 @@ def main():
|
|||
|
||||
# Setup Navigation
|
||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
||||
|
|
@ -35,6 +37,7 @@ def main():
|
|||
pg = st.navigation(
|
||||
{
|
||||
"Overview": [overview_page],
|
||||
"Data": [data_page],
|
||||
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
||||
"Model State": [model_state_page],
|
||||
"Inference": [inference_page],
|
||||
|
|
|
|||
470
src/entropice/dashboard/plots/targets.py
Normal file
470
src/entropice/dashboard/plots/targets.py
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
"""Plots for visualizing target labels in datasets."""
|
||||
|
||||
import antimeridian
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import pydeck as pdk
|
||||
from plotly.subplots import make_subplots
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
|
||||
from entropice.ml.dataset import TrainingSet
|
||||
from entropice.utils.types import TargetDataset, Task
|
||||
|
||||
|
||||
def create_target_distribution_plot(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]) -> go.Figure:
    """Create a plot showing the distribution of target labels across datasets and tasks.

    This function creates a comprehensive visualization showing:
    - Classification tasks (binary, count_regimes, density_regimes): categorical distributions
    - Regression tasks (count, density): histogram distributions
    - Train/test split information for each combination
    - Separate subplots for each TargetDataset

    Args:
        train_data_dict: Nested dictionary with structure:
            {target_dataset: {task: TrainingSet}}
            where target_dataset is in ["darts_v1", "darts_mllabels"]
            and task is in ["binary", "count_regimes", "density_regimes", "count", "density"]

    Returns:
        Plotly Figure with subplots showing target distributions.

    """
    # Define task types and their properties
    classification_tasks: list[Task] = ["binary", "count_regimes", "density_regimes"]

    task_titles: dict[Task, str] = {
        "binary": "Binary",
        "count_regimes": "Count Regimes",
        "density_regimes": "Density Regimes",
        "count": "Count (Regression)",
        "density": "Density (Regression)",
    }

    # Get all available target datasets and tasks
    target_datasets = sorted(train_data_dict.keys())
    all_tasks = sorted(
        {task for tasks in train_data_dict.values() for task in tasks.keys()},
        key=lambda x: (x not in classification_tasks, x),  # Classification first
    )

    # Create subplots: one row per target dataset, one column per task
    n_rows = len(target_datasets)
    n_cols = len(all_tasks)

    # Create column titles (tasks) for the subplots
    column_titles = [task_titles.get(task, str(task)) for task in all_tasks]  # type: ignore[arg-type]

    # Create row titles (target datasets)
    row_titles = [target_ds.replace("_", " ").title() for target_ds in target_datasets]

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        column_titles=column_titles,
        # max(..., 1) keeps the spacing expressions valid for a single row/column
        vertical_spacing=0.20 / max(n_rows, 1),
        horizontal_spacing=0.08 / max(n_cols, 1),
    )

    # Iterate through each target dataset and task
    for row_idx, target_dataset in enumerate(target_datasets, start=1):
        task_dict = train_data_dict[target_dataset]  # type: ignore[index]

        for col_idx, task in enumerate(all_tasks, start=1):
            if task not in task_dict:
                # Skip if this task is not available for this target dataset
                continue

            dataset = task_dict[task]  # type: ignore[index]

            # Determine if this is a classification or regression task
            if task in classification_tasks:
                _add_classification_subplot(fig, dataset, task, row_idx, col_idx, n_rows)
            else:
                _add_regression_subplot(fig, dataset, task, row_idx, col_idx, n_rows)

    # Update layout
    fig.update_layout(
        height=400 * n_rows,
        showlegend=True,
        legend={
            "orientation": "h",
            "yanchor": "top",
            "y": -0.15,
            "xanchor": "center",
            "x": 0.5,
        },
        margin={"l": 120, "r": 40, "t": 80, "b": 100},
    )

    # Update all x-axes and y-axes
    fig.update_xaxes(title_font={"size": 11})
    fig.update_yaxes(title_text="Count", title_font={"size": 11})

    # Manually add row title annotations (left-aligned dataset names)
    for row_idx, row_title in enumerate(row_titles, start=1):
        # Calculate the y position for each row (at the top of the row)
        # Account for vertical spacing between rows
        spacing = 0.20 / max(n_rows, 1)  # must match vertical_spacing passed to make_subplots
        row_height = (1 - spacing * (n_rows - 1)) / n_rows
        y_position = 1.08 - (row_idx - 1) * (row_height + spacing)

        fig.add_annotation(
            text=f"<b>{row_title}</b>",
            xref="paper",
            yref="paper",
            x=-0.02,  # slightly left of the plotting area, inside the wide left margin
            y=y_position,
            xanchor="left",
            yanchor="top",
            textangle=0,
            font={"size": 20},
            showarrow=False,
        )

    return fig
|
||||
|
||||
|
||||
def _add_classification_subplot(
    fig: go.Figure,
    dataset: TrainingSet,
    task: str,
    row: int,
    col: int,
    n_rows: int,
) -> None:
    """Add a classification task subplot showing categorical distribution.

    One bar trace is drawn per split ("Train" on top of "Test"), with the
    per-category counts annotated inside the bars.

    Args:
        fig: Plotly figure to add subplot to.
        dataset: TrainingSet containing the data.
        task: Task name for color selection.
        row: Subplot row index.
        col: Subplot column index.
        n_rows: Total number of rows (currently unused in the body; kept for
            signature symmetry with _add_regression_subplot).

    """
    labels = dataset.targets["y"]
    categories = labels.cat.categories.tolist()
    # Trim the first/last palette entries to avoid too light/dark colors.
    bar_colors = get_palette(task, len(categories) + 2)[1:-1]

    # Only the very first subplot contributes legend entries.
    show_legend = row == 1 and col == 1

    # Draw Train (opaque) first, then Test (translucent), as grouped bars.
    for split_name, bar_opacity in (("Train", 0.9), ("Test", 0.5)):
        counts = [((labels == cat) & (dataset.split == split_name.lower())).sum() for cat in categories]
        fig.add_trace(
            go.Bar(
                name=split_name,
                x=categories,
                y=counts,
                marker_color=bar_colors,
                opacity=bar_opacity,
                text=counts,
                textposition="inside",
                textfont={"size": 9, "color": "white"},
                legendgroup="split",
                showlegend=show_legend,
            ),
            row=row,
            col=col,
        )

    # Slant category labels so long names stay readable.
    fig.update_xaxes(tickangle=-45, row=row, col=col)
|
||||
|
||||
|
||||
def _add_regression_subplot(
    fig: go.Figure,
    dataset: TrainingSet,
    task: str,
    row: int,
    col: int,
    n_rows: int,
) -> None:
    """Add a regression task subplot showing histogram distribution.

    Train and test values share the same bin edges so the two overlaid
    histograms are directly comparable.

    Args:
        fig: Plotly figure to add subplot to.
        dataset: TrainingSet containing the data.
        task: Task name for color selection.
        row: Subplot row index.
        col: Subplot column index.
        n_rows: Total number of rows (currently unused in the body; kept for
            signature symmetry with _add_classification_subplot).

    """
    # Get raw values
    z_values = dataset.targets["z"]

    # Split into train and test
    train_values = z_values[dataset.split == "train"]
    test_values = z_values[dataset.split == "test"]

    # Determine bin edges based on combined data so both histograms align
    all_values = pd.concat([train_values, test_values])
    # Use a reasonable number of bins, clamped to at least 1 so tiny or empty
    # datasets cannot produce bins=0 (np.histogram_bin_edges raises on that).
    n_bins = min(30, max(1, int(np.sqrt(len(all_values)))))
    bin_edges = np.histogram_bin_edges(all_values, bins=n_bins)

    # Get colors for the task
    colors = get_palette(task, 6)[1:5:2]  # Use less bright/dark colors

    # Add train histogram
    fig.add_trace(
        go.Histogram(
            name="Train",
            x=train_values,
            xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
            marker_color=colors[0],
            opacity=0.7,
            legendgroup="split",
            showlegend=(row == 1 and col == 1),  # Only show legend once
        ),
        row=row,
        col=col,
    )

    # Add test histogram
    fig.add_trace(
        go.Histogram(
            name="Test",
            x=test_values,
            xbins={"start": bin_edges[0], "end": bin_edges[-1], "size": bin_edges[1] - bin_edges[0]},
            marker_color=colors[1],
            opacity=0.5,
            legendgroup="split",
            showlegend=(row == 1 and col == 1),  # Only show legend once
        ),
        row=row,
        col=col,
    )

    # Label the x-axis with the task name
    fig.update_xaxes(title_text=task.capitalize(), row=row, col=col)
|
||||
|
||||
|
||||
def _assign_split_colors(gdf: pd.DataFrame) -> pd.DataFrame:
    """Assign colors based on train/test split.

    Args:
        gdf: GeoDataFrame with 'split' column.

    Returns:
        GeoDataFrame with 'fill_color' column added.

    """
    palette = get_palette("split", 2)
    rgb_by_split = dict(zip(("train", "test"), (hex_to_rgb(color) for color in palette)))
    # Build a plain list (not a Series.map) to avoid categorical hashing issues
    # when the values are RGB lists.
    gdf["fill_color"] = [rgb_by_split[split] for split in gdf["split"]]
    return gdf
|
||||
|
||||
|
||||
def _assign_classification_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
    """Assign colors for classification tasks based on categorical labels.

    Args:
        gdf: GeoDataFrame with categorical 'y' column.
        task: Task name for color selection.

    Returns:
        GeoDataFrame with 'fill_color' column added.

    """
    categories = gdf["y"].cat.categories.tolist()
    palette = get_palette(task, len(categories))
    rgb_by_category = {category: hex_to_rgb(color) for category, color in zip(categories, palette)}
    # Build a plain list (not a Series.map) to avoid categorical hashing issues
    # when the values are RGB lists.
    gdf["fill_color"] = [rgb_by_category[label] for label in gdf["y"]]
    return gdf
|
||||
|
||||
|
||||
def _assign_regression_colors(gdf: pd.DataFrame, task: Task) -> pd.DataFrame:
    """Assign colors for regression tasks based on continuous values.

    Args:
        gdf: GeoDataFrame with numeric 'z' column.
        task: Task name for color selection.

    Returns:
        GeoDataFrame with 'fill_color' column added.

    """
    values = gdf["z"]
    lo, hi = values.min(), values.max()

    # Min-max normalize to [0, 1]; a constant column maps to the midpoint.
    if hi > lo:
        normalized = (values - lo) / (hi - lo)
    else:
        normalized = pd.Series([0.5] * len(values), index=values.index)

    # Map normalized values through the task colormap to 0-255 RGB triples.
    cmap = get_cmap(task)

    def to_rgb_triple(fraction):
        red, green, blue = cmap(fraction)[:3]
        return [int(red * 255), int(green * 255), int(blue * 255)]

    gdf["fill_color"] = normalized.apply(to_rgb_triple)
    return gdf
|
||||
|
||||
|
||||
def _assign_elevation(gdf: pd.DataFrame, make_3d_map: bool, is_classification: bool) -> pd.DataFrame:
|
||||
"""Assign elevation values for 3D visualization.
|
||||
|
||||
Args:
|
||||
gdf: GeoDataFrame to add elevation to.
|
||||
make_3d_map: Whether to create a 3D map.
|
||||
is_classification: Whether this is a classification task.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with 'elevation' column added.
|
||||
|
||||
"""
|
||||
if not make_3d_map:
|
||||
gdf["elevation"] = 0
|
||||
return gdf
|
||||
|
||||
print(f"{make_3d_map=}, {is_classification=}")
|
||||
if is_classification:
|
||||
# For classification, use category codes for elevation
|
||||
# Convert to int to avoid overflow with int8 categorical codes
|
||||
z_values = gdf["y"].cat.codes.astype(int)
|
||||
else:
|
||||
# For regression, use z values for elevation
|
||||
z_values = gdf["z"]
|
||||
z_min, z_max = z_values.min(), z_values.max()
|
||||
|
||||
if z_max > z_min:
|
||||
gdf["elevation"] = ((z_values - z_min) / (z_max - z_min)).fillna(0)
|
||||
else:
|
||||
gdf["elevation"] = 0
|
||||
|
||||
print(gdf["elevation"].describe())
|
||||
return gdf
|
||||
|
||||
|
||||
def create_target_spatial_distribution_map(
    training_set: TrainingSet, make_3d_map: bool, show_split: bool, task: Task
) -> pdk.Deck:
    """Create a spatial distribution map of target labels using PyDeck.

    Args:
        training_set (TrainingSet): The training set containing spatial data and target labels.
        make_3d_map (bool): Whether to create a 3D map (True) or a 2D map (False).
        show_split (bool): Whether to color by train/test split (True) or by target labels (False).
        task (Task): The task type to determine color mapping.

    Returns:
        pdk.Deck: A PyDeck Deck object representing the spatial distribution map.

    """
    # Get the targets GeoDataFrame (copy so color/elevation columns don't leak back)
    gdf = training_set.targets.copy()

    # Add split information
    gdf["split"] = training_set.split.to_numpy()

    # Determine if this is a classification or regression task
    is_classification = gdf["y"].dtype.name == "category"

    # Convert to WGS84 for pydeck
    gdf_wgs84 = gdf.to_crs("EPSG:4326")

    # Fix antimeridian issues (polygons crossing +/-180° render wrapped otherwise)
    def fix_hex_geometry(geom):
        # Best-effort repair; keep the original geometry on any failure.
        # (Was `except (ValueError, Exception)` — redundant, Exception covers ValueError.)
        try:
            return shape(antimeridian.fix_shape(geom))
        except Exception:
            return geom

    gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(fix_hex_geometry)

    # Assign colors based on mode
    if show_split:
        gdf_wgs84 = _assign_split_colors(gdf_wgs84)
    elif is_classification:
        gdf_wgs84 = _assign_classification_colors(gdf_wgs84, task)
    else:
        gdf_wgs84 = _assign_regression_colors(gdf_wgs84, task)

    # Add elevation for 3D visualization
    gdf_wgs84 = _assign_elevation(gdf_wgs84, make_3d_map, is_classification)

    # Convert to GeoJSON format
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "fill_color": row["fill_color"],
                "elevation": row["elevation"],
                "split": row["split"],
                "y_label": str(row["y"]),
                "z_value": float(row["z"]) if pd.notna(row["z"]) else None,
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=0.7,
        stroked=True,
        filled=True,
        extruded=make_3d_map,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if make_3d_map else 0,
        elevation_scale=500000 if make_3d_map else 0,
        pickable=True,
    )

    # Set initial view state (centered on the Arctic)
    view_state = pdk.ViewState(
        latitude=70,
        longitude=0,
        zoom=2 if not make_3d_map else 1.5,
        pitch=0 if not make_3d_map else 45,
    )

    # Create tooltip
    tooltip_html = "<b>Split:</b> {split}<br/><b>Label:</b> {y_label}<br/><b>Value:</b> {z_value}<br/>"

    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": tooltip_html,
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    return deck
|
||||
|
|
@ -14,7 +14,7 @@ from entropice.dashboard.plots.dataset_statistics import (
|
|||
create_sample_count_bar_chart,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics, load_all_default_dataset_statistics
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, MemberStatistics
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
|
|
@ -436,9 +436,104 @@ def _render_aggregation_selection(
|
|||
return dimension_filters
|
||||
|
||||
|
||||
def render_ensemble_details(
    selected_members: list[L2SourceDataset],
    selected_member_stats: dict[L2SourceDataset, MemberStatistics],
):
    """Render ensemble-specific details: feature breakdown and member information.

    This function displays the feature breakdown by data source and detailed
    member information, independent of task or target selection. Can be reused
    on other pages by supplying a dataset ensemble.

    Args:
        selected_members: List of selected member datasets.
            NOTE(review): not referenced in the body — confirm with the call
            sites whether it can be dropped before removing it.
        selected_member_stats: Statistics for selected members

    """
    # Calculate total features for selected members
    total_features = sum(ms.feature_count for ms in selected_member_stats.values())

    # Feature breakdown by source
    st.markdown("#### Feature Breakdown by Data Source")

    breakdown_data = []
    for member, member_stats in selected_member_stats.items():
        # Guard against a zero total so an all-empty selection cannot divide by zero.
        share = member_stats.feature_count / total_features * 100 if total_features else 0.0
        breakdown_data.append(
            {
                "Data Source": member,
                "Number of Features": member_stats.feature_count,
                "Percentage": f"{share:.1f}%",
            }
        )

    breakdown_df = pd.DataFrame(breakdown_data)

    # Group members by their source prefix once, then assign one palette per
    # source (avoids rescanning the member list and refetching the palette
    # for every member).
    members_by_source: dict[str, list[L2SourceDataset]] = {}
    for member in sorted(selected_member_stats.keys()):
        members_by_source.setdefault(member.split("-")[0], []).append(member)

    source_color_map_raw = {}
    for source, members in members_by_source.items():
        # +2 then offset by 1: skip the lightest/darkest palette entries.
        palette = get_palette(source, n_colors=len(members) + 2)
        for idx, member in enumerate(members):
            source_color_map_raw[member] = palette[idx + 1]

    # Create and display pie chart
    fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
    st.plotly_chart(fig, width="stretch")

    # Show detailed table
    st.dataframe(breakdown_df, hide_index=True, width="stretch")

    # Detailed member information
    with st.expander("📦 Detailed Source Information", expanded=False):
        # Create detailed table
        member_details_dict = {
            member: {
                "feature_count": ms.feature_count,
                "variable_names": ms.variable_names,
                "dimensions": ms.dimensions,
                "size_bytes": ms.size_bytes,
            }
            for member, ms in selected_member_stats.items()
        }
        details_df = create_member_details_table(member_details_dict)
        st.dataframe(details_df, hide_index=True, width="stretch")

        # Individual member details
        for member, member_stats in selected_member_stats.items():
            st.markdown(f"### {member}")

            # Variables
            st.markdown("**Variables:**")
            vars_html = " ".join(
                [
                    f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
                    for v in member_stats.variable_names
                ]
            )
            st.markdown(vars_html, unsafe_allow_html=True)

            # Dimensions
            st.markdown("**Dimensions:**")
            dim_html = " ".join(
                [
                    f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
                    f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
                    f"{dim_name}: {dim_size:,}</span>"
                    for dim_name, dim_size in member_stats.dimensions.items()
                ]
            )
            st.markdown(dim_html, unsafe_allow_html=True)

            st.markdown("---")
|
||||
|
||||
|
||||
def _render_configuration_summary(
|
||||
selected_members: list[L2SourceDataset],
|
||||
selected_member_stats: dict[str, MemberStatistics],
|
||||
selected_member_stats: dict[L2SourceDataset, MemberStatistics],
|
||||
selected_target: TargetDataset,
|
||||
selected_task: Task,
|
||||
selected_temporal_mode: TemporalMode,
|
||||
|
|
@ -527,81 +622,8 @@ def _render_configuration_summary(
|
|||
)
|
||||
st.dataframe(class_dist_df, hide_index=True, width="stretch")
|
||||
|
||||
# Feature breakdown by source
|
||||
st.markdown("#### Feature Breakdown by Data Source")
|
||||
|
||||
breakdown_data = []
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
breakdown_data.append(
|
||||
{
|
||||
"Data Source": member,
|
||||
"Number of Features": member_stats.feature_count,
|
||||
"Percentage": f"{member_stats.feature_count / total_features * 100:.1f}%",
|
||||
}
|
||||
)
|
||||
|
||||
breakdown_df = pd.DataFrame(breakdown_data)
|
||||
|
||||
# Get all unique data sources and create color map
|
||||
unique_members_for_color = sorted(selected_member_stats.keys())
|
||||
source_color_map_raw = {}
|
||||
for member in unique_members_for_color:
|
||||
source = member.split("-")[0]
|
||||
n_members = sum(1 for m in unique_members_for_color if m.split("-")[0] == source)
|
||||
palette = get_palette(source, n_colors=n_members + 2)
|
||||
idx = [m for m in unique_members_for_color if m.split("-")[0] == source].index(member)
|
||||
source_color_map_raw[member] = palette[idx + 1]
|
||||
|
||||
# Create and display pie chart
|
||||
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map_raw)
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
|
||||
# Show detailed table
|
||||
st.dataframe(breakdown_df, hide_index=True, width="stretch")
|
||||
|
||||
# Detailed member information
|
||||
with st.expander("📦 Detailed Source Information", expanded=False):
|
||||
# Create detailed table
|
||||
member_details_dict = {
|
||||
member: {
|
||||
"feature_count": ms.feature_count,
|
||||
"variable_names": ms.variable_names,
|
||||
"dimensions": ms.dimensions,
|
||||
"size_bytes": ms.size_bytes,
|
||||
}
|
||||
for member, ms in selected_member_stats.items()
|
||||
}
|
||||
details_df = create_member_details_table(member_details_dict)
|
||||
st.dataframe(details_df, hide_index=True, width="stretch")
|
||||
|
||||
# Individual member details
|
||||
for member, member_stats in selected_member_stats.items():
|
||||
st.markdown(f"### {member}")
|
||||
|
||||
# Variables
|
||||
st.markdown("**Variables:**")
|
||||
vars_html = " ".join(
|
||||
[
|
||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
||||
for v in member_stats.variable_names
|
||||
]
|
||||
)
|
||||
st.markdown(vars_html, unsafe_allow_html=True)
|
||||
|
||||
# Dimensions
|
||||
st.markdown("**Dimensions:**")
|
||||
dim_html = " ".join(
|
||||
[
|
||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
||||
f"{dim_name}: {dim_size:,}</span>"
|
||||
for dim_name, dim_size in member_stats.dimensions.items()
|
||||
]
|
||||
)
|
||||
st.markdown(dim_html, unsafe_allow_html=True)
|
||||
|
||||
st.markdown("---")
|
||||
# Render ensemble-specific details
|
||||
render_ensemble_details(selected_members, selected_member_stats)
|
||||
|
||||
|
||||
@st.fragment
|
||||
|
|
@ -650,7 +672,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
|||
# Use pre-computed statistics (much faster)
|
||||
stats = all_stats[grid_level_key][selected_temporal_mode] # type: ignore[literal-required,index]
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||
selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
|
||||
m: stats.members[m] for m in selected_members if m in stats.members
|
||||
}
|
||||
else:
|
||||
# Create the actual ensemble with selected members and dimension filters
|
||||
ensemble = DatasetEnsemble(
|
||||
|
|
@ -663,7 +687,9 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
|||
# Recompute stats with the filtered ensemble
|
||||
stats = DatasetStatistics.from_ensemble(ensemble)
|
||||
# Filter to selected members only
|
||||
selected_member_stats = {m: stats.members[m] for m in selected_members if m in stats.members}
|
||||
selected_member_stats: dict[L2SourceDataset, MemberStatistics] = {
|
||||
m: stats.members[m] for m in selected_members if m in stats.members
|
||||
}
|
||||
|
||||
# Render configuration summary and statistics
|
||||
_render_configuration_summary(
|
||||
|
|
@ -678,16 +704,16 @@ def render_configuration_explorer_tab(all_stats: DatasetStatsCache):
|
|||
st.info("👆 Select at least one data source to see feature statistics")
|
||||
|
||||
|
||||
def render_dataset_statistics(all_stats: DatasetStatsCache):
|
||||
def render_dataset_statistics(
|
||||
all_stats: DatasetStatsCache,
|
||||
training_sample_df: pd.DataFrame,
|
||||
feature_breakdown_df: pd.DataFrame,
|
||||
comparison_df: pd.DataFrame,
|
||||
inference_sample_df: pd.DataFrame,
|
||||
):
|
||||
"""Render the dataset statistics section with sample and feature counts."""
|
||||
st.header("📈 Dataset Statistics")
|
||||
|
||||
all_stats: DatasetStatsCache = load_all_default_dataset_statistics()
|
||||
training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
|
||||
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
||||
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
|
||||
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
|
||||
|
||||
# Create tabs for different analysis views
|
||||
analysis_tabs = st.tabs(
|
||||
[
|
||||
|
|
|
|||
225
src/entropice/dashboard/sections/targets.py
Normal file
225
src/entropice/dashboard/sections/targets.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Target dataset dashboard section."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.colors as mcolors
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.targets import create_target_distribution_plot, create_target_spatial_distribution_map
|
||||
from entropice.dashboard.utils.colors import get_cmap, get_palette, hex_to_rgb
|
||||
from entropice.ml.dataset import TrainingSet
|
||||
from entropice.utils.types import TargetDataset, Task
|
||||
|
||||
|
||||
def _render_distribution(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Show the target-label distribution figure with a short explanation."""
    st.subheader("Target Labels Distribution")
    distribution_fig = create_target_distribution_plot(train_data_dict)
    st.plotly_chart(distribution_fig, width="stretch")
    st.markdown(
        """
        The above plot shows the distribution of target labels in the training and test datasets.
        You can use this information to understand the balance of classes and identify any potential
        issues with class imbalance that may affect model training.
        """
    )
|
||||
|
||||
|
||||
def _render_split_legend(training_set: TrainingSet):
    """Render legend for train/test split visualization."""
    st.markdown("**Data Split:**")
    split_colors = get_palette("split", 2)

    def swatch_html(label: str, rgb) -> str:
        # Small colored square followed by the split name.
        return (
            '<div style="display: flex; align-items: center;">'
            '<div style="width: 20px; height: 20px; '
            f"background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); "
            'margin-right: 8px; border: 1px solid #ccc;"></div>'
            f"<span>{label}</span></div>"
        )

    # One column per split, Train on the left and Test on the right.
    columns = st.columns(2)
    for column, label, color_hex in zip(columns, ("Train", "Test"), split_colors):
        with column:
            st.markdown(swatch_html(label, hex_to_rgb(color_hex)), unsafe_allow_html=True)

    # Show split statistics
    train_count = (training_set.split == "train").sum()
    test_count = (training_set.split == "test").sum()
    total = len(training_set)
    st.caption(
        f"Train: {train_count:,} ({train_count / total * 100:.1f}%) | "
        f"Test: {test_count:,} ({test_count / total * 100:.1f}%)"
    )
|
||||
|
||||
|
||||
def _render_classification_legend(training_set: TrainingSet, selected_task: Task):
    """Render legend for classification task visualization."""
    st.markdown("**Target Classes:**")
    categories = training_set.targets["y"].cat.categories.tolist()
    palette = get_palette(selected_task, len(categories))
    intervals = training_set.target_intervals

    # One legend row per class: colored swatch, class name, value interval.
    for position, category in enumerate(categories):
        interval_min, interval_max = intervals[position]
        interval_str = _format_interval(interval_min, interval_max, selected_task)
        legend_row = (
            '<div style="display: flex; align-items: center; margin-bottom: 4px;">'
            f'<div style="width: 20px; height: 20px; background-color: {palette[position]}; '
            'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
            f"<span>{category}{interval_str}</span></div>"
        )
        st.markdown(legend_row, unsafe_allow_html=True)

    # Show total samples
    st.caption(f"Total samples: {len(training_set):,}")
|
||||
|
||||
|
||||
def _format_interval(interval_min, interval_max, selected_task: Task) -> str:
|
||||
"""Format interval string based on task type."""
|
||||
if interval_min is None or interval_max is None:
|
||||
return ""
|
||||
|
||||
if selected_task in ["count", "count_regimes"]:
|
||||
# Integer values for count
|
||||
if interval_min == interval_max:
|
||||
return f" ({int(interval_min)})"
|
||||
return f" ({int(interval_min)}-{int(interval_max)})"
|
||||
|
||||
if selected_task in ["density", "density_regimes"]:
|
||||
# Percentage values for density
|
||||
if interval_min == interval_max:
|
||||
return f" ({interval_min * 100:.4f}%)"
|
||||
return f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
|
||||
|
||||
# Binary or other
|
||||
return ""
|
||||
|
||||
|
||||
def _render_regression_legend(training_set: TrainingSet, selected_task: Task):
    """Render a legend for regression targets: summary statistics plus the color gradient."""
    st.markdown("**Target Values:**")
    targets = training_set.targets["z"]
    lo, hi = targets.min(), targets.max()

    st.markdown(
        f"- **Range:** {lo:.4f} to {hi:.4f}\n- **Mean:** {targets.mean():.4f}\n- **Median:** {targets.median():.4f}"
    )

    # Reproduce the exact colormap used on the map as a CSS linear gradient
    st.markdown("**Color Gradient:**")
    cmap = get_cmap(selected_task)
    n_stops = 10
    gradient_str = ", ".join(mcolors.to_hex(cmap(step / (n_stops - 1))) for step in range(n_stops))

    st.markdown(
        f'<div style="display: flex; align-items: center; margin-top: 8px;">'
        f'<span style="margin-right: 8px;">Low ({lo:.4f})</span>'
        f'<div style="width: 200px; height: 20px; '
        f"background: linear-gradient(to right, {gradient_str}); "
        f'border: 1px solid #ccc;"></div>'
        f'<span style="margin-left: 8px;">High ({hi:.4f})</span>'
        f"</div>",
        unsafe_allow_html=True,
    )
|
||||
|
||||
|
||||
def _render_elevation_info(training_set: TrainingSet):
    """Explain what the 3D extrusion height encodes, depending on the task type."""
    st.markdown("---")
    st.markdown("**3D Elevation:**")

    # Categorical dtype on "y" marks a classification task; otherwise regression.
    if training_set.targets["y"].dtype.name == "category":
        st.info("💡 Height represents category level. Rotate the map by holding Ctrl/Cmd and dragging with your mouse.")
        return

    z = training_set.targets["z"]
    st.info(
        f"💡 Height represents target value (normalized): {z.min():.4f} (low) → {z.max():.4f} (high). "
        "Rotate the map by holding Ctrl/Cmd and dragging with your mouse."
    )
|
||||
|
||||
|
||||
@st.fragment
def _render_target_map(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render the interactive target-label map with dataset/task selectors and a legend."""
    st.subheader("Target Labels Spatial Distribution")

    target_col, task_col, toggle_col = st.columns([2, 2, 1])
    with target_col:
        chosen_target = cast(
            TargetDataset,
            st.selectbox(
                "Select Target Dataset",
                options=sorted(train_data_dict.keys()),
                index=0,
            ),
        )
    with task_col:
        chosen_task = cast(
            Task,
            st.selectbox(
                "Select Task",
                options=sorted(train_data_dict[chosen_target].keys()),
                index=0,
            ),
        )
    with toggle_col:
        # Toggle between an extruded 3D view and a flat 2D view
        use_3d = cast(bool, st.checkbox("3D Map", value=True))
        # Toggle between coloring by train/test split and coloring by label/value
        color_by_split = cast(bool, st.checkbox("Show Train/Test Split", value=False))

    training_set = train_data_dict[chosen_target][chosen_task]
    st.pydeck_chart(create_target_spatial_distribution_map(training_set, use_3d, color_by_split, chosen_task))

    with st.expander("🎨 Legend", expanded=True):
        if color_by_split:
            _render_split_legend(training_set)
        elif training_set.targets["y"].dtype.name == "category":
            # Classification targets get a per-class legend
            _render_classification_legend(training_set, chosen_task)
        else:
            # Regression targets get a gradient legend
            _render_regression_legend(training_set, chosen_task)

        # Elevation explanation only applies when the map is extruded
        if use_3d:
            _render_elevation_info(training_set)
|
||||
|
||||
|
||||
def render_target_information_tab(train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]]):
    """Render the targets tab: label distributions followed by the spatial map.

    Args:
        train_data_dict: Nested dictionary with structure
            {target_dataset: {task: TrainingSet}},
            where target_dataset is in ["darts_v1", "darts_mllabels"]
            and task is in ["binary", "count_regimes", "density_regimes", "count", "density"].

    """
    _render_distribution(train_data_dict)

    st.divider()

    _render_target_map(train_data_dict)
|
||||
|
|
@ -31,6 +31,20 @@ import streamlit as st
|
|||
from pypalettes import load_cmap
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> list[int]:
    """Convert a hex color string to an RGB channel list.

    Args:
        hex_color: Hex color string, with or without a leading '#' (e.g. '#FF5733').

    Returns:
        List of [R, G, B] values (0-255).

    """
    digits = hex_color.lstrip("#")
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return [red, green, blue]
|
||||
|
||||
|
||||
def get_cmap(variable: str) -> mcolors.Colormap:
|
||||
"""Get a color palette by a "data" variable.
|
||||
|
||||
|
|
|
|||
155
src/entropice/dashboard/views/dataset_page.py
Normal file
155
src/entropice/dashboard/views/dataset_page.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Data page: Visualization of the data."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import streamlit as st
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.sections.dataset_statistics import render_ensemble_details
|
||||
from entropice.dashboard.sections.targets import render_target_information_tab
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics
|
||||
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||
from entropice.utils.types import (
|
||||
GridConfig,
|
||||
L2SourceDataset,
|
||||
TargetDataset,
|
||||
Task,
|
||||
TemporalMode,
|
||||
all_target_datasets,
|
||||
all_tasks,
|
||||
grid_configs,
|
||||
)
|
||||
|
||||
|
||||
def render_dataset_configuration_sidebar() -> DatasetEnsemble:
    """Render the dataset configuration form in the sidebar and return the selection.

    Halts the page run (via ``st.stop()``) until the form has been submitted,
    then builds a ``DatasetEnsemble`` from the selected grid configuration,
    temporal mode, and member datasets.

    Returns:
        The configured DatasetEnsemble.

    """

    def _temporal_label(mode: TemporalMode) -> str:
        # Human-readable label for each temporal mode option.
        if mode == "synopsis":
            return "Synopsis (all years)"
        if mode == "feature":
            return "Years-as-Features"
        return f"Year {mode}"

    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")

        # Grid selection (grid system + resolution level combined into one label)
        grid_options = [gc.display_name for gc in grid_configs]
        grid_level_combined = st.selectbox(
            "Grid Configuration",
            options=grid_options,
            index=0,
            help="Select the grid system and resolution level",
        )

        # Resolve the display name back to its grid config
        selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)

        # Temporal mode selection
        temporal_mode = st.selectbox(
            "Temporal Mode",
            options=cast(list[TemporalMode], ["synopsis", "feature", 2018, 2019, 2020, 2021, 2022, 2023]),
            index=0,
            format_func=_temporal_label,
            help="Select temporal mode: 'synopsis' for temporal features or specific year",
        )

        # Members selection: one checkbox per available source dataset
        st.subheader("Dataset Members")

        all_members = cast(
            list[L2SourceDataset],
            ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
        )
        selected_members: list[L2SourceDataset] = [
            member for member in all_members if st.checkbox(member, value=True, help=f"Include {member} in the dataset")
        ]

        # At least one member is required to build an ensemble
        submitted = st.form_submit_button(
            "Load Dataset",
            type="primary",
            use_container_width=True,
            disabled=len(selected_members) == 0,
        )

    if not submitted:
        st.info("👆 Click 'Load Dataset' to apply changes.")
        st.stop()

    return DatasetEnsemble(
        grid=selected_grid_config.grid,
        level=selected_grid_config.level,
        temporal_mode=cast(TemporalMode, temporal_mode),
        members=selected_members,
    )
|
||||
|
||||
|
||||
def render_dataset_page():
    """Render the Dataset page of the dashboard."""
    st.title("📊 Entropice Data Visualization")
    st.markdown(
        """
        This section allows you to visualize the source datasets and training data
        used in your experiments. Use the sidebar to configure the dataset parameters.
        """
    )

    # Sidebar form blocks the run until a dataset configuration is submitted
    ensemble = render_dataset_configuration_sidebar()

    st.info(f"Loaded Dataset with ID `{ensemble.id()}`")

    st.divider()

    # Ensemble-level statistics
    stats = DatasetStatistics.from_ensemble(ensemble)
    render_ensemble_details(ensemble.members, stats.members)

    st.divider()

    # Preload the training set for every (target, task) combination
    train_data_dict: dict[TargetDataset, dict[Task, TrainingSet]] = {
        target: {task: ensemble.create_training_set(target=target, task=task) for task in all_tasks}
        for target in all_target_datasets
    }

    era5_members = [m for m in ensemble.members if m.startswith("ERA5")]

    # Targets and Areas are always shown; member tabs appear only when selected
    tab_names = ["🎯 Targets", "📐 Areas"]
    if "AlphaEarth" in ensemble.members:
        tab_names.append("🌍 AlphaEarth")
    if "ArcticDEM" in ensemble.members:
        tab_names.append("🏔️ ArcticDEM")
    if era5_members:
        tab_names.append("🌡️ ERA5")

    # Consume tabs in creation order instead of tracking an index by hand
    tab_iter = iter(st.tabs(tab_names))

    with next(tab_iter):
        st.header("🎯 Target Labels Visualization")
        render_target_information_tab(train_data_dict)
    with next(tab_iter):
        st.header("📐 Areas Visualization")
    if "AlphaEarth" in ensemble.members:
        with next(tab_iter):
            st.header("🌍 AlphaEarth Visualization")
    if "ArcticDEM" in ensemble.members:
        with next(tab_iter):
            st.header("🏔️ ArcticDEM Visualization")
    if era5_members:
        with next(tab_iter):
            st.header("🌡️ ERA5 Visualization")

    st.balloons()
    stopwatch.summary()
|
||||
|
|
@ -9,9 +9,7 @@ from entropice.dashboard.sections.experiment_results import (
|
|||
render_training_results_summary,
|
||||
)
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.dashboard.utils.stats import (
|
||||
load_all_default_dataset_statistics,
|
||||
)
|
||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||
|
||||
|
||||
def render_overview_page():
|
||||
|
|
@ -47,7 +45,12 @@ def render_overview_page():
|
|||
|
||||
# Render dataset analysis section
|
||||
all_stats = load_all_default_dataset_statistics()
|
||||
render_dataset_statistics(all_stats)
|
||||
training_sample_df = DatasetStatistics.get_target_sample_count_df(all_stats)
|
||||
feature_breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
|
||||
comparison_df = DatasetStatistics.get_comparison_df(all_stats)
|
||||
inference_sample_df = DatasetStatistics.get_inference_sample_count_df(all_stats)
|
||||
|
||||
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
|
||||
|
||||
st.balloons()
|
||||
stopwatch.summary()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue