diff --git a/Processing Documentation.md b/Processing Documentation.md index 01076b5..f7e5190 100644 --- a/Processing Documentation.md +++ b/Processing Documentation.md @@ -38,18 +38,7 @@ All spatial aggregations relied heavily on CPU compute, since Cupy lacking suppo and for higher resolution grids the amount of pixels to reduce where too small to overcome the data movement overhead of using a GPU. The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation. - -| grid | time | memory | processes | | ----- | ------ | ------ | --------- | | Hex3 | | | | | Hex4 | | | | | Hex5 | | | | | Hex6 | | | | | Hpx6 | 37 min | ~300GB | 40 | | Hpx7 | | | | | Hpx8 | | | | | Hpx9 | 25m | ~300GB | 40 | | Hpx10 | 34 min | ~300GB | 40 | +All spatial aggregations into the different grids took around 30 min each, with a total memory peak of ~300 GB partitioned over 40 processes. ## Alpha Earth @@ -71,4 +60,31 @@ Each scale was choosen so that each grid cell had around 10000px do estimate the ## Era5 +### Spatial aggregations into grids + +All spatial aggregations relied heavily on CPU compute, since Cupy lacks support for nanquantile +and for higher resolution grids the amount of pixels to reduce was too small to overcome the data movement overhead of using a GPU. + +The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation. 
+ +Since the resolution of the ERA5 dataset is spatially coarser than the resolution of the higher-resolution grids, different aggregation methods were used for different grid-levels: + +- Common aggregations: mean, min, max, std, median, p01, p05, p25, p75, p95, p99 for low resolution grids +- Only mean aggregations for medium resolution grids +- Linear interpolation for high resolution grids + +For geometries crossing the antimeridian, geometries are corrected. + +| grid | method | | ----- | ----------- | | Hex3 | Common | | Hex4 | Common | | Hex5 | Mean | | Hex6 | Interpolate | | Hpx6 | Common | | Hpx7 | Common | | Hpx8 | Common | | Hpx9 | Mean | | Hpx10 | Interpolate | + ??? diff --git a/pixi.lock b/pixi.lock index f83af07..591f7f6 100644 --- a/pixi.lock +++ b/pixi.lock @@ -462,6 +462,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5b/03/c17464bbf682ea87e7e3de2ddc63395e359a78ae9c01f55fc78759ecbd79/anywidget-0.9.21-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl @@ -497,6 +498,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/11/a8/c6a4b901d17399c77cd81fb001ce8961e9f5e04d3daf27e8925cb012e163/docutils-0.22.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a3/cf/7feb3222d770566ca9eaf0bf6922745fadd1ed7ab11832520063a515c240/ecmwf_datastores_client-0.4.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/54/5e3b0e41799e17e5eff1547fda4aab53878c0adb4243de6b95f8ddef899e/ee_extra-2025.7.2-py3-none-any.whl @@ -788,6 +790,15 @@ packages: - jupyter-book ; extra == 'doc' - vl-convert-python ; extra == 'doc' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl + name: antimeridian + version: 0.4.5 + sha256: 8b1f82c077d2c48eae0a6606759cfec9133a0701250371cb707a56959451d9dd + requires_dist: + - numpy>=1.22.4 + - shapely>=2.0 + - click>=8.1.6 ; extra == 'cli' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/noarch/anyio-4.11.0-pyhcf101f3_0.conda sha256: 7378b5b9d81662d73a906fabfc2fb81daddffe8dc0680ed9cda7a9562af894b0 md5: 814472b61da9792fae28156cb9ee54f5 @@ -2763,6 +2774,18 @@ packages: - pytest ; extra == 'test' - cloudpickle ; extra == 'test' requires_python: '>=3.8' +- pypi: 
https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl + name: duckdb + version: 1.4.2 + sha256: a456adbc3459c9dcd99052fad20bd5f8ef642be5b04d09590376b2eb3eb84f5c + requires_dist: + - ipython ; extra == 'all' + - fsspec ; extra == 'all' + - numpy ; extra == 'all' + - pandas ; extra == 'all' + - pyarrow ; extra == 'all' + - adbc-driver-manager ; extra == 'all' + requires_python: '>=3.9.0' - pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl name: earthengine-api version: 1.7.1 @@ -2819,7 +2842,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30 + sha256: c335ffb8f5ffc53929fcd9d656087692b6e9918938384df60d136124ca5365bc requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -2875,6 +2898,8 @@ packages: - cupy-xarray>=0.1.4,<0.2 - memray>=1.19.1,<2 - xarray-histogram>=0.2.2,<0.3 + - antimeridian>=0.4.5,<0.5 + - duckdb>=1.4.2,<2 requires_python: '>=3.13,<3.14' editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 diff --git a/pyproject.toml b/pyproject.toml index fbe3f25..4f0379c 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "xgboost>=3.1.1,<4", "s3fs>=2025.10.0,<2026", "xarray-spatial", - "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", + "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2", ] [project.scripts] diff --git a/src/entropice/aggregators.py b/src/entropice/aggregators.py index b3bff05..05047e3 100644 --- a/src/entropice/aggregators.py +++ b/src/entropice/aggregators.py @@ -1,6 +1,7 @@ """Aggregation helpers.""" import gc +import multiprocessing as 
mp import os from collections import defaultdict from collections.abc import Callable, Generator @@ -9,7 +10,11 @@ from dataclasses import dataclass, field from functools import cache from typing import Literal +import antimeridian +import cudf +import cuml.cluster import geopandas as gpd +import matplotlib.pyplot as plt import numpy as np import odc.geo.geobox import pandas as pd @@ -250,7 +255,18 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati return cell_data -def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]: +def partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int, plot: bool = False) -> Generator[gpd.GeoDataFrame]: + """Partition the input GeoDataFrame into n_partitions parts. + + Args: + grid_gdf (gpd.GeoDataFrame): The input GeoDataFrame to partition. + n_partitions (int): The number of partitions. + plot (bool, optional): Whether to plot the partitions. Defaults to False. + + Yields: + Generator[gpd.GeoDataFrame]: Partitions of the input GeoDataFrame. 
+ + """ if grid_gdf.crs.to_epsg() == 4326: crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian) else: @@ -262,7 +278,24 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[ # Simple partitioning by splitting the GeoDataFrame into n_partitions parts centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y}) - labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids) + + # use cuml and cudf if len of centroids is larger than 100000 + if len(centroids) > 100000: + print(f"Using cuML KMeans for partitioning {len(centroids)} centroids") + centroids_cudf = cudf.DataFrame.from_pandas(centroids) + kmeans = cuml.cluster.KMeans(n_clusters=n_partitions, random_state=42) + labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy() + else: + labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids) + + if plot: + grid_gdf = grid_gdf.copy() + grid_gdf["partition"] = labels + ax = grid_gdf.plot(column="partition", categorical=True, legend=True, figsize=(10, 10)) + if crosses_antimeridian.any(): + grid_gdf_am.plot(ax=ax, color="red", edgecolor="black", alpha=0.5) + ax.set_title("Grid partitions") + plt.show() for i in range(n_partitions): partition = grid_gdf[labels == i] yield partition @@ -292,19 +325,30 @@ class _MemoryProfiler: memprof = None +shared_raster = None -def _init_worker(): +def _init_worker(r: xr.Dataset | None): global memprof + global shared_raster memprof = _MemoryProfiler() + if r is not None: + # print("Initializing shared raster in worker") + shared_raster = r def _align_partition( grid_partition_gdf: gpd.GeoDataFrame, - raster: xr.Dataset | Callable[[], xr.Dataset], - aggregations: _Aggregations, + raster: xr.Dataset | Callable[[], xr.Dataset] | None, + aggregations: _Aggregations | None, # None -> Interpolation pxbuffer: int, ): + # ? 
This function is expected to run inside a worker process + # It heavily utilizes different techniques to reduce memory usage such as + # Lazy operations, reading only necessary data, and cleaning up memory after use. + # Shared in-memory raster datasets are used when possible to avoid duplicating large datasets in memory. + # Shared raster datasets only work when using the "fork" start method for multiprocessing. + # Strategy for each cell: # 1. Correct the geometry to account for the antimeridian # 2. Cop the dataset and load the data into memory @@ -328,122 +372,159 @@ def _align_partition( memprof.log_memory("Before reading partial raster", log=False) - if callable(raster) and not isinstance(raster, xr.Dataset): + need_to_close_raster = False + if raster is None: + # print("Using shared raster in worker") + raster = shared_raster + elif callable(raster) and not isinstance(raster, xr.Dataset): + # print("Loading raster in partition") raster = raster() need_to_close_raster = True - else: - need_to_close_raster = False + # else: + # print("Using provided raster in partition") - others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]) - ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) - ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32) - - partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) - partial_extent = partial_extent.buffered( - raster.odc.geobox.resolution.x * pxbuffer, - raster.odc.geobox.resolution.y * pxbuffer, - ) # buffer by pxbuffer pixels - with stopwatch("Cropping raster to partition extent", log=False): - try: - partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute() - except Exception as e: - print(f"Error cropping raster to partition extent: {e}") - return ongrid - - if partial_raster.nbytes / 1e9 > 20: - print( - f"{os.getpid()}: WARNING! 
Partial raster size is larger than 20GB:" - f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)." - f" This may lead to out-of-memory errors." + if aggregations is None: + cell_ids = grids.convert_cell_ids(grid_partition_gdf) + if grid_partition_gdf.crs.to_epsg() == 4326: + centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid) + cx = centroids.apply(lambda p: p.x) + cy = centroids.apply(lambda p: p.y) + else: + centroids = grid_partition_gdf.geometry.centroid + cx = centroids.x + cy = centroids.y + interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids}) + interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids}) + interp_coords = ( + {"latitude": interp_y, "longitude": interp_x} + if "latitude" in raster.dims and "longitude" in raster.dims + else {"y": interp_y, "x": interp_x} ) - memprof.log_memory("After reading partial raster", log=False) + # ?: Cubic does not work with NaNs in xarray interp + with stopwatch("Interpolating data to grid centroids", log=False): + ongrid = raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan}) + memprof.log_memory("After interpolating data", log=False) + else: + partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) + partial_extent = partial_extent.buffered( + raster.odc.geobox.resolution.x * pxbuffer, + raster.odc.geobox.resolution.y * pxbuffer, + ) # buffer by pxbuffer pixels + with stopwatch("Cropping raster to partition extent", log=False): + try: + partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute() + except Exception as e: + print(f"Error cropping raster to partition extent: {e}") + raise e - for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): - try: - cell_data = _process_geom(row.geometry, partial_raster, aggregations) - except (SystemError, SystemExit, KeyboardInterrupt) as e: - raise e - except 
Exception as e: - print(f"Error processing cell {row['cell_id']}: {e}") - continue - ongrid[i, ...] = cell_data + if partial_raster.nbytes / 1e9 > 20: + print( + f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:" + f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)." + f" This may lead to out-of-memory errors." + ) + memprof.log_memory("After reading partial raster", log=False) + others_shape = tuple( + [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] + ) + ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) + ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32) - cell_ids = grids.convert_cell_ids(grid_partition_gdf) - dims = ["cell_ids", "variables", "aggregations"] - coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} - for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: - dims.append(dim) - coords[dim] = raster.coords[dim] + for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): + try: + cell_data = _process_geom(row.geometry, partial_raster, aggregations) + except (SystemError, SystemExit, KeyboardInterrupt) as e: + raise e + except Exception as e: + print(f"Error processing cell {row['cell_id']}: {e}") + continue + ongrid[i, ...] 
= cell_data - ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables") + cell_ids = grids.convert_cell_ids(grid_partition_gdf) + dims = ["cell_ids", "variables", "aggregations"] + coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} + for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: + dims.append(dim) + coords[dim] = raster.coords[dim] + + ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables") + + partial_raster.close() + del partial_raster - partial_raster.close() - del partial_raster if need_to_close_raster: raster.close() del raster gc.collect() memprof.log_memory("After cleaning", log=False) - print("Finished processing partition") - print("### Stopwatch summary ###\n") - print(stopwatch.summary()) - print("### Memory summary ###\n") - print(memprof.summary()) - print("#########################") + # print("Finished processing partition") + # print("### Stopwatch summary ###\n") + # print(stopwatch.summary()) + # print("### Memory summary ###\n") + # print(memprof.summary()) + # print("#########################") return ongrid @stopwatch("Aligning data with grid") def _align_data( - grid_gdf: gpd.GeoDataFrame, + grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame], raster: xr.Dataset | Callable[[], xr.Dataset], - aggregations: _Aggregations, - n_partitions: int, + aggregations: _Aggregations | None, + n_partitions: int | None, concurrent_partitions: int, pxbuffer: int, ): partial_ongrids = [] - _init_worker() + if isinstance(grid_gdf, list): + n_partitions = len(grid_gdf) + grid_partitions = grid_gdf + else: + grid_partitions = partition_grid(grid_gdf, n_partitions) if n_partitions < concurrent_partitions: print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}") concurrent_partitions = n_partitions if concurrent_partitions <= 1: - for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)): - 
print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells") + _init_worker(None) # No need to use a shared raster, since the processing is done in the main process + for i, grid_partition in enumerate(grid_partitions): + # print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells") part_ongrid = _align_partition( grid_partition, - raster, # .copy() if isinstance(raster, xr.Dataset) else raster, + raster, aggregations, pxbuffer, ) partial_ongrids.append(part_ongrid) else: + # For mp start method fork, we can share the raster dataset between workers + if mp.get_start_method(allow_none=True) == "fork": + _init_worker(raster if isinstance(raster, xr.Dataset) else None) + initargs = (None,) + else: + # For spawn or forkserver, we need to copy the raster into each worker + initargs = (raster if isinstance(raster, xr.Dataset) else None,) + with ProcessPoolExecutor( max_workers=concurrent_partitions, initializer=_init_worker, - # initializer=_init_raster_global, - # initargs=(raster,), + initargs=initargs, ) as executor: futures = {} - for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)): + for i, grid_partition in enumerate(grid_partitions): futures[ executor.submit( _align_partition, grid_partition, - raster.copy() if isinstance(raster, xr.Dataset) else raster, + None if isinstance(raster, xr.Dataset) else raster, aggregations, pxbuffer, ) ] = i - if i == 6: - print("Breaking after 3 partitions for testing purposes") - - print("Submitted all partitions, waiting for results...") for future in track( as_completed(futures), @@ -452,7 +533,7 @@ def _align_data( ): try: i = futures[future] - print(f"Processed partition {i + 1}/{len(futures)}") + # print(f"Processed partition {i + 1}/{len(futures)}") part_ongrid = future.result() partial_ongrids.append(part_ongrid) except Exception as e: @@ -465,11 +546,11 @@ def _align_data( def aggregate_raster_into_grid( raster: xr.Dataset | Callable[[], 
xr.Dataset], - grid_gdf: gpd.GeoDataFrame, - aggregations: _Aggregations, + grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame], + aggregations: _Aggregations | Literal["interpolate"], grid: Literal["hex", "healpix"], level: int, - n_partitions: int = 20, + n_partitions: int | None = 20, concurrent_partitions: int = 5, pxbuffer: int = 15, ): @@ -477,11 +558,13 @@ def aggregate_raster_into_grid( Args: raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it. - grid_gdf (gpd.GeoDataFrame): The grid to aggregate into. - aggregations (_Aggregations): The aggregations to perform. + grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into. + If a list of GeoDataFrames is provided, each will be processed as a separate partition. + No further partitioning will be done and the n_partitions argument will be ignored. + aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform. grid (Literal["hex", "healpix"]): The type of grid to use. level (int): The level of the grid. - n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20. + n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20. concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions. Defaults to 5. pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15. 
@@ -493,7 +576,7 @@ def aggregate_raster_into_grid( ongrid = _align_data( grid_gdf, raster, - aggregations, + aggregations if aggregations != "interpolate" else None, n_partitions=n_partitions, concurrent_partitions=concurrent_partitions, pxbuffer=pxbuffer, diff --git a/src/entropice/era5.py b/src/entropice/era5.py index 5677e14..afefb2a 100644 --- a/src/entropice/era5.py +++ b/src/entropice/era5.py @@ -87,10 +87,12 @@ import numpy as np import odc.geo import odc.geo.xr import pandas as pd +import shapely.geometry import ultraplot as uplt import xarray as xr import xdggs import xvec +from rasterio.features import shapes from rich import pretty, print, traceback from stopuhr import stopwatch @@ -643,6 +645,7 @@ def viz( # =========================== +@stopwatch("Correcting longitudes to -180 to 180") def _correct_longs(ds: xr.Dataset) -> xr.Dataset: return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude") @@ -651,6 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset: def spatial_agg( grid: Literal["hex", "healpix"], level: int, + concurrent_partitions: int = 20, ): """Perform spatial aggregation of ERA5 data to grid cells. @@ -661,6 +665,8 @@ def spatial_agg( Args: grid ("hex" | "healpix"): Grid type. level (int): Grid resolution level. + concurrent_partitions (int, optional): Number of concurrent partitions to process. + Defaults to 20. 
""" with stopwatch(f"Loading {grid} grid at level {level}"): @@ -669,6 +675,23 @@ def spatial_agg( grid_gdf = watermask.clip_grid(grid_gdf) grid_gdf = grid_gdf.to_crs("epsg:4326") + aggregations = { + "hex": { + 3: _Aggregations.common(), + 4: _Aggregations.common(), + 5: _Aggregations(mean=True), + 6: "interpolate", + }, + "healpix": { + 6: _Aggregations.common(), + 7: _Aggregations.common(), + 8: _Aggregations.common(), + 9: _Aggregations(mean=True), + 10: "interpolate", + }, + } + aggregations = aggregations[grid][level] + for agg in ["yearly", "seasonal", "shoulder"]: unaligned_store = get_era5_stores(agg) with stopwatch(f"Loading {agg} ERA5 data"): @@ -677,8 +700,26 @@ def spatial_agg( assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}" unaligned = _correct_longs(unaligned) - aggregations = _Aggregations.common() - aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level) + # Filter out Grid Cells that are completely outside the ERA5 valid area + valid_geoms = [] + for g, v in shapes( + unaligned.t2m_mean.isel(time=0).isnull().astype("uint8").values, + transform=unaligned.odc.transform, + ): + if v == 0: + valid_geoms.append(shapely.geometry.shape(g)) + grid_gdf_filtered = grid_gdf[grid_gdf.geometry.intersects(shapely.geometry.MultiPolygon(valid_geoms))] + + aggregated = aggregate_raster_into_grid( + unaligned, + grid_gdf_filtered, + aggregations, + grid, + level, + n_partitions=40, + concurrent_partitions=concurrent_partitions, + pxbuffer=10, + ) aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)}) store = get_era5_stores(agg, grid, level) diff --git a/src/entropice/watermask.py b/src/entropice/watermask.py index 36c6c2e..d744056 100644 --- a/src/entropice/watermask.py +++ b/src/entropice/watermask.py @@ -1,5 +1,6 @@ """Helpers for the watermask.""" +import duckdb import geopandas as gpd from entropice.paths import watermask_file 
@@ -16,7 +17,7 @@ def open(): return watermask -def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: +def clip_grid(grid_gdf: gpd.GeoDataFrame, allow_duckdb: bool = False) -> gpd.GeoDataFrame: """Clip the input GeoDataFrame with the watermask. Args: @@ -27,6 +28,71 @@ def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: """ watermask = open() - watermask = watermask.to_crs(gdf.crs) - gdf = gdf.overlay(watermask, how="difference") - return gdf + watermask = watermask.to_crs(grid_gdf.crs) + # ! Currently disabled - kernel crashes + allow_duckdb = False + if len(grid_gdf) >= 10000 and allow_duckdb: + # Use duckdb for large datasets + crs = grid_gdf.crs + + # Convert geometry columns to WKB format for DuckDB + # No need to copy the watermask + watermask["geometry"] = watermask.geometry.to_wkb() + grid_gdf = grid_gdf.copy() + grid_gdf["geometry"] = grid_gdf.geometry.to_wkb() + + # Connect to DuckDB + con = duckdb.connect(":memory:") + + # Install and load spatial extension + con.execute("INSTALL spatial;") + con.execute("LOAD spatial;") + + # Register the DataFrames as tables in DuckDB + con.register("watermask", watermask) + con.register("grid", grid_gdf) + + query = """ + SELECT g.* + FROM grid g + WHERE NOT EXISTS ( + SELECT 1 + FROM watermask w + WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry)) + ) + """ + + query = """ + WITH clipped AS ( + SELECT + g.* EXCLUDE (geometry), + CASE + WHEN EXISTS ( + SELECT 1 + FROM watermask w + WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry)) + ) + THEN ST_Difference( + ST_GeomFromWKB(g.geometry), + ( + SELECT ST_Union_Agg(ST_GeomFromWKB(w.geometry)) + FROM watermask w + WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry)) + ) + ) + ELSE ST_GeomFromWKB(g.geometry) + END AS geometry + FROM grid g + ) + SELECT * FROM clipped + WHERE NOT ST_IsEmpty(geometry) + """ + + result = con.execute(query).df() + # Convert back to GeoDataFrame + 
result["geometry"] = gpd.GeoSeries.from_wkb(result["geometry"].apply(bytes)) + grid_gdf = gpd.GeoDataFrame(result, geometry="geometry", crs=crs) + con.close() + else: + grid_gdf = grid_gdf.overlay(watermask, how="difference") + return grid_gdf