Fix minor bugs on training data page

This commit is contained in:
Tobias Hölzer 2025-12-19 01:04:01 +01:00
parent 3d6417ef6b
commit 31933b58d3
3 changed files with 42 additions and 17 deletions

View file

@@ -173,7 +173,7 @@ def render_alphaearth_plots(ds: xr.Dataset):
hovermode="x unified",
)
st.plotly_chart(fig, use_container_width=True)
st.plotly_chart(fig, width="stretch")
# Band statistics
with st.expander("📈 Statistics by Embedding Band", expanded=False):
@@ -194,7 +194,7 @@ def render_alphaearth_plots(ds: xr.Dataset):
)
band_df = pd.DataFrame(band_stats)
st.dataframe(band_df, use_container_width=True, hide_index=True)
st.dataframe(band_df, width="stretch", hide_index=True)
if len(ds["band"]) > 10:
st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
@@ -250,7 +250,7 @@ def render_arcticdem_overview(ds: xr.Dataset):
)
stats_df = pd.DataFrame(var_stats)
st.dataframe(stats_df, use_container_width=True, hide_index=True)
st.dataframe(stats_df, width="stretch", hide_index=True)
@st.fragment
@@ -296,7 +296,7 @@ def render_arcticdem_plots(ds: xr.Dataset):
height=400,
)
st.plotly_chart(fig, use_container_width=True)
st.plotly_chart(fig, width="stretch")
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
@@ -362,7 +362,7 @@ def render_era5_overview(ds: xr.Dataset, temporal_type: str):
)
stats_df = pd.DataFrame(var_stats)
st.dataframe(stats_df, use_container_width=True, hide_index=True)
st.dataframe(stats_df, width="stretch", hide_index=True)
@st.fragment
@@ -416,7 +416,7 @@ def render_era5_plots(ds: xr.Dataset, temporal_type: str):
hovermode="x unified",
)
st.plotly_chart(fig, use_container_width=True)
st.plotly_chart(fig, width="stretch")
@st.fragment

View file

@@ -163,8 +163,6 @@ def render_training_data_page():
render_spatial_map(train_data_dict)
st.balloons()
# AlphaEarth tab
tab_idx = 1
if "AlphaEarth" in ensemble.members:
@@ -181,9 +179,14 @@ def render_training_data_page():
st.markdown("---")
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
st.balloons()
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
tab_idx += 1
@@ -202,9 +205,14 @@ def render_training_data_page():
st.markdown("---")
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
st.balloons()
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
tab_idx += 1
@@ -243,9 +251,17 @@ def render_training_data_page():
st.markdown("---")
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
if (ensemble.grid == "hex" and ensemble.level == 6) or (
ensemble.grid == "healpix" and ensemble.level == 10
):
st.warning(
"🗺️ Spatial map rendering is disabled for this grid configuration (hex-6 or healpix-10) due to performance considerations."
)
else:
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
st.balloons()
# Show balloons once after all tabs are rendered
st.balloons()
else:
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")

View file

@@ -90,7 +90,16 @@ def bin_values(
non_none_values = values[~none_mask]
assert len(non_none_values) > 5, "Not enough non-none values to create bins."
binned_non_none = pd.qcut(non_none_values, q=5, labels=labels[1:]).cat.set_categories(labels, ordered=True)
# Create bins without labels first to handle duplicates
binned_non_none = pd.qcut(non_none_values, q=5, labels=False, duplicates="drop")
# Map the bin indices to labels, using only as many labels as bins created
n_bins = binned_non_none.max() + 1 # bins are 0-indexed
active_labels = labels[1 : 1 + n_bins] # Skip "None" and take n_bins labels
binned_non_none = binned_non_none.map(dict(enumerate(active_labels))).astype("category")
binned_non_none = binned_non_none.cat.set_categories(labels, ordered=True)
binned = pd.Series(index=values.index, dtype="category")
binned = binned.cat.set_categories(labels, ordered=True)
binned.update(binned_non_none)