From 98314fe8b3540490a3fc6374c376043480afaf39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Wed, 26 Nov 2025 18:11:15 +0100 Subject: [PATCH] Fix memory leak in aggregations --- Processing Documentation.md | 74 +++++++++ pixi.lock | 37 ++++- pyproject.toml | 2 +- src/entropice/aggregators.py | 282 ++++++++++++++++++----------------- src/entropice/arcticdem.py | 91 ++++++----- src/entropice/era5.py | 20 +-- src/entropice/grids.py | 21 +++ 7 files changed, 327 insertions(+), 200 deletions(-) create mode 100644 Processing Documentation.md diff --git a/Processing Documentation.md b/Processing Documentation.md new file mode 100644 index 0000000..01076b5 --- /dev/null +++ b/Processing Documentation.md @@ -0,0 +1,74 @@ +# Processing Documentation + +This document records how long each processing step took and how much memory and compute (CPU & GPU) it needed. + +| Grid | ArcticDEM | Era5 | AlphaEarth | Darts | | ----- | --------- | ---- | ---------- | ----- | | Hex3 | [ ] | [/] | [ ] | [/] | | Hex4 | [ ] | [/] | [ ] | [/] | | Hex5 | [ ] | [/] | [ ] | [/] | | Hex6 | [ ] | [ ] | [ ] | [ ] | | Hpx6 | [x] | [/] | [ ] | [/] | | Hpx7 | [ ] | [/] | [ ] | [/] | | Hpx8 | [ ] | [/] | [ ] | [/] | | Hpx9 | [ ] | [/] | [ ] | [/] | | Hpx10 | [ ] | [ ] | [ ] | [ ] | + +## Grid creation + +The creation of grids did not take up any significant amount of memory or compute. +The time taken to create a grid was between a few seconds for smaller levels up to a few minutes for the high levels. + +## DARTS + +Similar to grid creation, no significant amount of memory, compute or time needed. + +## ArcticDEM + +The download took around 8h with memory usage of about 10GB and no stronger limitations on compute. +The size of the resulting icechunk zarr datacube was approx. 160GB on disk which corresponds to approx. 270GB in memory if loaded in. 
+ +The enrichment took around 2h on a single A100 GPU node (40GB) with a local dask cluster consisting of 7 processes, each using 2 threads and 30GB of memory, making up a total of 210GB of memory. +These settings can be changed easily to consume less memory by reducing the number of processes or threads. +More processes or threads could not be used to ensure that the GPU does not run out of memory. + +### Spatial aggregations into grids + +All spatial aggregations relied heavily on CPU compute, since Cupy lacks support for nanquantile +and for higher resolution grids the amount of pixels to reduce was too small to overcome the data movement overhead of using a GPU. + +The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation. + +| grid | time | memory | processes | | ----- | ------ | ------ | --------- | | Hex3 | | | | | Hex4 | | | | | Hex5 | | | | | Hex6 | | | | | Hpx6 | 37 min | ~300GB | 40 | | Hpx7 | | | | | Hpx8 | | | | | Hpx9 | 25m | ~300GB | 40 | | Hpx10 | 34 min | ~300GB | 40 | + +## Alpha Earth + +The download was heavily limited by the scale of the input data, which is ~10m in the original dataset. +10m as a scale was not computationally feasible for the Google Earth Engine servers, thus each grid and level used another scale to aggregate and download the data. +Each scale was chosen so that each grid cell had around 10000px to estimate the aggregations from it. + +| grid | time | scale | | ----- | ------- | ----- | | Hex3 | | 1600 | | Hex4 | | 600 | | Hex5 | | 240 | | Hex6 | | 90 | | Hpx6 | 58 min | 1600 | | Hpx7 | 3:16 h | 800 | | Hpx8 | 13:19 h | 400 | | Hpx9 | | 200 | | Hpx10 | | 100 | + +## Era5 + +??? 
diff --git a/pixi.lock b/pixi.lock index aefb099..f83af07 100644 --- a/pixi.lock +++ b/pixi.lock @@ -475,6 +475,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/55/d2/507fd0ee4dd574d2bdbdeac5df83f39d2cae1ffe97d4622cca6f6bab39f1/botocore-1.40.70-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/bd/22/05555a9752357e24caa1cd92324d1a7fdde6386aab162fcc451f8f8eedc2/bottleneck-1.6.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl @@ -609,6 +610,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: git+https://github.com/davbyr/xAnimate#750e03e480db309407e09f4ffe5f49522a4c4f9b - pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d + - pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/ad/8f9ff43ff49ef02c7b8202a42c32a1fe8de1276bba0e6f55609e19ff7585/xdggs-0.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/56/b0/e3efafd9c97ed931f6453bd71aa8feaffc9217e6121af65fda06cf32f608/xgboost-3.1.1-py3-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/05/b9/b6a9cf72aef69c3e6db869dcc130e19452a658366dac9377f9cd32a76b80/xproj-0.2.1-py3-none-any.whl @@ -1410,6 +1412,13 @@ packages: - pkg:pypi/bokeh?source=compressed-mapping size: 5027028 timestamp: 1762557204752 +- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl + name: boost-histogram + version: 1.6.1 + sha256: c5700e53bf2d3d006f71610f71fc592d88dab3279dad300e8178dc084018b177 + requires_dist: + - numpy + requires_python: '>=3.9' - pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl name: boto3 version: 1.40.70 @@ -2810,7 +2819,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9 + sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -2865,6 +2874,7 @@ packages: - xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial - cupy-xarray>=0.1.4,<0.2 - memray>=1.19.1,<2 + - xarray-histogram>=0.2.2,<0.3 requires_python: '>=3.13,<3.14' editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 @@ -9373,6 +9383,31 @@ packages: - pkg:pypi/xarray?source=hash-mapping size: 989507 timestamp: 1763464377176 +- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl + name: 
xarray-histogram + version: 0.2.2 + sha256: 90682f9575131e5ea5d96feba5f8679e8c3fae039018d5e38f497696d55ccd5c + requires_dist: + - boost-histogram + - numpy + - xarray + - dask ; extra == 'full' + - scipy ; extra == 'full' + - dask ; extra == 'dev' + - scipy ; extra == 'dev' + - sphinx ; extra == 'dev' + - sphinx-book-theme ; extra == 'dev' + - ruff ; extra == 'dev' + - mypy>=1.5 ; extra == 'dev' + - pytest>=7.4 ; extra == 'dev' + - coverage ; extra == 'dev' + - pytest-cov ; extra == 'dev' + - dask ; extra == 'tests' + - scipy ; extra == 'tests' + - pytest>=7.4 ; extra == 'tests' + - coverage ; extra == 'tests' + - pytest-cov ; extra == 'tests' + requires_python: '>=3.11' - pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d name: xarray-spatial version: 0.1.dev519+g3a3120981 diff --git a/pyproject.toml b/pyproject.toml index d72c0d0..fbe3f25 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "xgboost>=3.1.1,<4", "s3fs>=2025.10.0,<2026", "xarray-spatial", - "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", + "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", ] [project.scripts] diff --git a/src/entropice/aggregators.py b/src/entropice/aggregators.py index e9cac8b..b3bff05 100644 --- a/src/entropice/aggregators.py +++ b/src/entropice/aggregators.py @@ -2,7 +2,6 @@ import gc import os -import sys from collections import defaultdict from collections.abc import Callable, Generator from concurrent.futures import ProcessPoolExecutor, as_completed @@ -33,6 +32,7 @@ from entropice import grids @dataclass(frozen=True) class _Aggregations: # ! The ordering is super important for this class! 
+ _common: bool = False # If true, use _agg_cell_data_single_common and ignore orther flags mean: bool = True sum: bool = False std: bool = False @@ -46,8 +46,14 @@ class _Aggregations: if self.median and 0.5 in self.quantiles: raise ValueError("Median aggregation cannot be used together with quantile 0.5") + @classmethod + def common(cls): + return cls(_common=True) + @cache def __len__(self) -> int: + if self._common: + return 11 length = 0 if self.mean: length += 1 @@ -65,6 +71,8 @@ class _Aggregations: return length def aggnames(self) -> list[str]: + if self._common: + return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"] names = [] if self.mean: names.append("mean") @@ -83,7 +91,24 @@ class _Aggregations: names.append(f"p{q_int}") return names + def _agg_cell_data_single_common(self, flattened_var: xr.DataArray) -> np.ndarray: + others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"]) + cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32) + cell_data[0, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy() + cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy() + cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy() + cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy() + quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99] # ? Ordering is important here! + cell_data[4:, ...] = flattened_var.quantile( + q=quantiles_to_compute, + dim="z", + skipna=True, + ).to_numpy() + return cell_data + def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray: + if self._common: + return self._agg_cell_data_single_common(flattened_var) others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"]) cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32) j = 0 @@ -106,6 +131,12 @@ class _Aggregations: quantiles_to_compute = sorted(self.quantiles) # ? 
Ordering is important here! if self.median: quantiles_to_compute.insert(0, 0.5) + + # Potential way to use cupy: + # cp.apply_along_axis( + # lambda x: cp.quantile(x[~cp.isnan(x)], q=quantiles), axis=0, # arr=data_gpu_raw + # ) + cell_data[j:, ...] = flattened_var.quantile( q=quantiles_to_compute, dim="z", @@ -173,7 +204,7 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations): spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"] flattened = cropped.stack(z=spatdims) # noqa: PD013 - # if flattened.z.size > 3000: + # if flattened.z.size > 1000000: # flattened = flattened.cupy.as_cupy() cell_data = aggregations.agg_cell_data(flattened) return cell_data @@ -187,6 +218,8 @@ def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggr else ["y", "x"] ) flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z") # noqa: PD013 + # if flattened.z.size > 1000000: + # flattened = flattened.cupy.as_cupy() cell_data = aggregations.agg_cell_data(flattened) return cell_data @@ -218,7 +251,15 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]: - # ! 
TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413 + if grid_gdf.crs.to_epsg() == 4326: + crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian) + else: + crosses_antimeridian = pd.Series([False] * len(grid_gdf), index=grid_gdf.index) + + if crosses_antimeridian.any(): + grid_gdf_am = grid_gdf[crosses_antimeridian] + grid_gdf = grid_gdf[~crosses_antimeridian] + # Simple partitioning by splitting the GeoDataFrame into n_partitions parts centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y}) labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids) @@ -226,6 +267,9 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[ partition = grid_gdf[labels == i] yield partition + if crosses_antimeridian.any(): + yield grid_gdf_am + class _MemoryProfiler: def __init__(self): @@ -279,20 +323,20 @@ def _align_partition( # 1. unmeasurable # 2. ~0.1-0.3s # 3. 
~0.05s + # => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes + # faster than a processpoolexecutor memprof.log_memory("Before reading partial raster", log=False) if callable(raster) and not isinstance(raster, xr.Dataset): raster = raster() - - if hasattr(raster, "__dask_graph__"): - graph_size = len(raster.__dask_graph__()) - graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2 - print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB") + need_to_close_raster = True + else: + need_to_close_raster = False others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]) - data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) - data = np.full(data_shape, np.nan, dtype=np.float32) + ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) + ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32) partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) partial_extent = partial_extent.buffered( @@ -304,57 +348,42 @@ def _align_partition( partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute() except Exception as e: print(f"Error cropping raster to partition extent: {e}") - return data + return ongrid - print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)") - memprof.log_memory("After reading partial raster", log=True) + if partial_raster.nbytes / 1e9 > 20: + print( + f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:" + f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)." + f" This may lead to out-of-memory errors." 
+ ) + memprof.log_memory("After reading partial raster", log=False) - with ProcessPoolExecutor( - max_workers=60, - ) as executor: - futures = {} - for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): - # ? Splitting the geometries already here and passing only the cropped data to the worker to - # reduce pickling overhead - geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs)) - if len(geoms) == 0: - continue - elif len(geoms) == 1: - cropped = [_read_cell_data(partial_raster, geoms[0])] - else: - cropped = _read_split_cell_data(partial_raster, geoms) - futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i + for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): + try: + cell_data = _process_geom(row.geometry, partial_raster, aggregations) + except (SystemError, SystemExit, KeyboardInterrupt) as e: + raise e + except Exception as e: + print(f"Error processing cell {row['cell_id']}: {e}") + continue + ongrid[i, ...] = cell_data - # Clean up immediately after submitting - del geoms, cropped # Added this line + cell_ids = grids.convert_cell_ids(grid_partition_gdf) + dims = ["cell_ids", "variables", "aggregations"] + coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} + for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: + dims.append(dim) + coords[dim] = raster.coords[dim] - for future in as_completed(futures): - i = futures[future] - try: - cell_data = future.result() - except (SystemError, SystemExit, KeyboardInterrupt) as e: - raise e - except Exception as e: - print(f"Error processing cell {i}: {e}") - continue - data[i, ...] 
= cell_data - del cell_data, futures[future] - - # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): - # try: - # cell_data = _process_geom(row.geometry, partial_raster, aggregations) - # except (SystemError, SystemExit, KeyboardInterrupt) as e: - # raise e - # except Exception as e: - # print(f"Error processing cell {row['cell_id']}: {e}") - # continue - # data[i, ...] = cell_data + ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables") partial_raster.close() - raster.close() - del partial_raster, raster, future, futures + del partial_raster + if need_to_close_raster: + raster.close() + del raster gc.collect() - memprof.log_memory("After cleaning", log=True) + memprof.log_memory("After cleaning", log=False) print("Finished processing partition") print("### Stopwatch summary ###\n") @@ -362,45 +391,8 @@ def _align_partition( print("### Memory summary ###\n") print(memprof.summary()) print("#########################") - # with ProcessPoolExecutor( - # max_workers=max_workers, - # initializer=_init_partial_raster_global, - # initargs=(raster, grid_partition_gdf, pxbuffer), - # ) as executor: - # futures = { - # executor.submit(_process_geom_global, row.geometry, aggregations): i - # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()) - # } - # others_shape = tuple( - # [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] - # ) - # data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape) - # print(f"Allocating partial data array of shape {data_shape}") - # data = np.full(data_shape, np.nan, dtype=np.float32) - - # lap = time.time() - # for future in track( - # as_completed(futures), - # total=len(futures), - # description="Spatially aggregating ERA5 data...", - # ): - # print(f"time since last cell: {time.time() - lap:.2f}s") - # i = futures[future] - # try: - # cell_data = future.result() - # data[i, ...] 
= cell_data - # except Exception as e: - # print(f"Error processing cell {i}: {e}") - # continue - # finally: - # # Try to free memory - # del futures[future], future - # if "cell_data" in locals(): - # del cell_data - # gc.collect() - # lap = time.time() - return data + return ongrid @stopwatch("Aligning data with grid") @@ -409,47 +401,66 @@ def _align_data( raster: xr.Dataset | Callable[[], xr.Dataset], aggregations: _Aggregations, n_partitions: int, - max_workers: int, + concurrent_partitions: int, pxbuffer: int, ): - partial_data = {} + partial_ongrids = [] - with ProcessPoolExecutor( - max_workers=max_workers, - initializer=_init_worker, - # initializer=_init_raster_global, - # initargs=(raster,), - ) as executor: - futures = {} + _init_worker() + + if n_partitions < concurrent_partitions: + print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}") + concurrent_partitions = n_partitions + + if concurrent_partitions <= 1: for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)): - futures[ - executor.submit( - _align_partition, - grid_partition, - raster.copy() if isinstance(raster, xr.Dataset) else raster, - aggregations, - pxbuffer, - ) - ] = i + print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells") + part_ongrid = _align_partition( + grid_partition, + raster, # .copy() if isinstance(raster, xr.Dataset) else raster, + aggregations, + pxbuffer, + ) + partial_ongrids.append(part_ongrid) + else: + with ProcessPoolExecutor( + max_workers=concurrent_partitions, + initializer=_init_worker, + # initializer=_init_raster_global, + # initargs=(raster,), + ) as executor: + futures = {} + for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)): + futures[ + executor.submit( + _align_partition, + grid_partition, + raster.copy() if isinstance(raster, xr.Dataset) else raster, + aggregations, + pxbuffer, + ) + ] = i + if i == 6: + print("Breaking after 3 partitions for 
testing purposes") - print("Submitted all partitions, waiting for results...") + print("Submitted all partitions, waiting for results...") - for future in track( - as_completed(futures), - total=len(futures), - description="Processing grid partitions...", - ): - try: - i = futures[future] - print(f"Processed partition {i + 1}/{len(futures)}") - part_data = future.result() - partial_data[i] = part_data - except Exception as e: - print(f"Error processing partition {i}: {e}") - raise e + for future in track( + as_completed(futures), + total=len(futures), + description="Processing grid partitions...", + ): + try: + i = futures[future] + print(f"Processed partition {i + 1}/{len(futures)}") + part_ongrid = future.result() + partial_ongrids.append(part_ongrid) + except Exception as e: + print(f"Error processing partition {i}: {e}") + raise e - data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0) - return data + ongrid = xr.concat(partial_ongrids, dim="cell_ids") + return ongrid def aggregate_raster_into_grid( @@ -459,7 +470,7 @@ def aggregate_raster_into_grid( grid: Literal["hex", "healpix"], level: int, n_partitions: int = 20, - max_workers: int = 5, + concurrent_partitions: int = 5, pxbuffer: int = 15, ): """Aggregate raster data into grid cells. @@ -471,19 +482,20 @@ def aggregate_raster_into_grid( grid (Literal["hex", "healpix"]): The type of grid to use. level (int): The level of the grid. n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20. - max_workers (int, optional): Maximum number of worker processes. Defaults to 5. + concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions. + Defaults to 5. pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15. Returns: xr.Dataset: Aggregated data aligned with the grid. 
""" - aligned = _align_data( + ongrid = _align_data( grid_gdf, raster, aggregations, n_partitions=n_partitions, - max_workers=max_workers, + concurrent_partitions=concurrent_partitions, pxbuffer=pxbuffer, ) # Dims of aligned: (cell_ids, variables, aggregations, other dims...) @@ -491,16 +503,6 @@ def aggregate_raster_into_grid( if callable(raster) and not isinstance(raster, xr.Dataset): raster = raster() - cell_ids = grids.get_cell_ids(grid, level) - - dims = ["cell_ids", "variables", "aggregations"] - coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} - for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: - dims.append(dim) - coords[dim] = raster.coords[dim] - - ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables") - gridinfo = { "grid_name": "h3" if grid == "hex" else grid, "level": level, diff --git a/src/entropice/arcticdem.py b/src/entropice/arcticdem.py index 3d22aa5..5db1ae4 100644 --- a/src/entropice/arcticdem.py +++ b/src/entropice/arcticdem.py @@ -10,6 +10,7 @@ import cupyx.scipy.signal import cyclopts import dask.array import dask.distributed as dd +import geopandas as gpd import icechunk import icechunk.xarray import numpy as np @@ -48,7 +49,7 @@ def download(): @dataclass -class KernelFactory: +class _KernelFactory: res: int size_px: int @@ -77,14 +78,14 @@ class KernelFactory: return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel)) -def tpi_cupy(chunk, kernels: KernelFactory): +def tpi_cupy(chunk, kernels: _KernelFactory): kernel = kernels.ring() kernel = kernel / cp.nansum(kernel) tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same") return tpi -def tri_cupy(chunk, kernels: KernelFactory): +def tri_cupy(chunk, kernels: _KernelFactory): kernel = kernels.tri() c2 = chunk**2 focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same") @@ -93,7 +94,7 @@ def tri_cupy(chunk, kernels: KernelFactory): return tri -def 
ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory): +def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory): slope_rad = slope * (cp.pi / 180) aspect_rad = aspect * (cp.pi / 180) aspect_rad = cp.where(aspect_rad == -1, 0, aspect_rad) @@ -150,9 +151,9 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array: res = 32 # 32m resolution - small_kernels = KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m) - medium_kernels = KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m) - large_kernels = KernelFactory(res=res, size_px=15) # ~15x15 kernels (480m) + small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m) + medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m) + large_kernels = _KernelFactory(res=res, size_px=15) # ~15x15 kernels (480m) # Check if there is data in the chunk if np.all(np.isnan(chunk)): @@ -306,49 +307,61 @@ def enrich(): def _open_adem(): arcticdem_store = get_arcticdem_stores() accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) + # config = icechunk.RepositoryConfig.default() + # config.caching = icechunk.CachingConfig( + # num_snapshot_nodes=None, + # num_chunk_refs=None, + # num_transaction_changes=None, + # num_bytes_attributes=None, + # num_bytes_chunks=None, + # ) + # repo = icechunk.Repository.open(accessor.repo.storage, config=config) + # session = repo.readonly_session("main") + # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref") adem = accessor.open_xarray() + # Don't use the datamask + adem = adem.drop_vars("datamask") assert {"x", "y"} == set(adem.dims) assert adem.odc.crs == "EPSG:3413" return adem @cli.command() -def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20): +def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20): 
mp.set_start_method("forkserver", force=True) + with ( + dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster, + ): + print(cluster) + with stopwatch(f"Loading {grid} grid at level {level}"): + grid_gdf = grids.open(grid, level) + # ? Mask out water, since we don't want to aggregate over oceans + grid_gdf = watermask.clip_grid(grid_gdf) + # ? Mask out cells that do not intersect with ArcticDEM extent + arcticdem_store = get_arcticdem_stores() + accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) + extent_info = gpd.read_parquet(accessor._aux_dir / "ArcticDEM_Mosaic_Index_v4_1_32m.parquet") + grid_gdf = grid_gdf[grid_gdf.intersects(extent_info.geometry.union_all())] - with stopwatch(f"Loading {grid} grid at level {level}"): - grid_gdf = grids.open(grid, level) - # ? Mask out water, since we don't want to aggregate over oceans - grid_gdf = watermask.clip_grid(grid_gdf) + aggregations = _Aggregations.common() + aggregated = aggregate_raster_into_grid( + _open_adem, + grid_gdf, + aggregations, + grid, + level, + concurrent_partitions=concurrent_partitions, + n_partitions=400, # 400 # Results in ~10GB large patches + pxbuffer=30, + ) + store = get_arcticdem_stores(grid, level) + with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"): + aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) - aggregations = _Aggregations( - mean=True, - sum=False, - std=True, - min=True, - max=True, - median=True, - quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99), - ) + print("Aggregation complete.") - aggregated = aggregate_raster_into_grid( - _open_adem, - grid_gdf, - aggregations, - grid, - level, - max_workers=max_workers, - n_partitions=200, # Results in 10GB large patches - pxbuffer=30, - ) - store = get_arcticdem_stores(grid, level) - with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"): - aggregated.to_zarr(store, mode="w", consolidated=False, 
encoding=codecs.from_ds(aggregated)) - - print("Aggregation complete.") - - print("### Finished ArcticDEM processing ###") - stopwatch.summary() + print("### Finished ArcticDEM processing ###") + stopwatch.summary() if __name__ == "__main__": diff --git a/src/entropice/era5.py b/src/entropice/era5.py index b589275..5677e14 100644 --- a/src/entropice/era5.py +++ b/src/entropice/era5.py @@ -677,25 +677,7 @@ def spatial_agg( assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}" unaligned = _correct_longs(unaligned) - # with stopwatch("Precomputing cell geometries"): - # with ProcessPoolExecutor(max_workers=20) as executor: - # cell_geometries = list( - # executor.map( - # _get_corrected_geoms, - # [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()], - # ) - # ) - # data = _align_data(cell_geometries, unaligned) - # aggregated = _create_aligned(unaligned, data, grid, level) - - aggregations = _Aggregations( - mean=True, - sum=False, - std=True, - min=True, - max=True, - quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99], - ) + aggregations = _Aggregations.common() aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level) aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)}) diff --git a/src/entropice/grids.py b/src/entropice/grids.py index 531b2dc..c562d45 100644 --- a/src/entropice/grids.py +++ b/src/entropice/grids.py @@ -66,6 +66,27 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int): return cell_ids +def convert_cell_ids(grid_gdf: gpd.GeoDataFrame): + """Convert cell IDs in a GeoDataFrame to xdggs-compatible format. + + Args: + grid_gdf (gpd.GeoDataFrame): The input grid GeoDataFrame. + + Returns: + pd.Series: The the converted cell IDs. 
+ + """ + if "cell_id" not in grid_gdf.columns: + raise ValueError("The GeoDataFrame must contain a 'cell_id' column.") + + # Check if cell IDs are hex strings (for H3) + if isinstance(grid_gdf["cell_id"].iloc[0], str): + grid_gdf = grid_gdf.copy() + grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda cid: int(cid, 16)) + + return grid_gdf["cell_id"] + + def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]: hex_batch = [] hex_id_batch = []