Make aggregations work
This commit is contained in:
parent
98314fe8b3
commit
7b09dda6a3
6 changed files with 328 additions and 97 deletions
|
|
@ -38,18 +38,7 @@ All spatial aggregations relied heavily on CPU compute, since Cupy lacking suppo
|
|||
and for higher resolution grids the number of pixels to reduce was too small to overcome the data movement overhead of using a GPU.
|
||||
|
||||
The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation.
|
||||
|
||||
| grid | time | memory | processes |
|
||||
| ----- | ------ | ------ | --------- |
|
||||
| Hex3 | | | |
|
||||
| Hex4 | | | |
|
||||
| Hex5 | | | |
|
||||
| Hex6 | | | |
|
||||
| Hpx6 | 37 min | ~300GB | 40 |
|
||||
| Hpx7 | | | |
|
||||
| Hpx8 | | | |
|
||||
| Hpx9 | 25m | ~300GB | 40 |
|
||||
| Hpx10 | 34 min | ~300GB | 40 |
|
||||
All spatial aggregations into the different grids took around 30 min each, with a total memory peak of ~300 GB partitioned over 40 processes.
|
||||
|
||||
## Alpha Earth
|
||||
|
||||
|
|
@ -71,4 +60,31 @@ Each scale was choosen so that each grid cell had around 10000px do estimate the
|
|||
|
||||
## Era5
|
||||
|
||||
### Spatial aggregations into grids
|
||||
|
||||
All spatial aggregations relied heavily on CPU compute, since Cupy lacks support for nanquantile
|
||||
and for higher resolution grids the number of pixels to reduce was too small to overcome the data movement overhead of using a GPU.
|
||||
|
||||
The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`) accumulating linearly more memory with higher parallel computation.
|
||||
|
||||
Since the resolution of the ERA5 dataset is spatially coarser than that of the higher-resolution grids, different aggregation methods were used for different grid levels:
|
||||
|
||||
- Common aggregations: mean, min, max, std, median, p01, p05, p25, p75, p95, p99 for low resolution grids
|
||||
- Only mean aggregations for medium resolution grids
|
||||
- Linear interpolation for high resolution grids
|
||||
|
||||
For geometries crossing the antimeridian, geometries are corrected.
|
||||
|
||||
| grid | method |
|
||||
| ----- | ----------- |
|
||||
| Hex3 | Common |
|
||||
| Hex4 | Common |
|
||||
| Hex5 | Mean |
|
||||
| Hex6 | Interpolate |
|
||||
| Hpx6 | Common |
|
||||
| Hpx7 | Common |
|
||||
| Hpx8 | Common |
|
||||
| Hpx9 | Mean |
|
||||
| Hpx10 | Interpolate |
|
||||
|
||||
???
|
||||
|
|
|
|||
27
pixi.lock
generated
27
pixi.lock
generated
|
|
@ -462,6 +462,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c8/a7/a597ff7dd1e1603abd94991ce242f93979d5f10b0d45ed23976dfb22bf64/altair_tiles-0.4.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/5b/03/c17464bbf682ea87e7e3de2ddc63395e359a78ae9c01f55fc78759ecbd79/anywidget-0.9.21-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
|
||||
|
|
@ -497,6 +498,7 @@ environments:
|
|||
- pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/11/a8/c6a4b901d17399c77cd81fb001ce8961e9f5e04d3daf27e8925cb012e163/docutils-0.22.3-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a3/cf/7feb3222d770566ca9eaf0bf6922745fadd1ed7ab11832520063a515c240/ecmwf_datastores_client-0.4.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/65/54/5e3b0e41799e17e5eff1547fda4aab53878c0adb4243de6b95f8ddef899e/ee_extra-2025.7.2-py3-none-any.whl
|
||||
|
|
@ -788,6 +790,15 @@ packages:
|
|||
- jupyter-book ; extra == 'doc'
|
||||
- vl-convert-python ; extra == 'doc'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
|
||||
name: antimeridian
|
||||
version: 0.4.5
|
||||
sha256: 8b1f82c077d2c48eae0a6606759cfec9133a0701250371cb707a56959451d9dd
|
||||
requires_dist:
|
||||
- numpy>=1.22.4
|
||||
- shapely>=2.0
|
||||
- click>=8.1.6 ; extra == 'cli'
|
||||
requires_python: '>=3.10'
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/anyio-4.11.0-pyhcf101f3_0.conda
|
||||
sha256: 7378b5b9d81662d73a906fabfc2fb81daddffe8dc0680ed9cda7a9562af894b0
|
||||
md5: 814472b61da9792fae28156cb9ee54f5
|
||||
|
|
@ -2763,6 +2774,18 @@ packages:
|
|||
- pytest ; extra == 'test'
|
||||
- cloudpickle ; extra == 'test'
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/82/29/153d1b4fc14c68e6766d7712d35a7ab6272a801c52160126ac7df681f758/duckdb-1.4.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: duckdb
|
||||
version: 1.4.2
|
||||
sha256: a456adbc3459c9dcd99052fad20bd5f8ef642be5b04d09590376b2eb3eb84f5c
|
||||
requires_dist:
|
||||
- ipython ; extra == 'all'
|
||||
- fsspec ; extra == 'all'
|
||||
- numpy ; extra == 'all'
|
||||
- pandas ; extra == 'all'
|
||||
- pyarrow ; extra == 'all'
|
||||
- adbc-driver-manager ; extra == 'all'
|
||||
requires_python: '>=3.9.0'
|
||||
- pypi: https://files.pythonhosted.org/packages/91/bd/d501c3c3602e70d1d729f042ae0b85446a1213a630a7a4290f361b37d9a8/earthengine_api-1.7.1-py3-none-any.whl
|
||||
name: earthengine-api
|
||||
version: 1.7.1
|
||||
|
|
@ -2819,7 +2842,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30
|
||||
sha256: c335ffb8f5ffc53929fcd9d656087692b6e9918938384df60d136124ca5365bc
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -2875,6 +2898,8 @@ packages:
|
|||
- cupy-xarray>=0.1.4,<0.2
|
||||
- memray>=1.19.1,<2
|
||||
- xarray-histogram>=0.2.2,<0.3
|
||||
- antimeridian>=0.4.5,<0.5
|
||||
- duckdb>=1.4.2,<2
|
||||
requires_python: '>=3.13,<3.14'
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ dependencies = [
|
|||
"xgboost>=3.1.1,<4",
|
||||
"s3fs>=2025.10.0,<2026",
|
||||
"xarray-spatial",
|
||||
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3",
|
||||
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Aggregation helpers."""
|
||||
|
||||
import gc
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Generator
|
||||
|
|
@ -9,7 +10,11 @@ from dataclasses import dataclass, field
|
|||
from functools import cache
|
||||
from typing import Literal
|
||||
|
||||
import antimeridian
|
||||
import cudf
|
||||
import cuml.cluster
|
||||
import geopandas as gpd
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import odc.geo.geobox
|
||||
import pandas as pd
|
||||
|
|
@ -250,7 +255,18 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati
|
|||
return cell_data
|
||||
|
||||
|
||||
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
|
||||
def partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int, plot: bool = False) -> Generator[gpd.GeoDataFrame]:
|
||||
"""Partition the input GeoDataFrame into n_partitions parts.
|
||||
|
||||
Args:
|
||||
grid_gdf (gpd.GeoDataFrame): The input GeoDataFrame to partition.
|
||||
n_partitions (int): The number of partitions.
|
||||
plot (bool, optional): Whether to plot the partitions. Defaults to True.
|
||||
|
||||
Yields:
|
||||
Generator[gpd.GeoDataFrame]: Partitions of the input GeoDataFrame.
|
||||
|
||||
"""
|
||||
if grid_gdf.crs.to_epsg() == 4326:
|
||||
crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian)
|
||||
else:
|
||||
|
|
@ -262,7 +278,24 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[
|
|||
|
||||
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
|
||||
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
|
||||
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
|
||||
|
||||
# use cuml and cudf if len of centroids is larger than 100000
|
||||
if len(centroids) > 100000:
|
||||
print(f"Using cuML KMeans for partitioning {len(centroids)} centroids")
|
||||
centroids_cudf = cudf.DataFrame.from_pandas(centroids)
|
||||
kmeans = cuml.cluster.KMeans(n_clusters=n_partitions, random_state=42)
|
||||
labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy()
|
||||
else:
|
||||
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
|
||||
|
||||
if plot:
|
||||
grid_gdf = grid_gdf.copy()
|
||||
grid_gdf["partition"] = labels
|
||||
ax = grid_gdf.plot(column="partition", categorical=True, legend=True, figsize=(10, 10))
|
||||
if crosses_antimeridian.any():
|
||||
grid_gdf_am.plot(ax=ax, color="red", edgecolor="black", alpha=0.5)
|
||||
ax.set_title("Grid partitions")
|
||||
plt.show()
|
||||
for i in range(n_partitions):
|
||||
partition = grid_gdf[labels == i]
|
||||
yield partition
|
||||
|
|
@ -292,19 +325,30 @@ class _MemoryProfiler:
|
|||
|
||||
|
||||
memprof = None
|
||||
shared_raster = None
|
||||
|
||||
|
||||
def _init_worker():
|
||||
def _init_worker(r: xr.Dataset | None):
|
||||
global memprof
|
||||
global shared_raster
|
||||
memprof = _MemoryProfiler()
|
||||
if r is not None:
|
||||
# print("Initializing shared raster in worker")
|
||||
shared_raster = r
|
||||
|
||||
|
||||
def _align_partition(
|
||||
grid_partition_gdf: gpd.GeoDataFrame,
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
aggregations: _Aggregations,
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
|
||||
aggregations: _Aggregations | None, # None -> Interpolation
|
||||
pxbuffer: int,
|
||||
):
|
||||
# ? This function is expected to run inside a worker process
|
||||
# It heavily utilizes different techniques to reduce memory usage such as
|
||||
# Lazy operations, reading only necessary data, and cleaning up memory after use.
|
||||
# Shared in-memory raster datasets are used when possible to avoid duplicating large datasets in memory.
|
||||
# Shared raster datasets only work when using the "fork" start method for multiprocessing.
|
||||
|
||||
# Strategy for each cell:
|
||||
# 1. Correct the geometry to account for the antimeridian
|
||||
# 2. Crop the dataset and load the data into memory
|
||||
|
|
@ -328,122 +372,159 @@ def _align_partition(
|
|||
|
||||
memprof.log_memory("Before reading partial raster", log=False)
|
||||
|
||||
if callable(raster) and not isinstance(raster, xr.Dataset):
|
||||
need_to_close_raster = False
|
||||
if raster is None:
|
||||
# print("Using shared raster in worker")
|
||||
raster = shared_raster
|
||||
elif callable(raster) and not isinstance(raster, xr.Dataset):
|
||||
# print("Loading raster in partition")
|
||||
raster = raster()
|
||||
need_to_close_raster = True
|
||||
else:
|
||||
need_to_close_raster = False
|
||||
# else:
|
||||
# print("Using provided raster in partition")
|
||||
|
||||
others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
|
||||
ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
|
||||
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
|
||||
|
||||
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
|
||||
partial_extent = partial_extent.buffered(
|
||||
raster.odc.geobox.resolution.x * pxbuffer,
|
||||
raster.odc.geobox.resolution.y * pxbuffer,
|
||||
) # buffer by pxbuffer pixels
|
||||
with stopwatch("Cropping raster to partition extent", log=False):
|
||||
try:
|
||||
partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
|
||||
except Exception as e:
|
||||
print(f"Error cropping raster to partition extent: {e}")
|
||||
return ongrid
|
||||
|
||||
if partial_raster.nbytes / 1e9 > 20:
|
||||
print(
|
||||
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
|
||||
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
|
||||
f" This may lead to out-of-memory errors."
|
||||
if aggregations is None:
|
||||
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
||||
if grid_partition_gdf.crs.to_epsg() == 4326:
|
||||
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
|
||||
cx = centroids.apply(lambda p: p.x)
|
||||
cy = centroids.apply(lambda p: p.y)
|
||||
else:
|
||||
centroids = grid_partition_gdf.geometry.centroid
|
||||
cx = centroids.x
|
||||
cy = centroids.y
|
||||
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||
interp_coords = (
|
||||
{"latitude": interp_y, "longitude": interp_x}
|
||||
if "latitude" in raster.dims and "longitude" in raster.dims
|
||||
else {"y": interp_y, "x": interp_x}
|
||||
)
|
||||
memprof.log_memory("After reading partial raster", log=False)
|
||||
# ?: Cubic does not work with NaNs in xarray interp
|
||||
with stopwatch("Interpolating data to grid centroids", log=False):
|
||||
ongrid = raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
|
||||
memprof.log_memory("After interpolating data", log=False)
|
||||
else:
|
||||
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
|
||||
partial_extent = partial_extent.buffered(
|
||||
raster.odc.geobox.resolution.x * pxbuffer,
|
||||
raster.odc.geobox.resolution.y * pxbuffer,
|
||||
) # buffer by pxbuffer pixels
|
||||
with stopwatch("Cropping raster to partition extent", log=False):
|
||||
try:
|
||||
partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute()
|
||||
except Exception as e:
|
||||
print(f"Error cropping raster to partition extent: {e}")
|
||||
raise e
|
||||
|
||||
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
|
||||
try:
|
||||
cell_data = _process_geom(row.geometry, partial_raster, aggregations)
|
||||
except (SystemError, SystemExit, KeyboardInterrupt) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"Error processing cell {row['cell_id']}: {e}")
|
||||
continue
|
||||
ongrid[i, ...] = cell_data
|
||||
if partial_raster.nbytes / 1e9 > 20:
|
||||
print(
|
||||
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
|
||||
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
|
||||
f" This may lead to out-of-memory errors."
|
||||
)
|
||||
memprof.log_memory("After reading partial raster", log=False)
|
||||
others_shape = tuple(
|
||||
[raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
|
||||
)
|
||||
ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
|
||||
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
|
||||
|
||||
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
||||
dims = ["cell_ids", "variables", "aggregations"]
|
||||
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
|
||||
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
|
||||
dims.append(dim)
|
||||
coords[dim] = raster.coords[dim]
|
||||
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
|
||||
try:
|
||||
cell_data = _process_geom(row.geometry, partial_raster, aggregations)
|
||||
except (SystemError, SystemExit, KeyboardInterrupt) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"Error processing cell {row['cell_id']}: {e}")
|
||||
continue
|
||||
ongrid[i, ...] = cell_data
|
||||
|
||||
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
|
||||
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
||||
dims = ["cell_ids", "variables", "aggregations"]
|
||||
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
|
||||
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
|
||||
dims.append(dim)
|
||||
coords[dim] = raster.coords[dim]
|
||||
|
||||
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
|
||||
|
||||
partial_raster.close()
|
||||
del partial_raster
|
||||
|
||||
partial_raster.close()
|
||||
del partial_raster
|
||||
if need_to_close_raster:
|
||||
raster.close()
|
||||
del raster
|
||||
gc.collect()
|
||||
memprof.log_memory("After cleaning", log=False)
|
||||
|
||||
print("Finished processing partition")
|
||||
print("### Stopwatch summary ###\n")
|
||||
print(stopwatch.summary())
|
||||
print("### Memory summary ###\n")
|
||||
print(memprof.summary())
|
||||
print("#########################")
|
||||
# print("Finished processing partition")
|
||||
# print("### Stopwatch summary ###\n")
|
||||
# print(stopwatch.summary())
|
||||
# print("### Memory summary ###\n")
|
||||
# print(memprof.summary())
|
||||
# print("#########################")
|
||||
|
||||
return ongrid
|
||||
|
||||
|
||||
@stopwatch("Aligning data with grid")
|
||||
def _align_data(
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
aggregations: _Aggregations,
|
||||
n_partitions: int,
|
||||
aggregations: _Aggregations | None,
|
||||
n_partitions: int | None,
|
||||
concurrent_partitions: int,
|
||||
pxbuffer: int,
|
||||
):
|
||||
partial_ongrids = []
|
||||
|
||||
_init_worker()
|
||||
if isinstance(grid_gdf, list):
|
||||
n_partitions = len(grid_gdf)
|
||||
grid_partitions = grid_gdf
|
||||
else:
|
||||
grid_partitions = partition_grid(grid_gdf, n_partitions)
|
||||
|
||||
if n_partitions < concurrent_partitions:
|
||||
print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}")
|
||||
concurrent_partitions = n_partitions
|
||||
|
||||
if concurrent_partitions <= 1:
|
||||
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
|
||||
print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
|
||||
_init_worker(None) # No need to use a shared raster, since the processing is done in the main process
|
||||
for i, grid_partition in enumerate(grid_partitions):
|
||||
# print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
|
||||
part_ongrid = _align_partition(
|
||||
grid_partition,
|
||||
raster, # .copy() if isinstance(raster, xr.Dataset) else raster,
|
||||
raster,
|
||||
aggregations,
|
||||
pxbuffer,
|
||||
)
|
||||
partial_ongrids.append(part_ongrid)
|
||||
else:
|
||||
# For mp start method fork, we can share the raster dataset between workers
|
||||
if mp.get_start_method(allow_none=True) == "fork":
|
||||
_init_worker(raster if isinstance(raster, xr.Dataset) else None)
|
||||
initargs = (None,)
|
||||
else:
|
||||
# For spawn or forkserver, we need to copy the raster into each worker
|
||||
initargs = (raster if isinstance(raster, xr.Dataset) else None,)
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=concurrent_partitions,
|
||||
initializer=_init_worker,
|
||||
# initializer=_init_raster_global,
|
||||
# initargs=(raster,),
|
||||
initargs=initargs,
|
||||
) as executor:
|
||||
futures = {}
|
||||
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
|
||||
for i, grid_partition in enumerate(grid_partitions):
|
||||
futures[
|
||||
executor.submit(
|
||||
_align_partition,
|
||||
grid_partition,
|
||||
raster.copy() if isinstance(raster, xr.Dataset) else raster,
|
||||
None if isinstance(raster, xr.Dataset) else raster,
|
||||
aggregations,
|
||||
pxbuffer,
|
||||
)
|
||||
] = i
|
||||
if i == 6:
|
||||
print("Breaking after 3 partitions for testing purposes")
|
||||
|
||||
print("Submitted all partitions, waiting for results...")
|
||||
|
||||
for future in track(
|
||||
as_completed(futures),
|
||||
|
|
@ -452,7 +533,7 @@ def _align_data(
|
|||
):
|
||||
try:
|
||||
i = futures[future]
|
||||
print(f"Processed partition {i + 1}/{len(futures)}")
|
||||
# print(f"Processed partition {i + 1}/{len(futures)}")
|
||||
part_ongrid = future.result()
|
||||
partial_ongrids.append(part_ongrid)
|
||||
except Exception as e:
|
||||
|
|
@ -465,11 +546,11 @@ def _align_data(
|
|||
|
||||
def aggregate_raster_into_grid(
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
aggregations: _Aggregations,
|
||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||
aggregations: _Aggregations | Literal["interpolate"],
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
n_partitions: int = 20,
|
||||
n_partitions: int | None = 20,
|
||||
concurrent_partitions: int = 5,
|
||||
pxbuffer: int = 15,
|
||||
):
|
||||
|
|
@ -477,11 +558,13 @@ def aggregate_raster_into_grid(
|
|||
|
||||
Args:
|
||||
raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
|
||||
grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
|
||||
aggregations (_Aggregations): The aggregations to perform.
|
||||
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
|
||||
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
||||
No further partitioning will be done and the n_partitions argument will be ignored.
|
||||
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
||||
grid (Literal["hex", "healpix"]): The type of grid to use.
|
||||
level (int): The level of the grid.
|
||||
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
|
||||
Defaults to 5.
|
||||
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
|
||||
|
|
@ -493,7 +576,7 @@ def aggregate_raster_into_grid(
|
|||
ongrid = _align_data(
|
||||
grid_gdf,
|
||||
raster,
|
||||
aggregations,
|
||||
aggregations if aggregations != "interpolate" else None,
|
||||
n_partitions=n_partitions,
|
||||
concurrent_partitions=concurrent_partitions,
|
||||
pxbuffer=pxbuffer,
|
||||
|
|
|
|||
|
|
@ -87,10 +87,12 @@ import numpy as np
|
|||
import odc.geo
|
||||
import odc.geo.xr
|
||||
import pandas as pd
|
||||
import shapely.geometry
|
||||
import ultraplot as uplt
|
||||
import xarray as xr
|
||||
import xdggs
|
||||
import xvec
|
||||
from rasterio.features import shapes
|
||||
from rich import pretty, print, traceback
|
||||
from stopuhr import stopwatch
|
||||
|
||||
|
|
@ -643,6 +645,7 @@ def viz(
|
|||
# ===========================
|
||||
|
||||
|
||||
@stopwatch("Correcting longitudes to -180 to 180")
|
||||
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
||||
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
||||
|
||||
|
|
@ -651,6 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
|||
def spatial_agg(
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
concurrent_partitions: int = 20,
|
||||
):
|
||||
"""Perform spatial aggregation of ERA5 data to grid cells.
|
||||
|
||||
|
|
@ -661,6 +665,8 @@ def spatial_agg(
|
|||
Args:
|
||||
grid ("hex" | "healpix"): Grid type.
|
||||
level (int): Grid resolution level.
|
||||
concurrent_partitions (int, optional): Number of concurrent partitions to process.
|
||||
Defaults to 20.
|
||||
|
||||
"""
|
||||
with stopwatch(f"Loading {grid} grid at level {level}"):
|
||||
|
|
@ -669,6 +675,23 @@ def spatial_agg(
|
|||
grid_gdf = watermask.clip_grid(grid_gdf)
|
||||
grid_gdf = grid_gdf.to_crs("epsg:4326")
|
||||
|
||||
aggregations = {
|
||||
"hex": {
|
||||
3: _Aggregations.common(),
|
||||
4: _Aggregations.common(),
|
||||
5: _Aggregations(mean=True),
|
||||
6: "interpolate",
|
||||
},
|
||||
"healpix": {
|
||||
6: _Aggregations.common(),
|
||||
7: _Aggregations.common(),
|
||||
8: _Aggregations.common(),
|
||||
9: _Aggregations(mean=True),
|
||||
10: "interpolate",
|
||||
},
|
||||
}
|
||||
aggregations = aggregations[grid][level]
|
||||
|
||||
for agg in ["yearly", "seasonal", "shoulder"]:
|
||||
unaligned_store = get_era5_stores(agg)
|
||||
with stopwatch(f"Loading {agg} ERA5 data"):
|
||||
|
|
@ -677,8 +700,26 @@ def spatial_agg(
|
|||
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
||||
unaligned = _correct_longs(unaligned)
|
||||
|
||||
aggregations = _Aggregations.common()
|
||||
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
|
||||
# Filter out Grid Cells that are completely outside the ERA5 valid area
|
||||
valid_geoms = []
|
||||
for g, v in shapes(
|
||||
unaligned.t2m_mean.isel(time=0).isnull().astype("uint8").values,
|
||||
transform=unaligned.odc.transform,
|
||||
):
|
||||
if v == 0:
|
||||
valid_geoms.append(shapely.geometry.shape(g))
|
||||
grid_gdf_filtered = grid_gdf[grid_gdf.geometry.intersects(shapely.geometry.MultiPolygon(valid_geoms))]
|
||||
|
||||
aggregated = aggregate_raster_into_grid(
|
||||
unaligned,
|
||||
grid_gdf_filtered,
|
||||
aggregations,
|
||||
grid,
|
||||
level,
|
||||
n_partitions=40,
|
||||
concurrent_partitions=concurrent_partitions,
|
||||
pxbuffer=10,
|
||||
)
|
||||
|
||||
aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
|
||||
store = get_era5_stores(agg, grid, level)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Helpers for the watermask."""
|
||||
|
||||
import duckdb
|
||||
import geopandas as gpd
|
||||
|
||||
from entropice.paths import watermask_file
|
||||
|
|
@ -16,7 +17,7 @@ def open():
|
|||
return watermask
|
||||
|
||||
|
||||
def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
|
||||
def clip_grid(grid_gdf: gpd.GeoDataFrame, allow_duckdb: bool = False) -> gpd.GeoDataFrame:
|
||||
"""Clip the input GeoDataFrame with the watermask.
|
||||
|
||||
Args:
|
||||
|
|
@ -27,6 +28,71 @@ def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
|
|||
|
||||
"""
|
||||
watermask = open()
|
||||
watermask = watermask.to_crs(gdf.crs)
|
||||
gdf = gdf.overlay(watermask, how="difference")
|
||||
return gdf
|
||||
watermask = watermask.to_crs(grid_gdf.crs)
|
||||
# ! Currently disabled - kernel crashes
|
||||
allow_duckdb = False
|
||||
if len(grid_gdf) >= 10000 and allow_duckdb:
|
||||
# Use duckdb for large datasets
|
||||
crs = grid_gdf.crs
|
||||
|
||||
# Convert geometry columns to WKB format for DuckDB
|
||||
# No need to copy the watermask
|
||||
watermask["geometry"] = watermask.geometry.to_wkb()
|
||||
grid_gdf = grid_gdf.copy()
|
||||
grid_gdf["geometry"] = grid_gdf.geometry.to_wkb()
|
||||
|
||||
# Connect to DuckDB
|
||||
con = duckdb.connect(":memory:")
|
||||
|
||||
# Install and load spatial extension
|
||||
con.execute("INSTALL spatial;")
|
||||
con.execute("LOAD spatial;")
|
||||
|
||||
# Register the DataFrames as tables in DuckDB
|
||||
con.register("watermask", watermask)
|
||||
con.register("grid", grid_gdf)
|
||||
|
||||
query = """
|
||||
SELECT g.*
|
||||
FROM grid g
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM watermask w
|
||||
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
|
||||
)
|
||||
"""
|
||||
|
||||
query = """
|
||||
WITH clipped AS (
|
||||
SELECT
|
||||
g.* EXCLUDE (geometry),
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM watermask w
|
||||
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
|
||||
)
|
||||
THEN ST_Difference(
|
||||
ST_GeomFromWKB(g.geometry),
|
||||
(
|
||||
SELECT ST_Union_Agg(ST_GeomFromWKB(w.geometry))
|
||||
FROM watermask w
|
||||
WHERE ST_Intersects(ST_GeomFromWKB(g.geometry), ST_GeomFromWKB(w.geometry))
|
||||
)
|
||||
)
|
||||
ELSE ST_GeomFromWKB(g.geometry)
|
||||
END AS geometry
|
||||
FROM grid g
|
||||
)
|
||||
SELECT * FROM clipped
|
||||
WHERE NOT ST_IsEmpty(geometry)
|
||||
"""
|
||||
|
||||
result = con.execute(query).df()
|
||||
# Convert back to GeoDataFrame
|
||||
result["geometry"] = gpd.GeoSeries.from_wkb(result["geometry"].apply(bytes))
|
||||
grid_gdf = gpd.GeoDataFrame(result, geometry="geometry", crs=crs)
|
||||
con.close()
|
||||
else:
|
||||
grid_gdf = grid_gdf.overlay(watermask, how="difference")
|
||||
return grid_gdf
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue