Add a Storage Analysis Section
This commit is contained in:
parent
636c034b55
commit
2664579a75
9 changed files with 370 additions and 1959 deletions
|
|
@ -18,7 +18,6 @@ from entropice.dashboard.views.inference_page import render_inference_page
|
||||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||||
from entropice.dashboard.views.overview_page import render_overview_page
|
from entropice.dashboard.views.overview_page import render_overview_page
|
||||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||||
from entropice.dashboard.views.training_data_page import render_training_data_page
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -28,7 +27,6 @@ def main():
|
||||||
# Setup Navigation
|
# Setup Navigation
|
||||||
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
overview_page = st.Page(render_overview_page, title="Overview", icon="🏡", default=True)
|
||||||
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
data_page = st.Page(render_dataset_page, title="Dataset", icon="📊")
|
||||||
training_data_page = st.Page(render_training_data_page, title="Training Data", icon="🎞️")
|
|
||||||
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
training_analysis_page = st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾")
|
||||||
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
autogluon_page = st.Page(render_autogluon_analysis_page, title="AutoGluon Analysis", icon="🤖")
|
||||||
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
model_state_page = st.Page(render_model_state_page, title="Model State", icon="🧮")
|
||||||
|
|
@ -38,8 +36,7 @@ def main():
|
||||||
{
|
{
|
||||||
"Overview": [overview_page],
|
"Overview": [overview_page],
|
||||||
"Data": [data_page],
|
"Data": [data_page],
|
||||||
"Training": [training_data_page, training_analysis_page, autogluon_page],
|
"Experiments": [training_analysis_page, autogluon_page, model_state_page],
|
||||||
"Model State": [model_state_page],
|
|
||||||
"Inference": [inference_page],
|
"Inference": [inference_page],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,366 +0,0 @@
|
||||||
"""Plotting functions for training data visualizations."""
|
|
||||||
|
|
||||||
import geopandas as gpd
|
|
||||||
import pandas as pd
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
import pydeck as pdk
|
|
||||||
import streamlit as st
|
|
||||||
|
|
||||||
from entropice.dashboard.utils.colors import get_palette
|
|
||||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
|
||||||
from entropice.ml.dataset import CategoricalTrainingDataset
|
|
||||||
|
|
||||||
|
|
||||||
def render_all_distribution_histograms(
|
|
||||||
train_data_dict: dict[str, CategoricalTrainingDataset],
|
|
||||||
):
|
|
||||||
"""Render histograms for all three tasks side by side.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.subheader("📊 Target Distribution by Task")
|
|
||||||
|
|
||||||
# Create a 3-column layout for the three tasks
|
|
||||||
cols = st.columns(3)
|
|
||||||
|
|
||||||
tasks = ["binary", "count", "density"]
|
|
||||||
task_titles = {
|
|
||||||
"binary": "Binary Classification",
|
|
||||||
"count": "Count Classification",
|
|
||||||
"density": "Density Classification",
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, task in enumerate(tasks):
|
|
||||||
dataset = train_data_dict[task]
|
|
||||||
categories = dataset.y.binned.cat.categories.tolist()
|
|
||||||
colors = get_palette(task, len(categories))
|
|
||||||
|
|
||||||
with cols[idx]:
|
|
||||||
st.markdown(f"**{task_titles[task]}**")
|
|
||||||
|
|
||||||
# Create histogram data
|
|
||||||
counts_df = pd.DataFrame(
|
|
||||||
{
|
|
||||||
"Category": categories,
|
|
||||||
"Train": [((dataset.y.binned == cat) & (dataset.split == "train")).sum() for cat in categories],
|
|
||||||
"Test": [((dataset.y.binned == cat) & (dataset.split == "test")).sum() for cat in categories],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create stacked bar chart
|
|
||||||
fig = go.Figure()
|
|
||||||
|
|
||||||
fig.add_trace(
|
|
||||||
go.Bar(
|
|
||||||
name="Train",
|
|
||||||
x=counts_df["Category"],
|
|
||||||
y=counts_df["Train"],
|
|
||||||
marker_color=colors,
|
|
||||||
opacity=0.9,
|
|
||||||
text=counts_df["Train"],
|
|
||||||
textposition="inside",
|
|
||||||
textfont={"size": 10, "color": "white"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.add_trace(
|
|
||||||
go.Bar(
|
|
||||||
name="Test",
|
|
||||||
x=counts_df["Category"],
|
|
||||||
y=counts_df["Test"],
|
|
||||||
marker_color=colors,
|
|
||||||
opacity=0.6,
|
|
||||||
text=counts_df["Test"],
|
|
||||||
textposition="inside",
|
|
||||||
textfont={"size": 10, "color": "white"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
barmode="group",
|
|
||||||
height=400,
|
|
||||||
margin={"l": 20, "r": 20, "t": 20, "b": 20},
|
|
||||||
showlegend=True,
|
|
||||||
legend={
|
|
||||||
"orientation": "h",
|
|
||||||
"yanchor": "bottom",
|
|
||||||
"y": 1.02,
|
|
||||||
"xanchor": "right",
|
|
||||||
"x": 1,
|
|
||||||
},
|
|
||||||
xaxis_title=None,
|
|
||||||
yaxis_title="Count",
|
|
||||||
xaxis={"tickangle": -45},
|
|
||||||
)
|
|
||||||
|
|
||||||
st.plotly_chart(fig, width="stretch")
|
|
||||||
|
|
||||||
# Show summary statistics
|
|
||||||
total = len(dataset)
|
|
||||||
train_pct = (dataset.split == "train").sum() / total * 100
|
|
||||||
test_pct = (dataset.split == "test").sum() / total * 100
|
|
||||||
|
|
||||||
st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
|
|
||||||
|
|
||||||
|
|
||||||
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
|
|
||||||
"""Assign colors to geodataframe based on the selected color mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gdf: GeoDataFrame to add colors to
|
|
||||||
color_mode: One of 'target_class' or 'split'
|
|
||||||
dataset: CategoricalTrainingDataset
|
|
||||||
selected_task: Task name for color palette selection
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
GeoDataFrame with 'fill_color' column added
|
|
||||||
|
|
||||||
"""
|
|
||||||
if color_mode == "target_class":
|
|
||||||
categories = dataset.y.binned.cat.categories.tolist()
|
|
||||||
colors_palette = get_palette(selected_task, len(categories))
|
|
||||||
|
|
||||||
# Create color mapping
|
|
||||||
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
|
|
||||||
gdf["color"] = gdf["target_class"].map(color_map)
|
|
||||||
|
|
||||||
# Convert hex colors to RGB
|
|
||||||
def hex_to_rgb(hex_color):
|
|
||||||
hex_color = hex_color.lstrip("#")
|
|
||||||
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
|
|
||||||
|
|
||||||
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
|
|
||||||
|
|
||||||
elif color_mode == "split":
|
|
||||||
split_colors = {
|
|
||||||
"train": [66, 135, 245],
|
|
||||||
"test": [245, 135, 66],
|
|
||||||
} # Blue # Orange
|
|
||||||
gdf["fill_color"] = gdf["split"].map(split_colors)
|
|
||||||
|
|
||||||
return gdf
|
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
|
||||||
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
|
|
||||||
"""Render a pydeck spatial map showing training data distribution with interactive controls.
|
|
||||||
|
|
||||||
This is a Streamlit fragment that reruns independently when users interact with the
|
|
||||||
visualization controls (color mode and opacity), without re-running the entire page.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.subheader("🗺️ Spatial Distribution Map")
|
|
||||||
|
|
||||||
# Create controls in columns
|
|
||||||
col1, col2 = st.columns([3, 1])
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
vis_mode = st.selectbox(
|
|
||||||
"Visualization mode",
|
|
||||||
options=["binary", "count", "density", "split"],
|
|
||||||
format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
|
|
||||||
key="spatial_map_mode",
|
|
||||||
)
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
opacity = st.slider(
|
|
||||||
"Opacity",
|
|
||||||
min_value=0.1,
|
|
||||||
max_value=1.0,
|
|
||||||
value=0.7,
|
|
||||||
step=0.1,
|
|
||||||
key="spatial_map_opacity",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine which task dataset to use and color mode
|
|
||||||
if vis_mode == "split":
|
|
||||||
# Use binary dataset for split visualization
|
|
||||||
dataset = train_data_dict["binary"]
|
|
||||||
color_mode = "split"
|
|
||||||
selected_task = "binary"
|
|
||||||
else:
|
|
||||||
# Use the selected task
|
|
||||||
dataset = train_data_dict[vis_mode]
|
|
||||||
color_mode = "target_class"
|
|
||||||
selected_task = vis_mode
|
|
||||||
|
|
||||||
# Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
|
|
||||||
gdf: gpd.GeoDataFrame = dataset.dataset.copy() # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Fix antimeridian issues
|
|
||||||
gdf["geometry"] = gdf["geometry"].apply(fix_hex_geometry)
|
|
||||||
|
|
||||||
# Add binned labels and split information from current dataset
|
|
||||||
gdf["target_class"] = dataset.y.binned.to_numpy()
|
|
||||||
gdf["split"] = dataset.split.to_numpy()
|
|
||||||
gdf["raw_value"] = dataset.z.to_numpy()
|
|
||||||
|
|
||||||
# Add information from all three tasks for tooltip
|
|
||||||
gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
|
|
||||||
gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
|
|
||||||
gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
|
|
||||||
gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
|
|
||||||
gdf["density_raw"] = train_data_dict["density"].z.to_numpy()
|
|
||||||
|
|
||||||
# Convert to WGS84 for pydeck
|
|
||||||
gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326") # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Assign colors based on the selected mode
|
|
||||||
gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)
|
|
||||||
|
|
||||||
# Convert to GeoJSON format and add elevation for 3D visualization
|
|
||||||
geojson_data = []
|
|
||||||
# Normalize raw values for elevation (only for count and density)
|
|
||||||
use_elevation = vis_mode in ["count", "density"]
|
|
||||||
if use_elevation:
|
|
||||||
raw_values = gdf_wgs84["raw_value"]
|
|
||||||
min_val, max_val = raw_values.min(), raw_values.max()
|
|
||||||
# Normalize to 0-1 range for better 3D visualization
|
|
||||||
if max_val > min_val:
|
|
||||||
gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
|
|
||||||
else:
|
|
||||||
gdf_wgs84["elevation"] = 0
|
|
||||||
|
|
||||||
for _, row in gdf_wgs84.iterrows():
|
|
||||||
feature = {
|
|
||||||
"type": "Feature",
|
|
||||||
"geometry": row["geometry"].__geo_interface__,
|
|
||||||
"properties": {
|
|
||||||
"target_class": str(row["target_class"]),
|
|
||||||
"split": str(row["split"]),
|
|
||||||
"raw_value": float(row["raw_value"]),
|
|
||||||
"fill_color": row["fill_color"],
|
|
||||||
"elevation": float(row["elevation"]) if use_elevation else 0,
|
|
||||||
"binary_label": str(row["binary_label"]),
|
|
||||||
"count_category": str(row["count_category"]),
|
|
||||||
"count_raw": int(row["count_raw"]),
|
|
||||||
"density_category": str(row["density_category"]),
|
|
||||||
"density_raw": f"{float(row['density_raw']):.4f}",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
geojson_data.append(feature)
|
|
||||||
|
|
||||||
# Create pydeck layer
|
|
||||||
layer = pdk.Layer(
|
|
||||||
"GeoJsonLayer",
|
|
||||||
geojson_data,
|
|
||||||
opacity=opacity,
|
|
||||||
stroked=True,
|
|
||||||
filled=True,
|
|
||||||
extruded=use_elevation,
|
|
||||||
wireframe=False,
|
|
||||||
get_fill_color="properties.fill_color",
|
|
||||||
get_line_color=[80, 80, 80],
|
|
||||||
line_width_min_pixels=0.5,
|
|
||||||
get_elevation="properties.elevation" if use_elevation else 0,
|
|
||||||
elevation_scale=500000, # Scale normalized values (0-1) to 500km height
|
|
||||||
pickable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set initial view state (centered on the Arctic)
|
|
||||||
# Adjust pitch and zoom based on whether we're using elevation
|
|
||||||
view_state = pdk.ViewState(
|
|
||||||
latitude=70,
|
|
||||||
longitude=0,
|
|
||||||
zoom=2 if not use_elevation else 1.5,
|
|
||||||
pitch=0 if not use_elevation else 45,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create deck
|
|
||||||
deck = pdk.Deck(
|
|
||||||
layers=[layer],
|
|
||||||
initial_view_state=view_state,
|
|
||||||
tooltip={
|
|
||||||
"html": "<b>Binary:</b> {binary_label}<br/>"
|
|
||||||
"<b>Count Category:</b> {count_category}<br/>"
|
|
||||||
"<b>Count Raw:</b> {count_raw}<br/>"
|
|
||||||
"<b>Density Category:</b> {density_category}<br/>"
|
|
||||||
"<b>Density Raw:</b> {density_raw}<br/>"
|
|
||||||
"<b>Split:</b> {split}",
|
|
||||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
|
||||||
},
|
|
||||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Render the map
|
|
||||||
st.pydeck_chart(deck)
|
|
||||||
|
|
||||||
# Show info about 3D visualization
|
|
||||||
if use_elevation:
|
|
||||||
st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")
|
|
||||||
|
|
||||||
# Add legend
|
|
||||||
with st.expander("Legend", expanded=True):
|
|
||||||
if color_mode == "target_class":
|
|
||||||
st.markdown("**Target Classes:**")
|
|
||||||
categories = dataset.y.binned.cat.categories.tolist()
|
|
||||||
colors_palette = get_palette(selected_task, len(categories))
|
|
||||||
intervals = dataset.y.intervals
|
|
||||||
|
|
||||||
# For count and density tasks, show intervals
|
|
||||||
if selected_task in ["count", "density"]:
|
|
||||||
for i, cat in enumerate(categories):
|
|
||||||
color = colors_palette[i]
|
|
||||||
interval_min, interval_max = intervals[i]
|
|
||||||
|
|
||||||
# Format interval display
|
|
||||||
if interval_min is None or interval_max is None:
|
|
||||||
interval_str = ""
|
|
||||||
elif selected_task == "count":
|
|
||||||
# Integer values for count
|
|
||||||
if interval_min == interval_max:
|
|
||||||
interval_str = f" ({int(interval_min)})"
|
|
||||||
else:
|
|
||||||
interval_str = f" ({int(interval_min)}-{int(interval_max)})"
|
|
||||||
else: # density
|
|
||||||
# Percentage values for density
|
|
||||||
if interval_min == interval_max:
|
|
||||||
interval_str = f" ({interval_min * 100:.4f}%)"
|
|
||||||
else:
|
|
||||||
interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
|
|
||||||
|
|
||||||
st.markdown(
|
|
||||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
|
||||||
f'<div style="width: 20px; height: 20px; background-color: {color}; '
|
|
||||||
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
|
||||||
f"<span>{cat}{interval_str}</span></div>",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Binary task: use original column layout
|
|
||||||
legend_cols = st.columns(len(categories))
|
|
||||||
for i, cat in enumerate(categories):
|
|
||||||
with legend_cols[i]:
|
|
||||||
color = colors_palette[i]
|
|
||||||
st.markdown(
|
|
||||||
f'<div style="display: flex; align-items: center;">'
|
|
||||||
f'<div style="width: 20px; height: 20px; background-color: {color}; '
|
|
||||||
f'margin-right: 8px; border: 1px solid #ccc;"></div>'
|
|
||||||
f"<span>{cat}</span></div>",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
if use_elevation:
|
|
||||||
st.markdown("---")
|
|
||||||
st.markdown("**Elevation (3D):**")
|
|
||||||
min_val = gdf_wgs84["raw_value"].min()
|
|
||||||
max_val = gdf_wgs84["raw_value"].max()
|
|
||||||
st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)")
|
|
||||||
elif color_mode == "split":
|
|
||||||
st.markdown("**Data Split:**")
|
|
||||||
legend_html = (
|
|
||||||
'<div style="display: flex; gap: 20px;">'
|
|
||||||
'<div style="display: flex; align-items: center;">'
|
|
||||||
'<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
|
|
||||||
'margin-right: 8px; border: 1px solid #ccc;"></div>'
|
|
||||||
"<span>Train</span></div>"
|
|
||||||
'<div style="display: flex; align-items: center;">'
|
|
||||||
'<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
|
|
||||||
'margin-right: 8px; border: 1px solid #ccc;"></div>'
|
|
||||||
"<span>Test</span></div></div>"
|
|
||||||
)
|
|
||||||
st.markdown(legend_html, unsafe_allow_html=True)
|
|
||||||
|
|
@ -431,7 +431,7 @@ def _render_aggregation_selection(
|
||||||
|
|
||||||
if not submitted:
|
if not submitted:
|
||||||
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
|
st.info("👆 Click 'Apply Aggregation Filters' to update the configuration")
|
||||||
st.stop()
|
return dimension_filters
|
||||||
|
|
||||||
return dimension_filters
|
return dimension_filters
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,8 +103,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expandable details for each result
|
# Expandable details for each result
|
||||||
st.subheader("Individual Experiment Details")
|
with st.expander("Show Individual Experiment Details", expanded=False):
|
||||||
|
|
||||||
for tr in filtered_results:
|
for tr in filtered_results:
|
||||||
tr_info = tr.display_info
|
tr_info = tr.display_info
|
||||||
display_name = tr_info.get_display_name("model_first")
|
display_name = tr_info.get_display_name("model_first")
|
||||||
|
|
@ -160,7 +159,7 @@ def render_experiment_results(training_results: list[TrainingResult]): # noqa:
|
||||||
unique_vals = tr.results[param].nunique()
|
unique_vals = tr.results[param].nunique()
|
||||||
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")
|
||||||
|
|
||||||
with st.expander("Show CV Results DataFrame"):
|
st.write("**CV Results DataFrame:**")
|
||||||
st.dataframe(tr.results, width="stretch", hide_index=True)
|
st.dataframe(tr.results, width="stretch", hide_index=True)
|
||||||
|
|
||||||
st.write(f"\n**Path:** `{tr.path}`")
|
st.write(f"\n**Path:** `{tr.path}`")
|
||||||
|
|
|
||||||
163
src/entropice/dashboard/sections/storage_statistics.py
Normal file
163
src/entropice/dashboard/sections/storage_statistics.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Storage Statistics Section for Entropice Dashboard."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.utils.loaders import StorageInfo, load_storage_statistics
|
||||||
|
from entropice.utils.paths import DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _format_bytes(bytes_value: int) -> str:
|
||||||
|
"""Format bytes into human-readable string."""
|
||||||
|
value = float(bytes_value)
|
||||||
|
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||||
|
if value < 1024.0:
|
||||||
|
return f"{value:.2f} {unit}"
|
||||||
|
value /= 1024.0
|
||||||
|
return f"{value:.2f} PB"
|
||||||
|
|
||||||
|
|
||||||
|
def _create_storage_bar_chart(storage_infos: list[StorageInfo]) -> go.Figure:
|
||||||
|
"""Create a horizontal bar chart showing storage usage by subdirectory."""
|
||||||
|
if not storage_infos:
|
||||||
|
return go.Figure()
|
||||||
|
|
||||||
|
# Prepare data
|
||||||
|
names = [info.name for info in storage_infos]
|
||||||
|
sizes = [info.size_bytes / (1024**3) for info in storage_infos] # Convert to GB
|
||||||
|
file_counts = [info.file_count for info in storage_infos]
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Add bar trace
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
y=names,
|
||||||
|
x=sizes,
|
||||||
|
orientation="h",
|
||||||
|
text=[f"{s:.2f} GB" for s in sizes],
|
||||||
|
textposition="auto",
|
||||||
|
hovertemplate="<b>%{y}</b><br>Size: %{x:.2f} GB<br>Files: %{customdata:,}<extra></extra>",
|
||||||
|
customdata=file_counts,
|
||||||
|
marker={
|
||||||
|
"color": sizes,
|
||||||
|
"colorscale": "Blues",
|
||||||
|
"showscale": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update layout
|
||||||
|
fig.update_layout(
|
||||||
|
title="Storage Usage by Subdirectory",
|
||||||
|
xaxis_title="Size (GB)",
|
||||||
|
yaxis_title="Directory",
|
||||||
|
height=max(400, len(names) * 40), # Dynamic height based on number of directories
|
||||||
|
showlegend=False,
|
||||||
|
margin={"l": 200, "r": 50, "t": 50, "b": 50},
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def _create_storage_pie_chart(storage_infos: list[StorageInfo]) -> go.Figure:
|
||||||
|
"""Create a pie chart showing storage distribution."""
|
||||||
|
if not storage_infos:
|
||||||
|
return go.Figure()
|
||||||
|
|
||||||
|
# Prepare data
|
||||||
|
names = [info.name for info in storage_infos]
|
||||||
|
sizes = [info.size_bytes for info in storage_infos]
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig = go.Figure(
|
||||||
|
data=[
|
||||||
|
go.Pie(
|
||||||
|
labels=names,
|
||||||
|
values=sizes,
|
||||||
|
textinfo="label+percent",
|
||||||
|
hovertemplate="<b>%{label}</b><br>Size: %{customdata}<br>%{percent}<extra></extra>",
|
||||||
|
customdata=[info.display_size for info in storage_infos],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update layout
|
||||||
|
fig.update_layout(
|
||||||
|
title="Storage Distribution",
|
||||||
|
height=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def render_storage_statistics():
|
||||||
|
"""Render the storage statistics section showing disk usage for DATA_DIR subdirectories."""
|
||||||
|
st.header("💾 Storage Statistics")
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
f"""
|
||||||
|
This section shows the disk usage of subdirectories in the data directory:
|
||||||
|
**`{DATA_DIR}`**
|
||||||
|
|
||||||
|
Data is collected using [dust](https://github.com/bootandy/dust), a modern disk usage analyzer.
|
||||||
|
Statistics are cached for 5 minutes to reduce overhead.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load storage statistics
|
||||||
|
with st.spinner("Analyzing storage usage..."):
|
||||||
|
storage_infos, total_size, total_files = load_storage_statistics()
|
||||||
|
|
||||||
|
if not storage_infos:
|
||||||
|
st.warning("No storage data available. The data directory may be empty or inaccessible.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Display summary metrics
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.metric("Total Storage Used", _format_bytes(total_size))
|
||||||
|
with col2:
|
||||||
|
st.metric("Total Files", f"{total_files:,}")
|
||||||
|
with col3:
|
||||||
|
st.metric("Number of Subdirectories", len(storage_infos))
|
||||||
|
|
||||||
|
# Create tabs for different visualizations
|
||||||
|
tab1, tab2, tab3 = st.tabs(["📊 Bar Chart", "🥧 Pie Chart", "📋 Detailed Table"])
|
||||||
|
|
||||||
|
with tab1:
|
||||||
|
st.plotly_chart(_create_storage_bar_chart(storage_infos), use_container_width=True)
|
||||||
|
|
||||||
|
with tab2:
|
||||||
|
st.plotly_chart(_create_storage_pie_chart(storage_infos), use_container_width=True)
|
||||||
|
|
||||||
|
with tab3:
|
||||||
|
# Create DataFrame for detailed view
|
||||||
|
df = pd.DataFrame(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"Directory": info.name,
|
||||||
|
"Size": info.display_size,
|
||||||
|
"Size (Bytes)": info.size_bytes,
|
||||||
|
"Files": info.file_count,
|
||||||
|
"Percentage": f"{(info.size_bytes / total_size * 100):.2f}%",
|
||||||
|
}
|
||||||
|
for info in storage_infos
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
st.dataframe(
|
||||||
|
df[["Directory", "Size", "Files", "Percentage"]],
|
||||||
|
use_container_width=True,
|
||||||
|
hide_index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add download button for detailed data
|
||||||
|
st.download_button(
|
||||||
|
label="📥 Download Storage Statistics (CSV)",
|
||||||
|
data=df.to_csv(index=False),
|
||||||
|
file_name="entropice_storage_statistics.csv",
|
||||||
|
mime="text/csv",
|
||||||
|
)
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
"""Data utilities for Entropice dashboard."""
|
"""Data utilities for Entropice dashboard."""
|
||||||
|
|
||||||
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
import subprocess
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -252,3 +254,148 @@ def load_training_sets(ensemble: DatasetEnsemble) -> dict[TargetDataset, dict[Ta
|
||||||
for task in all_tasks:
|
for task in all_tasks:
|
||||||
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
|
train_data_dict[target][task] = ensemble.create_training_set(target=target, task=task)
|
||||||
return train_data_dict
|
return train_data_dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StorageInfo:
|
||||||
|
"""Storage information for a directory."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
size_bytes: int
|
||||||
|
file_count: int
|
||||||
|
display_size: str
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_size_to_bytes(size_str: str) -> int:
|
||||||
|
"""Convert dust's human-readable size string to bytes.
|
||||||
|
|
||||||
|
Examples: "92K" -> 92*1024, "1.5M" -> 1.5*1024*1024, "928B" -> 928
|
||||||
|
"""
|
||||||
|
size_str = size_str.strip().upper()
|
||||||
|
if not size_str:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Extract numeric part and unit
|
||||||
|
numeric_part = ""
|
||||||
|
unit = ""
|
||||||
|
for char in size_str:
|
||||||
|
if char.isdigit() or char == ".":
|
||||||
|
numeric_part += char
|
||||||
|
else:
|
||||||
|
unit += char
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = float(numeric_part) if numeric_part else 0
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Convert based on unit
|
||||||
|
unit = unit.strip()
|
||||||
|
multipliers = {
|
||||||
|
"B": 1,
|
||||||
|
"K": 1024,
|
||||||
|
"M": 1024**2,
|
||||||
|
"G": 1024**3,
|
||||||
|
"T": 1024**4,
|
||||||
|
"P": 1024**5,
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(value * multipliers.get(unit, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def _run_dust_command(data_dir: Path, for_files: bool = False) -> dict | None:
|
||||||
|
"""Run dust command and return parsed JSON output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir: Directory to analyze
|
||||||
|
for_files: If True, count files (-f flag); if False, count disk space
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed JSON dict or None if command failed
|
||||||
|
|
||||||
|
"""
|
||||||
|
cmd = ["dust", "-j", "-d", "1"]
|
||||||
|
if for_files:
|
||||||
|
cmd.append("-f")
|
||||||
|
cmd.append(str(data_dir))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
return json.loads(result.stdout)
|
||||||
|
except (subprocess.TimeoutExpired, json.JSONDecodeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_file_counts_lookup(files_data: dict | None) -> dict[str, int]:
|
||||||
|
"""Build lookup dict for file counts from dust JSON output."""
|
||||||
|
file_counts = {}
|
||||||
|
if files_data and "children" in files_data:
|
||||||
|
for child in files_data["children"]:
|
||||||
|
name = Path(child["name"]).name
|
||||||
|
count_str = child.get("size", "0")
|
||||||
|
file_counts[name] = _parse_size_to_bytes(count_str)
|
||||||
|
return file_counts
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data(ttl=300) # Cache for 5 minutes
|
||||||
|
def load_storage_statistics() -> tuple[list[StorageInfo], int, int]:
|
||||||
|
"""Load storage statistics for DATA_DIR subdirectories using dust.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (subdirectory stats list, total size in bytes, total file count)
|
||||||
|
|
||||||
|
"""
|
||||||
|
data_dir = entropice.utils.paths.DATA_DIR
|
||||||
|
|
||||||
|
if not data_dir.exists():
|
||||||
|
return [], 0, 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run dust for disk space and file counts
|
||||||
|
space_data = _run_dust_command(data_dir, for_files=False)
|
||||||
|
files_data = _run_dust_command(data_dir, for_files=True)
|
||||||
|
|
||||||
|
if not space_data:
|
||||||
|
st.warning("Failed to get storage statistics from dust")
|
||||||
|
return [], 0, 0
|
||||||
|
|
||||||
|
# Build lookup dict for file counts
|
||||||
|
file_counts = _build_file_counts_lookup(files_data)
|
||||||
|
|
||||||
|
# Extract subdirectory information from space data
|
||||||
|
storage_infos = []
|
||||||
|
total_size = 0
|
||||||
|
total_files = 0
|
||||||
|
|
||||||
|
if "children" in space_data:
|
||||||
|
for child in space_data["children"]:
|
||||||
|
full_path = child.get("name", "")
|
||||||
|
dir_name = Path(full_path).name
|
||||||
|
size_str = child.get("size", "0")
|
||||||
|
size_bytes = _parse_size_to_bytes(size_str)
|
||||||
|
file_count = file_counts.get(dir_name, 0)
|
||||||
|
|
||||||
|
storage_infos.append(
|
||||||
|
StorageInfo(
|
||||||
|
name=dir_name,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
file_count=file_count,
|
||||||
|
display_size=size_str,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
total_size += size_bytes
|
||||||
|
total_files += file_count
|
||||||
|
|
||||||
|
# Sort by size descending
|
||||||
|
storage_infos.sort(key=lambda x: x.size_bytes, reverse=True)
|
||||||
|
|
||||||
|
return storage_infos, total_size, total_files
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
st.error("dust command not found. Please install dust: https://github.com/bootandy/dust")
|
||||||
|
return [], 0, 0
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Error getting storage statistics: {e}")
|
||||||
|
return [], 0, 0
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from entropice.dashboard.sections.experiment_results import (
|
||||||
render_experiment_results,
|
render_experiment_results,
|
||||||
render_training_results_summary,
|
render_training_results_summary,
|
||||||
)
|
)
|
||||||
|
from entropice.dashboard.sections.storage_statistics import render_storage_statistics
|
||||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||||
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
from entropice.dashboard.utils.stats import DatasetStatistics, load_all_default_dataset_statistics
|
||||||
|
|
||||||
|
|
@ -52,5 +53,10 @@ def render_overview_page():
|
||||||
|
|
||||||
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
|
render_dataset_statistics(all_stats, training_sample_df, feature_breakdown_df, comparison_df, inference_sample_df)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Render storage statistics section
|
||||||
|
render_storage_statistics()
|
||||||
|
|
||||||
st.balloons()
|
st.balloons()
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
|
||||||
|
|
@ -1,481 +0,0 @@
|
||||||
"""Training Data page: Visualization of training data distributions."""
|
|
||||||
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import streamlit as st
|
|
||||||
from stopuhr import stopwatch
|
|
||||||
|
|
||||||
from entropice.dashboard.plots.source_data import (
|
|
||||||
render_alphaearth_map,
|
|
||||||
render_alphaearth_overview,
|
|
||||||
render_alphaearth_plots,
|
|
||||||
render_arcticdem_map,
|
|
||||||
render_arcticdem_overview,
|
|
||||||
render_arcticdem_plots,
|
|
||||||
render_areas_map,
|
|
||||||
render_era5_map,
|
|
||||||
render_era5_overview,
|
|
||||||
render_era5_plots,
|
|
||||||
)
|
|
||||||
from entropice.dashboard.plots.training_data import (
|
|
||||||
render_all_distribution_histograms,
|
|
||||||
render_spatial_map,
|
|
||||||
)
|
|
||||||
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
|
|
||||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
|
||||||
from entropice.spatial import grids
|
|
||||||
from entropice.utils.types import GridConfig, L2SourceDataset, TargetDataset, Task, grid_configs
|
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_configuration_sidebar():
|
|
||||||
"""Render dataset configuration selector in sidebar with form.
|
|
||||||
|
|
||||||
Stores the selected ensemble in session state when form is submitted.
|
|
||||||
"""
|
|
||||||
with st.sidebar.form("dataset_config_form"):
|
|
||||||
st.header("Dataset Configuration")
|
|
||||||
|
|
||||||
# Grid selection
|
|
||||||
grid_options = [gc.display_name for gc in grid_configs]
|
|
||||||
|
|
||||||
grid_level_combined = st.selectbox(
|
|
||||||
"Grid Configuration",
|
|
||||||
options=grid_options,
|
|
||||||
index=0,
|
|
||||||
help="Select the grid system and resolution level",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find the selected grid config
|
|
||||||
selected_grid_config: GridConfig = next(gc for gc in grid_configs if gc.display_name == grid_level_combined)
|
|
||||||
|
|
||||||
# Target feature selection
|
|
||||||
target = st.selectbox(
|
|
||||||
"Target Feature",
|
|
||||||
options=["darts_rts", "darts_mllabels"],
|
|
||||||
index=0,
|
|
||||||
help="Select the target variable for training",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Members selection
|
|
||||||
st.subheader("Dataset Members")
|
|
||||||
|
|
||||||
all_members = cast(
|
|
||||||
list[L2SourceDataset],
|
|
||||||
["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"],
|
|
||||||
)
|
|
||||||
selected_members: list[L2SourceDataset] = []
|
|
||||||
|
|
||||||
for member in all_members:
|
|
||||||
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
|
||||||
selected_members.append(member) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
# Form submit button
|
|
||||||
load_button = st.form_submit_button(
|
|
||||||
"Load Dataset",
|
|
||||||
type="primary",
|
|
||||||
use_container_width=True,
|
|
||||||
disabled=len(selected_members) == 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create DatasetEnsemble only when form is submitted
|
|
||||||
if load_button:
|
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid=selected_grid_config.grid,
|
|
||||||
level=selected_grid_config.level,
|
|
||||||
target=cast(TargetDataset, target),
|
|
||||||
members=selected_members,
|
|
||||||
)
|
|
||||||
# Store ensemble in session state
|
|
||||||
st.session_state["dataset_ensemble"] = ensemble
|
|
||||||
st.session_state["dataset_loaded"] = True
|
|
||||||
|
|
||||||
|
|
||||||
def render_dataset_statistics(ensemble: DatasetEnsemble):
|
|
||||||
"""Render dataset statistics and configuration overview.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### 📊 Dataset Configuration")
|
|
||||||
|
|
||||||
# Display current configuration in columns
|
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
|
||||||
|
|
||||||
with col1:
|
|
||||||
st.metric(label="Grid Type", value=ensemble.grid.upper())
|
|
||||||
|
|
||||||
with col2:
|
|
||||||
st.metric(label="Grid Level", value=ensemble.level)
|
|
||||||
|
|
||||||
with col3:
|
|
||||||
st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
|
|
||||||
|
|
||||||
with col4:
|
|
||||||
st.metric(label="Members", value=len(ensemble.members))
|
|
||||||
|
|
||||||
# Display members in an expandable section
|
|
||||||
with st.expander("🗂️ Dataset Members", expanded=False):
|
|
||||||
members_cols = st.columns(len(ensemble.members))
|
|
||||||
for idx, member in enumerate(ensemble.members):
|
|
||||||
with members_cols[idx]:
|
|
||||||
st.markdown(f"✓ **{member}**")
|
|
||||||
|
|
||||||
# Display dataset ID in a styled container
|
|
||||||
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
|
||||||
|
|
||||||
# Display detailed dataset statistics
|
|
||||||
st.markdown("---")
|
|
||||||
st.markdown("### 📈 Dataset Statistics")
|
|
||||||
|
|
||||||
with st.spinner("Computing dataset statistics..."):
|
|
||||||
stats = ensemble.get_stats()
|
|
||||||
|
|
||||||
# High-level summary metrics
|
|
||||||
col1, col2, col3 = st.columns(3)
|
|
||||||
with col1:
|
|
||||||
st.metric(label="Total Samples", value=f"{stats['num_target_samples']:,}")
|
|
||||||
with col2:
|
|
||||||
st.metric(label="Total Features", value=f"{stats['total_features']:,}")
|
|
||||||
with col3:
|
|
||||||
st.metric(label="Data Sources", value=len(stats["members"]))
|
|
||||||
|
|
||||||
# Detailed member statistics in expandable section
|
|
||||||
with st.expander("📦 Data Source Details", expanded=False):
|
|
||||||
for member, member_stats in stats["members"].items():
|
|
||||||
st.markdown(f"### {member}")
|
|
||||||
|
|
||||||
# Create metrics for this member
|
|
||||||
metric_cols = st.columns(4)
|
|
||||||
with metric_cols[0]:
|
|
||||||
st.metric("Features", member_stats["num_features"])
|
|
||||||
with metric_cols[1]:
|
|
||||||
st.metric("Variables", member_stats["num_variables"])
|
|
||||||
with metric_cols[2]:
|
|
||||||
# Display dimensions in a more readable format
|
|
||||||
dim_str = " x ".join([f"{dim}" for dim in member_stats["dimensions"].values()]) # type: ignore[union-attr]
|
|
||||||
st.metric("Shape", dim_str)
|
|
||||||
with metric_cols[3]:
|
|
||||||
# Calculate total data points
|
|
||||||
total_points = 1
|
|
||||||
for dim_size in member_stats["dimensions"].values(): # type: ignore[union-attr]
|
|
||||||
total_points *= dim_size
|
|
||||||
st.metric("Data Points", f"{total_points:,}")
|
|
||||||
|
|
||||||
# Show variables as colored badges
|
|
||||||
st.markdown("**Variables:**")
|
|
||||||
vars_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #e3f2fd; color: #1976d2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">{v}</span>'
|
|
||||||
for v in member_stats["variables"] # type: ignore[union-attr]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(vars_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
# Show dimension details
|
|
||||||
st.markdown("**Dimensions:**")
|
|
||||||
dim_html = " ".join(
|
|
||||||
[
|
|
||||||
f'<span style="background-color: #f3e5f5; color: #7b1fa2; padding: 4px 8px; '
|
|
||||||
f'border-radius: 4px; margin: 2px; display: inline-block; font-size: 0.9em;">'
|
|
||||||
f"{dim_name}: {dim_size}</span>"
|
|
||||||
for dim_name, dim_size in member_stats["dimensions"].items() # type: ignore[union-attr]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
st.markdown(dim_html, unsafe_allow_html=True)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
|
|
||||||
def render_labels_view(ensemble: DatasetEnsemble, train_data_dict: dict[Task, CategoricalTrainingDataset]):
|
|
||||||
"""Render target labels distribution and spatial visualization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
train_data_dict: Pre-loaded training data for all tasks.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### Target Labels Distribution and Spatial Visualization")
|
|
||||||
|
|
||||||
# Calculate total samples (use binary as reference)
|
|
||||||
total_samples = len(train_data_dict["binary"])
|
|
||||||
train_samples = (train_data_dict["binary"].split == "train").sum().item()
|
|
||||||
test_samples = (train_data_dict["binary"].split == "test").sum().item()
|
|
||||||
|
|
||||||
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
|
|
||||||
|
|
||||||
# Render distribution histograms
|
|
||||||
st.markdown("---")
|
|
||||||
render_all_distribution_histograms(train_data_dict) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Render spatial map
|
|
||||||
binary_dataset = train_data_dict["binary"]
|
|
||||||
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
|
|
||||||
|
|
||||||
render_spatial_map(train_data_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def render_areas_view(ensemble: DatasetEnsemble, grid_gdf):
|
|
||||||
"""Render grid cell areas and land/water distribution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
grid_gdf: Pre-loaded grid GeoDataFrame.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### Grid Cell Areas and Land/Water Distribution")
|
|
||||||
|
|
||||||
st.markdown(
|
|
||||||
"This visualization shows the spatial distribution of cell areas, land areas, "
|
|
||||||
"water areas, and land ratio across the grid. The grid has been filtered to "
|
|
||||||
"include only cells in the permafrost region (>50° latitude, <85° latitude) "
|
|
||||||
"with >10% land coverage."
|
|
||||||
)
|
|
||||||
|
|
||||||
st.success(
|
|
||||||
f"Loaded {len(grid_gdf)} grid cells with areas ranging from "
|
|
||||||
f"{grid_gdf['cell_area'].min():.2f} to {grid_gdf['cell_area'].max():.2f} km²"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show summary statistics
|
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
|
||||||
with col1:
|
|
||||||
st.metric("Total Cells", f"{len(grid_gdf):,}")
|
|
||||||
with col2:
|
|
||||||
st.metric("Avg Cell Area", f"{grid_gdf['cell_area'].mean():.2f} km²")
|
|
||||||
with col3:
|
|
||||||
st.metric("Avg Land Ratio", f"{grid_gdf['land_ratio'].mean():.1%}")
|
|
||||||
with col4:
|
|
||||||
total_land = grid_gdf["land_area"].sum()
|
|
||||||
st.metric("Total Land Area", f"{total_land:,.0f} km²")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Check if we should skip map rendering for performance
|
|
||||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
|
||||||
st.warning(
|
|
||||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
|
||||||
"due to performance considerations."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
render_areas_map(grid_gdf, ensemble.grid)
|
|
||||||
|
|
||||||
|
|
||||||
def render_alphaearth_view(ensemble: DatasetEnsemble, alphaearth_ds, targets):
|
|
||||||
"""Render AlphaEarth embeddings analysis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
alphaearth_ds: Pre-loaded AlphaEarth dataset.
|
|
||||||
targets: Pre-loaded targets GeoDataFrame.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### AlphaEarth Embeddings Analysis")
|
|
||||||
|
|
||||||
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
|
|
||||||
|
|
||||||
render_alphaearth_overview(alphaearth_ds)
|
|
||||||
render_alphaearth_plots(alphaearth_ds)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Check if we should skip map rendering for performance
|
|
||||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
|
||||||
st.warning(
|
|
||||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
|
||||||
"due to performance considerations."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
|
|
||||||
|
|
||||||
|
|
||||||
def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
|
|
||||||
"""Render ArcticDEM terrain analysis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
arcticdem_ds: Pre-loaded ArcticDEM dataset.
|
|
||||||
targets: Pre-loaded targets GeoDataFrame.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### ArcticDEM Terrain Analysis")
|
|
||||||
|
|
||||||
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
|
|
||||||
|
|
||||||
render_arcticdem_overview(arcticdem_ds)
|
|
||||||
render_arcticdem_plots(arcticdem_ds)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Check if we should skip map rendering for performance
|
|
||||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
|
||||||
st.warning(
|
|
||||||
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
|
||||||
"due to performance considerations."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
|
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
|
||||||
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
|
|
||||||
"""Render ERA5 climate data analysis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ensemble: The dataset ensemble configuration.
|
|
||||||
era5_data: Dictionary mapping ERA5 member names to (dataset, temporal_type) tuples.
|
|
||||||
targets: Pre-loaded targets GeoDataFrame.
|
|
||||||
|
|
||||||
"""
|
|
||||||
st.markdown("### ERA5 Climate Data Analysis")
|
|
||||||
|
|
||||||
# Let user select which ERA5 temporal aggregation to view
|
|
||||||
era5_options = {
|
|
||||||
"ERA5-yearly": "Yearly",
|
|
||||||
"ERA5-seasonal": "Seasonal (Winter/Summer)",
|
|
||||||
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
|
|
||||||
}
|
|
||||||
|
|
||||||
available_era5 = {k: v for k, v in era5_options.items() if k in era5_data}
|
|
||||||
|
|
||||||
selected_era5 = st.selectbox(
|
|
||||||
"Select ERA5 temporal aggregation",
|
|
||||||
options=list(available_era5.keys()),
|
|
||||||
format_func=lambda x: available_era5[x],
|
|
||||||
key="era5_temporal_select",
|
|
||||||
)
|
|
||||||
|
|
||||||
if selected_era5 and selected_era5 in era5_data:
|
|
||||||
era5_ds, temporal_type = era5_data[selected_era5]
|
|
||||||
|
|
||||||
render_era5_overview(era5_ds, temporal_type)
|
|
||||||
render_era5_plots(era5_ds, temporal_type)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Check if we should skip map rendering for performance
|
|
||||||
if (ensemble.grid == "hex" and ensemble.level == 6) or (ensemble.grid == "healpix" and ensemble.level == 10):
|
|
||||||
st.warning(
|
|
||||||
"🗡️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) "
|
|
||||||
"due to performance considerations."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
|
|
||||||
|
|
||||||
|
|
||||||
def render_training_data_page():
|
|
||||||
"""Render the Training Data page of the dashboard."""
|
|
||||||
st.title("🎯 Training Data")
|
|
||||||
|
|
||||||
st.markdown(
|
|
||||||
"""
|
|
||||||
Explore and visualize the training data for RTS prediction models.
|
|
||||||
Configure your dataset by selecting grid configuration, target dataset,
|
|
||||||
and data sources in the sidebar, then click "Load Dataset" to begin.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Render sidebar configuration
|
|
||||||
render_dataset_configuration_sidebar()
|
|
||||||
|
|
||||||
# Check if dataset is loaded in session state
|
|
||||||
if not st.session_state.get("dataset_loaded", False) or "dataset_ensemble" not in st.session_state:
|
|
||||||
st.info(
|
|
||||||
"👈 Configure the dataset settings in the sidebar and click 'Load Dataset' to begin exploring training data"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get ensemble from session state
|
|
||||||
ensemble: DatasetEnsemble = st.session_state["dataset_ensemble"]
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# Load all necessary data once
|
|
||||||
with st.spinner("Loading dataset..."):
|
|
||||||
# Load training data for all tasks
|
|
||||||
train_data_dict = load_all_training_data(ensemble)
|
|
||||||
|
|
||||||
# Load grid data
|
|
||||||
grid_gdf = grids.open(ensemble.grid, ensemble.level)
|
|
||||||
|
|
||||||
# Load targets (needed by all source data views)
|
|
||||||
targets = ensemble._read_target()
|
|
||||||
|
|
||||||
# Load AlphaEarth data if in members
|
|
||||||
alphaearth_ds = None
|
|
||||||
if "AlphaEarth" in ensemble.members:
|
|
||||||
alphaearth_ds, _ = load_source_data(ensemble, "AlphaEarth")
|
|
||||||
|
|
||||||
# Load ArcticDEM data if in members
|
|
||||||
arcticdem_ds = None
|
|
||||||
if "ArcticDEM" in ensemble.members:
|
|
||||||
arcticdem_ds, _ = load_source_data(ensemble, "ArcticDEM")
|
|
||||||
|
|
||||||
# Load ERA5 data for all temporal aggregations in members
|
|
||||||
era5_data = {}
|
|
||||||
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
|
||||||
for era5_member in era5_members:
|
|
||||||
era5_ds, _ = load_source_data(ensemble, era5_member)
|
|
||||||
temporal_type = era5_member.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
|
|
||||||
era5_data[era5_member] = (era5_ds, temporal_type)
|
|
||||||
|
|
||||||
st.success(
|
|
||||||
f"Loaded dataset with {len(train_data_dict['binary'])} samples and {ensemble.get_stats()['total_features']} features"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Render dataset statistics
|
|
||||||
render_dataset_statistics(ensemble)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# Create tabs for different data views
|
|
||||||
tab_names = ["📊 Labels", "📐 Areas"]
|
|
||||||
|
|
||||||
# Add tabs for each member based on what's in the ensemble
|
|
||||||
if "AlphaEarth" in ensemble.members:
|
|
||||||
tab_names.append("🌍 AlphaEarth")
|
|
||||||
if "ArcticDEM" in ensemble.members:
|
|
||||||
tab_names.append("🏔️ ArcticDEM")
|
|
||||||
|
|
||||||
# Check for ERA5 members
|
|
||||||
if era5_members:
|
|
||||||
tab_names.append("🌡️ ERA5")
|
|
||||||
|
|
||||||
tabs = st.tabs(tab_names)
|
|
||||||
|
|
||||||
# Track current tab index
|
|
||||||
tab_idx = 0
|
|
||||||
|
|
||||||
# Labels tab
|
|
||||||
with tabs[tab_idx]:
|
|
||||||
render_labels_view(ensemble, train_data_dict)
|
|
||||||
tab_idx += 1
|
|
||||||
|
|
||||||
# Areas tab
|
|
||||||
with tabs[tab_idx]:
|
|
||||||
render_areas_view(ensemble, grid_gdf)
|
|
||||||
tab_idx += 1
|
|
||||||
|
|
||||||
# AlphaEarth tab
|
|
||||||
if "AlphaEarth" in ensemble.members:
|
|
||||||
with tabs[tab_idx]:
|
|
||||||
render_alphaearth_view(ensemble, alphaearth_ds, targets)
|
|
||||||
tab_idx += 1
|
|
||||||
|
|
||||||
# ArcticDEM tab
|
|
||||||
if "ArcticDEM" in ensemble.members:
|
|
||||||
with tabs[tab_idx]:
|
|
||||||
render_arcticdem_view(ensemble, arcticdem_ds, targets)
|
|
||||||
tab_idx += 1
|
|
||||||
|
|
||||||
# ERA5 tab (combining all temporal variants)
|
|
||||||
if era5_members:
|
|
||||||
with tabs[tab_idx]:
|
|
||||||
render_era5_view(ensemble, era5_data, targets)
|
|
||||||
|
|
||||||
# Show balloons once after all tabs are rendered
|
|
||||||
st.balloons()
|
|
||||||
stopwatch.summary()
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue