Start redoing the dashboard

This commit is contained in:
Tobias Hölzer 2025-12-18 22:49:25 +01:00
parent d22b857722
commit f5ea72e05e
22 changed files with 2610 additions and 1448 deletions

View file

@ -75,62 +75,14 @@ Since the resolution of the ERA5 dataset is spatially smaller than the resolutio
For geometries crossing the antimeridian, geometries are corrected. For geometries crossing the antimeridian, geometries are corrected.
| grid | method | | grid | method | ~#pixel |
| ----- | ----------- | | ----- | ----------- | ------------ |
| Hex3 | Common | | Hex3 | Common | 235 [30,850] |
| Hex4 | Mean | | Hex4 | Common | 44 [8,166] |
| Hex5 | Interpolate | | Hex5 | Mean-only | 11 [3,41] |
| Hex6 | Interpolate | | Hex6 | Interpolate | 4 [2,14] |
| Hpx6 | Common | | Hpx6 | Common | 204 [25,769] |
| Hpx7 | Mean | | Hpx7 | Common | 62 [9,231] |
| Hpx8 | Mean | | Hpx8 | Mean-only | 21 [4,75] |
| Hpx9 | Interpolate | | Hpx9 | Mean-only | 9 [2,29] |
| Hpx10 | Interpolate | | Hpx10 | Interpolate | 2 [2,15] |
- hex level 3
min: 30.0
max: 850.0
mean: 251.25216674804688
median: 235.5
- hex level 4
min: 8.0
max: 166.0
mean: 47.2462158203125
median: 44.0
- hex level 5
min: 3.0
max: 41.0
mean: 11.164162635803223
median: 10.0
- hex level 6
min: 2.0
max: 14.0
mean: 4.509947776794434
median: 4.0
- healpix level 6
min: 25.0
max: 769.0
mean: 214.97296142578125
median: 204.0
- healpix level 7
min: 9.0
max: 231.0
mean: 65.91140747070312
median: 62.0
- healpix level 8
min: 4.0
max: 75.0
mean: 22.516725540161133
median: 21.0
- healpix level 9
min: 2.0
max: 29.0
mean: 8.952080726623535
median: 9.0
- healpix level 10
min: 2.0
max: 15.0
mean: 4.361577987670898
median: 4.0
???

2712
pixi.lock generated

File diff suppressed because it is too large Load diff

View file

@ -62,6 +62,10 @@ dependencies = [
"xarray-histogram>=0.2.2,<0.3", "xarray-histogram>=0.2.2,<0.3",
"antimeridian>=0.4.5,<0.5", "antimeridian>=0.4.5,<0.5",
"duckdb>=1.4.2,<2", "duckdb>=1.4.2,<2",
"pydeck>=0.9.1,<0.10",
"pypalettes>=0.2.1,<0.3",
"ty>=0.0.2,<0.0.3",
"ruff>=0.14.9,<0.15", "pandas-stubs>=2.3.3.251201,<3",
] ]
[project.scripts] [project.scripts]
@ -70,8 +74,7 @@ darts = "entropice.darts:cli"
alpha-earth = "entropice.alphaearth:main" alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli" era5 = "entropice.era5:cli"
arcticdem = "entropice.arcticdem:cli" arcticdem = "entropice.arcticdem:cli"
train = "entropice.training:main" train = "entropice.training:cli"
dataset = "entropice.dataset:main"
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
@ -126,7 +129,7 @@ entropice = { path = ".", editable = true }
dashboard = { cmd = [ dashboard = { cmd = [
"streamlit", "streamlit",
"run", "run",
"src/entropice/training_analysis_dashboard.py", "src/entropice/dashboard/app.py",
"--server.port", "--server.port",
"8501", "8501",
"--server.address", "--server.address",

View file

@ -17,6 +17,7 @@ import geopandas as gpd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import odc.geo.geobox import odc.geo.geobox
import odc.geo.types
import pandas as pd import pandas as pd
import psutil import psutil
import shapely import shapely
@ -186,7 +187,7 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
x, y = enclosing.shape x, y = enclosing.shape
if x <= 1 or y <= 1: if x <= 1 or y <= 1:
return False return False
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing) roi: odc.geo.types.NormalizedROI = geobox.overlap_roi(enclosing)
roix, roiy = roi roix, roiy = roi
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1 return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
@ -216,7 +217,7 @@ def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggreg
@stopwatch("Extracting split cell data", log=False) @stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations): def _extract_split_cell_data(cropped_list: list[xr.Dataset] | list[xr.DataArray], aggregations: _Aggregations):
spatdims = ( spatdims = (
["latitude", "longitude"] ["latitude", "longitude"]
if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
@ -370,6 +371,7 @@ def _align_partition(
# => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes # => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
# faster than a processpoolexecutor # faster than a processpoolexecutor
assert memprof is not None, "Memory profiler is not initialized in worker"
memprof.log_memory("Before reading partial raster", log=False) memprof.log_memory("Before reading partial raster", log=False)
need_to_close_raster = False need_to_close_raster = False
@ -513,6 +515,7 @@ def _align_data(
n_partitions = len(grid_gdf) n_partitions = len(grid_gdf)
grid_partitions = grid_gdf grid_partitions = grid_gdf
else: else:
assert n_partitions is not None, "n_partitions must be provided when grid_gdf is not a list"
grid_partitions = partition_grid(grid_gdf, n_partitions) grid_partitions = partition_grid(grid_gdf, n_partitions)
if n_partitions < concurrent_partitions: if n_partitions < concurrent_partitions:
@ -539,7 +542,8 @@ def _align_data(
# For spawn or forkserver, we need to copy the raster into each worker # For spawn or forkserver, we need to copy the raster into each worker
workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,) workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,)
# For mp start method fork, we can share the raster dataset between workers # For mp start method fork, we can share the raster dataset between workers
if mp.get_start_method(allow_none=True) == "fork" and is_raster_in_memory: if is_mpfork and is_raster_in_memory:
assert isinstance(raster, xr.Dataset) # satisfy type checker, but this is already checked above
_init_worker(raster) _init_worker(raster)
with ProcessPoolExecutor( with ProcessPoolExecutor(

View file

@ -72,7 +72,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
scale_factor = scale_factors[grid][level] scale_factor = scale_factors[grid][level]
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.") print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
for year in track(range(2018, 2025), total=7, description="Processing years..."): for year in track(range(2021, 2025), total=4, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL") embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31") embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"] aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]

View file

@ -18,6 +18,7 @@ import smart_geocubes
import xarray as xr import xarray as xr
import xdggs import xdggs
import xrspatial import xrspatial
import xrspatial.convolution
import zarr import zarr
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
from rich import pretty, print, traceback from rich import pretty, print, traceback
@ -116,7 +117,8 @@ def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory):
return vrm return vrm
def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> tuple[cp.array, cp.array]: def _get_xy_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=None) -> tuple[cp.ndarray, cp.ndarray]:
assert isinstance(block_info, list) and len(block_info) >= 1
chunk_loc = block_info[0]["chunk-location"] chunk_loc = block_info[0]["chunk-location"]
d = 15 d = 15
cs = 3600 cs = 3600
@ -149,7 +151,7 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
return xx, yy return xx, yy
def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array: def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=None) -> np.ndarray:
res = 32 # 32m resolution res = 32 # 32m resolution
small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m) small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m)
medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m) medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m)

View file

@ -63,10 +63,15 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
) )
# Apply corrections to NaNs # Apply corrections to NaNs
covered = ~grid_gdf[f"darts_{year}_coverage"].isnull() covered = ~grid_gdf[f"darts_{year}_coverage"].isna()
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna( grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
0.0 0.0
) )
grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[
covered, f"darts_{year}_rts_density"
].fillna(0.0)
grid_gdf[f"darts_{year}_has_coverage"] = covered
grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1) grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1) grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
@ -128,9 +133,10 @@ def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
# Apply corrections to NaNs # Apply corrections to NaNs
covered = ~grid_gdf["dartsml_coverage"].isna() covered = ~grid_gdf["dartsml_coverage"].isna()
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0) grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0)
grid_gdf["dartsml_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna() grid_gdf["dartsml_has_coverage"] = covered
grid_gdf["dartsml_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna() grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0
output_path = get_darts_rts_file(grid, level, labels=True) output_path = get_darts_rts_file(grid, level, labels=True)
grid_gdf.to_parquet(output_path) grid_gdf.to_parquet(output_path)

View file

View file

@ -0,0 +1,44 @@
"""Streamlit app for Entropice dashboard.
Pages:
- Overview: List of available result directories with some summary statistics.
- Training Data: Visualization of training data distributions.
- Training Results Analysis: Analysis of training results and model performance.
- Model State: Visualization of model state and features.
- Inference: Visualization of inference results.
"""
import streamlit as st
from entropice.dashboard.inference_page import render_inference_page
from entropice.dashboard.model_state_page import render_model_state_page
from entropice.dashboard.overview_page import render_overview_page
from entropice.dashboard.training_analysis_page import render_training_analysis_page
from entropice.dashboard.training_data_page import render_training_data_page
def main():
    """Entry point: configure Streamlit and dispatch to the selected page."""
    st.set_page_config(page_title="Entropice Dashboard", layout="wide")

    # Declare every page once, grouped by sidebar section, then hand the
    # mapping to st.navigation which renders the sidebar and routes clicks.
    pages = {
        "Overview": [
            st.Page(render_overview_page, title="Overview", icon="🏡", default=True),
        ],
        "Training": [
            st.Page(render_training_data_page, title="Training Data", icon="🎞️"),
            st.Page(render_training_analysis_page, title="Training Results Analysis", icon="🦾"),
        ],
        "Model State": [
            st.Page(render_model_state_page, title="Model State", icon="🧮"),
        ],
        "Inference": [
            st.Page(render_inference_page, title="Inference", icon="🗺️"),
        ],
    }
    st.navigation(pages).run()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,8 @@
import streamlit as st
def render_inference_page():
    """Render the Inference page of the dashboard.

    Placeholder page: shows only a title and a short description until the
    actual inference visualizations are implemented.
    """
    st.title("Inference Results")
    st.write("This page will display inference results and visualizations.")

View file

@ -0,0 +1,8 @@
import streamlit as st
def render_model_state_page():
    """Render the Model State page of the dashboard.

    Placeholder page: shows only a title and a short description until the
    actual model-state visualizations are implemented.
    """
    st.title("Model State")
    st.write("This page will display model state and feature visualizations.")

View file

@ -0,0 +1,155 @@
"""Overview page: List of available result directories with some summary statistics."""
from datetime import datetime
import pandas as pd
import streamlit as st
from entropice.dashboard.utils.data import load_all_training_results
def render_overview_page():
    """Render the Overview page of the dashboard.

    Shows aggregate metrics over all discovered training results, a summary
    table with one row per result, and an expandable detail section per
    result (configuration, per-metric scores, explored parameter ranges).
    """
    st.title("🏡 Training Results Overview")

    training_results = load_all_training_results()

    if not training_results:
        st.warning("No training results found. Please run some training experiments first.")
        return

    st.write(f"Found **{len(training_results)}** training result(s)")

    # Summary statistics at the top
    st.subheader("Summary Statistics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        tasks = {tr.settings.get("task", "Unknown") for tr in training_results}
        st.metric("Tasks", len(tasks))
    with col2:
        grids = {tr.settings.get("grid", "Unknown") for tr in training_results}
        st.metric("Grid Types", len(grids))
    with col3:
        models = {tr.settings.get("model", "Unknown") for tr in training_results}
        st.metric("Model Types", len(models))
    with col4:
        # NOTE(review): assumes load_all_training_results() returns results
        # sorted newest-first — confirm in the loader.
        latest = training_results[0]
        latest_date = datetime.fromtimestamp(latest.created_at).strftime("%Y-%m-%d")
        st.metric("Latest Run", latest_date)

    st.divider()

    # Detailed results table
    st.subheader("Training Results")

    # Build a summary dataframe with a FIXED schema. The previous per-result
    # column name (f"Best {primary_metric.title()}") produced ragged,
    # NaN-filled columns as soon as two results had different primary metrics.
    summary_data = []
    for tr in training_results:
        # All CV test-score columns follow the sklearn "mean_test_<metric>" convention.
        score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
        best_scores = {col.replace("mean_test_", ""): tr.results[col].max() for col in score_cols}

        # Primary metric: prefer accuracy, otherwise the first test metric found.
        if "mean_test_accuracy" in tr.results.columns:
            primary_metric = "accuracy"
        elif score_cols:
            primary_metric = score_cols[0].replace("mean_test_", "")
        else:
            primary_metric = "N/A"
        primary_score = best_scores.get(primary_metric, 0.0)

        summary_data.append(
            {
                "Date": datetime.fromtimestamp(tr.created_at).strftime("%Y-%m-%d %H:%M"),
                "Task": tr.settings.get("task", "Unknown"),
                "Grid": tr.settings.get("grid", "Unknown"),
                "Level": tr.settings.get("level", "Unknown"),
                "Model": tr.settings.get("model", "Unknown"),
                "Primary Metric": primary_metric.title(),
                "Best Score": f"{primary_score:.4f}",
                "Trials": len(tr.results),
                "Path": str(tr.path.name),
            }
        )

    summary_df = pd.DataFrame(summary_data)
    st.dataframe(
        summary_df,
        width="stretch",
        hide_index=True,
    )

    st.divider()

    # Expandable details for each result
    st.subheader("Detailed Results")
    for tr in training_results:
        with st.expander(tr.name):
            col1, col2 = st.columns([1, 2])
            with col1:
                st.write("**Configuration:**")
                st.write(f"- **Task:** {tr.settings.get('task', 'Unknown')}")
                st.write(f"- **Grid:** {tr.settings.get('grid', 'Unknown')}")
                st.write(f"- **Level:** {tr.settings.get('level', 'Unknown')}")
                st.write(f"- **Model:** {tr.settings.get('model', 'Unknown')}")
                st.write(f"- **CV Splits:** {tr.settings.get('cv_splits', 'Unknown')}")
                st.write(f"- **Classes:** {tr.settings.get('classes', 'Unknown')}")
                st.write("\n**Files:**")
                st.write("- 📊 search_results.parquet")
                st.write("- 🧮 best_estimator_state.nc")
                st.write("- 🎯 predicted_probabilities.parquet")
                st.write("- ⚙️ search_settings.toml")
            with col2:
                st.write("**Best Scores:**")
                # One table row per test metric: best / mean / std over all trials.
                score_cols = [col for col in tr.results.columns if col.startswith("mean_test_")]
                if score_cols:
                    metric_data = []
                    for col in score_cols:
                        metric_data.append(
                            {
                                "Metric": col.replace("mean_test_", "").title(),
                                "Best": f"{tr.results[col].max():.4f}",
                                "Mean": f"{tr.results[col].mean():.4f}",
                                "Std": f"{tr.results[col].std():.4f}",
                            }
                        )
                    metric_df = pd.DataFrame(metric_data)
                    st.dataframe(metric_df, width="stretch", hide_index=True)
                else:
                    st.write("No test scores found in results.")

                # Show parameter space explored ("initial_K" marks a search-result frame)
                if "initial_K" in tr.results.columns:
                    st.write("\n**Parameter Ranges Explored:**")
                    for param in ["initial_K", "eps_cl", "eps_e"]:
                        if param in tr.results.columns:
                            min_val = tr.results[param].min()
                            max_val = tr.results[param].max()
                            unique_vals = tr.results[param].nunique()
                            st.write(f"- **{param}:** {unique_vals} values ({min_val:.2e} to {max_val:.2e})")

            st.write(f"\n**Path:** `{tr.path}`")

View file

@ -0,0 +1,92 @@
"""Color related utilities for dashboard plots.
Color palettes from https://python-graph-gallery.com/color-palette-finder/
Material palettes:
- amber_material
- blue_grey_material
- blue_material
- brown_material
- cyan_material
- deep_orange_material
- deep_purple_material
- green_material
- grey_material
- indigo_material
- light_blue_material
- light_green_material
- lime_material
- orange_material
- pink_material
- purple_material
- red_material
- teal_material
- yellow_material
"""
import matplotlib.colors as mcolors
from pypalettes import load_cmap
def get_cmap(variable: str) -> mcolors.Colormap:
    """Get a color palette by a "data" variable.

    Each variable (meaning of the data) should be associated with another color palette when plotting.
    This function should help to standardize the color palettes used for each variable type.

    The variable can be any string, descriptive names are recommended.

    Args:
        variable: The variable to load a palette for.

    Returns:
        A matplotlib Colormap, chosen deterministically from the variable name.
    """
    # The 19 pypalettes "material" palettes; the hash below maps a variable
    # name to an index in this list, so the same variable always gets the
    # same palette (list order matters for stability).
    material_palettes = [
        "amber_material",
        "blue_grey_material",
        "blue_material",
        "brown_material",
        "cyan_material",
        "deep_orange_material",
        "deep_purple_material",
        "green_material",
        "grey_material",
        "indigo_material",
        "light_blue_material",
        "light_green_material",
        "lime_material",
        "orange_material",
        "pink_material",
        "purple_material",
        "red_material",
        "teal_material",
        "yellow_material",
    ]
    # Fuzzy map from variable type to palette name: sum of character codes
    # modulo palette count — deterministic, but not collision-free (different
    # variables may share a palette).
    material_idx = sum(ord(c) for c in variable) % len(material_palettes)
    palette_name = material_palettes[material_idx]
    cmap = load_cmap(name=palette_name)
    return cmap
def get_palette(variable: str, n_colors: int) -> list[str]:
    """Get a discrete color palette for a "data" variable.

    Deterministically derives a colormap from the variable name (see
    ``get_cmap``) and samples exactly ``n_colors`` evenly spaced colors
    from it, so the same variable always maps to the same colors.

    Args:
        variable: The variable to load a palette for.
        n_colors: The number of colors to return.

    Returns:
        A list of hex color strings of length ``n_colors``.
    """
    resampled = get_cmap(variable).resampled(n_colors)
    return [mcolors.to_hex(resampled(idx)) for idx in range(resampled.N)]

View file

@ -0,0 +1,356 @@
"""Plotting functions for training data visualizations."""
import geopandas as gpd
import pandas as pd
import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_palette
from entropice.dataset import CategoricalTrainingDataset
def render_all_distribution_histograms(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render target-distribution histograms for the three tasks side by side.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and
            CategoricalTrainingDataset values.
    """
    st.subheader("📊 Target Distribution by Task")

    # One column per task, rendered left to right in this fixed order.
    task_titles = {
        "binary": "Binary Classification",
        "count": "Count Classification",
        "density": "Density Classification",
    }

    for column, (task, title) in zip(st.columns(3), task_titles.items()):
        dataset = train_data_dict[task]
        categories = dataset.y.binned.cat.categories.tolist()
        colors = get_palette(task, len(categories))

        with column:
            st.markdown(f"**{title}**")

            # Per-category counts, separated by train/test membership.
            split_counts = {
                label: [((dataset.y.binned == cat) & (dataset.split == key)).sum() for cat in categories]
                for label, key in (("Train", "train"), ("Test", "test"))
            }

            # Grouped bar chart: train bars at full opacity, test bars lighter.
            fig = go.Figure()
            for trace_name, trace_opacity in (("Train", 0.9), ("Test", 0.6)):
                values = split_counts[trace_name]
                fig.add_trace(
                    go.Bar(
                        name=trace_name,
                        x=categories,
                        y=values,
                        marker_color=colors,
                        opacity=trace_opacity,
                        text=values,
                        textposition="inside",
                        textfont={"size": 10, "color": "white"},
                    )
                )
            fig.update_layout(
                barmode="group",
                height=400,
                margin={"l": 20, "r": 20, "t": 20, "b": 20},
                showlegend=True,
                legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
                xaxis_title=None,
                yaxis_title="Count",
                xaxis={"tickangle": -45},
            )
            st.plotly_chart(fig, width="stretch")

            # Train/test share of this task's dataset.
            total = len(dataset)
            train_pct = (dataset.split == "train").sum() / total * 100
            test_pct = (dataset.split == "test").sum() / total * 100
            st.caption(f"Total: {total:,} | Train: {train_pct:.1f}% | Test: {test_pct:.1f}%")
def _fix_hex_geometry(geom):
    """Return *geom* with antimeridian-crossing rings corrected.

    On failure the error is surfaced in the Streamlit UI and the geometry is
    returned unchanged.
    """
    import antimeridian

    try:
        fixed = shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        fixed = geom
    return fixed
def _assign_colors_by_mode(gdf, color_mode, dataset, selected_task):
"""Assign colors to geodataframe based on the selected color mode.
Args:
gdf: GeoDataFrame to add colors to
color_mode: One of 'target_class' or 'split'
dataset: CategoricalTrainingDataset
selected_task: Task name for color palette selection
Returns:
GeoDataFrame with 'fill_color' column added
"""
if color_mode == "target_class":
categories = dataset.y.binned.cat.categories.tolist()
colors_palette = get_palette(selected_task, len(categories))
# Create color mapping
color_map = {cat: colors_palette[i] for i, cat in enumerate(categories)}
gdf["color"] = gdf["target_class"].map(color_map)
# Convert hex colors to RGB
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip("#")
return [int(hex_color[i : i + 2], 16) for i in (0, 2, 4)]
gdf["fill_color"] = gdf["color"].apply(hex_to_rgb)
elif color_mode == "split":
split_colors = {"train": [66, 135, 245], "test": [245, 135, 66]} # Blue # Orange
gdf["fill_color"] = gdf["split"].map(split_colors)
return gdf
@st.fragment
def render_spatial_map(train_data_dict: dict[str, CategoricalTrainingDataset]):
    """Render a pydeck spatial map showing training data distribution with interactive controls.

    This is a Streamlit fragment that reruns independently when users interact with the
    visualization controls (color mode and opacity), without re-running the entire page.

    Args:
        train_data_dict: Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    st.subheader("🗺️ Spatial Distribution Map")
    # Create controls in columns
    col1, col2 = st.columns([3, 1])
    with col1:
        vis_mode = st.selectbox(
            "Visualization mode",
            options=["binary", "count", "density", "split"],
            format_func=lambda x: x.capitalize() if x != "split" else "Train/Test Split",
            key="spatial_map_mode",
        )
    with col2:
        opacity = st.slider("Opacity", min_value=0.1, max_value=1.0, value=0.7, step=0.1, key="spatial_map_opacity")
    # Determine which task dataset to use and color mode
    if vis_mode == "split":
        # Use binary dataset for split visualization
        dataset = train_data_dict["binary"]
        color_mode = "split"
        selected_task = "binary"
    else:
        # Use the selected task
        dataset = train_data_dict[vis_mode]
        color_mode = "target_class"
        selected_task = vis_mode
    # Prepare data for visualization - dataset.dataset should already be a GeoDataFrame
    gdf: gpd.GeoDataFrame = dataset.dataset.copy()  # type: ignore[assignment]
    # Fix antimeridian issues
    gdf["geometry"] = gdf["geometry"].apply(_fix_hex_geometry)
    # Add binned labels and split information from current dataset
    gdf["target_class"] = dataset.y.binned.to_numpy()
    gdf["split"] = dataset.split.to_numpy()
    # NOTE(review): dataset.z appears to hold the raw (un-binned) target
    # values — confirm against the dataset definition.
    gdf["raw_value"] = dataset.z.to_numpy()
    # Add information from all three tasks for tooltip
    gdf["binary_label"] = train_data_dict["binary"].y.binned.to_numpy()
    gdf["count_category"] = train_data_dict["count"].y.binned.to_numpy()
    gdf["count_raw"] = train_data_dict["count"].z.to_numpy()
    gdf["density_category"] = train_data_dict["density"].y.binned.to_numpy()
    gdf["density_raw"] = train_data_dict["density"].z.to_numpy()
    # Convert to WGS84 for pydeck
    gdf_wgs84: gpd.GeoDataFrame = gdf.to_crs("EPSG:4326")  # type: ignore[assignment]
    # Assign colors based on the selected mode
    gdf_wgs84 = _assign_colors_by_mode(gdf_wgs84, color_mode, dataset, selected_task)
    # Convert to GeoJSON format and add elevation for 3D visualization
    geojson_data = []
    # Normalize raw values for elevation (only for count and density)
    use_elevation = vis_mode in ["count", "density"]
    if use_elevation:
        raw_values = gdf_wgs84["raw_value"]
        min_val, max_val = raw_values.min(), raw_values.max()
        # Normalize to 0-1 range for better 3D visualization
        if max_val > min_val:
            gdf_wgs84["elevation"] = ((raw_values - min_val) / (max_val - min_val)).fillna(0)
        else:
            # Degenerate case: all raw values are equal -> flat map
            gdf_wgs84["elevation"] = 0
    # Build one GeoJSON feature per cell; properties feed the tooltip and styling.
    for _, row in gdf_wgs84.iterrows():
        feature = {
            "type": "Feature",
            "geometry": row["geometry"].__geo_interface__,
            "properties": {
                "target_class": str(row["target_class"]),
                "split": str(row["split"]),
                "raw_value": float(row["raw_value"]),
                "fill_color": row["fill_color"],
                # "elevation" column only exists when use_elevation is True
                "elevation": float(row["elevation"]) if use_elevation else 0,
                "binary_label": str(row["binary_label"]),
                "count_category": str(row["count_category"]),
                "count_raw": int(row["count_raw"]),
                "density_category": str(row["density_category"]),
                "density_raw": f"{float(row['density_raw']):.4f}",
            },
        }
        geojson_data.append(feature)
    # Create pydeck layer
    layer = pdk.Layer(
        "GeoJsonLayer",
        geojson_data,
        opacity=opacity,
        stroked=True,
        filled=True,
        extruded=use_elevation,
        wireframe=False,
        get_fill_color="properties.fill_color",
        get_line_color=[80, 80, 80],
        line_width_min_pixels=0.5,
        get_elevation="properties.elevation" if use_elevation else 0,
        elevation_scale=500000,  # Scale normalized values (0-1) to 500km height
        pickable=True,
    )
    # Set initial view state (centered on the Arctic)
    # Adjust pitch and zoom based on whether we're using elevation
    view_state = pdk.ViewState(
        latitude=70, longitude=0, zoom=2 if not use_elevation else 1.5, pitch=0 if not use_elevation else 45
    )
    # Create deck
    deck = pdk.Deck(
        layers=[layer],
        initial_view_state=view_state,
        tooltip={
            "html": "<b>Binary:</b> {binary_label}<br/>"
            "<b>Count Category:</b> {count_category}<br/>"
            "<b>Count Raw:</b> {count_raw}<br/>"
            "<b>Density Category:</b> {density_category}<br/>"
            "<b>Density Raw:</b> {density_raw}<br/>"
            "<b>Split:</b> {split}",
            "style": {"backgroundColor": "steelblue", "color": "white"},
        },
        map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
    )
    # Render the map
    st.pydeck_chart(deck)
    # Show info about 3D visualization
    if use_elevation:
        st.info("💡 3D elevation represents raw values. Rotate the map by holding Ctrl/Cmd and dragging.")
    # Add legend
    with st.expander("Legend", expanded=True):
        if color_mode == "target_class":
            st.markdown("**Target Classes:**")
            categories = dataset.y.binned.cat.categories.tolist()
            colors_palette = get_palette(selected_task, len(categories))
            intervals = dataset.y.intervals
            # For count and density tasks, show intervals
            if selected_task in ["count", "density"]:
                for i, cat in enumerate(categories):
                    color = colors_palette[i]
                    interval_min, interval_max = intervals[i]
                    # Format interval display
                    if interval_min is None or interval_max is None:
                        interval_str = ""
                    elif selected_task == "count":
                        # Integer values for count
                        if interval_min == interval_max:
                            interval_str = f" ({int(interval_min)})"
                        else:
                            interval_str = f" ({int(interval_min)}-{int(interval_max)})"
                    else:  # density
                        # Percentage values for density
                        if interval_min == interval_max:
                            interval_str = f" ({interval_min * 100:.4f}%)"
                        else:
                            interval_str = f" ({interval_min * 100:.4f}%-{interval_max * 100:.4f}%)"
                    st.markdown(
                        f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
                        f'<div style="width: 20px; height: 20px; background-color: {color}; '
                        f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
                        f"<span>{cat}{interval_str}</span></div>",
                        unsafe_allow_html=True,
                    )
            else:
                # Binary task: use original column layout
                legend_cols = st.columns(len(categories))
                for i, cat in enumerate(categories):
                    with legend_cols[i]:
                        color = colors_palette[i]
                        st.markdown(
                            f'<div style="display: flex; align-items: center;">'
                            f'<div style="width: 20px; height: 20px; background-color: {color}; '
                            f'margin-right: 8px; border: 1px solid #ccc;"></div>'
                            f"<span>{cat}</span></div>",
                            unsafe_allow_html=True,
                        )
            if use_elevation:
                st.markdown("---")
                st.markdown("**Elevation (3D):**")
                min_val = gdf_wgs84["raw_value"].min()
                max_val = gdf_wgs84["raw_value"].max()
                st.markdown(f"Height represents raw value: {min_val:.2f} (low) → {max_val:.2f} (high)")
        elif color_mode == "split":
            st.markdown("**Data Split:**")
            # Colors match the RGB triplets assigned in _assign_colors_by_mode.
            legend_html = (
                '<div style="display: flex; gap: 20px;">'
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(66, 135, 245); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Train</span></div>"
                '<div style="display: flex; align-items: center;">'
                '<div style="width: 20px; height: 20px; background-color: rgb(245, 135, 66); '
                'margin-right: 8px; border: 1px solid #ccc;"></div>'
                "<span>Test</span></div></div>"
            )
            st.markdown(legend_html, unsafe_allow_html=True)

View file

@ -0,0 +1,10 @@
"""Training Results Analysis page: Analysis of training results and model performance."""
import streamlit as st
def render_training_analysis_page():
    """Render the Training Results Analysis page of the dashboard.

    Placeholder page: shows only a title and a short description until the
    actual analysis visualizations are implemented.
    """
    st.title("Training Results Analysis")
    st.write("This page will display analysis of training results and model performance.")

View file

@ -0,0 +1,136 @@
"""Training Data page: Visualization of training data distributions."""
import streamlit as st
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
from entropice.dashboard.utils.data import load_all_training_data
from entropice.dataset import DatasetEnsemble
def render_training_data_page():
    """Render the Training Data page of the dashboard.

    Collects a dataset configuration via a sidebar form, builds a
    ``DatasetEnsemble`` on submit (kept in ``st.session_state``), then shows
    the configuration summary, per-task sample counts, distribution
    histograms, and a spatial map for the loaded data.
    """
    st.title("Training Data")
    # Sidebar widgets for dataset configuration in a form
    # (a form batches widget changes: nothing reruns until submit is pressed).
    with st.sidebar.form("dataset_config_form"):
        st.header("Dataset Configuration")
        # Combined grid and level selection ("<grid>-<level>" strings).
        grid_options = [
            "hex-3",
            "hex-4",
            "hex-5",
            "hex-6",
            "healpix-6",
            "healpix-7",
            "healpix-8",
            "healpix-9",
            "healpix-10",
        ]
        grid_level_combined = st.selectbox(
            "Grid Configuration", options=grid_options, index=0, help="Select the grid system and resolution level"
        )
        # Parse grid type and level
        grid, level_str = grid_level_combined.split("-")
        level = int(level_str)
        # Target feature selection
        target = st.selectbox(
            "Target Feature",
            options=["darts_rts", "darts_mllabels"],
            index=0,
            help="Select the target variable for training",
        )
        # Members selection
        st.subheader("Dataset Members")
        # Check if AlphaEarth should be disabled
        # (presumably unavailable at the finest resolutions — confirm with data pipeline).
        disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
        all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
        selected_members = []
        for member in all_members:
            if member == "AlphaEarth" and disable_alphaearth:
                # Show disabled checkbox with explanation
                st.checkbox(
                    member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}"
                )
            else:
                if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
                    selected_members.append(member)
        # Form submit button (disabled until at least one member is selected).
        load_button = st.form_submit_button(
            "Load Dataset", type="primary", use_container_width=True, disabled=len(selected_members) == 0
        )
    # Create DatasetEnsemble only when form is submitted
    if load_button:
        ensemble = DatasetEnsemble(grid=grid, level=level, target=target, members=selected_members)
        # Store ensemble in session state so it survives Streamlit reruns.
        st.session_state["dataset_ensemble"] = ensemble
        st.session_state["dataset_loaded"] = True
    # Display dataset information if loaded
    if st.session_state.get("dataset_loaded", False) and "dataset_ensemble" in st.session_state:
        ensemble = st.session_state["dataset_ensemble"]
        # Display current configuration
        st.subheader("📊 Current Configuration")
        # Create a visually appealing layout with columns
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric(label="Grid Type", value=ensemble.grid.upper())
        with col2:
            st.metric(label="Grid Level", value=ensemble.level)
        with col3:
            st.metric(label="Target Feature", value=ensemble.target.replace("darts_", ""))
        with col4:
            st.metric(label="Members", value=len(ensemble.members))
        # Display members in an expandable section
        with st.expander("🗂️ Dataset Members", expanded=False):
            members_cols = st.columns(len(ensemble.members))
            for idx, member in enumerate(ensemble.members):
                with members_cols[idx]:
                    st.markdown(f"✓ **{member}**")
        # Display dataset ID in a styled container
        st.info(f"**Dataset ID:** `{ensemble.id()}`")
        # Load training data for all three tasks (cached via st.cache_data).
        train_data_dict = load_all_training_data(ensemble)
        # Calculate total samples (use binary as reference)
        total_samples = len(train_data_dict["binary"])
        train_samples = (train_data_dict["binary"].split == "train").sum().item()
        test_samples = (train_data_dict["binary"].split == "test").sum().item()
        st.success(f"Loaded {total_samples} samples ({train_samples} train, {test_samples} test) for all three tasks")
        # Render distribution histograms
        st.markdown("---")
        render_all_distribution_histograms(train_data_dict)
        st.markdown("---")
        # Render spatial map (as a fragment for efficient re-rendering)
        # Extract geometries from the X.data dataframe (which has geometry as a column)
        # The index should be cell_id
        binary_dataset = train_data_dict["binary"]
        assert "geometry" in binary_dataset.dataset.columns, "Geometry column missing in dataset"
        render_spatial_map(train_data_dict)
        # Add more components and visualizations as needed for training data.
    else:
        st.info("Configure the dataset settings in the sidebar and click 'Load Dataset' to begin.")

View file

@ -0,0 +1,101 @@
"""Data utilities for Entropice dashboard."""
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import antimeridian
import pandas as pd
import streamlit as st
import toml
from shapely.geometry import shape
import entropice.paths
from entropice.dataset import CategoricalTrainingDataset, DatasetEnsemble
@dataclass
class TrainingResult:
    """Simple wrapper of training result data."""

    # Human-readable display name ("task grid-level (created_at)").
    name: str
    # Directory the result was loaded from.
    path: Path
    # Parsed [settings] table of search_settings.toml.
    settings: dict
    # Contents of search_results.parquet.
    results: pd.DataFrame
    # Creation timestamp (st_ctime) of the result directory.
    created_at: float

    @classmethod
    def from_path(cls, result_path: Path) -> "TrainingResult":
        """Load a TrainingResult from a given result directory path."""
        search_results_file = result_path / "search_results.parquet"
        settings_file = result_path / "search_settings.toml"
        # All four artifacts must be present for a directory to count as a result.
        required = (
            search_results_file,
            result_path / "best_estimator_state.nc",
            result_path / "predicted_probabilities.parquet",
            settings_file,
        )
        if any(not f.exists() for f in required):
            raise FileNotFoundError(f"Missing required files in {result_path}")
        ctime = result_path.stat().st_ctime
        cfg = toml.load(settings_file)["settings"]
        # Name should be "task grid-level (created_at)"
        display_name = (
            f"**{cfg.get('task', 'Unknown').capitalize()}** -"
            f" {cfg.get('grid', 'Unknown').capitalize()}-{cfg.get('level', 'Unknown')}"
            f" ({datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M')})"
        )
        return cls(
            name=display_name,
            path=result_path,
            settings=cfg,
            results=pd.read_parquet(search_results_file),
            created_at=ctime,
        )
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    Returns the antimeridian-corrected shapely geometry; if the correction
    fails with a ValueError, shows the error in the Streamlit UI and falls
    back to returning the original geometry unchanged (best-effort).
    """
    try:
        # shape() is kept inside the try so its ValueErrors are handled too.
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
    """Load all training results from the results directory."""
    found: list[TrainingResult] = []
    # Every subdirectory of the results dir is a candidate result;
    # directories missing required artifacts are silently skipped.
    for candidate in entropice.paths.RESULTS_DIR.iterdir():
        if not candidate.is_dir():
            continue
        try:
            found.append(TrainingResult.from_path(candidate))
        except FileNotFoundError:
            continue
    # Most recent first.
    return sorted(found, key=lambda tr: tr.created_at, reverse=True)
@st.cache_data
def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]:
    """Load training data for all three tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    # Insertion order matches the documented key order.
    return {task: e.create_cat_training_dataset(task) for task in ("binary", "count", "density")}

View file

@ -1,3 +1,4 @@
# ruff: noqa: N806
"""Training dataset preparation and model training. """Training dataset preparation and model training.
Naming conventions: Naming conventions:
@ -14,15 +15,18 @@ Naming conventions:
import hashlib import hashlib
import json import json
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from functools import cached_property, lru_cache
from typing import Literal from typing import Literal
import cyclopts import cyclopts
import geopandas as gpd import geopandas as gpd
import pandas as pd import pandas as pd
import seaborn as sns import seaborn as sns
import torch
import xarray as xr import xarray as xr
from rich import pretty, traceback from rich import pretty, traceback
from sklearn import set_config from sklearn import set_config
from sklearn.model_selection import train_test_split
import entropice.paths import entropice.paths
@ -35,29 +39,111 @@ sns.set_theme("talk", "whitegrid")
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
if temporal == "yearly": if temporal == "yearly":
return df.index.get_level_values("time").year return time_index.year
elif temporal == "seasonal": elif temporal == "seasonal":
seasons = {10: "winter", 4: "summer"} seasons = {10: "winter", 4: "summer"}
return ( return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder": elif temporal == "shoulder":
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
return ( return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
type Task = Literal["binary", "count", "density"]
def bin_values(
    values: pd.Series,
    task: Literal["count", "density"],
    none_val: float = 0,
) -> pd.Series:
    """Bin values into predefined intervals for different tasks.

    First, a 'none' bin is created for values equal to `none_val`, usually 0.
    Then, the remaining values are binned automatically into 5 bins, all containing roughly the same number of samples.

    Args:
        values (pd.Series): Pandas Series of numerical values to bin.
        task (Literal["count", "density"]): Task type - 'count' or 'density'.
        none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.

    Returns:
        pd.Series: Pandas Series of ordered categorical binned values.

    Raises:
        ValueError: If a value is NaN, or if fewer than 6 non-none values are
            available to form the 5 quantile bins.
    """
    labels_dict = {
        "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
        "density": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
    }
    labels = labels_dict[task]
    if values.isna().any():
        raise ValueError("Values contain NaN")
    # Separate none values from others
    none_mask = values == none_val
    non_none_values = values[~none_mask]
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would let qcut fail later with a less helpful error.
    if len(non_none_values) <= 5:
        raise ValueError("Not enough non-none values to create bins.")
    # Quantile-bin the remaining values, then widen the categories to include
    # the 'none' label so both label sets can coexist in one categorical.
    binned_non_none = pd.qcut(non_none_values, q=5, labels=labels[1:]).cat.set_categories(labels, ordered=True)
    binned = pd.Series(index=values.index, dtype="category")
    binned = binned.cat.set_categories(labels, ordered=True)
    binned.update(binned_non_none)
    binned.loc[none_mask] = labels[0]
    return binned
@dataclass(frozen=True, eq=False)
class DatasetLabels:
binned: pd.Series
train: torch.Tensor
test: torch.Tensor
raw_values: pd.Series
@cached_property
def intervals(self) -> list[tuple[float, float] | tuple[int, int]]:
# For each category get the min and max values from raw_values
intervals = []
for category in self.binned.cat.categories:
category_mask = self.binned == category
if category_mask.sum() == 0:
intervals.append((None, None))
else:
category_raw_values = self.raw_values[category_mask]
intervals.append((category_raw_values.min(), category_raw_values.max()))
return intervals
@cached_property
def labels(self) -> list[str]:
return list(self.binned.cat.categories)
@dataclass(frozen=True, eq=False)
class DatasetInputs:
    """Model input features plus their train/test tensor splits."""

    # Full feature matrix (one row per sample).
    data: pd.DataFrame
    # Feature tensors for the train and test splits.
    train: torch.Tensor
    test: torch.Tensor
@dataclass(frozen=True)
class CategoricalTrainingDataset:
    """Fully prepared categorical training dataset for a single task."""

    # Source dataframe the inputs/labels were derived from (includes geometry).
    dataset: pd.DataFrame
    # Input features and their train/test tensors.
    X: DatasetInputs
    # Binned labels and their train/test tensors.
    y: DatasetLabels
    # Raw (un-binned) label values per sample.
    z: pd.Series
    # "train"/"test" assignment per sample.
    split: pd.Series

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.z)
@cyclopts.Parameter("*") @cyclopts.Parameter("*")
@dataclass @dataclass(frozen=True)
class DatasetEnsemble: class DatasetEnsemble:
grid: Literal["hex", "healpix"] grid: Literal["hex", "healpix"]
level: int level: int
@ -70,17 +156,35 @@ class DatasetEnsemble:
filter_target: str | Literal[False] = False filter_target: str | Literal[False] = False
add_lonlat: bool = True add_lonlat: bool = True
def __hash__(self):
return int(self.id(), 16)
def id(self): def id(self):
return hashlib.blake2b( return hashlib.blake2b(
json.dumps(asdict(self), sort_keys=True).encode("utf-8"), json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
digest_size=16, digest_size=16,
).hexdigest() ).hexdigest()
    @property
    def covcol(self) -> str:
        """Name of the coverage-flag column for the configured target."""
        return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
def taskcol(self, task: Task) -> str:
if task == "binary":
return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts"
elif task == "count":
return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count"
elif task == "density":
return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density"
else:
raise ValueError(f"Invalid task: {task}")
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth": if member == "AlphaEarth":
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level) store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
store = entropice.paths.get_era5_stores(member.split("-")[1], grid=self.grid, level=self.level) era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
store = entropice.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
elif member == "ArcticDEM": elif member == "ArcticDEM":
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level) store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
else: else:
@ -145,7 +249,7 @@ class DatasetEnsemble:
def _prep_era5( def _prep_era5(
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"] self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame: ) -> pd.DataFrame:
era5 = self._read_member(f"ERA5-{temporal}", targets) era5 = self._read_member("ERA5-" + temporal, targets)
era5_df = era5.to_dataframe() era5_df = era5.to_dataframe()
era5_df["t"] = _get_era5_tempus(era5_df, temporal) era5_df["t"] = _get_era5_tempus(era5_df, temporal)
if "aggregations" not in era5.dims: if "aggregations" not in era5.dims:
@ -190,9 +294,10 @@ class DatasetEnsemble:
n_cols += n_cols_member n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}") print(f"=== Total number of features in dataset: {n_cols}")
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: @lru_cache(maxsize=1)
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists # n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.paths.get_dataset_cache(self.id()) cache_file = entropice.paths.get_dataset_cache(self.id(), subset=filter_target_col)
if cache_mode == "r" and cache_file.exists(): if cache_mode == "r" and cache_file.exists():
dataset = gpd.read_parquet(cache_file) dataset = gpd.read_parquet(cache_file)
print( print(
@ -201,11 +306,14 @@ class DatasetEnsemble:
) )
return dataset return dataset
targets = self._read_target() targets = self._read_target()
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
member_dfs = [] member_dfs = []
for member in self.members: for member in self.members:
if member.startswith("ERA5"): if member.startswith("ERA5"):
member_dfs.append(self._prep_era5(targets, member.split("-")[1])) era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets, era5_agg))
elif member == "AlphaEarth": elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets)) member_dfs.append(self._prep_embeddings(targets))
elif member == "ArcticDEM": elif member == "ArcticDEM":
@ -220,3 +328,65 @@ class DatasetEnsemble:
dataset.to_parquet(cache_file) dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.") print(f"Saved dataset to cache at {cache_file}.")
return dataset return dataset
def create_cat_training_dataset(self, task: Task) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
"""
covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
dataset = self.create(filter_target_col=covcol)
taskcol = self.taskcol(task)
valid_labels = dataset[taskcol].notna()
cols_to_drop = {"geometry", taskcol, covcol}
cols_to_drop |= {
col
for col in dataset.columns
if col.startswith("dartsml_" if self.target == "darts_mllabels" else "darts_")
}
model_inputs = dataset.drop(columns=cols_to_drop)
# Assert that no column in all-nan
assert not model_inputs.isna().all("index").any(), "Some input columns are all NaN"
# Get valid inputs (rows)
valid_inputs = model_inputs.notna().all("columns")
dataset = dataset.loc[valid_labels & valid_inputs]
model_inputs = model_inputs.loc[valid_labels & valid_inputs]
model_labels = dataset[taskcol]
if task == "binary":
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
elif task == "count":
binned = bin_values(model_labels.astype(int), task=task)
elif task == "density":
binned = bin_values(model_labels, task=task)
else:
raise ValueError("Invalid task.")
# Create train / test split
train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
split = pd.Series(index=dataset.index, dtype=object)
split.loc[train_idx] = "train"
split.loc[test_idx] = "test"
split = split.astype("category")
X_train = torch.asarray(model_inputs.loc[train_idx].to_numpy(dtype="float64"), device=0)
X_test = torch.asarray(model_inputs.loc[test_idx].to_numpy(dtype="float64"), device=0)
y_train = torch.asarray(binned.loc[train_idx].cat.codes.to_numpy(dtype="int64"), device=0)
y_test = torch.asarray(binned.loc[test_idx].cat.codes.to_numpy(dtype="int64"), device=0)
return CategoricalTrainingDataset(
dataset=dataset.to_crs("EPSG:4326"),
X=DatasetInputs(data=model_inputs, train=X_train, test=X_test),
y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels),
z=model_labels,
split=split,
)

View file

@ -177,11 +177,11 @@ def download_daily_aggregated():
tchunksize = era5.chunksizes["time"][0] tchunksize = era5.chunksizes["time"][0]
era5_chunk_starts = pd.date_range(era5.time.min().item(), era5.time.max().item(), freq=f"{tchunksize}h") era5_chunk_starts = pd.date_range(era5.time.min().item(), era5.time.max().item(), freq=f"{tchunksize}h")
closest_chunk_start = era5_chunk_starts[ closest_chunk_start = era5_chunk_starts[
era5_chunk_starts.get_indexer([pd.to_datetime(min_time)], method="ffill")[0] era5_chunk_starts.get_indexer([pd.to_datetime(min_time)], method="ffill")[0] # ty:ignore[invalid-argument-type]
] ]
subset["time"] = slice(str(closest_chunk_start), max_time) subset["time"] = slice(str(closest_chunk_start), max_time)
era5 = era5.sel(**subset) era5 = era5.sel(subset)
daily_raw = xr.merge( daily_raw = xr.merge(
[ [
@ -680,7 +680,7 @@ def spatial_agg(
invalid_cell_id = [3059646, 3063547] invalid_cell_id = [3059646, 3063547]
grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)] grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)]
aggregations = { aggregations_by_gridlevel: dict[str, dict[int, _Aggregations | Literal["interpolate"]]] = {
"hex": { "hex": {
3: _Aggregations.common(), 3: _Aggregations.common(),
4: _Aggregations.common(), 4: _Aggregations.common(),
@ -695,9 +695,9 @@ def spatial_agg(
10: "interpolate", 10: "interpolate",
}, },
} }
aggregations = aggregations[grid][level] aggregations = aggregations_by_gridlevel[grid][level]
for agg in ["yearly", "seasonal", "shoulder"]: for agg in ("yearly", "seasonal", "shoulder"):
unaligned_store = get_era5_stores(agg) unaligned_store = get_era5_stores(agg)
with stopwatch(f"Loading {agg} ERA5 data"): with stopwatch(f"Loading {agg} ERA5 data"):
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load() unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load()

View file

@ -4,9 +4,11 @@
import geopandas as gpd import geopandas as gpd
import pandas as pd import pandas as pd
import torch import torch
from cuml.ensemble import RandomForestClassifier
from entropy import ESPAClassifier
from rich import pretty, traceback from rich import pretty, traceback
from sklearn import set_config from sklearn import set_config
from sklearn.base import BaseEstimator from xgboost.sklearn import XGBClassifier
from entropice.dataset import DatasetEnsemble from entropice.dataset import DatasetEnsemble
@ -16,7 +18,9 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def predict_proba(e: DatasetEnsemble, clf: BaseEstimator, classes: list) -> gpd.GeoDataFrame: def predict_proba(
e: DatasetEnsemble, clf: RandomForestClassifier | ESPAClassifier | XGBClassifier, classes: list
) -> gpd.GeoDataFrame:
"""Get predicted probabilities for each cell. """Get predicted probabilities for each cell.
Args: Args:

View file

@ -6,12 +6,15 @@ import os
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
DATA_DIR = Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None)).resolve() / "entropice" DATA_DIR = (
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
)
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
GRIDS_DIR = DATA_DIR / "grids" GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures") FIGURES_DIR = Path("figures")
DARTS_DIR = DATA_DIR / "darts" RTS_DIR = DATA_DIR / "darts-rts"
RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels"
ERA5_DIR = DATA_DIR / "era5" ERA5_DIR = DATA_DIR / "era5"
ARCTICDEM_DIR = DATA_DIR / "arcticdem" ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR = DATA_DIR / "embeddings"
@ -22,7 +25,7 @@ RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True) GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True)
DARTS_DIR.mkdir(parents=True, exist_ok=True) RTS_DIR.mkdir(parents=True, exist_ok=True)
ERA5_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True)
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True) ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
@ -34,9 +37,9 @@ DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet" dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet" dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = DARTS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps" darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str: def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
@ -58,9 +61,9 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path: def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
if labels: if labels:
rtsfile = DARTS_DIR / f"{gridname}_darts-mllabels.parquet" rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
else: else:
rtsfile = DARTS_DIR / f"{gridname}_darts.parquet" rtsfile = RTS_DIR / f"{gridname}_darts.parquet"
return rtsfile return rtsfile
@ -107,8 +110,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file return dataset_file
def get_dataset_cache(eid: str) -> Path: def get_dataset_cache(eid: str, subset: str | None = None) -> Path:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet" if subset is None:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
else:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_{subset}_dataset.parquet"
return cache_file return cache_file
@ -116,7 +122,7 @@ def get_cv_results_dir(
name: str, name: str,
grid: Literal["hex", "healpix"], grid: Literal["hex", "healpix"],
level: int, level: int,
task: Literal["binary", "multi"], task: Literal["binary", "count", "density"],
) -> Path: ) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

View file

@ -1,4 +1,3 @@
# ruff: noqa: N806
"""Training of classification models training.""" """Training of classification models training."""
import pickle import pickle
@ -8,7 +7,6 @@ from typing import Literal
import cyclopts import cyclopts
import pandas as pd import pandas as pd
import toml import toml
import torch
import xarray as xr import xarray as xr
from cuml.ensemble import RandomForestClassifier from cuml.ensemble import RandomForestClassifier
from cuml.neighbors import KNeighborsClassifier from cuml.neighbors import KNeighborsClassifier
@ -17,9 +15,9 @@ from rich import pretty, traceback
from scipy.stats import loguniform, randint from scipy.stats import loguniform, randint
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
from sklearn import set_config from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from sklearn.model_selection import KFold, RandomizedSearchCV
from stopuhr import stopwatch from stopuhr import stopwatch
from xgboost import XGBClassifier from xgboost.sklearn import XGBClassifier
from entropice.dataset import DatasetEnsemble from entropice.dataset import DatasetEnsemble
from entropice.inference import predict_proba from entropice.inference import predict_proba
@ -30,7 +28,7 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) cli = cyclopts.App("entropice-training", config=cyclopts.config.Toml("training-config.toml")) # ty:ignore[invalid-argument-type]
_metrics = { _metrics = {
"binary": ["accuracy", "recall", "precision", "f1", "jaccard"], "binary": ["accuracy", "recall", "precision", "f1", "jaccard"],
@ -57,51 +55,6 @@ class CVSettings:
model: Literal["espa", "xgboost", "rf", "knn"] = "espa" model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
def _create_xy_data(e: DatasetEnsemble, task: Literal["binary", "count", "density"] = "binary"):
data = e.create()
covcol = "dartsml_has_coverage" if e.target == "darts_mllabels" else "darts_has_coverage"
bincol = "dartsml_has_rts" if e.target == "darts_mllabels" else "darts_has_rts"
countcol = "dartsml_rts_count" if e.target == "darts_mllabels" else "darts_rts_count"
densitycol = "dartsml_rts_density" if e.target == "darts_mllabels" else "darts_rts_density"
data = data[data[covcol]].reset_index(drop=True)
cols_to_drop = ["geometry"]
if e.target == "darts_mllabels":
cols_to_drop += [col for col in data.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
X_data = data.drop(columns=cols_to_drop).dropna()
if task == "binary":
labels = ["No RTS", "RTS"]
y_data = data.loc[X_data.index, bincol]
elif task == "count":
# Put into n categories (log scaled)
y_data = data.loc[X_data.index, countcol]
n_categories = 5
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
# Change the first interval to start at 1 and add a category for 0
bins = pd.IntervalIndex.from_tuples(
[(-1, 0)] + [(int(interval.left), int(interval.right)) for interval in bins]
)
print(f"{bins=}")
y_data = pd.cut(y_data, bins=bins)
labels = [str(v) for v in y_data.sort_values().unique()]
y_data = y_data.cat.codes
elif task == "density":
y_data = data.loc[X_data.index, densitycol]
n_categories = 5
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
print(f"{bins=}")
y_data = pd.cut(y_data, bins=bins)
labels = [str(v) for v in y_data.sort_values().unique()]
y_data = y_data.cat.codes
else:
raise ValueError(f"Unknown task: {task}")
return data, X_data, y_data, labels
def _create_clf( def _create_clf(
settings: CVSettings, settings: CVSettings,
): ):
@ -196,15 +149,7 @@ def random_cv(
""" """
print("Creating training data...") print("Creating training data...")
_, X_data, y_data, labels = _create_xy_data(dataset_ensemble, task=settings.task) training_data = dataset_ensemble.create_cat_training_dataset(task=settings.task)
print(f"Using {settings.task}-class classification with {len(labels)} classes: {labels}")
print(f"{y_data.describe()=}")
X = X_data.to_numpy(dtype="float64")
y = y_data.to_numpy(dtype="int8")
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
print(f"{X.shape=}, {y.shape=}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
clf, param_grid, fit_params = _create_clf(settings) clf, param_grid, fit_params = _create_clf(settings)
print(f"Using model: {settings.model} with parameters: {param_grid}") print(f"Using model: {settings.model} with parameters: {param_grid}")
@ -224,14 +169,14 @@ def random_cv(
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"): with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
search.fit(X_train, y_train, **fit_params) search.fit(training_data.X.train, training_data.y.train, **fit_params)
print("Best parameters combination found:") print("Best parameters combination found:")
best_parameters = search.best_estimator_.get_params() best_parameters = search.best_estimator_.get_params()
for param_name in sorted(param_grid.keys()): for param_name in sorted(param_grid.keys()):
print(f"{param_name}: {best_parameters[param_name]}") print(f"{param_name}: {best_parameters[param_name]}")
test_accuracy = search.score(X_test, y_test) test_accuracy = search.score(training_data.X.test, training_data.y.test)
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}")
@ -251,7 +196,7 @@ def random_cv(
"param_grid": param_grid_serializable, "param_grid": param_grid_serializable,
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": labels, "classes": training_data.y.labels,
} }
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
@ -267,7 +212,7 @@ def random_cv(
# Store the search results # Store the search results
results = pd.DataFrame(search.cv_results_) results = pd.DataFrame(search.cv_results_)
# Parse the params into individual columns # Parse the params into individual columns
params = pd.json_normalize(results["params"]) params = pd.json_normalize(results["params"]) # ty:ignore[invalid-argument-type]
# Concatenate the params columns with the original DataFrame # Concatenate the params columns with the original DataFrame
results = pd.concat([results.drop(columns=["params"]), params], axis=1) results = pd.concat([results.drop(columns=["params"]), params], axis=1)
results_file = results_dir / "search_results.parquet" results_file = results_dir / "search_results.parquet"
@ -278,7 +223,7 @@ def random_cv(
if settings.model == "espa": if settings.model == "espa":
best_estimator = search.best_estimator_ best_estimator = search.best_estimator_
# Annotate the state with xarray metadata # Annotate the state with xarray metadata
features = X_data.columns.tolist() features = training_data.X.data.columns.tolist()
boxes = list(range(best_estimator.K_)) boxes = list(range(best_estimator.K_))
box_centers = xr.DataArray( box_centers = xr.DataArray(
best_estimator.S_.cpu().numpy(), best_estimator.S_.cpu().numpy(),
@ -290,7 +235,7 @@ def random_cv(
box_assignments = xr.DataArray( box_assignments = xr.DataArray(
best_estimator.Lambda_.cpu().numpy(), best_estimator.Lambda_.cpu().numpy(),
dims=["class", "box"], dims=["class", "box"],
coords={"class": labels, "box": boxes}, coords={"class": training_data.y.labels, "box": boxes},
name="box_assignments", name="box_assignments",
attrs={"description": "Assignments of samples to boxes."}, attrs={"description": "Assignments of samples to boxes."},
) )
@ -317,7 +262,7 @@ def random_cv(
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")
preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=labels) preds = predict_proba(dataset_ensemble, clf=best_estimator, classes=training_data.y.labels)
preds_file = results_dir / "predicted_probabilities.parquet" preds_file = results_dir / "predicted_probabilities.parquet"
print(f"Storing predicted probabilities to {preds_file}") print(f"Storing predicted probabilities to {preds_file}")
preds.to_parquet(preds_file) preds.to_parquet(preds_file)