From 33c9667383653533cba80826254fdf2a4a600495 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 30 Nov 2025 01:12:00 +0100 Subject: [PATCH] Finally fix aggregations also for era5 --- Processing Documentation.md | 64 +++++++++++++++++++---- scripts/03era5.sh | 18 +++---- scripts/04arcticdem.sh | 15 ++++++ src/entropice/aggregators.py | 99 ++++++++++++++++++++++++------------ src/entropice/era5.py | 7 ++- 5 files changed, 151 insertions(+), 52 deletions(-) create mode 100644 scripts/04arcticdem.sh diff --git a/Processing Documentation.md b/Processing Documentation.md index f7e5190..f34cd5c 100644 --- a/Processing Documentation.md +++ b/Processing Documentation.md @@ -48,14 +48,14 @@ Each scale was choosen so that each grid cell had around 10000px do estimate the | grid | time | scale | | ----- | ------- | ----- | -| Hex3 | | 1600 | -| Hex4 | | 600 | -| Hex5 | | 240 | +| Hex3 | 46 min | 1600 | +| Hex4 | 5:04 h | 600 | +| Hex5 | 31:48 h | 240 | | Hex6 | | 90 | | Hpx6 | 58 min | 1600 | | Hpx7 | 3:16 h | 800 | | Hpx8 | 13:19 h | 400 | -| Hpx9 | | 200 | +| Hpx9 | 51:33 h | 200 | | Hpx10 | | 100 | ## Era5 @@ -78,13 +78,59 @@ For geometries crossing the antimeridian, geometries are corrected. 
| grid | method | | ----- | ----------- | | Hex3 | Common | -| Hex4 | Common | -| Hex5 | Mean | +| Hex4 | Mean | +| Hex5 | Interpolate | | Hex6 | Interpolate | | Hpx6 | Common | -| Hpx7 | Common | -| Hpx8 | Common | -| Hpx9 | Mean | +| Hpx7 | Mean | +| Hpx8 | Mean | +| Hpx9 | Interpolate | | Hpx10 | Interpolate | +- hex level 3 + min: 30.0 + max: 850.0 + mean: 251.25216674804688 + median: 235.5 +- hex level 4 + min: 8.0 + max: 166.0 + mean: 47.2462158203125 + median: 44.0 +- hex level 5 + min: 3.0 + max: 41.0 + mean: 11.164162635803223 + median: 10.0 +- hex level 6 + min: 2.0 + max: 14.0 + mean: 4.509947776794434 + median: 4.0 +- healpix level 6 + min: 25.0 + max: 769.0 + mean: 214.97296142578125 + median: 204.0 +- healpix level 7 + min: 9.0 + max: 231.0 + mean: 65.91140747070312 + median: 62.0 +- healpix level 8 + min: 4.0 + max: 75.0 + mean: 22.516725540161133 + median: 21.0 +- healpix level 9 + min: 2.0 + max: 29.0 + mean: 8.952080726623535 + median: 9.0 +- healpix level 10 + min: 2.0 + max: 15.0 + mean: 4.361577987670898 + median: 4.0 + ??? 
diff --git a/scripts/03era5.sh b/scripts/03era5.sh index 60f3e07..fed0203 100644 --- a/scripts/03era5.sh +++ b/scripts/03era5.sh @@ -3,13 +3,13 @@ # pixi run era5 download # pixi run era5 enrich -pixi run era5 spatial-agg --grid hex --level 3 -pixi run era5 spatial-agg --grid hex --level 4 -pixi run era5 spatial-agg --grid hex --level 5 -pixi run era5 spatial-agg --grid hex --level 6 +pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20 -pixi run era5 spatial-agg --grid healpix --level 6 -pixi run era5 spatial-agg --grid healpix --level 7 -pixi run era5 spatial-agg --grid healpix --level 8 -pixi run era5 spatial-agg --grid healpix --level 9 -pixi run era5 spatial-agg --grid healpix --level 10 +pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20 +pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20 diff --git a/scripts/04arcticdem.sh b/scripts/04arcticdem.sh new file mode 100644 index 0000000..5c8edd8 --- /dev/null +++ b/scripts/04arcticdem.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# pixi run arcticdem download +# pixi run arcticdem enrich + +pixi run arcticdem aggregate --grid hex --level 3 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid hex --level 4 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid hex --level 5 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid hex --level 6 --concurrent-partitions 20 + +pixi run arcticdem aggregate --grid healpix --level 6 --concurrent-partitions 20 +pixi run 
arcticdem aggregate --grid healpix --level 7 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid healpix --level 8 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid healpix --level 9 --concurrent-partitions 20 +pixi run arcticdem aggregate --grid healpix --level 10 --concurrent-partitions 20 diff --git a/src/entropice/aggregators.py b/src/entropice/aggregators.py index 05047e3..e082461 100644 --- a/src/entropice/aggregators.py +++ b/src/entropice/aggregators.py @@ -147,6 +147,7 @@ class _Aggregations: dim="z", skipna=True, ).to_numpy() + return cell_data def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray: if isinstance(flattened, xr.DataArray): @@ -199,7 +200,6 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis # Split geometry in case it crossed antimeridian else: geoms = _split_antimeridian_cell(geom) - geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms] geoms = list(filter(lambda g: _check_geom(gbox, g), geoms)) return geoms @@ -373,16 +373,45 @@ def _align_partition( memprof.log_memory("Before reading partial raster", log=False) need_to_close_raster = False + if raster is None: + assert shared_raster is not None, "Shared raster is not initialized in worker" # print("Using shared raster in worker") raster = shared_raster - elif callable(raster) and not isinstance(raster, xr.Dataset): + elif callable(raster): # print("Loading raster in partition") raster = raster() need_to_close_raster = True # else: # print("Using provided raster in partition") + # Partition the raster if necessary + is_raster_in_memory = isinstance(raster, xr.Dataset) and all( + isinstance(raster[var].data, np.ndarray) for var in raster.data_vars + ) + if is_raster_in_memory: + partial_raster = raster + else: + partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) + partial_extent = partial_extent.buffered( + raster.odc.geobox.resolution.x * pxbuffer, + 
raster.odc.geobox.resolution.y * pxbuffer, + ) # buffer by pxbuffer pixels + with stopwatch("Cropping raster to partition extent", log=False): + try: + partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute() + except Exception as e: + print(f"Error cropping raster to partition extent: {e}") + raise e + + if partial_raster.nbytes / 1e9 > 20: + print( + f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:" + f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)." + f" This may lead to out-of-memory errors." + ) + memprof.log_memory("After reading partial raster", log=False) + if aggregations is None: cell_ids = grids.convert_cell_ids(grid_partition_gdf) if grid_partition_gdf.crs.to_epsg() == 4326: @@ -402,28 +431,9 @@ def _align_partition( ) # ?: Cubic does not work with NaNs in xarray interp with stopwatch("Interpolating data to grid centroids", log=False): - ongrid = raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan}) + ongrid = partial_raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan}) memprof.log_memory("After interpolating data", log=False) else: - partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) - partial_extent = partial_extent.buffered( - raster.odc.geobox.resolution.x * pxbuffer, - raster.odc.geobox.resolution.y * pxbuffer, - ) # buffer by pxbuffer pixels - with stopwatch("Cropping raster to partition extent", log=False): - try: - partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute() - except Exception as e: - print(f"Error cropping raster to partition extent: {e}") - raise e - - if partial_raster.nbytes / 1e9 > 20: - print( - f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:" - f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)." - f" This may lead to out-of-memory errors." 
- ) - memprof.log_memory("After reading partial raster", log=False) others_shape = tuple( [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] ) @@ -449,9 +459,8 @@ def _align_partition( ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables") - partial_raster.close() - del partial_raster - + partial_raster.close() + del partial_raster if need_to_close_raster: raster.close() del raster @@ -477,6 +486,27 @@ def _align_data( concurrent_partitions: int, pxbuffer: int, ): + # ? Logic for memory management of raster dataset in multiprocessing scenarios + # There are multiple possible scenarios for the raster dataset: + # Raster can be either: + # 1. An in-memory xarray.Dataset object + # 2. A callable that returns an xarray.Dataset object + # 3. A lazy xarray.Dataset object (e.g., backed by zarr or icechunk) + # Concurrent partitions can be either: + # A. 1 (no multiprocessing) + # B. >1 (multiprocessing with multiple workers in fork) + # C. >1 (multiprocessing with multiple workers in spawn/forkserver) + + # This results in potential working modes: + # 1. Work on the complete raster in memory (no partitioning of the raster is needed) + # A. Pass it to the function + # B. Make a raster handle in each worker + # C. Move it into a shared memory buffer (NOT Implemented) + # 2. & 3. Work on partitions of the raster + # A. Load a partition into memory in each task + # B. & C. Make a raster handle in each worker and load partitions into memory in each task + # ! 
Note, this is not YET implemented, but this is how it SHOULD be implemented + partial_ongrids = [] if isinstance(grid_gdf, list): @@ -501,18 +531,21 @@ def _align_data( ) partial_ongrids.append(part_ongrid) else: + is_raster_in_memory = isinstance(raster, xr.Dataset) and all( + isinstance(raster[var].data, np.ndarray) for var in raster.data_vars + ) + is_mpfork = mp.get_start_method(allow_none=True) == "fork" + taskarg = None if is_raster_in_memory else raster + # For spawn or forkserver, we need to copy the raster into each worker + workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,) # For mp start method fork, we can share the raster dataset between workers - if mp.get_start_method(allow_none=True) == "fork": - _init_worker(raster if isinstance(raster, xr.Dataset) else None) - initargs = (None,) - else: - # For spawn or forkserver, we need to copy the raster into each worker - initargs = (raster if isinstance(raster, xr.Dataset) else None,) + if mp.get_start_method(allow_none=True) == "fork" and is_raster_in_memory: + _init_worker(raster) with ProcessPoolExecutor( max_workers=concurrent_partitions, initializer=_init_worker, - initargs=initargs, + initargs=workerargs, ) as executor: futures = {} for i, grid_partition in enumerate(grid_partitions): @@ -520,7 +553,7 @@ def _align_data( executor.submit( _align_partition, grid_partition, - None if isinstance(raster, xr.Dataset) else raster, + taskarg, aggregations, pxbuffer, ) diff --git a/src/entropice/era5.py b/src/entropice/era5.py index afefb2a..76d12c5 100644 --- a/src/entropice/era5.py +++ b/src/entropice/era5.py @@ -674,6 +674,11 @@ def spatial_agg( # ? 
Mask out water, since we don't want to aggregate over oceans grid_gdf = watermask.clip_grid(grid_gdf) grid_gdf = grid_gdf.to_crs("epsg:4326") + # There is a small group of invalid geometries in healpix level 10; filter them out manually + # They are invalid because the watermask clipping does not handle the antimeridian well + if grid == "healpix" and level == 10: + invalid_cell_id = [3059646, 3063547] + grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)] aggregations = { "hex": { @@ -685,7 +690,7 @@ def spatial_agg( "healpix": { 6: _Aggregations.common(), 7: _Aggregations.common(), - 8: _Aggregations.common(), + 8: _Aggregations(mean=True), 9: _Aggregations(mean=True), 10: "interpolate", },