Add first training

This commit is contained in:
Tobias Hölzer 2025-11-07 15:56:54 +01:00
parent ad3d7aae73
commit 3e0e6e0d2d
11 changed files with 5368 additions and 83 deletions

2
.gitattributes vendored Normal file
View file

@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true

4
.gitignore vendored
View file

@ -21,3 +21,7 @@ pg.ipynb
playground.ipynb playground.ipynb
*fix*.ipynb *fix*.ipynb
*debug*.ipynb *debug*.ipynb
# pixi environments
.pixi
*.egg-info

4795
pixi.lock generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -4,7 +4,7 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }] authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
requires-python = ">=3.12" # requires-python = ">=3.10,<3.13"
dependencies = [ dependencies = [
"aiohttp>=3.12.11", "aiohttp>=3.12.11",
"bokeh>=3.7.3", "bokeh>=3.7.3",
@ -16,7 +16,7 @@ dependencies = [
"distributed>=2025.5.1", "distributed>=2025.5.1",
"earthengine-api>=1.6.9", "earthengine-api>=1.6.9",
"eemont>=2025.7.1", "eemont>=2025.7.1",
"entropyc", "entropy",
"flox>=0.10.4", "flox>=0.10.4",
"folium>=0.19.7", "folium>=0.19.7",
"geemap>=0.36.3", "geemap>=0.36.3",
@ -48,7 +48,7 @@ dependencies = [
"xarray>=2025.9.0", "xarray>=2025.9.0",
"xdggs>=0.2.1", "xdggs>=0.2.1",
"xvec>=0.5.1", "xvec>=0.5.1",
"zarr[remote]>=3.1.3", "zarr[remote]>=3.1.3", "geocube>=0.7.1,<0.8",
] ]
[project.scripts] [project.scripts]
@ -56,6 +56,7 @@ create-grid = "entropice.grids:main"
darts = "entropice.darts:main" darts = "entropice.darts:main"
alpha-earth = "entropice.alphaearth:main" alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli" era5 = "entropice.era5:cli"
train = "entropice.training:cli"
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
@ -66,4 +67,40 @@ package = true
[tool.uv.sources] [tool.uv.sources]
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" } entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
xanimate = { git = "https://github.com/davbyr/xAnimate" } xanimate = { git = "https://github.com/davbyr/xAnimate" }
[tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports
allowed-unused-imports = [
"hvplot.pandas",
"hvplot.xarray",
"rioxarray",
"odc.geo.xr",
"cupy_xarray",
"xdggs",
"xvec",
]
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]
[tool.pixi.activation.env]
SCIPY_ARRAY_API = "1"
[tool.pixi.system-requirements]
cuda = "12"
[tool.pixi.pypi-dependencies]
entropice = { path = ".", editable = true }
[tool.pixi.tasks]
[tool.pixi.dependencies]
pytorch-gpu = ">=2.5.1,<3"
cupy = ">=13.6.0,<14"
nccl = ">=2.27.7.1,<3"
cudnn = ">=9.13.1.26,<10"
cusparselt = ">=0.8.1.1,<0.9"
cuda-version = "12.1.*"

View file

@ -58,6 +58,15 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"] grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"]
) )
# Apply corrections to NaNs
covered = ~grid_gdf[f"darts_{year}_coverage"].isnull()
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
0.0
)
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
output_path = get_darts_rts_file(grid, level) output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path) grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}") print(f"Saved RTS labels to {output_path}")

View file

@ -52,6 +52,10 @@ Monthly, Winter, Summer & Yearly Aggregations (Names don't change):
- *_rest -> median - *_rest -> median
- accum variables: sum - accum variables: sum
Additionally:
- snow_cover_min [instant]: min of snowc_mean over month/winter/summer/year
- snow_cover_max [instant]: max of snowc_mean over month/winter/summer/year
Derived & (from monthly) Aggregated Winter Variables: Derived & (from monthly) Aggregated Winter Variables:
- effective_snow_depth [instant]: (sde_mean * M + 1 - m).sum(M) / (m).sum(M),see also https://tc.copernicus.org/articles/11/989/2017/tc-11-989-2017.pdf - effective_snow_depth [instant]: (sde_mean * M + 1 - m).sum(M) / (m).sum(M),see also https://tc.copernicus.org/articles/11/989/2017/tc-11-989-2017.pdf
@ -75,6 +79,7 @@ Date: June to October 2025
import cProfile import cProfile
import time import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Literal from typing import Literal
import cyclopts import cyclopts
@ -88,6 +93,7 @@ import shapely.ops
import ultraplot as uplt import ultraplot as uplt
import xarray as xr import xarray as xr
import xdggs import xdggs
import xvec
from rich import pretty, print, traceback from rich import pretty, print, traceback
from rich.progress import track from rich.progress import track
from shapely.geometry import LineString, Polygon from shapely.geometry import LineString, Polygon
@ -95,6 +101,7 @@ from stopuhr import stopwatch
from entropice import codecs, grids, watermask from entropice import codecs, grids, watermask
from entropice.paths import FIGURES_DIR, get_era5_stores from entropice.paths import FIGURES_DIR, get_era5_stores
from entropice.xvec import to_xvec
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile]) traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
pretty.install() pretty.install()
@ -572,21 +579,38 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily
@cli.command @cli.command
def viz(agg: Literal["daily", "monthly", "yearly", "summer", "winter"]): def viz(
grid: Literal["hex", "healpix"] | None = None,
level: int | None = None,
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
high_qual: bool = False,
):
"""Visualize a small overview of ERA5 variables for a given aggregation. """Visualize a small overview of ERA5 variables for a given aggregation.
Args: Args:
agg (Literal["daily", "monthly", "yearly", "summer", "winter"]): grid (Literal["hex", "healpix"], optional): Grid type for spatial representation.
If provided along with level, the ERA5 data will be decoded onto the specified grid.
level (int, optional): Level of the grid for spatial representation.
If provided along with grid, the ERA5 data will be decoded onto the specified grid.
agg (Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"], optional):
Aggregation identifier used to locate the appropriate ERA5 Zarr store via Aggregation identifier used to locate the appropriate ERA5 Zarr store via
get_era5_stores. Determines which dataset is opened and visualized. get_era5_stores. Determines which dataset is opened and visualized.
high_qual (bool, optional): Whether to use high quality plotting settings.
If True, the plot will be generated with higher resolution and quality settings.
Defaults to False.
Example: Example:
>>> # produce and save an overview for monthly ERA5 data >>> # produce and save an overview for monthly ERA5 data
>>> viz("monthly") >>> viz("monthly")
""" """
store = get_era5_stores(agg) is_grid = grid is not None and level is not None
store = get_era5_stores(agg, grid, level)
ds = xr.open_zarr(store, consolidated=False).set_coords("spatial_ref") ds = xr.open_zarr(store, consolidated=False).set_coords("spatial_ref")
if is_grid:
ds = xdggs.decode(ds)
# Get cell boundaries for plotting
ds = to_xvec(ds)
tis = [0, 1, -2, -1] tis = [0, 1, -2, -1]
ts = [str(ds.time.isel(time=t).values)[:10] for t in tis] ts = [str(ds.time.isel(time=t).values)[:10] for t in tis]
@ -594,16 +618,28 @@ def viz(agg: Literal["daily", "monthly", "yearly", "summer", "winter"]):
vunits = [ds[var].attrs.get("units", "") for var in ds.data_vars] vunits = [ds[var].attrs.get("units", "") for var in ds.data_vars]
vlabels = [f"{name} [{unit}]" if unit else name for name, unit in zip(vnames, vunits)] vlabels = [f"{name} [{unit}]" if unit else name for name, unit in zip(vnames, vunits)]
fig, axs = uplt.subplots(ncols=len(tis), nrows=len(vlabels), proj="npaeqd", share=0) proj = uplt.Proj("npaeqd")
fig, axs = uplt.subplots(ncols=len(tis), nrows=len(vlabels), proj=proj, share=0)
axs.format(boundinglat=50, coast=True, toplabels=ts, leftlabels=vlabels) axs.format(boundinglat=50, coast=True, toplabels=ts, leftlabels=vlabels)
for t in tis: for t in tis:
for i, var in enumerate(ds.data_vars): for i, var in enumerate(ds.data_vars):
subset = ds[var].isel(time=t).load() subset = ds[var].isel(time=t)
m = axs[i, t].pcolormesh(subset, cmap="viridis") if is_grid:
axs[i, t].colorbar(m, loc="ll", label=var) subset.xvec.to_geodataframe().to_crs(proj.proj4_init).plot(ax=axs[i, t], column=var)
else:
if not high_qual:
subset = subset.isel(latitude=slice(None, None, 4), longitude=slice(None, None, 4))
m = axs[i, t].pcolormesh(subset, cmap="viridis")
axs[i, t].colorbar(m, loc="ll", label=var)
fig.format(suptitle="ERA5 Data") fig.format(suptitle="ERA5 Data")
(FIGURES_DIR / "era5").mkdir(parents=True, exist_ok=True) figdir = (FIGURES_DIR / "era5").resolve()
fig.savefig(FIGURES_DIR / "era5" / f"{agg}_overview_unaligned.png") figdir.mkdir(parents=True, exist_ok=True)
if high_qual:
print(f"Saving ERA5 {agg} overview figure to {figdir / f'{agg}_overview_unaligned.pdf'}.")
fig.savefig(figdir / f"{agg}_overview_unaligned.pdf")
else:
print(f"Saving ERA5 {agg} overview figure to {figdir / f'{agg}_overview_unaligned.png'}.")
fig.savefig(figdir / f"{agg}_overview_unaligned.png", dpi=100)
# =========================== # ===========================
@ -640,7 +676,7 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
@stopwatch("Getting corrected geometries", log=False) @stopwatch("Getting corrected geometries", log=False)
def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc.geo.Geometry]: def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]:
"""Get corrected geometries for antimeridian-crossing polygons. """Get corrected geometries for antimeridian-crossing polygons.
Args: Args:
@ -651,6 +687,7 @@ def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc
list[odc.geo.Geometry]: List of corrected, georeferenced geometries. list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
""" """
geom, gbox = inp
# cell.geometry is a shapely Polygon # cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom): if not _crosses_antimeridian(geom):
geoms = [geom] geoms = [geom]
@ -659,43 +696,71 @@ def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc
geoms = _split_antimeridian_cell(geom) geoms = _split_antimeridian_cell(geom)
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms] geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
geoms = filter(lambda g: _check_geom(gbox, g), geoms) geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
return geoms return geoms
@stopwatch("Extracting cell data", log=False)
def extract_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry]) -> xr.Dataset | bool:
"""Extract ERA5 data for a specific grid cell geometry.
Extracts and spatially averages ERA5 data within the bounds of a grid cell.
Handles antimeridian-crossing cells by splitting them appropriately.
Args:
ds (xr.Dataset): An ERA5 dataset.
geoms (list[odc.geo.Geometry]): List of (valid) geometries of the grid cell.
"""
if len(geoms) == 0:
return False
elif len(geoms) == 1:
return ds.odc.crop(geoms[0]).drop_vars("spatial_ref").mean(["latitude", "longitude"], skipna=True).compute()
else:
parts = [
ds.odc.crop(geom).drop_vars("spatial_ref").mean(["latitude", "longitude"], skipna=True) for geom in geoms
]
parts = [part for part in parts if part.latitude.size > 0 and part.longitude.size > 0]
if len(parts) == 0:
raise ValueError("No valid parts found for geometry. This should not happen!")
elif len(parts) == 1:
return parts[0].compute()
else:
return xr.concat(parts, dim="part").mean("part", skipna=True).compute()
def _correct_longs(ds: xr.Dataset) -> xr.Dataset: def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude") return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
@stopwatch("Extracting cell data", log=False)
def _extract_cell_data(ds, geom):
    """Crop the dataset to a single cell geometry and spatially average it.

    Returns a mapping from variable name to its (time-resolved) array of
    spatial means over the cropped latitude/longitude window.
    """
    cell = ds.odc.crop(geom).drop_vars("spatial_ref")
    # All-NaN windows would emit divide/invalid RuntimeWarnings; silence them.
    with np.errstate(divide="ignore", invalid="ignore"):
        averaged = cell.mean(["latitude", "longitude"])
    return {name: averaged[name].values for name in averaged.data_vars}
@stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(dss: list[xr.Dataset], geoms):
    """Extract the spatial mean for a cell that was split into several parts.

    Used for antimeridian-crossing cells: each part is cropped separately and
    the overall mean is computed as a coverage-weighted, NaN-aware mean
    (sum of values across parts divided by the total valid-pixel count).

    Args:
        dss (list[xr.Dataset]): One pre-cropped (unmasked) dataset per geometry part.
        geoms: Geometries to crop with, one per dataset — presumably
            odc.geo.Geometry instances; TODO confirm.

    Returns:
        dict: Variable name -> array of spatial means (per time step).
    """
    parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)]
    # Count valid (non-NaN) pixels per part so each part is weighted by its coverage.
    partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts]
    # Parts with zero valid pixels yield 0/0 -> NaN; suppress the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        # NOTE(review): despite the name, these are per-part *sums*; the division
        # by the total count below turns them into the overall mean.
        partial_means = [part.sum(["latitude", "longitude"]) for part in parts]
        n = xr.concat(partial_counts, dim="part").sum("part")
        cell_data = xr.concat(partial_means, dim="part").sum("part") / n
    return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Aligning data")
def _align_data(
    cell_geometries: list[list[odc.geo.Geometry]],
    unaligned: xr.Dataset,
    max_workers: int = 10,
) -> dict[str, np.ndarray]:
    """Spatially aggregate gridded ERA5 data onto grid cells in parallel.

    For every cell, crops the (eagerly loaded) dataset to the cell's geometry
    (or geometries, for antimeridian-split cells) in a worker process and stores
    the spatial mean time series into a preallocated per-variable array.

    Args:
        cell_geometries (list[list[odc.geo.Geometry]]): Per-cell lists of valid,
            georeferenced geometries. An empty list leaves that cell's row as NaN.
        unaligned (xr.Dataset): ERA5 dataset with latitude/longitude/time dims.
        max_workers (int, optional): Number of worker processes for the
            extraction pool. Defaults to 10.

    Returns:
        dict[str, np.ndarray]: Variable name -> (n_cells, n_times) float32 array.
    """
    # Persist the dataset, as all the operations here MUST NOT be lazy
    unaligned = unaligned.load()
    data = {
        var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, dtype=np.float32)
        for var in unaligned.data_vars
    }
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {}
        for i, geoms in track(
            enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
        ):
            if len(geoms) == 0:
                # No valid geometry for this cell: its rows stay NaN.
                continue
            elif len(geoms) == 1:
                geom = geoms[0]
                # Reduce the amount of data needed to be sent to the worker.
                # Since we dont mask the data, only isel operations are done here.
                # Thus, to properly extract the subset, another masked crop needs to be done in the worker.
                unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
                futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i
            else:
                # Same as above but for multiple parts (antimeridian-split cells).
                unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
                futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i
        for future in track(
            as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
        ):
            i = futures[future]
            cell_data = future.result()
            for var in unaligned.data_vars:
                data[var][i, :] = cell_data[var]
    return data
@stopwatch("Creating aligned dataset", log=False) @stopwatch("Creating aligned dataset", log=False)
def _create_aligned( def _create_aligned(
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
@ -737,48 +802,38 @@ def spatial_agg(
level (int): Grid resolution level. level (int): Grid resolution level.
""" """
grid_gdf = grids.open(grid, level) with stopwatch(f"Loading {grid} grid at level {level}"):
# ? Mask out water, since we don't want to aggregate over oceans grid_gdf = grids.open(grid, level)
grid_gdf = watermask.clip_grid(grid_gdf) # ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = grid_gdf.to_crs("epsg:4326") grid_gdf = watermask.clip_grid(grid_gdf)
grid_gdf = grid_gdf.to_crs("epsg:4326")
# Precompute the geometries to clip later for agg in ["yearly", "seasonal", "shoulder"]:
daily_store = get_era5_stores("daily")
daily_unaligned = xr.open_zarr(daily_store, consolidated=False).set_coords("spatial_ref")
assert {"latitude", "longitude", "time"} == set(daily_unaligned.dims)
assert daily_unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {daily_unaligned.odc.crs}"
daily_unaligned = _correct_longs(daily_unaligned)
cell_geometries = [_get_corrected_geoms(row.geometry, daily_unaligned.odc.geobox) for _, row in grid_gdf.iterrows()]
for agg in ["summer", "winter", "yearly"]:
unaligned_store = get_era5_stores(agg) unaligned_store = get_era5_stores(agg)
with stopwatch(f"Loading {agg} ERA5 data"): with stopwatch(f"Loading {agg} ERA5 data"):
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref") unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load()
assert {"latitude", "longitude", "time"} == set(unaligned.dims) assert {"latitude", "longitude", "time"} == set(unaligned.dims)
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}" assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
unaligned = _correct_longs(unaligned) unaligned = _correct_longs(unaligned)
data = { with stopwatch("Precomputing cell geometries"):
var: np.full((len(grid_gdf), len(unaligned.time)), np.nan, dtype=np.float32) for var in unaligned.data_vars with ProcessPoolExecutor(max_workers=20) as executor:
} cell_geometries = list(
for i, geoms in track( executor.map(
enumerate(cell_geometries), _get_corrected_geoms,
total=len(grid_gdf), [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
description=f"Spatially aggregating {agg} ERA5 data...", )
): )
if len(geoms) == 0:
print(f"Warning: No valid geometry for cell {grid_gdf.iloc[i].cell_id}.")
continue
cell_data = extract_cell_data(unaligned, geoms) data = _align_data(cell_geometries, unaligned)
for var in unaligned.data_vars:
data[var][i, :] = cell_data[var].values
aggregated = _create_aligned(unaligned, data, grid, level) aggregated = _create_aligned(unaligned, data, grid, level)
store = get_era5_stores(agg, grid, level) store = get_era5_stores(agg, grid, level)
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
print(f"Finished spatial matching for {agg} data.") print(f"Finished spatial matching for {agg} data.")
print("### Stopwatch Summary ###")
stopwatch.summary() stopwatch.summary()

View file

@ -1,11 +1,13 @@
# ruff: noqa: D103 # ruff: noqa: D103
"""Paths for entropice data storage.""" """Paths for entropice data storage."""
import datetime
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")).resolve() / "entropice" DATA_DIR = Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None)).resolve() / "entropice"
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
GRIDS_DIR = DATA_DIR / "grids" GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures") FIGURES_DIR = Path("figures")
@ -13,6 +15,8 @@ DARTS_DIR = DATA_DIR / "darts"
ERA5_DIR = DATA_DIR / "era5" ERA5_DIR = DATA_DIR / "era5"
EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR = DATA_DIR / "embeddings"
WATERMASK_DIR = DATA_DIR / "watermask" WATERMASK_DIR = DATA_DIR / "watermask"
TRAINING_DIR = DATA_DIR / "training"
RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True) GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True)
@ -20,9 +24,10 @@ DARTS_DIR.mkdir(parents=True, exist_ok=True)
ERA5_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
WATERMASK_DIR.mkdir(parents=True, exist_ok=True) WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
watermask_file = DATA_DIR / "simplified_water_polygons.shp"
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet" dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet" dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
@ -46,7 +51,7 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path: def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
rtsfile = DARTS_DIR / f"{gridname}_rts.parquet" rtsfile = DARTS_DIR / f"{gridname}_darts.parquet"
return rtsfile return rtsfile
@ -74,3 +79,18 @@ def get_era5_stores(
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr" aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr"
return aligned_path return aligned_path
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
gridname = _get_gridname(grid, level)
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
return dataset_file
def get_cv_results_file(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}"
results_dir.mkdir(parents=True, exist_ok=True)
results_file = results_dir / "search_results.parquet"
return results_file

295
src/entropice/training.py Normal file
View file

@ -0,0 +1,295 @@
# ruff: noqa: N806
"""Training dataset preparation and model training."""
from pathlib import Path
from typing import Literal
import cyclopts
import geopandas as gpd
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import xarray as xr
from entropy import ESPAClassifier
from rich import pretty, traceback
from scipy.stats import loguniform, randint
from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from stopuhr import stopwatch
from entropice.paths import (
get_cv_results_file,
get_darts_rts_file,
get_embeddings_store,
get_era5_stores,
get_train_dataset_file,
)
traceback.install()
pretty.install()
set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
cli = cyclopts.App()
@cli.command()
def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
    """Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.

    Reads the per-cell DARTS RTS parquet, the yearly grid-aligned ERA5 Zarr
    store, and the embeddings Zarr store; pivots ERA5 and embeddings into wide
    form (one row per cell, one column per variable/year combination); joins
    everything on the cell id; and writes the result to the training parquet.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
    """
    rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
    # Filter to coverage: only keep cells actually observed by DARTS.
    rts = rts[rts["darts_has_coverage"]]
    # Convert hex cell_id to int — presumably the hex grid stores ids as
    # hexadecimal strings while the stores index by integer id; TODO confirm.
    if grid == "hex":
        rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
    # Get era5 data, restricted to the covered cells.
    era5_store = get_era5_stores("yearly", grid=grid, level=level)
    era5 = xr.open_zarr(era5_store, consolidated=False)
    era5 = era5.sel(cell_ids=rts["cell_id"].values)
    era5_df = []
    for var in era5.data_vars:
        df = era5[var].drop_vars("spatial_ref").to_dataframe()
        df["year"] = df.index.get_level_values("time").year
        # Wide format: one column per (variable, year), e.g. "<var>_2020".
        df = (
            df.pivot_table(index="cell_ids", columns="year", values=var)
            .rename(columns=lambda x: f"{var}_{x}")
            .rename_axis(None, axis=1)
        )
        era5_df.append(df)
    era5_df = pd.concat(era5_df, axis=1)
    # TODO: season and shoulder data
    # Get embeddings data.
    embs_store = get_embeddings_store(grid=grid, level=level)
    # NOTE(review): the store apparently holds an unnamed DataArray, hence the
    # dunder attribute access — consider naming the array when writing the store.
    embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
    embeddings = embeddings.sel(cell=rts["cell_id"].values)
    embeddings_df = embeddings.to_dataframe(name="value")
    # Wide format: one column per (agg, band, year) combination.
    embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
    embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
    # Combine datasets by cell id / cell (left joins onto the RTS labels).
    dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
    print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
    dataset_file = get_train_dataset_file(grid=grid, level=level)
    dataset.reset_index().to_parquet(dataset_file)
@cli.command()
def random_cv(grid: Literal["hex", "healpix"], level: int):
    """Perform random cross-validation on the training dataset.

    Loads the prepared training dataset, builds a GPU-resident feature matrix,
    runs a RandomizedSearchCV over ESPAClassifier hyperparameters (eps_cl,
    eps_e, initial_K) with 5-fold CV, reports best parameters and test score,
    stores the full CV results as parquet, and plots them.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
    """
    data = get_train_dataset_file(grid=grid, level=level)
    data = gpd.read_parquet(data)
    # Target is "darts_has_rts"; drop it, all other darts_* label columns,
    # and identifier/geometry columns from the feature matrix.
    cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
    cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
    X_data = data.drop(columns=cols_to_drop).dropna()
    # Align labels with the rows that survived dropna().
    y_data = data.loc[X_data.index, "darts_has_rts"]
    X = X_data.to_numpy(dtype="float32")
    y = y_data.to_numpy(dtype="int8")
    # Move to GPU (device 0); sklearn dispatches on torch tensors via the
    # array API (set_config(array_api_dispatch=True) at module import).
    X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
    print(f"{X.shape=}, {y.shape=}")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
    print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
    # Log-uniform sampling for both epsilons, uniform integers for initial_K.
    param_grid = {
        "eps_cl": loguniform(1e-3, 1e7),
        "eps_e": loguniform(1e-3, 1e7),
        "initial_K": randint(20, 400),
    }
    clf = ESPAClassifier(20, 0.1, 0.1, random_state=42)
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]  # "roc_auc" does not work on GPU
    search = RandomizedSearchCV(
        clf,
        param_grid,
        n_iter=20,
        n_jobs=24,
        cv=cv,
        random_state=42,
        verbose=10,
        scoring=metrics,
        refit="f1",  # refit the best estimator (and define best_score_) by F1
    )
    print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
    with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
        search.fit(X_train, y_train, max_iter=100)
    print("Best parameters combination found:")
    best_parameters = search.best_estimator_.get_params()
    for param_name in sorted(param_grid.keys()):
        print(f"{param_name}: {best_parameters[param_name]}")
    test_accuracy = search.score(X_test, y_test)
    print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
    print(f"Accuracy on test set: {test_accuracy:.3f}")
    # Store the search results
    results = pd.DataFrame(search.cv_results_)
    # Parse the params into individual columns
    params = pd.json_normalize(results["params"])
    # Concatenate the params columns with the original DataFrame
    results = pd.concat([results.drop(columns=["params"]), params], axis=1)
    results["grid"] = grid
    results["level"] = level
    results_file = get_cv_results_file("random_search", grid=grid, level=level)
    print(f"Storing CV results to {results_file}")
    results.to_parquet(results_file)
    stopwatch.summary()
    print("Done.")
    plot_random_cv_results(results_file)
def _plot_k_binned(
    results: pd.DataFrame, target: str, *, vmin_percentile: float | None = None, vmax_percentile: float | None = None
):
    """Scatter eps_e vs. eps_cl colored by a metric, faceted by binned initial_K.

    Args:
        results (pd.DataFrame): CV results with "eps_e", "eps_cl", the target
            metric column and a precomputed "initial_K_binned" column.
        target (str): Name of the metric column driving the color scale.
        vmin_percentile (float | None): If set, clip the lower end of the color
            scale to this percentile of the target values. Mutually exclusive
            with vmax_percentile.
        vmax_percentile (float | None): If set, clip the upper end of the color
            scale to this percentile of the target values. Mutually exclusive
            with vmin_percentile.

    Returns:
        The seaborn FacetGrid holding the scatter plots.
    """
    assert vmin_percentile is None or vmax_percentile is None, (
        "Only one of vmin_percentile or vmax_percentile can be set."
    )
    assert "initial_K_binned" in results.columns, "initial_K_binned column not found in results."
    assert target in results.columns, f"{target} column not found in results."
    assert "eps_e" in results.columns, "eps_e column not found in results."
    assert "eps_cl" in results.columns, "eps_cl column not found in results."
    # add a colorbar instead of the sampled legend
    cmap = sns.color_palette("ch:", as_cmap=True)
    # sophisticated normalization: optionally clip one end of the color range
    if vmin_percentile is not None:
        vmin = np.percentile(results[target], vmin_percentile)
        norm = mcolors.Normalize(vmin=vmin)
    elif vmax_percentile is not None:
        vmax = np.percentile(results[target], vmax_percentile)
        norm = mcolors.Normalize(vmax=vmax)
    else:
        norm = mcolors.Normalize()
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    # nice col-wrap based on columns
    n_cols = results["initial_K_binned"].unique().size
    col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)
    scatter = sns.relplot(
        data=results,
        x="eps_e",
        y="eps_cl",
        hue=target,
        hue_norm=sm.norm,
        palette=cmap,
        legend=False,
        col="initial_K_binned",
        col_wrap=col_wrap,
    )
    # Apply log scale to all axes
    for ax in scatter.axes.flat:
        ax.set_xscale("log")
        ax.set_yscale("log")
    # Tight layout
    scatter.figure.tight_layout()
    # Add a shared colorbar at the bottom
    scatter.figure.subplots_adjust(bottom=0.15)  # Make room for the colorbar
    cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02])  # [left, bottom, width, height]
    cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
    cbar.set_label(target)
    return scatter
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
assert "initial_K" in results.columns, "initial_K column not found in results."
assert metric in results.columns, f"{metric} not found in results."
if target == "eps_cl":
hue = "eps_cl"
col = "eps_e_binned"
elif target == "eps_e":
hue = "eps_e"
col = "eps_cl_binned"
assert hue in results.columns, f"{hue} column not found in results."
assert col in results.columns, f"{col} column not found in results."
return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
@cli.command()
def plot_random_cv_results(file: Path):
    """Plot analysis of the results from the RandomCVSearch.

    Reads the stored cv_results_ parquet, bins the hyperparameters, and saves a
    set of PDF figures (parameter scatter plots and metric-vs-K plots) next to
    the results file.

    Args:
        file (Path): The file of the results.
    """
    print(f"Plotting random CV results from {file}...")
    results = pd.read_parquet(file)
    # Bin the initial_K into 40er bins (matching the randint(20, 400) search space)
    results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
    # Bin the eps_cl and eps_e into logarithmic bins (matching loguniform(1e-3, 1e7))
    eps_cl_bins = np.logspace(-3, 7, num=10)
    eps_e_bins = np.logspace(-3, 7, num=10)
    results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
    results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)
    # Figures go next to the results parquet.
    figdir = file.parent
    # K-Plots
    metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
    for metric in metrics:
        _plot_k_binned(
            results,
            f"mean_test_{metric}",
            vmin_percentile=50,
        ).figure.savefig(figdir / f"params3d-mean_{metric}.pdf")
        _plot_k_binned(
            results,
            f"std_test_{metric}",
            vmax_percentile=50,
        ).figure.savefig(figdir / f"params3d-std_{metric}.pdf")
        _plot_k_binned(results, f"mean_test_{metric}").figure.savefig(figdir / f"params3d-mean_{metric}-noperc.pdf")
        _plot_k_binned(results, f"std_test_{metric}").figure.savefig(figdir / f"params3d-std_{metric}-noperc.pdf")
        # eps-Plots
        _plot_eps_binned(
            results,
            "eps_cl",
            f"mean_test_{metric}",
        ).figure.savefig(figdir / f"k-eps_cl-mean_{metric}.pdf")
        _plot_eps_binned(
            results,
            "eps_e",
            f"mean_test_{metric}",
        ).figure.savefig(figdir / f"k-eps_e-mean_{metric}.pdf")
    # Close all figures to free matplotlib memory after the batch of savefigs.
    plt.close("all")
    print("Done.")
if __name__ == "__main__":
    # Allow running the module directly, in addition to the `train` console script.
    cli()

View file

@ -27,6 +27,6 @@ def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
""" """
watermask = open() watermask = open()
watermask = watermask.to_crs("EPSG:3413") watermask = watermask.to_crs(gdf.crs)
gdf = gdf.overlay(watermask, how="difference") gdf = gdf.overlay(watermask, how="difference")
return gdf return gdf

13
src/entropice/xvec.py Normal file
View file

@ -0,0 +1,13 @@
import xarray as xr
import xdggs
import xvec
def to_xvec(ds: xr.Dataset) -> xr.Dataset:
    """Re-index a DGGS dataset by cell-boundary geometries for use with xvec.

    Decodes the DGGS cell boundaries into a ``geometry`` variable, turns it
    into the indexing dimension (replacing ``cell_ids``) with an xvec geometry
    index in EPSG:4326, and keeps the original cell ids available as a
    coordinate on the new ``geometry`` dimension.

    NOTE(review): assumes ``ds`` carries a dggs accessor whose
    ``cell_boundaries()`` output is decodable by ``xdggs.decode`` — confirm
    against the xdggs version in use.
    """
    ds["geometry"] = xdggs.decode(ds.dggs.cell_boundaries())
    original_ids = ds.cell_ids.values
    # Swap the indexing dimension from cell ids to geometries, step by step.
    ds = ds.set_index(cell_ids="geometry")
    ds = ds.rename_dims({"cell_ids": "geometry"})
    ds = ds.rename_vars({"cell_ids": "geometry"})
    ds = ds.xvec.set_geom_indexes("geometry", crs="epsg:4326")
    # Preserve the original cell ids as a coordinate on the geometry dimension.
    ds["cell_ids"] = ("geometry", original_ids)
    return ds.set_coords("cell_ids")

65
uv.lock generated
View file

@ -136,6 +136,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl", hash = "sha256:944b82ef1dd17b8ff0fb6d1f199f613caf9111338e6e2857da478f6e73770cb8", size = 220671 }, { url = "https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl", hash = "sha256:944b82ef1dd17b8ff0fb6d1f199f613caf9111338e6e2857da478f6e73770cb8", size = 220671 },
] ]
[[package]]
name = "appdirs"
version = "1.4.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566 },
]
[[package]] [[package]]
name = "appnope" name = "appnope"
version = "0.1.4" version = "0.1.4"
@ -145,6 +154,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 },
] ]
[[package]]
name = "array-api-compat"
version = "1.12.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8d/bd/9fa5c7c5621698d5632cc852a79fbbdc28024462c9396698e5fdcb395f37/array_api_compat-1.12.0.tar.gz", hash = "sha256:585bc615f650de53ac24b7c012baecfcdd810f50df3573be47e6dd9fa20df974", size = 99883 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl", hash = "sha256:a0b4795b6944a9507fde54679f9350e2ad2b1e2acf4a2408a098cdc27f890a8b", size = 58156 },
]
[[package]]
name = "array-api-extra"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "array-api-compat" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cb/be/387d596e0ed6d191988d7b61bb0e252bb8023965485218a33bd1a8ccc72a/array_api_extra-0.9.0.tar.gz", hash = "sha256:2d49e38394f5a96caef17b80964d228636e6fdf1dada33697ab8ec3f5364eb77", size = 81579 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8c/17/8a7503c9dc3f8cc2868c632291e5d224822b05ae62f1279c529c459368d2/array_api_extra-0.9.0-py3-none-any.whl", hash = "sha256:36b34e29380b678007f151511be950cb2ea199606fe4a7ad466efc5044ea9e44", size = 48234 },
]
[[package]] [[package]]
name = "arro3-compute" name = "arro3-compute"
version = "0.6.5" version = "0.6.5"
@ -1217,10 +1247,11 @@ dependencies = [
{ name = "distributed" }, { name = "distributed" },
{ name = "earthengine-api" }, { name = "earthengine-api" },
{ name = "eemont" }, { name = "eemont" },
{ name = "entropyc" }, { name = "entropy" },
{ name = "flox" }, { name = "flox" },
{ name = "folium" }, { name = "folium" },
{ name = "geemap" }, { name = "geemap" },
{ name = "geocube" },
{ name = "geopandas" }, { name = "geopandas" },
{ name = "h3" }, { name = "h3" },
{ name = "h5netcdf" }, { name = "h5netcdf" },
@ -1264,10 +1295,11 @@ requires-dist = [
{ name = "distributed", specifier = ">=2025.5.1" }, { name = "distributed", specifier = ">=2025.5.1" },
{ name = "earthengine-api", specifier = ">=1.6.9" }, { name = "earthengine-api", specifier = ">=1.6.9" },
{ name = "eemont", specifier = ">=2025.7.1" }, { name = "eemont", specifier = ">=2025.7.1" },
{ name = "entropyc", git = "ssh://git@github.com/AlbertEMC2Stein/entropyc?branch=refactor%2Ftobi" }, { name = "entropy", git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" },
{ name = "flox", specifier = ">=0.10.4" }, { name = "flox", specifier = ">=0.10.4" },
{ name = "folium", specifier = ">=0.19.7" }, { name = "folium", specifier = ">=0.19.7" },
{ name = "geemap", specifier = ">=0.36.3" }, { name = "geemap", specifier = ">=0.36.3" },
{ name = "geocube", specifier = ">=0.7.1,<0.8" },
{ name = "geopandas", specifier = ">=1.1.0" }, { name = "geopandas", specifier = ">=1.1.0" },
{ name = "h3", specifier = ">=4.2.2" }, { name = "h3", specifier = ">=4.2.2" },
{ name = "h5netcdf", specifier = ">=1.6.4" }, { name = "h5netcdf", specifier = ">=1.6.4" },
@ -1300,11 +1332,13 @@ requires-dist = [
] ]
[[package]] [[package]]
name = "entropyc" name = "entropy"
version = "0.1.0" version = "0.1.0"
source = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc?branch=refactor%2Ftobi#22a191d194a76b6c182481acb2af1bde3f60b49e" } source = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7" }
dependencies = [ dependencies = [
{ name = "numpy" }, { name = "array-api-compat" },
{ name = "array-api-extra" },
{ name = "scikit-learn" },
{ name = "scipy" }, { name = "scipy" },
] ]
@ -1569,6 +1603,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4f/6b/13166c909ad2f2d76b929a4227c952630ebaf0d729f6317eb09cbceccbab/geocoder-1.38.1-py2.py3-none-any.whl", hash = "sha256:a733e1dfbce3f4e1a526cac03aadcedb8ed1239cf55bd7f3a23c60075121a834", size = 98590 }, { url = "https://files.pythonhosted.org/packages/4f/6b/13166c909ad2f2d76b929a4227c952630ebaf0d729f6317eb09cbceccbab/geocoder-1.38.1-py2.py3-none-any.whl", hash = "sha256:a733e1dfbce3f4e1a526cac03aadcedb8ed1239cf55bd7f3a23c60075121a834", size = 98590 },
] ]
[[package]]
name = "geocube"
version = "0.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "appdirs" },
{ name = "click" },
{ name = "geopandas" },
{ name = "numpy" },
{ name = "odc-geo" },
{ name = "pyproj" },
{ name = "rasterio" },
{ name = "rioxarray" },
{ name = "scipy" },
{ name = "xarray" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6c/03/d39b7a372f2054ae374247c8e0130b8f23aee89b0624c8f04fa49b2c1199/geocube-0.7.1.tar.gz", hash = "sha256:5f0f4a2143b379434d81172ae8c9fb49c2c5ff2f9723864ed79d3947a68ea37f", size = 20528 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/31/c6/a9341239e2e2953537b9e90a46ebc59f2e122247a3fe22373cc37520fc44/geocube-0.7.1-py3-none-any.whl", hash = "sha256:661a12c0b2106f27477290b5f18e76eb5855c9c50cac7fd19028fde4babca628", size = 23107 },
]
[[package]] [[package]]
name = "geographiclib" name = "geographiclib"
version = "2.1" version = "2.1"