Add data sources viz to dashboard
This commit is contained in:
parent
f5ea72e05e
commit
3d6417ef6b
3 changed files with 952 additions and 19 deletions
798
src/entropice/dashboard/plots/source_data.py
Normal file
798
src/entropice/dashboard/plots/source_data.py
Normal file
|
|
@ -0,0 +1,798 @@
|
||||||
|
"""Plotting functions for source data visualizations (AlphaEarth, ArcticDEM, ERA5)."""
|
||||||
|
|
||||||
|
import antimeridian
|
||||||
|
import geopandas as gpd
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import pydeck as pdk
|
||||||
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
from shapely.geometry import shape
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.colors import get_cmap
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_hex_geometry(geom):
    """Repair a hexagon geometry that straddles the antimeridian.

    On failure the error is surfaced in the dashboard and the geometry is
    returned unmodified.
    """
    try:
        fixed = shape(antimeridian.fix_shape(geom))
    except ValueError as exc:
        st.error(f"Error fixing geometry: {exc}")
        return geom
    return fixed
|
||||||
|
|
||||||
|
|
||||||
|
def render_alphaearth_overview(ds: xr.Dataset):
    """Render overview statistics for AlphaEarth embeddings data.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings.

    """
    st.subheader("📊 AlphaEarth Embeddings Statistics")

    # Headline metrics, one per column.
    headline = (
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Embedding Dimensions", f"{len(ds['band'])}"),
        ("Years Available", f"{len(ds['year'])}"),
        ("Aggregations", f"{len(ds['agg'])}"),
    )
    for column, (label, value) in zip(st.columns(4), headline):
        with column:
            st.metric(label, value)

    # Temporal coverage of the year coordinate.
    st.markdown("**Temporal Coverage:**")
    sorted_years = sorted(ds["year"].values)
    st.write(f"Years: {sorted_years[0]} - {sorted_years[-1]}")

    # Available aggregation methods.
    st.markdown("**Available Aggregations:**")
    st.write(", ".join(str(agg) for agg in ds["agg"].to_numpy()))
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_alphaearth_plots(ds: xr.Dataset):
    """Render interactive plots for AlphaEarth embeddings data.

    Shows a per-band summary chart (mean ± std with a min/max envelope) for a
    user-selected year and aggregation, plus an expander with coarse per-band
    statistics across all years and aggregations.

    Args:
        ds: xarray Dataset containing AlphaEarth embeddings.

    """
    st.markdown("---")
    st.markdown("**Embedding Distribution by Band**")

    embeddings_data = ds["embeddings"]

    # Select year and aggregation for visualization
    col1, col2 = st.columns(2)
    with col1:
        selected_year = st.selectbox("Select Year", options=sorted(ds["year"].values), key="stats_year")
    with col2:
        selected_agg = st.selectbox("Select Aggregation", options=ds["agg"].values, key="stats_agg")

    # Get data for selected year and aggregation
    year_agg_data = embeddings_data.sel(year=selected_year, agg=selected_agg)

    # Calculate statistics for each band; NaNs are dropped first, and a band
    # that is entirely NaN contributes no row.
    band_stats = []
    for band_idx in range(len(ds["band"])):
        band_data = year_agg_data.isel(band=band_idx).values.flatten()
        band_data = band_data[~np.isnan(band_data)]  # Remove NaN values

        if len(band_data) > 0:
            band_stats.append(
                {
                    "Band": band_idx,
                    "Mean": float(np.mean(band_data)),
                    "Std": float(np.std(band_data)),
                    "Min": float(np.min(band_data)),
                    "25%": float(np.percentile(band_data, 25)),
                    "Median": float(np.median(band_data)),
                    "75%": float(np.percentile(band_data, 75)),
                    "Max": float(np.max(band_data)),
                }
            )

    band_df = pd.DataFrame(band_stats)

    # Guard: if every band was all-NaN the frame has no columns and indexing
    # band_df["Band"] below would raise a KeyError instead of a useful message.
    if band_df.empty:
        st.warning("No valid embedding data for the selected year and aggregation")
    else:
        # Create plot showing distribution across bands
        fig = go.Figure()

        # Add min/max range first (background)
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Min"],
                mode="lines",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                name="Min/Max Range",
                showlegend=True,
            )
        )

        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Max"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(200, 200, 200, 0.1)",
                line={"color": "lightgray", "width": 1, "dash": "dash"},
                showlegend=False,
            )
        )

        # Add std band
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] - band_df["Std"],
                mode="lines",
                line={"width": 0},
                showlegend=False,
                hoverinfo="skip",
            )
        )

        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"] + band_df["Std"],
                mode="lines",
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.2)",
                line={"width": 0},
                name="±1 Std",
            )
        )

        # Add mean line on top
        fig.add_trace(
            go.Scatter(
                x=band_df["Band"],
                y=band_df["Mean"],
                mode="lines+markers",
                name="Mean",
                line={"color": "#1f77b4", "width": 2},
                marker={"size": 4},
            )
        )

        fig.update_layout(
            title=f"Embedding Statistics per Band (Year: {selected_year}, Aggregation: {selected_agg})",
            xaxis_title="Band",
            yaxis_title="Embedding Value",
            height=450,
            hovermode="x unified",
        )

        st.plotly_chart(fig, use_container_width=True)

    # Band statistics across all years/aggregations. Distinct variable names
    # (overall_stats/overall_df) so the per-year band_stats/band_df above are
    # not shadowed within the same scope.
    with st.expander("📈 Statistics by Embedding Band", expanded=False):
        st.markdown("Statistics aggregated across all years and aggregations for each embedding dimension:")

        overall_stats = []
        for band_idx in range(min(10, len(ds["band"]))):  # Show first 10 bands
            band_data = embeddings_data.isel(band=band_idx)
            overall_stats.append(
                {
                    "Band": int(band_idx),
                    "Mean": float(band_data.mean().values),
                    "Std": float(band_data.std().values),
                    "Min": float(band_data.min().values),
                    "Max": float(band_data.max().values),
                }
            )

        overall_df = pd.DataFrame(overall_stats)
        st.dataframe(overall_df, use_container_width=True, hide_index=True)

        if len(ds["band"]) > 10:
            st.info(f"Showing first 10 of {len(ds['band'])} embedding dimensions")
|
||||||
|
|
||||||
|
|
||||||
|
def render_arcticdem_overview(ds: xr.Dataset):
    """Render overview statistics for ArcticDEM terrain data.

    Args:
        ds: xarray Dataset containing ArcticDEM data.

    """
    st.subheader("🏔️ ArcticDEM Terrain Statistics")

    # Headline metrics, one per column.
    headline = (
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Aggregations", f"{len(ds['aggregations'])}"),
    )
    for column, (label, value) in zip(st.columns(3), headline):
        with column:
            st.metric(label, value)

    # List the data variables and aggregation methods present in the dataset.
    st.markdown("**Available Variables:**")
    st.write(", ".join(list(ds.data_vars)))

    st.markdown("**Available Aggregations:**")
    st.write(", ".join(str(agg) for agg in ds["aggregations"].to_numpy()))

    # Per-variable summary statistics, pooled over every aggregation.
    st.markdown("---")
    st.markdown("**Variable Statistics (across all aggregations)**")

    stats_df = pd.DataFrame(
        [
            {
                "Variable": name,
                "Mean": float(ds[name].mean().values),
                "Std": float(ds[name].std().values),
                "Min": float(ds[name].min().values),
                "Max": float(ds[name].max().values),
                "Missing %": float((ds[name].isnull().sum() / ds[name].size * 100).values),
            }
            for name in ds.data_vars
        ]
    )
    st.dataframe(stats_df, use_container_width=True, hide_index=True)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_arcticdem_plots(ds: xr.Dataset):
    """Render interactive plots for ArcticDEM terrain data.

    Args:
        ds: xarray Dataset containing ArcticDEM data.

    """
    st.markdown("---")
    st.markdown("**Variable Distributions**")

    # Let the user pick which terrain variable to inspect.
    selected_var = st.selectbox(
        "Select variable to visualize", options=list(ds.data_vars), key="arcticdem_var_select"
    )

    if not selected_var:
        return

    # Overlaid histograms, one trace per aggregation method, NaNs dropped.
    fig = go.Figure()
    for agg in ds["aggregations"].to_numpy():
        samples = ds[selected_var].sel(aggregations=agg).to_numpy().flatten()
        fig.add_trace(
            go.Histogram(
                x=samples[~np.isnan(samples)],
                name=str(agg),
                opacity=0.7,
                nbinsx=50,
            )
        )

    fig.update_layout(
        title=f"Distribution of {selected_var} by Aggregation",
        xaxis_title=selected_var,
        yaxis_title="Count",
        barmode="overlay",
        height=400,
    )

    st.plotly_chart(fig, use_container_width=True)
|
||||||
|
|
||||||
|
|
||||||
|
def render_era5_overview(ds: xr.Dataset, temporal_type: str):
    """Render overview statistics for ERA5 climate data.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.

    """
    st.subheader(f"🌡️ ERA5 Climate Statistics ({temporal_type.capitalize()})")

    # Some ERA5 variants carry an explicit aggregation dimension.
    has_agg = "aggregations" in ds.dims

    # Headline metrics, one per column; without an aggregation dimension the
    # dataset implicitly holds a single aggregation.
    headline = (
        ("Total Cells", f"{len(ds['cell_ids']):,}"),
        ("Variables", f"{len(ds.data_vars)}"),
        ("Time Steps", f"{len(ds['time'])}"),
        ("Aggregations", f"{len(ds['aggregations'])}" if has_agg else "1"),
    )
    for column, (label, value) in zip(st.columns(4), headline):
        with column:
            st.metric(label, value)

    # List the data variables present in the dataset.
    st.markdown("**Available Variables:**")
    st.write(", ".join(list(ds.data_vars)))

    # Span of the time coordinate.
    st.markdown("**Temporal Range:**")
    time_index = pd.to_datetime(ds["time"].values)
    st.write(f"{time_index.min().strftime('%Y-%m-%d')} to {time_index.max().strftime('%Y-%m-%d')}")

    if has_agg:
        st.markdown("**Available Aggregations:**")
        st.write(", ".join(str(agg) for agg in ds["aggregations"].to_numpy()))

    # Per-variable summary statistics, pooled over time and aggregations.
    st.markdown("---")
    st.markdown("**Variable Statistics (across all time steps and aggregations)**")

    stats_df = pd.DataFrame(
        [
            {
                "Variable": name,
                "Mean": float(ds[name].mean().values),
                "Std": float(ds[name].std().values),
                "Min": float(ds[name].min().values),
                "Max": float(ds[name].max().values),
                "Missing %": float((ds[name].isnull().sum() / ds[name].size * 100).values),
            }
            for name in ds.data_vars
        ]
    )
    st.dataframe(stats_df, use_container_width=True, hide_index=True)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_era5_plots(ds: xr.Dataset, temporal_type: str):
    """Render interactive plots for ERA5 climate data.

    Args:
        ds: xarray Dataset containing ERA5 data.
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.

    """
    st.markdown("---")
    st.markdown("**Temporal Trends**")

    has_agg = "aggregations" in ds.dims

    selected_var = st.selectbox(
        "Select variable to visualize", options=list(ds.data_vars), key=f"era5_{temporal_type}_var_select"
    )

    if not selected_var:
        return

    # Spatial mean per time step, also collapsing the aggregation axis when
    # the dataset has one.
    reduce_dims = ["cell_ids", "aggregations"] if has_agg else "cell_ids"
    time_series = ds[selected_var].mean(dim=reduce_dims)

    trend = pd.DataFrame({"Time": pd.to_datetime(ds["time"].to_numpy()), "Value": time_series.to_numpy()})

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=trend["Time"],
            y=trend["Value"],
            mode="lines+markers",
            name=selected_var,
            line={"width": 2},
        )
    )

    fig.update_layout(
        title=f"Temporal Trend of {selected_var} (Spatial Mean)",
        xaxis_title="Time",
        yaxis_title=selected_var,
        height=400,
        hovermode="x unified",
    )

    st.plotly_chart(fig, use_container_width=True)
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_alphaearth_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for AlphaEarth embeddings.

    Args:
        ds: xarray Dataset containing AlphaEarth data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').

    """
    st.subheader("🗺️ AlphaEarth Spatial Distribution")

    # Controls
    col1, col2, col3, col4 = st.columns([2, 2, 2, 1])

    with col1:
        selected_year = st.selectbox("Year", options=sorted(ds["year"].values), key="alphaearth_year")

    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["agg"].values, key="alphaearth_agg")

    with col3:
        selected_band = st.selectbox("Band", options=list(range(len(ds["band"]))), key="alphaearth_band")

    with col4:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="alphaearth_opacity")

    # Extract data for selected parameters
    data_values = ds["embeddings"].sel(year=selected_year, agg=selected_agg).isel(band=selected_band)

    # Create GeoDataFrame restricted to cells present in the dataset
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")

    # Add values
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")

    # Fix geometries after CRS conversion
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Normalize values for color mapping. Guard against all-NaN data so this
    # renderer behaves like render_arcticdem_map / render_era5_map instead of
    # pushing NaNs into the colormap and the (JSON) deck payload.
    values = gdf_wgs84["value"].to_numpy()
    values_clean = values[~np.isnan(values)]

    if len(values_clean) == 0:
        st.warning("No valid data available for selected parameters")
        return

    vmin, vmax = np.nanpercentile(values_clean, [2, 98])  # Use percentiles to avoid outliers
    normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)

    # Apply colormap; NaN cells fall back to a neutral gray
    cmap = get_cmap("embeddings")
    colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
    gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]

    # Create GeoJSON features; NaN values become null so the tooltip and the
    # serialized payload stay valid
    geojson_data = []
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            },
        }
        geojson_data.append(feature)

    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        pickable=True,
    )

    view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)

    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={"html": "<b>Value:</b> {value}", "style": {"backgroundColor": "steelblue", "color": "white"}},
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )

    st.pydeck_chart(deck)

    # Show statistics (computed over valid values only)
    st.caption(
        f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
        f"Std: {np.nanstd(values_clean):.4f}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_arcticdem_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str):
    """Render interactive pydeck map for ArcticDEM terrain data.

    Builds a GeoJSON layer of the target cells colored by the selected
    variable/aggregation, with a tooltip that also lists the other
    aggregations when more than one is present.

    Args:
        ds: xarray Dataset containing ArcticDEM data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').

    """
    st.subheader("🗺️ ArcticDEM Spatial Distribution")

    # Controls
    variables = list(ds.data_vars)
    col1, col2, col3 = st.columns([3, 3, 1])

    with col1:
        selected_var = st.selectbox("Variable", options=variables, key="arcticdem_var")

    with col2:
        selected_agg = st.selectbox("Aggregation", options=ds["aggregations"].values, key="arcticdem_agg")

    with col3:
        opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key="arcticdem_opacity")

    # Extract data for selected parameters
    data_values = ds[selected_var].sel(aggregations=selected_agg)

    # Create GeoDataFrame restricted to cells present in the dataset,
    # keyed by cell_id for the joins below
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")

    # Add values
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Add all aggregation values for tooltip
    if len(ds["aggregations"]) > 1:
        for agg in ds["aggregations"].values:
            agg_data = ds[selected_var].sel(aggregations=agg).to_dataframe(name=f"agg_{agg}")
            # Drop the aggregations column to avoid conflicts
            if "aggregations" in agg_data.columns:
                agg_data = agg_data.drop(columns=["aggregations"])
            gdf = gdf.join(agg_data, how="left")

    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")

    # Fix geometries after CRS conversion (hex cells may straddle the antimeridian)
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Normalize values for color mapping
    values = gdf_wgs84["value"].values
    values_clean = values[~np.isnan(values)]

    if len(values_clean) > 0:
        # 2nd/98th percentile clipping keeps outliers from washing out the color scale
        vmin, vmax = np.nanpercentile(values_clean, [2, 98])
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)

        # Apply colormap; NaN cells get a neutral gray
        cmap = get_cmap(f"arcticdem_{selected_var}")
        colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
        gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]

        # Create GeoJSON; NaN values become null so the payload stays valid JSON
        geojson_data = []
        for _, row in gdf_wgs84.iterrows():
            properties = {
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            }
            # Add all aggregation values if available
            if len(ds["aggregations"]) > 1:
                for agg in ds["aggregations"].values:
                    agg_col = f"agg_{agg}"
                    if agg_col in row.index:
                        properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None

            feature = {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
            geojson_data.append(feature)

        # Create pydeck layer
        layer = pdk.Layer(
            "GeoJsonLayer",
            geojson_data,
            opacity=opacity,
            stroked=True,
            filled=True,
            get_fill_color="properties.fill_color",
            get_line_color=[80, 80, 80],
            line_width_min_pixels=0.5,
            pickable=True,
        )

        # High-latitude default view centered on the Arctic
        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)

        # Build tooltip HTML for ArcticDEM; {value}/{agg_*} are substituted
        # by pydeck from the feature properties
        if len(ds["aggregations"]) > 1:
            tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
            tooltip_lines.append("<b>All aggregations:</b><br/>")
            for agg in ds["aggregations"].values:
                tooltip_lines.append(f"  {agg}: {{agg_{agg}}}<br/>")
            tooltip_html = "".join(tooltip_lines)
        else:
            tooltip_html = f"<b>{selected_var}:</b> {{value}}"

        deck = pdk.Deck(
            layers=[layer],
            initial_view_state=view_state,
            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
        )

        st.pydeck_chart(deck)

        # Show statistics (valid values only; vmin/vmax are the clipped percentiles)
        st.caption(
            f"Min: {vmin:.2f} | Max: {vmax:.2f} | Mean: {np.nanmean(values_clean):.2f} | "
            f"Std: {np.nanstd(values_clean):.2f}"
        )
    else:
        st.warning("No valid data available for selected parameters")
|
||||||
|
|
||||||
|
|
||||||
|
@st.fragment
def render_era5_map(ds: xr.Dataset, targets: gpd.GeoDataFrame, grid: str, temporal_type: str):
    """Render interactive pydeck map for ERA5 climate data.

    Builds a GeoJSON layer of the target cells colored by the selected
    variable at the selected time step (and aggregation, when present),
    with a tooltip that also lists the other aggregations.

    Args:
        ds: xarray Dataset containing ERA5 data.
        targets: GeoDataFrame with geometry for each cell.
        grid: Grid type ('hex' or 'healpix').
        temporal_type: One of 'yearly', 'seasonal', 'shoulder'.

    """
    st.subheader("🗺️ ERA5 Spatial Distribution")

    # Controls; layout depends on whether an aggregation dimension exists
    variables = list(ds.data_vars)
    has_agg = "aggregations" in ds.dims

    if has_agg:
        col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
    else:
        col1, col2, col3 = st.columns([3, 3, 1])

    with col1:
        selected_var = st.selectbox("Variable", options=variables, key=f"era5_{temporal_type}_var")

    with col2:
        # Convert time to readable format; the dict maps label back to timestamp
        time_values = pd.to_datetime(ds["time"].values)
        time_options = {str(t): t for t in time_values}
        selected_time = st.selectbox("Time", options=list(time_options.keys()), key=f"era5_{temporal_type}_time")

    if has_agg:
        with col3:
            selected_agg = st.selectbox(
                "Aggregation", options=ds["aggregations"].values, key=f"era5_{temporal_type}_agg"
            )
        with col4:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")
    else:
        with col3:
            opacity = st.slider("Opacity", 0.1, 1.0, 0.7, 0.1, key=f"era5_{temporal_type}_opacity")

    # Extract data for selected parameters
    time_val = time_options[selected_time]
    if has_agg:
        data_values = ds[selected_var].sel(time=time_val, aggregations=selected_agg)
    else:
        data_values = ds[selected_var].sel(time=time_val)

    # Create GeoDataFrame restricted to cells present in the dataset,
    # keyed by cell_id for the joins below
    gdf = targets.copy()
    gdf = gdf[gdf["cell_id"].isin(ds["cell_ids"].values)]
    gdf = gdf.set_index("cell_id")

    # Add values
    values_df = data_values.to_dataframe(name="value")
    gdf = gdf.join(values_df, how="inner")

    # Add all aggregation values for tooltip if has_agg
    if has_agg and len(ds["aggregations"]) > 1:
        for agg in ds["aggregations"].values:
            agg_data = ds[selected_var].sel(time=time_val, aggregations=agg).to_dataframe(name=f"agg_{agg}")
            # Drop dimension columns to avoid conflicts
            cols_to_drop = [col for col in ["aggregations", "time"] if col in agg_data.columns]
            if cols_to_drop:
                agg_data = agg_data.drop(columns=cols_to_drop)
            gdf = gdf.join(agg_data, how="left")

    # Convert to WGS84 first
    gdf_wgs84 = gdf.to_crs("EPSG:4326")

    # Fix geometries after CRS conversion (hex cells may straddle the antimeridian)
    if grid == "hex":
        gdf_wgs84["geometry"] = gdf_wgs84["geometry"].apply(_fix_hex_geometry)

    # Normalize values for color mapping
    values = gdf_wgs84["value"].values
    values_clean = values[~np.isnan(values)]

    if len(values_clean) > 0:
        # 2nd/98th percentile clipping keeps outliers from washing out the color scale
        vmin, vmax = np.nanpercentile(values_clean, [2, 98])
        normalized = np.clip((values - vmin) / (vmax - vmin), 0, 1)

        # Apply colormap - use variable-specific colors; NaN cells get a neutral gray
        cmap = get_cmap(f"era5_{selected_var}")
        colors = [cmap(val) if not np.isnan(val) else (0.5, 0.5, 0.5, 0.5) for val in normalized]
        gdf_wgs84["fill_color"] = [[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in colors]

        # Create GeoJSON; NaN values become null so the payload stays valid JSON
        geojson_data = []
        for _, row in gdf_wgs84.iterrows():
            properties = {
                "value": float(row["value"]) if not np.isnan(row["value"]) else None,
                "fill_color": row["fill_color"],
            }
            # Add all aggregation values if available
            if has_agg and len(ds["aggregations"]) > 1:
                for agg in ds["aggregations"].values:
                    agg_col = f"agg_{agg}"
                    if agg_col in row.index:
                        properties[agg_col] = float(row[agg_col]) if not np.isnan(row[agg_col]) else None

            feature = {
                "type": "Feature",
                "geometry": row["geometry"].__geo_interface__,
                "properties": properties,
            }
            geojson_data.append(feature)

        # Create pydeck layer
        layer = pdk.Layer(
            "GeoJsonLayer",
            geojson_data,
            opacity=opacity,
            stroked=True,
            filled=True,
            get_fill_color="properties.fill_color",
            get_line_color=[80, 80, 80],
            line_width_min_pixels=0.5,
            pickable=True,
        )

        # High-latitude default view centered on the Arctic
        view_state = pdk.ViewState(latitude=70, longitude=0, zoom=2, pitch=0)

        # Build tooltip HTML for ERA5; {value}/{agg_*} are substituted
        # by pydeck from the feature properties
        if has_agg and len(ds["aggregations"]) > 1:
            tooltip_lines = [f"<b>{selected_var} (selected: {selected_agg}):</b> {{value}}<br/>"]
            tooltip_lines.append("<b>All aggregations:</b><br/>")
            for agg in ds["aggregations"].values:
                tooltip_lines.append(f"  {agg}: {{agg_{agg}}}<br/>")
            tooltip_html = "".join(tooltip_lines)
        else:
            tooltip_html = f"<b>{selected_var}:</b> {{value}}"

        deck = pdk.Deck(
            layers=[layer],
            initial_view_state=view_state,
            tooltip={"html": tooltip_html, "style": {"backgroundColor": "steelblue", "color": "white"}},
            map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
        )

        st.pydeck_chart(deck)

        # Show statistics (valid values only; vmin/vmax are the clipped percentiles)
        st.caption(
            f"Min: {vmin:.4f} | Max: {vmax:.4f} | Mean: {np.nanmean(values_clean):.4f} | "
            f"Std: {np.nanstd(values_clean):.4f}"
        )
    else:
        st.warning("No valid data available for selected parameters")
|
||||||
|
|
@ -2,8 +2,19 @@
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from entropice.dashboard.plots.source_data import (
|
||||||
|
render_alphaearth_map,
|
||||||
|
render_alphaearth_overview,
|
||||||
|
render_alphaearth_plots,
|
||||||
|
render_arcticdem_map,
|
||||||
|
render_arcticdem_overview,
|
||||||
|
render_arcticdem_plots,
|
||||||
|
render_era5_map,
|
||||||
|
render_era5_overview,
|
||||||
|
render_era5_plots,
|
||||||
|
)
|
||||||
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
||||||
from entropice.dashboard.utils.data import load_all_training_data
|
from entropice.dashboard.utils.data import load_all_training_data, load_source_data
|
||||||
from entropice.dataset import DatasetEnsemble
|
from entropice.dataset import DatasetEnsemble
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -107,30 +118,134 @@ def render_training_data_page():
|
||||||
# Display dataset ID in a styled container
|
# Display dataset ID in a styled container
|
||||||
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
st.info(f"**Dataset ID:** `{ensemble.id()}`")
|
||||||
|
|
||||||
# Load training data for all three tasks
|
# Create tabs for different data views
|
||||||
train_data_dict = load_all_training_data(ensemble)
|
tab_names = ["📊 Labels"]
|
||||||
|
|
||||||
# Calculate total samples (use binary as reference)
|
# Add tabs for each member
|
||||||
total_samples = len(train_data_dict["binary"])
|
for member in ensemble.members:
|
||||||
train_samples = (train_data_dict["binary"].split == "train").sum().item()
|
if member == "AlphaEarth":
|
||||||
test_samples = (train_data_dict["binary"].split == "test").sum().item()
|
tab_names.append("🌍 AlphaEarth")
|
||||||
|
elif member == "ArcticDEM":
|
||||||
|
tab_names.append("🏔️ ArcticDEM")
|
||||||
|
elif member.startswith("ERA5"):
|
||||||
|
# Group ERA5 temporal variants
|
||||||
|
if "🌡️ ERA5" not in tab_names:
|
||||||
|
tab_names.append("🌡️ ERA5")
|
||||||
|
|
||||||
st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
|
tabs = st.tabs(tab_names)
|
||||||
|
|
||||||
# Render distribution histograms
|
# Labels tab
|
||||||
st.markdown("---")
|
with tabs[0]:
|
||||||
render_all_distribution_histograms(train_data_dict)
|
st.markdown("### Target Labels Distribution and Spatial Visualization")
|
||||||
|
|
||||||
st.markdown("---")
|
# Load training data for all three tasks
|
||||||
|
with st.spinner("Loading training data for all tasks..."):
|
||||||
|
train_data_dict = load_all_training_data(ensemble)
|
||||||
|
|
||||||
# Render spatial map (as a fragment for efficient re-rendering)
|
# Calculate total samples (use binary as reference)
|
||||||
# Extract geometries from the X.data dataframe (which has geometry as a column)
|
total_samples = len(train_data_dict["binary"])
|
||||||
# The index should be cell_id
|
train_samples = (train_data_dict["binary"].split == "train").sum().item()
|
||||||
binary_dataset = train_data_dict["binary"]
|
test_samples = (train_data_dict["binary"].split == "test").sum().item()
|
||||||
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
|
|
||||||
|
|
||||||
render_spatial_map(train_data_dict)
|
st.success(
|
||||||
|
f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Render distribution histograms
|
||||||
|
st.markdown("---")
|
||||||
|
render_all_distribution_histograms(train_data_dict)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# Render spatial map
|
||||||
|
binary_dataset = train_data_dict["binary"]
|
||||||
|
assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
|
||||||
|
|
||||||
|
render_spatial_map(train_data_dict)
|
||||||
|
|
||||||
|
st.balloons()
|
||||||
|
|
||||||
|
# AlphaEarth tab
|
||||||
|
tab_idx = 1
|
||||||
|
if "AlphaEarth" in ensemble.members:
|
||||||
|
with tabs[tab_idx]:
|
||||||
|
st.markdown("### AlphaEarth Embeddings Analysis")
|
||||||
|
|
||||||
|
with st.spinner("Loading AlphaEarth data..."):
|
||||||
|
alphaearth_ds, targets = load_source_data(ensemble, "AlphaEarth")
|
||||||
|
|
||||||
|
st.success(f"Loaded AlphaEarth data with {len(alphaearth_ds['cell_ids'])} cells")
|
||||||
|
|
||||||
|
render_alphaearth_overview(alphaearth_ds)
|
||||||
|
render_alphaearth_plots(alphaearth_ds)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
render_alphaearth_map(alphaearth_ds, targets, ensemble.grid)
|
||||||
|
|
||||||
|
st.balloons()
|
||||||
|
|
||||||
|
tab_idx += 1
|
||||||
|
|
||||||
|
# ArcticDEM tab
|
||||||
|
if "ArcticDEM" in ensemble.members:
|
||||||
|
with tabs[tab_idx]:
|
||||||
|
st.markdown("### ArcticDEM Terrain Analysis")
|
||||||
|
|
||||||
|
with st.spinner("Loading ArcticDEM data..."):
|
||||||
|
arcticdem_ds, targets = load_source_data(ensemble, "ArcticDEM")
|
||||||
|
|
||||||
|
st.success(f"Loaded ArcticDEM data with {len(arcticdem_ds['cell_ids'])} cells")
|
||||||
|
|
||||||
|
render_arcticdem_overview(arcticdem_ds)
|
||||||
|
render_arcticdem_plots(arcticdem_ds)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
|
||||||
|
|
||||||
|
st.balloons()
|
||||||
|
|
||||||
|
tab_idx += 1
|
||||||
|
|
||||||
|
# ERA5 tab (combining all temporal variants)
|
||||||
|
era5_members = [m for m in ensemble.members if m.startswith("ERA5")]
|
||||||
|
if era5_members:
|
||||||
|
with tabs[tab_idx]:
|
||||||
|
st.markdown("### ERA5 Climate Data Analysis")
|
||||||
|
|
||||||
|
# Let user select which ERA5 temporal aggregation to view
|
||||||
|
era5_options = {
|
||||||
|
"ERA5-yearly": "Yearly",
|
||||||
|
"ERA5-seasonal": "Seasonal (Winter/Summer)",
|
||||||
|
"ERA5-shoulder": "Shoulder Seasons (JFM/AMJ/JAS/OND)",
|
||||||
|
}
|
||||||
|
|
||||||
|
available_era5 = {k: v for k, v in era5_options.items() if k in era5_members}
|
||||||
|
|
||||||
|
selected_era5 = st.selectbox(
|
||||||
|
"Select ERA5 temporal aggregation",
|
||||||
|
options=list(available_era5.keys()),
|
||||||
|
format_func=lambda x: available_era5[x],
|
||||||
|
key="era5_temporal_select",
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_era5:
|
||||||
|
temporal_type = selected_era5.split("-")[1] # 'yearly', 'seasonal', or 'shoulder'
|
||||||
|
|
||||||
|
with st.spinner(f"Loading {selected_era5} data..."):
|
||||||
|
era5_ds, targets = load_source_data(ensemble, selected_era5)
|
||||||
|
|
||||||
|
st.success(f"Loaded {selected_era5} data with {len(era5_ds['cell_ids'])} cells")
|
||||||
|
|
||||||
|
render_era5_overview(era5_ds, temporal_type)
|
||||||
|
render_era5_plots(era5_ds, temporal_type)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
render_era5_map(era5_ds, targets, ensemble.grid, temporal_type)
|
||||||
|
|
||||||
|
st.balloons()
|
||||||
|
|
||||||
# Add more components and visualizations as needed for training data.
|
|
||||||
else:
|
else:
|
||||||
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
|
st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")
|
||||||
|
|
|
||||||
|
|
@ -99,3 +99,23 @@ def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingD
|
||||||
"count": e.create_cat_training_dataset("count"),
|
"count": e.create_cat_training_dataset("count"),
|
||||||
"density": e.create_cat_training_dataset("density"),
|
"density": e.create_cat_training_dataset("density"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data
|
||||||
|
def load_source_data(e: DatasetEnsemble, source: str):
|
||||||
|
"""Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: DatasetEnsemble object.
|
||||||
|
source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xarray.Dataset with the raw data for the specified source.
|
||||||
|
|
||||||
|
"""
|
||||||
|
targets = e._read_target()
|
||||||
|
|
||||||
|
# Load the member data lazily to get metadata
|
||||||
|
ds = e._read_member(source, targets, lazy=False)
|
||||||
|
|
||||||
|
return ds, targets
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue