Make aggregations work

This commit is contained in:
Tobias Hölzer 2025-11-28 00:03:51 +01:00
parent 98314fe8b3
commit 7b09dda6a3
6 changed files with 328 additions and 97 deletions

View file

@ -38,18 +38,7 @@ All spatial aggregations relied heavily on CPU compute, since Cupy lacking suppo
and for higher-resolution grids the number of pixels to reduce was too small to overcome the data movement overhead of using a GPU.
The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation.
| grid | time | memory | processes |
| ----- | ------ | ------ | --------- |
| Hex3 | | | |
| Hex4 | | | |
| Hex5 | | | |
| Hex6 | | | |
| Hpx6 | 37 min | ~300GB | 40 |
| Hpx7 | | | |
| Hpx8 | | | |
| Hpx9 | 25m | ~300GB | 40 |
| Hpx10 | 34 min | ~300GB | 40 |
All spatial aggregations into the different grids took around 30 min each, with a total memory peak of ~300 GB partitioned over 40 processes.
## Alpha Earth
@ -71,4 +60,31 @@ Each scale was chosen so that each grid cell had around 10000px to estimate the
## Era5
### Spatial aggregations into grids
All spatial aggregations relied heavily on CPU compute, since CuPy lacks support for nanquantile,
and for higher-resolution grids the number of pixels to reduce was too small to overcome the data movement overhead of using a GPU.
The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation.
Since the spatial resolution of the ERA5 dataset is coarser than that of the higher-resolution grids, different aggregation methods were used for different grid levels:
- Common aggregations: mean, min, max, std, median, p01, p05, p25, p75, p95, p99 for low resolution grids
- Only mean aggregations for medium resolution grids
- Linear interpolation for high resolution grids
Geometries crossing the antimeridian are corrected.
| grid | method |
| ----- | ----------- |
| Hex3 | Common |
| Hex4 | Common |
| Hex5 | Mean |
| Hex6 | Interpolate |
| Hpx6 | Common |
| Hpx7 | Common |
| Hpx8 | Common |
| Hpx9 | Mean |
| Hpx10 | Interpolate |
???

27
pixi.lock generated
View file

@ -462,6 +462,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5b/03/c17464bbf682ea87e7e3de2ddc63395e359a78ae9c01f55fc78759ecbd79/anywidget-0.9.21-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
@ -497,6 +498,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/11/a8/c6a4b901d17399c77cd81fb001ce8961e9f5e04d3daf27e8925cb012e163/docutils-0.22.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/cf/7feb3222d770566ca9eaf0bf6922745fadd1ed7ab11832520063a515c240/ecmwf_datastores_client-0.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/65/54/5e3b0e41799e17e5eff1547fda4aab53878c0adb4243de6b95f8ddef899e/ee_extra-2025.7.2-py3-none-any.whl
@ -788,6 +790,15 @@ packages:
- jupyter-book ; extra == 'doc'
- vl-convert-python ; extra == 'doc'
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
name: antimeridian
version: 0.4.5
sha256: 8b1f82c077d2c48eae0a6606759cfec9133a0701250371cb707a56959451d9dd
requires_dist:
- numpy>=1.22.4
- shapely>=2.0
- click>=8.1.6 ; extra == 'cli'
requires_python: '>=3.10'
- conda: https://conda.anaconda.org/conda-forge/noarch/anyio-4.11.0-pyhcf101f3_0.conda
sha256: 7378b5b9d81662d73a906fabfc2fb81daddffe8dc0680ed9cda7a9562af894b0
md5: 814472b61da9792fae28156cb9ee54f5
@ -2763,6 +2774,18 @@ packages:
- pytest ; extra == 'test'
- cloudpickle ; extra == 'test'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl
name: duckdb
version: 1.4.2
sha256: a456adbc3459c9dcd99052fad20bd5f8ef642be5b04d09590376b2eb3eb84f5c
requires_dist:
- ipython ; extra == 'all'
- fsspec ; extra == 'all'
- numpy ; extra == 'all'
- pandas ; extra == 'all'
- pyarrow ; extra == 'all'
- adbc-driver-manager ; extra == 'all'
requires_python: '>=3.9.0'
- pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl
name: earthengine-api
version: 1.7.1
@ -2819,7 +2842,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30
sha256: c335ffb8f5ffc53929fcd9d656087692b6e9918938384df60d136124ca5365bc
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -2875,6 +2898,8 @@ packages:
- cupy-xarray>=0.1.4,<0.2
- memray>=1.19.1,<2
- xarray-histogram>=0.2.2,<0.3
- antimeridian>=0.4.5,<0.5
- duckdb>=1.4.2,<2
requires_python: '>=3.13,<3.14'
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7

View file

@ -57,7 +57,7 @@ dependencies = [
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2",
]
[project.scripts]

View file

@ -1,6 +1,7 @@
"""Aggregation helpers."""
import gc
import multiprocessing as mp
import os
from collections import defaultdict
from collections.abc import Callable, Generator
@ -9,7 +10,11 @@ from dataclasses import dataclass, field
from functools import cache
from typing import Literal
import antimeridian
import cudf
import cuml.cluster
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import odc.geo.geobox
import pandas as pd
@ -250,7 +255,18 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati
return cell_data
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
def partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int, plot: bool = False) -> Generator[gpd.GeoDataFrame]:
"""Partition the input GeoDataFrame into n_partitions parts.
Args:
grid_gdf (gpd.GeoDataFrame): The input GeoDataFrame to partition.
n_partitions (int): The number of partitions.
plot (bool, optional): Whether to plot the partitions. Defaults to True.
Yields:
Generator[gpd.GeoDataFrame]: Partitions of the input GeoDataFrame.
"""
if grid_gdf.crs.to_epsg() == 4326:
crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian)
else:
@ -262,7 +278,24 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
# use cuml and cudf if len of centroids is larger than 100000
if len(centroids) > 100000:
print(f"Using cuML KMeans for partitioning {len(centroids)} centroids")
centroids_cudf = cudf.DataFrame.from_pandas(centroids)
kmeans = cuml.cluster.KMeans(n_clusters=n_partitions, random_state=42)
labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy()
else:
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
if plot:
grid_gdf = grid_gdf.copy()
grid_gdf["partition"] = labels
ax = grid_gdf.plot(column="partition", categorical=True, legend=True, figsize=(10, 10))
if crosses_antimeridian.any():
grid_gdf_am.plot(ax=ax, color="red", edgecolor="black", alpha=0.5)
ax.set_title("Grid partitions")
plt.show()
for i in range(n_partitions):
partition = grid_gdf[labels == i]
yield partition
@ -292,19 +325,30 @@ class _MemoryProfiler:
memprof = None
shared_raster = None
def _init_worker():
def _init_worker(r: xr.Dataset | None):
    """Initialize per-worker process globals.

    Creates a fresh ``_MemoryProfiler`` for this process and, when a raster is
    supplied, stores it in the module-level ``shared_raster`` so worker tasks
    can reuse it instead of receiving a copy per task.

    Args:
        r (xr.Dataset | None): Raster dataset to share with this worker, or
            ``None`` when the raster is passed per task (or when running in
            the main process).
    """
    global memprof
    global shared_raster
    memprof = _MemoryProfiler()
    if r is not None:
        # print("Initializing shared raster in worker")
        shared_raster = r
def _align_partition(
grid_partition_gdf: gpd.GeoDataFrame,
raster: xr.Dataset | Callable[[], xr.Dataset],
aggregations: _Aggregations,
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
aggregations: _Aggregations | None, # None -> Interpolation
pxbuffer: int,
):
# ? This function is expected to run inside a worker process
# It heavily utilizes different techniques to reduce memory usage such as
# Lazy operations, reading only necessary data, and cleaning up memory after use.
# Shared in-memory raster datasets are used when possible to avoid duplicating large datasets in memory.
# Shared raster datasets only work when using the "fork" start method for multiprocessing.
# Strategy for each cell:
# 1. Correct the geometry to account for the antimeridian
# 2. Cop the dataset and load the data into memory
@ -328,122 +372,159 @@ def _align_partition(
memprof.log_memory("Before reading partial raster", log=False)
if callable(raster) and not isinstance(raster, xr.Dataset):
need_to_close_raster = False
if raster is None:
# print("Using shared raster in worker")
raster = shared_raster
elif callable(raster) and not isinstance(raster, xr.Dataset):
# print("Loading raster in partition")
raster = raster()
need_to_close_raster = True
else:
need_to_close_raster = False
# else:
# print("Using provided raster in partition")
others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
partial_extent = partial_extent.buffered(
raster.odc.geobox.resolution.x * pxbuffer,
raster.odc.geobox.resolution.y * pxbuffer,
) # buffer by pxbuffer pixels
with stopwatch("Cropping raster to partition extent", log=False):
try:
partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
except Exception as e:
print(f"Error cropping raster to partition extent: {e}")
return ongrid
if partial_raster.nbytes / 1e9 > 20:
print(
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
f" This may lead to out-of-memory errors."
if aggregations is None:
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
if grid_partition_gdf.crs.to_epsg() == 4326:
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
cx = centroids.apply(lambda p: p.x)
cy = centroids.apply(lambda p: p.y)
else:
centroids = grid_partition_gdf.geometry.centroid
cx = centroids.x
cy = centroids.y
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
interp_coords = (
{"latitude": interp_y, "longitude": interp_x}
if "latitude" in raster.dims and "longitude" in raster.dims
else {"y": interp_y, "x": interp_x}
)
memprof.log_memory("After reading partial raster", log=False)
# ?: Cubic does not work with NaNs in xarray interp
with stopwatch("Interpolating data to grid centroids", log=False):
ongrid = raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
memprof.log_memory("After interpolating data", log=False)
else:
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
partial_extent = partial_extent.buffered(
raster.odc.geobox.resolution.x * pxbuffer,
raster.odc.geobox.resolution.y * pxbuffer,
) # buffer by pxbuffer pixels
with stopwatch("Cropping raster to partition extent", log=False):
try:
partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute()
except Exception as e:
print(f"Error cropping raster to partition extent: {e}")
raise e
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
try:
cell_data = _process_geom(row.geometry, partial_raster, aggregations)
except (SystemError, SystemExit, KeyboardInterrupt) as e:
raise e
except Exception as e:
print(f"Error processing cell {row['cell_id']}: {e}")
continue
ongrid[i, ...] = cell_data
if partial_raster.nbytes / 1e9 > 20:
print(
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
f" This may lead to out-of-memory errors."
)
memprof.log_memory("After reading partial raster", log=False)
others_shape = tuple(
[raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
)
ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
try:
cell_data = _process_geom(row.geometry, partial_raster, aggregations)
except (SystemError, SystemExit, KeyboardInterrupt) as e:
raise e
except Exception as e:
print(f"Error processing cell {row['cell_id']}: {e}")
continue
ongrid[i, ...] = cell_data
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
partial_raster.close()
del partial_raster
partial_raster.close()
del partial_raster
if need_to_close_raster:
raster.close()
del raster
gc.collect()
memprof.log_memory("After cleaning", log=False)
print("Finished processing partition")
print("### Stopwatch summary ###\n")
print(stopwatch.summary())
print("### Memory summary ###\n")
print(memprof.summary())
print("#########################")
# print("Finished processing partition")
# print("### Stopwatch summary ###\n")
# print(stopwatch.summary())
# print("### Memory summary ###\n")
# print(memprof.summary())
# print("#########################")
return ongrid
@stopwatch("Aligning data with grid")
def _align_data(
grid_gdf: gpd.GeoDataFrame,
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
raster: xr.Dataset | Callable[[], xr.Dataset],
aggregations: _Aggregations,
n_partitions: int,
aggregations: _Aggregations | None,
n_partitions: int | None,
concurrent_partitions: int,
pxbuffer: int,
):
partial_ongrids = []
_init_worker()
if isinstance(grid_gdf, list):
n_partitions = len(grid_gdf)
grid_partitions = grid_gdf
else:
grid_partitions = partition_grid(grid_gdf, n_partitions)
if n_partitions < concurrent_partitions:
print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}")
concurrent_partitions = n_partitions
if concurrent_partitions <= 1:
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
_init_worker(None) # No need to use a shared raster, since the processing is done in the main process
for i, grid_partition in enumerate(grid_partitions):
# print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
part_ongrid = _align_partition(
grid_partition,
raster, # .copy() if isinstance(raster, xr.Dataset) else raster,
raster,
aggregations,
pxbuffer,
)
partial_ongrids.append(part_ongrid)
else:
# For mp start method fork, we can share the raster dataset between workers
if mp.get_start_method(allow_none=True) == "fork":
_init_worker(raster if isinstance(raster, xr.Dataset) else None)
initargs = (None,)
else:
# For spawn or forkserver, we need to copy the raster into each worker
initargs = (raster if isinstance(raster, xr.Dataset) else None,)
with ProcessPoolExecutor(
max_workers=concurrent_partitions,
initializer=_init_worker,
# initializer=_init_raster_global,
# initargs=(raster,),
initargs=initargs,
) as executor:
futures = {}
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
for i, grid_partition in enumerate(grid_partitions):
futures[
executor.submit(
_align_partition,
grid_partition,
raster.copy() if isinstance(raster, xr.Dataset) else raster,
None if isinstance(raster, xr.Dataset) else raster,
aggregations,
pxbuffer,
)
] = i
if i == 6:
print("Breaking after 3 partitions for testing purposes")
print("Submitted all partitions, waiting for results...")
for future in track(
as_completed(futures),
@ -452,7 +533,7 @@ def _align_data(
):
try:
i = futures[future]
print(f"Processed partition {i + 1}/{len(futures)}")
# print(f"Processed partition {i + 1}/{len(futures)}")
part_ongrid = future.result()
partial_ongrids.append(part_ongrid)
except Exception as e:
@ -465,11 +546,11 @@ def _align_data(
def aggregate_raster_into_grid(
raster: xr.Dataset | Callable[[], xr.Dataset],
grid_gdf: gpd.GeoDataFrame,
aggregations: _Aggregations,
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
aggregations: _Aggregations | Literal["interpolate"],
grid: Literal["hex", "healpix"],
level: int,
n_partitions: int = 20,
n_partitions: int | None = 20,
concurrent_partitions: int = 5,
pxbuffer: int = 15,
):
@ -477,11 +558,13 @@ def aggregate_raster_into_grid(
Args:
raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
aggregations (_Aggregations): The aggregations to perform.
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
No further partitioning will be done and the n_partitions argument will be ignored.
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
grid (Literal["hex", "healpix"]): The type of grid to use.
level (int): The level of the grid.
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
Defaults to 5.
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
@ -493,7 +576,7 @@ def aggregate_raster_into_grid(
ongrid = _align_data(
grid_gdf,
raster,
aggregations,
aggregations if aggregations != "interpolate" else None,
n_partitions=n_partitions,
concurrent_partitions=concurrent_partitions,
pxbuffer=pxbuffer,

View file

@ -87,10 +87,12 @@ import numpy as np
import odc.geo
import odc.geo.xr
import pandas as pd
import shapely.geometry
import ultraplot as uplt
import xarray as xr
import xdggs
import xvec
from rasterio.features import shapes
from rich import pretty, print, traceback
from stopuhr import stopwatch
@ -643,6 +645,7 @@ def viz(
# ===========================
@stopwatch("Correcting longitudes to -180 to 180")
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
    """Wrap the ``longitude`` coordinate into [-180, 180) and sort by it ascending.

    Args:
        ds (xr.Dataset): Dataset with a ``longitude`` coordinate in [0, 360).

    Returns:
        xr.Dataset: Dataset with longitudes wrapped to [-180, 180) and sorted.
    """
    wrapped_longs = ((ds.longitude + 180) % 360) - 180
    return ds.assign_coords(longitude=wrapped_longs).sortby("longitude")
@ -651,6 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
def spatial_agg(
grid: Literal["hex", "healpix"],
level: int,
concurrent_partitions: int = 20,
):
"""Perform spatial aggregation of ERA5 data to grid cells.
@ -661,6 +665,8 @@ def spatial_agg(
Args:
grid ("hex" | "healpix"): Grid type.
level (int): Grid resolution level.
concurrent_partitions (int, optional): Number of concurrent partitions to process.
Defaults to 20.
"""
with stopwatch(f"Loading {grid} grid at level {level}"):
@ -669,6 +675,23 @@ def spatial_agg(
grid_gdf = watermask.clip_grid(grid_gdf)
grid_gdf = grid_gdf.to_crs("epsg:4326")
aggregations = {
"hex": {
3: _Aggregations.common(),
4: _Aggregations.common(),
5: _Aggregations(mean=True),
6: "interpolate",
},
"healpix": {
6: _Aggregations.common(),
7: _Aggregations.common(),
8: _Aggregations.common(),
9: _Aggregations(mean=True),
10: "interpolate",
},
}
aggregations = aggregations[grid][level]
for agg in ["yearly", "seasonal", "shoulder"]:
unaligned_store = get_era5_stores(agg)
with stopwatch(f"Loading {agg} ERA5 data"):
@ -677,8 +700,26 @@ def spatial_agg(
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
unaligned = _correct_longs(unaligned)
aggregations = _Aggregations.common()
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
# Filter out Grid Cells that are completely outside the ERA5 valid area
valid_geoms = []
for g, v in shapes(
unaligned.t2m_mean.isel(time=0).isnull().astype("uint8").values,
transform=unaligned.odc.transform,
):
if v == 0:
valid_geoms.append(shapely.geometry.shape(g))
grid_gdf_filtered = grid_gdf[grid_gdf.geometry.intersects(shapely.geometry.MultiPolygon(valid_geoms))]
aggregated = aggregate_raster_into_grid(
unaligned,
grid_gdf_filtered,
aggregations,
grid,
level,
n_partitions=40,
concurrent_partitions=concurrent_partitions,
pxbuffer=10,
)
aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
store = get_era5_stores(agg, grid, level)

View file

@ -1,5 +1,6 @@
"""Helpers for the watermask."""
import duckdb
import geopandas as gpd
from entropice.paths import watermask_file
@ -16,7 +17,7 @@ def open():
return watermask
def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
def clip_grid(grid_gdf: gpd.GeoDataFrame, allow_duckdb: bool = False) -> gpd.GeoDataFrame:
"""Clip the input GeoDataFrame with the watermask.
Args:
@ -27,6 +28,71 @@ def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
"""
watermask = open()
watermask = watermask.to_crs(gdf.crs)
gdf = gdf.overlay(watermask, how="difference")
return gdf
watermask = watermask.to_crs(grid_gdf.crs)
# ! Currently disabled - kernel crashes
allow_duckdb = False
if len(grid_gdf) >= 10000 and allow_duckdb:
# Use duckdb for large datasets
crs = grid_gdf.crs
# Convert geometry columns to WKB format for DuckDB
# No need to copy the watermask
watermask["geometry"] = watermask.geometry.to_wkb()
grid_gdf = grid_gdf.copy()
grid_gdf["geometry"] = grid_gdf.geometry.to_wkb()
# Connect to DuckDB
con = duckdb.connect(":memory:")
# Install and load spatial extension
con.execute("INSTALL spatial;")
con.execute("LOAD spatial;")
# Register the DataFrames as tables in DuckDB
con.register("watermask", watermask)
con.register("grid", grid_gdf)
query = """
SELECT g.*
FROM grid g
WHERE NOT EXISTS (
SELECT 1
FROM watermask w
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
)
"""
query = """
WITH clipped AS (
SELECT
g.* EXCLUDE (geometry),
CASE
WHEN EXISTS (
SELECT 1
FROM watermask w
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
)
THEN ST_Difference(
ST_GeomFromWKB(g.geometry),
(
SELECT ST_Union_Agg(ST_GeomFromWKB(w.geometry))
FROM watermask w
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
)
)
ELSE ST_GeomFromWKB(g.geometry)
END AS geometry
FROM grid g
)
SELECT * FROM clipped
WHERE NOT ST_IsEmpty(geometry)
"""
result = con.execute(query).df()
# Convert back to GeoDataFrame
result["geometry"] = gpd.GeoSeries.from_wkb(result["geometry"].apply(bytes))
grid_gdf = gpd.GeoDataFrame(result, geometry="geometry", crs=crs)
con.close()
else:
grid_gdf = grid_gdf.overlay(watermask, how="difference")
return grid_gdf