From e5382670ec76392830b7208e26cd67fe0cb85a6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 22 Nov 2025 22:18:45 +0100 Subject: [PATCH] Refactor spatial aggregations --- src/entropice/aggregators.py | 254 +++++++++++++++++++++++++++++++++++ src/entropice/arcticdem.py | 48 +++++-- src/entropice/era5.py | 173 +++--------------------- src/entropice/paths.py | 14 +- src/entropice/training.py | 208 +++++++++++++++++----------- 5 files changed, 457 insertions(+), 240 deletions(-) create mode 100644 src/entropice/aggregators.py diff --git a/src/entropice/aggregators.py b/src/entropice/aggregators.py new file mode 100644 index 0000000..9534880 --- /dev/null +++ b/src/entropice/aggregators.py @@ -0,0 +1,254 @@ +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from typing import Literal + +import cupy_xarray +import geopandas as gpd +import numpy as np +import odc.geo.geobox +import shapely +import shapely.ops +import xarray as xr +import xdggs +import xvec +from rich.progress import track +from shapely.geometry import LineString, Polygon +from stopuhr import stopwatch +from xdggs.healpix import HealpixInfo + +from entropice import grids + + +@dataclass +class _Aggregations: + mean: bool = True + sum: bool = False + std: bool = False + min: bool = False + max: bool = False + quantiles: list[float] = field(default_factory=lambda: []) + + def varnames(self, vars: list[str] | str) -> list[str]: + if isinstance(vars, str): + vars = [vars] + agg_vars = [] + for var in vars: + if self.mean: + agg_vars.append(f"{var}_mean") + if self.sum: + agg_vars.append(f"{var}_sum") + if self.std: + agg_vars.append(f"{var}_std") + if self.min: + agg_vars.append(f"{var}_min") + if self.max: + agg_vars.append(f"{var}_max") + for q in self.quantiles: + q_int = int(q * 100) + agg_vars.append(f"{var}_p{q_int}") + return agg_vars + + +def _crosses_antimeridian(geom: Polygon) -> bool: + coords = 
shapely.get_coordinates(geom) + crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any() + return crosses_any_meridian and abs(coords[:, 0]).max() > 90 + + +def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]: + # Assumes that it is an antimeridian hex + coords = shapely.get_coordinates(geom) + for i in range(coords.shape[0]): + if coords[i, 0] < 0: + coords[i, 0] += 360 + geom = Polygon(coords) + antimeridian = LineString([[180, -90], [180, 90]]) + polys = shapely.ops.split(geom, antimeridian) + return list(polys.geoms) + + +def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool: + enclosing = geobox.enclosing(geom) + x, y = enclosing.shape + if x <= 1 or y <= 1: + return False + roi: tuple[slice, slice] = geobox.overlap_roi(enclosing) + roix, roiy = roi + return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1 + + +@stopwatch("Getting corrected geometries", log=False) +def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]: + geom, gbox, crs = inp + # cell.geometry is a shapely Polygon + if not _crosses_antimeridian(geom): + geoms = [geom] + # Split geometry in case it crosses the antimeridian + else: + geoms = _split_antimeridian_cell(geom) + + geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms] + geoms = list(filter(lambda g: _check_geom(gbox, g), geoms)) + return geoms + + +@stopwatch("Correcting geometries") +def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox): + """Get corrected geometries for antimeridian-crossing polygons. + + Args: + grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame. + gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference. + + Returns: + list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
+ + """ + with ProcessPoolExecutor(max_workers=20) as executor: + cell_geometries = list( + executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()]) + ) + return cell_geometries + + +@stopwatch("Aggregating cell data", log=False) +def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations): + cell_data = {} + for var in flattened.data_vars: + if aggregations.mean: + cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True) + if aggregations.sum: + cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True) + if aggregations.std: + cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True) + if aggregations.min: + cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True) + if aggregations.max: + cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True) + if len(aggregations.quantiles) > 0: + quantile_values = flattened[var].quantile( + q=aggregations.quantiles, + dim="z", + skipna=True, + ) + for q, qv in zip(aggregations.quantiles, quantile_values): + q_int = int(q * 100) + cell_data[f"{var}_p{q_int}"] = qv + # Transform to numpy arrays + for key in cell_data: + cell_data[key] = cell_data[key].cupy.as_numpy().values + return cell_data + + +@stopwatch("Extracting cell data", log=False) +def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations): + spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"] + cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref") + flattened = cropped.stack(z=spatdims) + if flattened.z.size > 3000: + flattened = flattened.cupy.as_cupy() + cell_data = _agg_cell_data(flattened, aggregations) + return cell_data + + # with np.errstate(divide="ignore", invalid="ignore"): + # cell_data = cropped.mean(spatdims) + # return {var: cell_data[var].values for var in cell_data.data_vars} + + +@stopwatch("Extracting split cell data", log=False) +def 
_extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations): + spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"] + cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)] + flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z") + cell_data = _agg_cell_data(flattened, aggregations) + return cell_data + + # partial_counts = [part.notnull().sum(dim=spatdims) for part in parts] + # with np.errstate(divide="ignore", invalid="ignore"): + # partial_means = [part.sum(dim=spatdims) for part in parts] + # n = xr.concat(partial_counts, dim="part").sum("part") + # cell_data = xr.concat(partial_means, dim="part").sum("part") / n + # return {var: cell_data[var].values for var in cell_data.data_vars} + + +@stopwatch("Aligning data with grid") +def _align_data( + grid_gdf: gpd.GeoDataFrame, + unaligned: xr.Dataset, + aggregations: _Aggregations, +) -> dict[str, np.ndarray]: + # Persist the dataset, as all the operations here MUST NOT be lazy + unaligned = unaligned.load() + + cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox) + other_dims_shape = tuple( + [unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]] + ) + data_shape = (len(cell_geometries), *other_dims_shape) + data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)} + with ProcessPoolExecutor(max_workers=10) as executor: + futures = {} + for i, geoms in track( + enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..." 
+ ): + if len(geoms) == 0: + continue + elif len(geoms) == 1: + geom = geoms[0] + # Reduce the amount of data needed to be sent to the worker + # Since we dont mask the data, only isel operations are done here + # Thus, to properly extract the subset, another masked crop needs to be done in the worker + unaligned_subset = unaligned.odc.crop(geom, apply_mask=False) + fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations) + else: + # Same as above but for multiple parts + unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms] + fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations) + futures[fut] = i + + for future in track( + as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..." + ): + i = futures[future] + cell_data = future.result() + for var in unaligned.data_vars: + data[var][i, :] = cell_data[var] + return data + + +def aggregate_raster_into_grid( + raster: xr.Dataset, + grid_gdf: gpd.GeoDataFrame, + aggregations: _Aggregations, + grid: Literal["hex", "healpix"], + level: int, +): + aligned = _align_data(grid_gdf, raster, aggregations) + + cell_ids = grids.get_cell_ids(grid, level) + + dims = ["cell_ids"] + coords = {"cell_ids": cell_ids} + for dim in raster.dims: + if dim not in ["y", "x", "latitude", "longitude"]: + dims.append(dim) + coords[dim] = raster.coords[dim] + + data_vars = {var: (dims, values) for var, values in aligned.items()} + ongrid = xr.Dataset( + data_vars, + coords=coords, + ) + gridinfo = { + "grid_name": "h3" if grid == "hex" else grid, + "level": level, + } + if grid == "healpix": + gridinfo["indexing_scheme"] = "nested" + ongrid.cell_ids.attrs = gridinfo + for var in raster.data_vars: + for v in aggregations.varnames(var): + ongrid[v].attrs = raster[var].attrs + + ongrid = xdggs.decode(ongrid) + return ongrid diff --git a/src/entropice/arcticdem.py b/src/entropice/arcticdem.py index e419f89..824581b 100644 --- 
a/src/entropice/arcticdem.py +++ b/src/entropice/arcticdem.py @@ -1,6 +1,7 @@ import datetime from dataclasses import dataclass from math import ceil +from typing import Literal import cupy as cp import cupy_xarray @@ -12,6 +13,7 @@ import icechunk.xarray import numpy as np import smart_geocubes import xarray as xr +import xdggs import xrspatial import zarr from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt @@ -22,7 +24,9 @@ from xrspatial.curvature import _run_cupy as curvature_cupy from xrspatial.slope import _run_cupy as slope_cupy from zarr.codecs import BloscCodec -from entropice.paths import arcticdem_store +from entropice import codecs, grids, watermask +from entropice.aggregators import _Aggregations, aggregate_raster_into_grid +from entropice.paths import get_arcticdem_stores traceback.install(show_locals=True, suppress=[cyclopts]) pretty.install() @@ -34,6 +38,7 @@ cli = cyclopts.App(name="arcticdem") @cli.command() def download(): + arcticdem_store = get_arcticdem_stores() adem = smart_geocubes.ArcticDEM32m(arcticdem_store) adem.create(exists_ok=True) with stopwatch("Download ArcticDEM data"): @@ -172,7 +177,7 @@ def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> tpi_large = tpi_cupy(chunk, large_kernels) # Slope slope = slope_cupy(chunk, res, res) - # Aspect + # Aspect & correction aspect = aspect_cupy(chunk) xx, yy = _get_xy_chunk(chunk, x, y, block_info) aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90 @@ -235,12 +240,6 @@ def _enrich(terrain: xr.DataArray): enriched_da = dask.array.map_overlap( _enrich_chunk, terrain.data, - # dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat( - # terrain.x.size, axis=1 - # ), - # dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat( - # terrain.y.size, axis=0 - # ), x=terrain.x.to_numpy(), y=terrain.y.to_numpy(), depth=15, # 
large_kernels.size_px @@ -273,14 +272,13 @@ def enrich(): print(client) print(client.dashboard_link) + arcticdem_store = get_arcticdem_stores() accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) # Garbage collect from previous runs accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC)) adem = accessor.open_xarray() - # session = adem.repo.readonly_session("main") - # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref") del adem.y.attrs["_FillValue"] del adem.x.attrs["_FillValue"] enriched, new_features = _enrich(adem.dem) @@ -289,10 +287,9 @@ def enrich(): adem[feature] = enriched.sel(feature=feature) print(adem[new_features]) # subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000)) - encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features} - # subset.to_zarr("test2.zarr", mode="w", encoding=encodings) session = accessor.repo.writable_session("main") + encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features} icechunk.xarray.to_icechunk( adem[new_features], session, @@ -304,5 +301,32 @@ def enrich(): print("Enrichment complete.") +@cli.command() +def aggregate(grid: Literal["hex", "healpix"], level: int): + with stopwatch(f"Loading {grid} grid at level {level}"): + grid_gdf = grids.open(grid, level) + # ? 
Mask out water, since we don't want to aggregate over oceans + grid_gdf = watermask.clip_grid(grid_gdf) + + arcticdem_store = get_arcticdem_stores() + accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) + adem = accessor.open_xarray() + assert {"x", "y"} == set(adem.dims) + assert adem.odc.crs == "EPSG:3413" + + aggregations = _Aggregations( + mean=True, + sum=False, + std=True, + min=True, + max=True, + quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99], + ) + aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level) + store = get_arcticdem_stores(grid, level) + with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"): + aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) + + if __name__ == "__main__": cli() diff --git a/src/entropice/era5.py b/src/entropice/era5.py index 8a35c01..b589275 100644 --- a/src/entropice/era5.py +++ b/src/entropice/era5.py @@ -1,4 +1,4 @@ -# ruff: noqa: PD003, PD011 +# ruff: noqa: PD003 """Download and preprocess ERA5 data. 
Variables of Interest: @@ -79,7 +79,6 @@ Date: June to October 2025 import cProfile import time -from concurrent.futures import ProcessPoolExecutor, as_completed from typing import Literal import cyclopts @@ -88,18 +87,15 @@ import numpy as np import odc.geo import odc.geo.xr import pandas as pd -import shapely -import shapely.ops import ultraplot as uplt import xarray as xr import xdggs import xvec from rich import pretty, print, traceback -from rich.progress import track -from shapely.geometry import LineString, Polygon from stopuhr import stopwatch from entropice import codecs, grids, watermask +from entropice.aggregators import _Aggregations, aggregate_raster_into_grid from entropice.paths import FIGURES_DIR, get_era5_stores from entropice.xvec import to_xvec @@ -647,145 +643,10 @@ def viz( # =========================== -def _crosses_antimeridian(geom: Polygon) -> bool: - coords = shapely.get_coordinates(geom) - crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any() - return crosses_any_meridian and abs(coords[:, 0]).max() > 90 - - -def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]: - # Assumes that it is a antimeridian hex - coords = shapely.get_coordinates(geom) - for i in range(coords.shape[0]): - if coords[i, 0] < 0: - coords[i, 0] += 360 - geom = Polygon(coords) - antimeridian = LineString([[180, -90], [180, 90]]) - polys = shapely.ops.split(geom, antimeridian) - return list(polys.geoms) - - -def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool: - enclosing = geobox.enclosing(geom) - x, y = enclosing.shape - if x <= 1 or y <= 1: - return False - roi: tuple[slice, slice] = geobox.overlap_roi(enclosing) - roix, roiy = roi - return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1 - - -@stopwatch("Getting corrected geometries", log=False) -def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]: - """Get corrected geometries for antimeridian-crossing 
polygons. - - Args: - geom (Polygon): Input polygon geometry. - gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference. - - Returns: - list[odc.geo.Geometry]: List of corrected, georeferenced geometries. - - """ - geom, gbox = inp - # cell.geometry is a shapely Polygon - if not _crosses_antimeridian(geom): - geoms = [geom] - # Split geometry in case it crossed antimeridian - else: - geoms = _split_antimeridian_cell(geom) - - geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms] - geoms = list(filter(lambda g: _check_geom(gbox, g), geoms)) - return geoms - - def _correct_longs(ds: xr.Dataset) -> xr.Dataset: return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude") -@stopwatch("Extracting cell data", log=False) -def _extract_cell_data(ds, geom): - cropped = ds.odc.crop(geom).drop_vars("spatial_ref") - with np.errstate(divide="ignore", invalid="ignore"): - cell_data = cropped.mean(["latitude", "longitude"]) - return {var: cell_data[var].values for var in cell_data.data_vars} - - -@stopwatch("Extracting split cell data", log=False) -def _extract_split_cell_data(dss: xr.Dataset, geoms): - parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)] - partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts] - with np.errstate(divide="ignore", invalid="ignore"): - partial_means = [part.sum(["latitude", "longitude"]) for part in parts] - n = xr.concat(partial_counts, dim="part").sum("part") - cell_data = xr.concat(partial_means, dim="part").sum("part") / n - return {var: cell_data[var].values for var in cell_data.data_vars} - - -@stopwatch("Aligning data") -def _align_data(cell_geometries: list[list[odc.geo.Geometry]], unaligned: xr.Dataset) -> dict[str, np.ndarray]: - # Persist the dataset, as all the operations here MUST NOT be lazy - unaligned = unaligned.load() - - data = { - var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, 
dtype=np.float32) - for var in unaligned.data_vars - } - with ProcessPoolExecutor(max_workers=10) as executor: - futures = {} - for i, geoms in track( - enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..." - ): - if len(geoms) == 0: - continue - elif len(geoms) == 1: - geom = geoms[0] - # Reduce the amount of data needed to be sent to the worker - # Since we dont mask the data, only isel operations are done here - # Thus, to properly extract the subset, another masked crop needs to be done in the worker - unaligned_subset = unaligned.odc.crop(geom, apply_mask=False) - futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i - else: - # Same as above but for multiple parts - unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms] - futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i - - for future in track( - as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..." 
- ): - i = futures[future] - cell_data = future.result() - for var in unaligned.data_vars: - data[var][i, :] = cell_data[var] - return data - - -@stopwatch("Creating aligned dataset", log=False) -def _create_aligned( - ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int -) -> xr.Dataset: - cell_ids = grids.get_cell_ids(grid, level) - data_vars = {var: (["cell_ids", "time"], values) for var, values in data.items()} - aligned = xr.Dataset( - data_vars, - coords={"cell_ids": cell_ids, "time": ds.time}, - ) - gridinfo = { - "grid_name": "h3" if grid == "hex" else grid, - "level": level, - } - if grid == "healpix": - gridinfo["indexing_scheme"] = "nested" - aligned.cell_ids.attrs = gridinfo - for var in ds.data_vars: - aligned[var].attrs = ds[var].attrs - - aligned = aligned.chunk({"cell_ids": min(len(aligned.cell_ids), 10000), "time": len(aligned.time)}) - aligned = xdggs.decode(aligned) - return aligned - - @cli.command def spatial_agg( grid: Literal["hex", "healpix"], @@ -816,18 +677,28 @@ def spatial_agg( assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}" unaligned = _correct_longs(unaligned) - with stopwatch("Precomputing cell geometries"): - with ProcessPoolExecutor(max_workers=20) as executor: - cell_geometries = list( - executor.map( - _get_corrected_geoms, - [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()], - ) - ) + # with stopwatch("Precomputing cell geometries"): + # with ProcessPoolExecutor(max_workers=20) as executor: + # cell_geometries = list( + # executor.map( + # _get_corrected_geoms, + # [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()], + # ) + # ) + # data = _align_data(cell_geometries, unaligned) + # aggregated = _create_aligned(unaligned, data, grid, level) - data = _align_data(cell_geometries, unaligned) + aggregations = _Aggregations( + mean=True, + sum=False, + std=True, + min=True, + max=True, + quantiles=[0.01, 0.05, 
0.25, 0.75, 0.95, 0.99], + ) + aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level) - aggregated = _create_aligned(unaligned, data, grid, level) + aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)}) store = get_era5_stores(agg, grid, level) with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"): aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) diff --git a/src/entropice/paths.py b/src/entropice/paths.py index 779632e..f2f1602 100644 --- a/src/entropice/paths.py +++ b/src/entropice/paths.py @@ -13,6 +13,7 @@ GRIDS_DIR = DATA_DIR / "grids" FIGURES_DIR = Path("figures") DARTS_DIR = DATA_DIR / "darts" ERA5_DIR = DATA_DIR / "era5" +ARCTICDEM_DIR = DATA_DIR / "arcticdem" EMBEDDINGS_DIR = DATA_DIR / "embeddings" WATERMASK_DIR = DATA_DIR / "watermask" TRAINING_DIR = DATA_DIR / "training" @@ -22,12 +23,12 @@ GRIDS_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True) DARTS_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True) +ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) WATERMASK_DIR.mkdir(parents=True, exist_ok=True) TRAINING_DIR.mkdir(parents=True, exist_ok=True) RESULTS_DIR.mkdir(parents=True, exist_ok=True) -arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" @@ -83,6 +84,17 @@ def get_era5_stores( return aligned_path +def get_arcticdem_stores( + grid: Literal["hex", "healpix"] | None = None, + level: int | None = None, +): + if grid is None or level is None: + return DATA_DIR / "arcticdem32m.icechunk.zarr" + gridname = _get_gridname(grid, level) + aligned_path = ARCTICDEM_DIR / f"{gridname}_dem.zarr" + return aligned_path + + def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: gridname = 
_get_gridname(grid, level) dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet" diff --git a/src/entropice/training.py b/src/entropice/training.py index 483d647..182783b 100644 --- a/src/entropice/training.py +++ b/src/entropice/training.py @@ -1,5 +1,5 @@ # ruff: noqa: N806 -"""Training dataset preparation and model training.""" +"""Training of classification models.""" import pickle from typing import Literal @@ -10,12 +10,15 @@ import pandas as pd import toml import torch import xarray as xr +from cuml.ensemble import RandomForestClassifier +from cuml.neighbors import KNeighborsClassifier from entropy import ESPAClassifier from rich import pretty, traceback from scipy.stats import loguniform, randint from sklearn import set_config from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from stopuhr import stopwatch +from xgboost import XGBClassifier from entropice.inference import predict_proba from entropice.paths import ( @@ -29,13 +32,13 @@ pretty.install() set_config(array_api_dispatch=True) -def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"): +def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"): """Create X and y data from the training dataset. Args: grid (Literal["hex", "healpix"]): The grid type to use. level (int): The grid level to use. - task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary". + task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary". Returns: Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
@@ -51,7 +54,7 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b if task == "binary": labels = ["No RTS", "RTS"] y_data = data.loc[X_data.index, "darts_has_rts"] - else: + elif task == "count": # Put into n categories (log scaled) y_data = data.loc[X_data.index, "darts_rts_count"] n_categories = 5 @@ -63,6 +66,17 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b y_data = pd.cut(y_data, bins=bins) labels = [str(v) for v in y_data.sort_values().unique()] y_data = y_data.cat.codes + elif task == "density": + y_data = data.loc[X_data.index, "darts_rts_density"] + n_categories = 5 + bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories + # Change the first interval to start at 0 + bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins]) + y_data = pd.cut(y_data, bins=bins) + labels = [str(v) for v in y_data.sort_values().unique()] + y_data = y_data.cat.codes + else: + raise ValueError(f"Unknown task: {task}") return data, X_data, y_data, labels @@ -71,7 +85,8 @@ def random_cv( level: int, n_iter: int = 2000, robust: bool = False, - task: Literal["binary", "multi"] = "binary", + task: Literal["binary", "count", "density"] = "binary", + model: Literal["espa", "xgboost", "rf", "knn"] = "espa", ): """Perform random cross-validation on the training dataset. @@ -80,38 +95,74 @@ def random_cv( level (int): The grid level to use. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. robust (bool, optional): Whether to use robust training. Defaults to False. - task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary". + task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary". 
""" _, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task) print(f"Using {task}-class classification with {len(labels)} classes: {labels}") print(f"{y_data.describe()=}") - X = X_data.to_numpy(dtype="float32") + X = X_data.to_numpy(dtype="float64") y = y_data.to_numpy(dtype="int8") X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) print(f"{X.shape=}, {y.shape=}") X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") - param_grid = { - "eps_cl": loguniform(1e-3, 1e7), - "eps_e": loguniform(1e-3, 1e7), - "initial_K": randint(20, 400), - } - - clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) + if model == "espa": + clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) + if task == "binary": + param_grid = { + "eps_cl": loguniform(1e-4, 1e1), + "eps_e": loguniform(1e1, 1e7), + "initial_K": randint(20, 400), + } + else: + param_grid = { + "eps_cl": loguniform(1e-11, 1e-6), + "eps_e": loguniform(1e4, 1e8), + "initial_K": randint(400, 800), + } + elif model == "xgboost": + param_grid = { + "learning_rate": loguniform(1e-4, 1e-1), + "max_depth": randint(3, 15), + "n_estimators": randint(100, 1000), + "subsample": loguniform(0.5, 1.0), + "colsample_bytree": loguniform(0.5, 1.0), + } + clf = XGBClassifier( + objective="multi:softprob" if task != "binary" else "binary:logistic", + eval_metric="mlogloss" if task != "binary" else "logloss", + random_state=42, + tree_method="gpu_hist", + device="cuda", + ) + elif model == "rf": + param_grid = { + "max_depth": randint(5, 50), + "n_estimators": randint(50, 500), + } + clf = RandomForestClassifier(random_state=42) + elif model == "knn": + param_grid = { + "n_neighbors": randint(3, 15), + "weights": ["uniform", "distance"], + "algorithm": ["brute", "kd_tree", "ball_tree"], + } + clf = KNeighborsClassifier(random_state=42) + else: + raise 
ValueError(f"Unknown model: {model}") cv = KFold(n_splits=5, shuffle=True, random_state=42) if task == "binary": metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU else: metrics = [ - "accuracy", # equals "f1_micro", "precision_micro", "recall_micro", + "accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted" "f1_macro", "f1_weighted", "precision_macro", "precision_weighted", "recall_macro", - "recall_weighted", "jaccard_micro", "jaccard_macro", "jaccard_weighted", @@ -120,17 +171,17 @@ def random_cv( clf, param_grid, n_iter=n_iter, - n_jobs=16, + n_jobs=8, cv=cv, random_state=42, - verbose=3, + verbose=10, scoring=metrics, refit="f1" if task == "binary" else "f1_weighted", ) print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"): - search.fit(X_train, y_train, max_iter=100) + search.fit(X_train, y_train, max_iter=300) print("Best parameters combination found:") best_parameters = search.best_estimator_.get_params() @@ -144,29 +195,33 @@ def random_cv( results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task) # Store the search settings + # First convert the param_grid distributions to a serializable format + param_grid_serializable = {} + for key, dist in param_grid.items(): + if isinstance(dist, loguniform): + param_grid_serializable[key] = { + "distribution": "loguniform", + "low": dist.a, + "high": dist.b, + } + elif isinstance(dist, randint): + param_grid_serializable[key] = { + "distribution": "randint", + "low": dist.a, + "high": dist.b, + } + elif isinstance(dist, list): + param_grid_serializable[key] = dist + else: + raise ValueError(f"Unknown distribution type for {key}: {type(dist)}") settings = { "task": task, + "model": model, "grid": grid, "level": level, "random_state": 42, "n_iter": n_iter, - "param_grid": { - "eps_cl": { - "distribution": "loguniform", - 
"low": param_grid["eps_cl"].a, - "high": param_grid["eps_cl"].b, - }, - "eps_e": { - "distribution": "loguniform", - "low": param_grid["eps_e"].a, - "high": param_grid["eps_e"].b, - }, - "initial_K": { - "distribution": "randint", - "low": param_grid["initial_K"].a, - "high": param_grid["initial_K"].b, - }, - }, + "param_grid": param_grid_serializable, "cv_splits": cv.get_n_splits(), "metrics": metrics, "classes": labels, @@ -195,46 +250,47 @@ def random_cv( results.to_parquet(results_file) # Get the inner state of the best estimator - best_estimator = search.best_estimator_ - # Annotate the state with xarray metadata - features = X_data.columns.tolist() - boxes = list(range(best_estimator.K_)) - box_centers = xr.DataArray( - best_estimator.S_.cpu().numpy(), - dims=["feature", "box"], - coords={"feature": features, "box": boxes}, - name="box_centers", - attrs={"description": "Centers of the boxes in feature space."}, - ) - box_assignments = xr.DataArray( - best_estimator.Lambda_.cpu().numpy(), - dims=["class", "box"], - coords={"class": labels, "box": boxes}, - name="box_assignments", - attrs={"description": "Assignments of samples to boxes."}, - ) - feature_weights = xr.DataArray( - best_estimator.W_.cpu().numpy(), - dims=["feature"], - coords={"feature": features}, - name="feature_weights", - attrs={"description": "Feature weights for each box."}, - ) - state = xr.Dataset( - { - "box_centers": box_centers, - "box_assignments": box_assignments, - "feature_weights": feature_weights, - }, - attrs={ - "description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.", - "grid": grid, - "level": level, - }, - ) - state_file = results_dir / "best_estimator_state.nc" - print(f"Storing best estimator state to {state_file}") - state.to_netcdf(state_file, engine="h5netcdf") + if model == "espa": + best_estimator = search.best_estimator_ + # Annotate the state with xarray metadata + features = X_data.columns.tolist() + boxes = list(range(best_estimator.K_)) + 
box_centers = xr.DataArray( + best_estimator.S_.cpu().numpy(), + dims=["feature", "box"], + coords={"feature": features, "box": boxes}, + name="box_centers", + attrs={"description": "Centers of the boxes in feature space."}, + ) + box_assignments = xr.DataArray( + best_estimator.Lambda_.cpu().numpy(), + dims=["class", "box"], + coords={"class": labels, "box": boxes}, + name="box_assignments", + attrs={"description": "Assignments of samples to boxes."}, + ) + feature_weights = xr.DataArray( + best_estimator.W_.cpu().numpy(), + dims=["feature"], + coords={"feature": features}, + name="feature_weights", + attrs={"description": "Feature weights for each box."}, + ) + state = xr.Dataset( + { + "box_centers": box_centers, + "box_assignments": box_assignments, + "feature_weights": feature_weights, + }, + attrs={ + "description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.", + "grid": grid, + "level": level, + }, + ) + state_file = results_dir / "best_estimator_state.nc" + print(f"Storing best estimator state to {state_file}") + state.to_netcdf(state_file, engine="h5netcdf") # Predict probabilities for all cells print("Predicting probabilities for all cells...")