Fix memory leak in aggregations

This commit is contained in:
Tobias Hölzer 2025-11-26 18:11:15 +01:00
parent 341fa7b836
commit 98314fe8b3
7 changed files with 327 additions and 200 deletions

@ -0,0 +1,74 @@
# Processing Documentation
This document records how long each processing step took and how much memory and compute (CPU & GPU) it required.
Legend: `[x]` = done, `[/]` = in progress, `[ ]` = open.
| Grid | ArcticDEM | Era5 | AlphaEarth | Darts |
| ----- | --------- | ---- | ---------- | ----- |
| Hex3 | [ ] | [/] | [ ] | [/] |
| Hex4 | [ ] | [/] | [ ] | [/] |
| Hex5 | [ ] | [/] | [ ] | [/] |
| Hex6 | [ ] | [ ] | [ ] | [ ] |
| Hpx6 | [x] | [/] | [ ] | [/] |
| Hpx7 | [ ] | [/] | [ ] | [/] |
| Hpx8 | [ ] | [/] | [ ] | [/] |
| Hpx9 | [ ] | [/] | [ ] | [/] |
| Hpx10 | [ ] | [ ] | [ ] | [ ] |
## Grid creation
The creation of grids did not take up any significant amount of memory or compute.
The time taken to create a grid ranged from a few seconds for smaller levels up to a few minutes for the higher levels.
## DARTS
Similar to grid creation, no significant amount of memory, compute, or time was needed.
## ArcticDEM
The download took around 8h with memory usage of about 10GB and no notable compute constraints.
The resulting icechunk Zarr datacube was approx. 160GB on disk, which corresponds to approx. 270GB in memory when loaded.
The enrichment took around 2h on a single A100 GPU node (40GB) with a local dask cluster consisting of 7 processes, each using 2 threads and 30GB of memory, for a total of 210GB of memory.
These settings can easily be changed to consume less memory by reducing the number of processes or threads.
More processes or threads could not be used, to ensure that the GPU does not run out of memory.
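The cluster sizing above can be sketched as follows. The `LocalCluster` arguments shown in the comment are an assumption reconstructed from the numbers in this paragraph, and `cluster_memory_gb` is an illustrative helper, not a function from this repository:

```python
# Assumed dask cluster configuration for the enrichment step:
#
#   dask.distributed.LocalCluster(
#       n_workers=7, threads_per_worker=2, memory_limit="30GB",
#   )

def cluster_memory_gb(n_workers: int, memory_limit_gb: int) -> int:
    """Total memory budget of a local dask cluster with per-worker limits."""
    return n_workers * memory_limit_gb

# 7 workers x 30 GB each -> 210 GB, matching the total quoted above.
total_gb = cluster_memory_gb(7, 30)
```

Reducing `n_workers` or `memory_limit` trades throughput for a smaller footprint, which is the knob described above.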
### Spatial aggregations into grids
All spatial aggregations relied heavily on CPU compute, since CuPy lacks support for nanquantile,
and for higher-resolution grids the number of pixels to reduce was too small to overcome the data-movement overhead of using a GPU.
The aggregations scale with the number of concurrent processes (specified by `--concurrent_partitions`), consuming linearly more memory with higher parallelism.
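The missing nanquantile is worked around per column, as hinted at in a comment in the aggregation code. A minimal NumPy sketch of that workaround (`nanquantile_along_z` is an illustrative name, not a function from this repository):

```python
import numpy as np

def nanquantile_along_z(data: np.ndarray, qs) -> np.ndarray:
    """NaN-aware quantiles along axis 0, computed one column at a time.

    Mirrors the apply_along_axis workaround sketched for CuPy in the
    aggregation code; each column must contain at least one finite value.
    """
    return np.apply_along_axis(
        lambda col: np.quantile(col[~np.isnan(col)], qs), 0, data
    )

data = np.array([[1.0, np.nan], [3.0, 2.0], [5.0, 4.0]])
out = nanquantile_along_z(data, [0.25, 0.5])
```

On the CPU this is equivalent to `np.nanquantile(data, qs, axis=0)`; the per-column form is what would carry over to CuPy, at the cost of a Python-level loop.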
| grid | time | memory | processes |
| ----- | ------ | ------ | --------- |
| Hex3 | | | |
| Hex4 | | | |
| Hex5 | | | |
| Hex6 | | | |
| Hpx6 | 37 min | ~300GB | 40 |
| Hpx7 | | | |
| Hpx8 | | | |
| Hpx9  | 25 min | ~300GB | 40 |
| Hpx10 | 34 min | ~300GB | 40 |
## Alpha Earth
The download was heavily limited by the scale of the input data, which is ~10m in the original dataset.
A scale of 10m was not computationally feasible for the Google Earth Engine servers, so each grid and level used a different scale for aggregating and downloading the data.
Each scale was chosen so that each grid cell covered around 10000px from which to estimate the aggregations.
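The scale selection above follows directly from the target pixel count. A sketch of the relation (`pick_scale` is an illustrative helper, not a function from this repository; the listed scales are additionally rounded to convenient values):

```python
import math

def pick_scale(cell_area_m2: float, target_px: int = 10_000) -> float:
    """Ground sampling distance (m) so one grid cell covers ~target_px pixels."""
    return math.sqrt(cell_area_m2 / target_px)

# e.g. a 100 km^2 cell -> ~100 m scale, before rounding.
scale = pick_scale(100e6)
```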
| grid | time | scale |
| ----- | ------- | ----- |
| Hex3 | | 1600 |
| Hex4 | | 600 |
| Hex5 | | 240 |
| Hex6 | | 90 |
| Hpx6 | 58 min | 1600 |
| Hpx7 | 3:16 h | 800 |
| Hpx8 | 13:19 h | 400 |
| Hpx9 | | 200 |
| Hpx10 | | 100 |
## Era5
???

pixi.lock generated

@ -475,6 +475,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/55/d2/507fd0ee4dd574d2bdbdeac5df83f39d2cae1ffe97d4622cca6f6bab39f1/botocore-1.40.70-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/bd/22/05555a9752357e24caa1cd92324d1a7fdde6386aab162fcc451f8f8eedc2/bottleneck-1.6.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
@ -609,6 +610,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
- pypi: git+https://github.com/davbyr/xAnimate#750e03e480db309407e09f4ffe5f49522a4c4f9b
- pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/65/ad/8f9ff43ff49ef02c7b8202a42c32a1fe8de1276bba0e6f55609e19ff7585/xdggs-0.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/56/b0/e3efafd9c97ed931f6453bd71aa8feaffc9217e6121af65fda06cf32f608/xgboost-3.1.1-py3-none-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/05/b9/b6a9cf72aef69c3e6db869dcc130e19452a658366dac9377f9cd32a76b80/xproj-0.2.1-py3-none-any.whl
@ -1410,6 +1412,13 @@ packages:
- pkg:pypi/bokeh?source=compressed-mapping
size: 5027028
timestamp: 1762557204752
- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
name: boost-histogram
version: 1.6.1
sha256: c5700e53bf2d3d006f71610f71fc592d88dab3279dad300e8178dc084018b177
requires_dist:
- numpy
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
name: boto3
version: 1.40.70
@ -2810,7 +2819,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -2865,6 +2874,7 @@ packages:
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
- cupy-xarray>=0.1.4,<0.2
- memray>=1.19.1,<2
- xarray-histogram>=0.2.2,<0.3
requires_python: '>=3.13,<3.14'
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
@ -9373,6 +9383,31 @@ packages:
- pkg:pypi/xarray?source=hash-mapping
size: 989507
timestamp: 1763464377176
- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
name: xarray-histogram
version: 0.2.2
sha256: 90682f9575131e5ea5d96feba5f8679e8c3fae039018d5e38f497696d55ccd5c
requires_dist:
- boost-histogram
- numpy
- xarray
- dask ; extra == 'full'
- scipy ; extra == 'full'
- dask ; extra == 'dev'
- scipy ; extra == 'dev'
- sphinx ; extra == 'dev'
- sphinx-book-theme ; extra == 'dev'
- ruff ; extra == 'dev'
- mypy>=1.5 ; extra == 'dev'
- pytest>=7.4 ; extra == 'dev'
- coverage ; extra == 'dev'
- pytest-cov ; extra == 'dev'
- dask ; extra == 'tests'
- scipy ; extra == 'tests'
- pytest>=7.4 ; extra == 'tests'
- coverage ; extra == 'tests'
- pytest-cov ; extra == 'tests'
requires_python: '>=3.11'
- pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
name: xarray-spatial
version: 0.1.dev519+g3a3120981

@ -57,7 +57,7 @@ dependencies = [
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3",
]
[project.scripts]

@ -2,7 +2,6 @@
import gc
import os
import sys
from collections import defaultdict
from collections.abc import Callable, Generator
from concurrent.futures import ProcessPoolExecutor, as_completed
@ -33,6 +32,7 @@ from entropice import grids
@dataclass(frozen=True)
class _Aggregations:
# ! The ordering is super important for this class!
_common: bool = False # If true, use _agg_cell_data_single_common and ignore other flags
mean: bool = True
sum: bool = False
std: bool = False
@ -46,8 +46,14 @@ class _Aggregations:
if self.median and 0.5 in self.quantiles:
raise ValueError("Median aggregation cannot be used together with quantile 0.5")
@classmethod
def common(cls):
return cls(_common=True)
@cache
def __len__(self) -> int:
if self._common:
return 11
length = 0
if self.mean:
length += 1
@ -65,6 +71,8 @@ class _Aggregations:
return length
def aggnames(self) -> list[str]:
if self._common:
return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
names = []
if self.mean:
names.append("mean")
@ -83,7 +91,24 @@ class _Aggregations:
names.append(f"p{q_int}")
return names
def _agg_cell_data_single_common(self, flattened_var: xr.DataArray) -> np.ndarray:
others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
cell_data[0, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99] # ? Ordering is important here!
cell_data[4:, ...] = flattened_var.quantile(
q=quantiles_to_compute,
dim="z",
skipna=True,
).to_numpy()
return cell_data
def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
if self._common:
return self._agg_cell_data_single_common(flattened_var)
others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
j = 0
@ -106,6 +131,12 @@ class _Aggregations:
quantiles_to_compute = sorted(self.quantiles) # ? Ordering is important here!
if self.median:
quantiles_to_compute.insert(0, 0.5)
# Potential way to use cupy:
# cp.apply_along_axis(
# lambda x: cp.quantile(x[~cp.isnan(x)], q=quantiles), axis=0, # arr=data_gpu_raw
# )
cell_data[j:, ...] = flattened_var.quantile(
q=quantiles_to_compute,
dim="z",
@ -173,7 +204,7 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
flattened = cropped.stack(z=spatdims) # noqa: PD013
# if flattened.z.size > 3000:
# if flattened.z.size > 1000000:
# flattened = flattened.cupy.as_cupy()
cell_data = aggregations.agg_cell_data(flattened)
return cell_data
@ -187,6 +218,8 @@ def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggr
else ["y", "x"]
)
flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z") # noqa: PD013
# if flattened.z.size > 1000000:
# flattened = flattened.cupy.as_cupy()
cell_data = aggregations.agg_cell_data(flattened)
return cell_data
@ -218,7 +251,15 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
# ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
if grid_gdf.crs.to_epsg() == 4326:
crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian)
else:
crosses_antimeridian = pd.Series([False] * len(grid_gdf), index=grid_gdf.index)
if crosses_antimeridian.any():
grid_gdf_am = grid_gdf[crosses_antimeridian]
grid_gdf = grid_gdf[~crosses_antimeridian]
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
@ -226,6 +267,9 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[
partition = grid_gdf[labels == i]
yield partition
if crosses_antimeridian.any():
yield grid_gdf_am
class _MemoryProfiler:
def __init__(self):
@ -279,20 +323,20 @@ def _align_partition(
# 1. unmeasurable
# 2. ~0.1-0.3s
# 3. ~0.05s
# => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
# faster than a processpoolexecutor
memprof.log_memory("Before reading partial raster", log=False)
if callable(raster) and not isinstance(raster, xr.Dataset):
raster = raster()
if hasattr(raster, "__dask_graph__"):
graph_size = len(raster.__dask_graph__())
graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")
need_to_close_raster = True
else:
need_to_close_raster = False
others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
data = np.full(data_shape, np.nan, dtype=np.float32)
ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
partial_extent = partial_extent.buffered(
@ -304,57 +348,42 @@ def _align_partition(
partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
except Exception as e:
print(f"Error cropping raster to partition extent: {e}")
return data
return ongrid
print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
memprof.log_memory("After reading partial raster", log=True)
if partial_raster.nbytes / 1e9 > 20:
print(
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
f" This may lead to out-of-memory errors."
)
memprof.log_memory("After reading partial raster", log=False)
with ProcessPoolExecutor(
max_workers=60,
) as executor:
futures = {}
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
# ? Splitting the geometries already here and passing only the cropped data to the worker to
# reduce pickling overhead
geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
if len(geoms) == 0:
continue
elif len(geoms) == 1:
cropped = [_read_cell_data(partial_raster, geoms[0])]
else:
cropped = _read_split_cell_data(partial_raster, geoms)
futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
try:
cell_data = _process_geom(row.geometry, partial_raster, aggregations)
except (SystemError, SystemExit, KeyboardInterrupt) as e:
raise e
except Exception as e:
print(f"Error processing cell {row['cell_id']}: {e}")
continue
ongrid[i, ...] = cell_data
# Clean up immediately after submitting
del geoms, cropped # release the per-cell crops immediately
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
for future in as_completed(futures):
i = futures[future]
try:
cell_data = future.result()
except (SystemError, SystemExit, KeyboardInterrupt) as e:
raise e
except Exception as e:
print(f"Error processing cell {i}: {e}")
continue
data[i, ...] = cell_data
del cell_data, futures[future]
# for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
# try:
# cell_data = _process_geom(row.geometry, partial_raster, aggregations)
# except (SystemError, SystemExit, KeyboardInterrupt) as e:
# raise e
# except Exception as e:
# print(f"Error processing cell {row['cell_id']}: {e}")
# continue
# data[i, ...] = cell_data
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
partial_raster.close()
raster.close()
del partial_raster, raster, future, futures
del partial_raster
if need_to_close_raster:
raster.close()
del raster
gc.collect()
memprof.log_memory("After cleaning", log=True)
memprof.log_memory("After cleaning", log=False)
print("Finished processing partition")
print("### Stopwatch summary ###\n")
@ -362,45 +391,8 @@ def _align_partition(
print("### Memory summary ###\n")
print(memprof.summary())
print("#########################")
# with ProcessPoolExecutor(
# max_workers=max_workers,
# initializer=_init_partial_raster_global,
# initargs=(raster, grid_partition_gdf, pxbuffer),
# ) as executor:
# futures = {
# executor.submit(_process_geom_global, row.geometry, aggregations): i
# for i, (idx, row) in enumerate(grid_partition_gdf.iterrows())
# }
# others_shape = tuple(
# [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
# )
# data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape)
# print(f"Allocating partial data array of shape {data_shape}")
# data = np.full(data_shape, np.nan, dtype=np.float32)
# lap = time.time()
# for future in track(
# as_completed(futures),
# total=len(futures),
# description="Spatially aggregating ERA5 data...",
# ):
# print(f"time since last cell: {time.time() - lap:.2f}s")
# i = futures[future]
# try:
# cell_data = future.result()
# data[i, ...] = cell_data
# except Exception as e:
# print(f"Error processing cell {i}: {e}")
# continue
# finally:
# # Try to free memory
# del futures[future], future
# if "cell_data" in locals():
# del cell_data
# gc.collect()
# lap = time.time()
return data
return ongrid
@stopwatch("Aligning data with grid")
@ -409,47 +401,66 @@ def _align_data(
raster: xr.Dataset | Callable[[], xr.Dataset],
aggregations: _Aggregations,
n_partitions: int,
max_workers: int,
concurrent_partitions: int,
pxbuffer: int,
):
partial_data = {}
partial_ongrids = []
with ProcessPoolExecutor(
max_workers=max_workers,
initializer=_init_worker,
# initializer=_init_raster_global,
# initargs=(raster,),
) as executor:
futures = {}
_init_worker()
if n_partitions < concurrent_partitions:
print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}")
concurrent_partitions = n_partitions
if concurrent_partitions <= 1:
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
futures[
executor.submit(
_align_partition,
grid_partition,
raster.copy() if isinstance(raster, xr.Dataset) else raster,
aggregations,
pxbuffer,
)
] = i
print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
part_ongrid = _align_partition(
grid_partition,
raster, # .copy() if isinstance(raster, xr.Dataset) else raster,
aggregations,
pxbuffer,
)
partial_ongrids.append(part_ongrid)
else:
with ProcessPoolExecutor(
max_workers=concurrent_partitions,
initializer=_init_worker,
# initializer=_init_raster_global,
# initargs=(raster,),
) as executor:
futures = {}
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
futures[
executor.submit(
_align_partition,
grid_partition,
raster.copy() if isinstance(raster, xr.Dataset) else raster,
aggregations,
pxbuffer,
)
] = i
print("Submitted all partitions, waiting for results...")
print("Submitted all partitions, waiting for results...")
for future in track(
as_completed(futures),
total=len(futures),
description="Processing grid partitions...",
):
try:
i = futures[future]
print(f"Processed partition {i + 1}/{len(futures)}")
part_data = future.result()
partial_data[i] = part_data
except Exception as e:
print(f"Error processing partition {i}: {e}")
raise e
for future in track(
as_completed(futures),
total=len(futures),
description="Processing grid partitions...",
):
try:
i = futures[future]
print(f"Processed partition {i + 1}/{len(futures)}")
part_ongrid = future.result()
partial_ongrids.append(part_ongrid)
except Exception as e:
print(f"Error processing partition {i}: {e}")
raise e
data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
return data
ongrid = xr.concat(partial_ongrids, dim="cell_ids")
return ongrid
def aggregate_raster_into_grid(
@ -459,7 +470,7 @@ def aggregate_raster_into_grid(
grid: Literal["hex", "healpix"],
level: int,
n_partitions: int = 20,
max_workers: int = 5,
concurrent_partitions: int = 5,
pxbuffer: int = 15,
):
"""Aggregate raster data into grid cells.
@ -471,19 +482,20 @@ def aggregate_raster_into_grid(
grid (Literal["hex", "healpix"]): The type of grid to use.
level (int): The level of the grid.
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
Defaults to 5.
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
Returns:
xr.Dataset: Aggregated data aligned with the grid.
"""
aligned = _align_data(
ongrid = _align_data(
grid_gdf,
raster,
aggregations,
n_partitions=n_partitions,
max_workers=max_workers,
concurrent_partitions=concurrent_partitions,
pxbuffer=pxbuffer,
)
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
@ -491,16 +503,6 @@ def aggregate_raster_into_grid(
if callable(raster) and not isinstance(raster, xr.Dataset):
raster = raster()
cell_ids = grids.get_cell_ids(grid, level)
dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,

@ -10,6 +10,7 @@ import cupyx.scipy.signal
import cyclopts
import dask.array
import dask.distributed as dd
import geopandas as gpd
import icechunk
import icechunk.xarray
import numpy as np
@ -48,7 +49,7 @@ def download():
@dataclass
class KernelFactory:
class _KernelFactory:
res: int
size_px: int
@ -77,14 +78,14 @@ class KernelFactory:
return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))
def tpi_cupy(chunk, kernels: KernelFactory):
def tpi_cupy(chunk, kernels: _KernelFactory):
kernel = kernels.ring()
kernel = kernel / cp.nansum(kernel)
tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
return tpi
def tri_cupy(chunk, kernels: KernelFactory):
def tri_cupy(chunk, kernels: _KernelFactory):
kernel = kernels.tri()
c2 = chunk**2
focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
@ -93,7 +94,7 @@ def tri_cupy(chunk, kernels: KernelFactory):
return tri
def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory):
slope_rad = slope * (cp.pi / 180)
aspect_rad = aspect * (cp.pi / 180)
aspect_rad = cp.where(aspect_rad == -1, 0, aspect_rad)
@ -150,9 +151,9 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
res = 32 # 32m resolution
small_kernels = KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m)
medium_kernels = KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m)
large_kernels = KernelFactory(res=res, size_px=15) # ~15x15 kernels (480m)
small_kernels = _KernelFactory(res=res, size_px=3) # ~3x3 kernels (96m)
medium_kernels = _KernelFactory(res=res, size_px=7) # ~7x7 kernels (224m)
large_kernels = _KernelFactory(res=res, size_px=15) # ~15x15 kernels (480m)
# Check if there is data in the chunk
if np.all(np.isnan(chunk)):
@ -306,49 +307,61 @@ def enrich():
def _open_adem():
arcticdem_store = get_arcticdem_stores()
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
# config = icechunk.RepositoryConfig.default()
# config.caching = icechunk.CachingConfig(
# num_snapshot_nodes=None,
# num_chunk_refs=None,
# num_transaction_changes=None,
# num_bytes_attributes=None,
# num_bytes_chunks=None,
# )
# repo = icechunk.Repository.open(accessor.repo.storage, config=config)
# session = repo.readonly_session("main")
# adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
adem = accessor.open_xarray()
# Don't use the datamask
adem = adem.drop_vars("datamask")
assert {"x", "y"} == set(adem.dims)
assert adem.odc.crs == "EPSG:3413"
return adem
@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20):
mp.set_start_method("forkserver", force=True)
with (
dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
):
print(cluster)
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
# ? Mask out cells that do not intersect with ArcticDEM extent
arcticdem_store = get_arcticdem_stores()
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
extent_info = gpd.read_parquet(accessor._aux_dir / "ArcticDEM_Mosaic_Index_v4_1_32m.parquet")
grid_gdf = grid_gdf[grid_gdf.intersects(extent_info.geometry.union_all())]
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
aggregations = _Aggregations.common()
aggregated = aggregate_raster_into_grid(
_open_adem,
grid_gdf,
aggregations,
grid,
level,
concurrent_partitions=concurrent_partitions,
n_partitions=400, # 400 # Results in ~10GB large patches
pxbuffer=30,
)
store = get_arcticdem_stores(grid, level)
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
aggregations = _Aggregations(
mean=True,
sum=False,
std=True,
min=True,
max=True,
median=True,
quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
)
print("Aggregation complete.")
aggregated = aggregate_raster_into_grid(
_open_adem,
grid_gdf,
aggregations,
grid,
level,
max_workers=max_workers,
n_partitions=200, # Results in 10GB large patches
pxbuffer=30,
)
store = get_arcticdem_stores(grid, level)
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
print("Aggregation complete.")
print("### Finished ArcticDEM processing ###")
stopwatch.summary()
print("### Finished ArcticDEM processing ###")
stopwatch.summary()
if __name__ == "__main__":

@ -677,25 +677,7 @@ def spatial_agg(
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
unaligned = _correct_longs(unaligned)
# with stopwatch("Precomputing cell geometries"):
# with ProcessPoolExecutor(max_workers=20) as executor:
# cell_geometries = list(
# executor.map(
# _get_corrected_geoms,
# [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
# )
# )
# data = _align_data(cell_geometries, unaligned)
# aggregated = _create_aligned(unaligned, data, grid, level)
aggregations = _Aggregations(
mean=True,
sum=False,
std=True,
min=True,
max=True,
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
)
aggregations = _Aggregations.common()
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})

@ -66,6 +66,27 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
return cell_ids
def convert_cell_ids(grid_gdf: gpd.GeoDataFrame):
"""Convert cell IDs in a GeoDataFrame to xdggs-compatible format.
Args:
grid_gdf (gpd.GeoDataFrame): The input grid GeoDataFrame.
Returns:
pd.Series: The converted cell IDs.
"""
if "cell_id" not in grid_gdf.columns:
raise ValueError("The GeoDataFrame must contain a 'cell_id' column.")
# Check if cell IDs are hex strings (for H3)
if isinstance(grid_gdf["cell_id"].iloc[0], str):
grid_gdf = grid_gdf.copy()
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda cid: int(cid, 16))
return grid_gdf["cell_id"]
def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
hex_batch = []
hex_id_batch = []