Finally fix aggregations also for era5

This commit is contained in:
Tobias Hölzer 2025-11-30 01:12:00 +01:00
parent 7b09dda6a3
commit 33c9667383
5 changed files with 151 additions and 52 deletions

View file

@ -48,14 +48,14 @@ Each scale was chosen so that each grid cell had around 10000px to estimate the
| grid | time | scale | | grid | time | scale |
| ----- | ------- | ----- | | ----- | ------- | ----- |
| Hex3 | | 1600 | | Hex3 | 46 min | 1600 |
| Hex4 | | 600 | | Hex4 | 5:04 h | 600 |
| Hex5 | | 240 | | Hex5 | 31:48 h | 240 |
| Hex6 | | 90 | | Hex6 | | 90 |
| Hpx6 | 58 min | 1600 | | Hpx6 | 58 min | 1600 |
| Hpx7 | 3:16 h | 800 | | Hpx7 | 3:16 h | 800 |
| Hpx8 | 13:19 h | 400 | | Hpx8 | 13:19 h | 400 |
| Hpx9 | | 200 | | Hpx9 | 51:33 h | 200 |
| Hpx10 | | 100 | | Hpx10 | | 100 |
## Era5 ## Era5
@ -78,13 +78,59 @@ For geometries crossing the antimeridian, geometries are corrected.
| grid | method | | grid | method |
| ----- | ----------- | | ----- | ----------- |
| Hex3 | Common | | Hex3 | Common |
| Hex4 | Common | | Hex4 | Mean |
| Hex5 | Mean | | Hex5 | Interpolate |
| Hex6 | Interpolate | | Hex6 | Interpolate |
| Hpx6 | Common | | Hpx6 | Common |
| Hpx7 | Common | | Hpx7 | Mean |
| Hpx8 | Common | | Hpx8 | Mean |
| Hpx9 | Mean | | Hpx9 | Interpolate |
| Hpx10 | Interpolate | | Hpx10 | Interpolate |
- hex level 3
min: 30.0
max: 850.0
mean: 251.25216674804688
median: 235.5
- hex level 4
min: 8.0
max: 166.0
mean: 47.2462158203125
median: 44.0
- hex level 5
min: 3.0
max: 41.0
mean: 11.164162635803223
median: 10.0
- hex level 6
min: 2.0
max: 14.0
mean: 4.509947776794434
median: 4.0
- healpix level 6
min: 25.0
max: 769.0
mean: 214.97296142578125
median: 204.0
- healpix level 7
min: 9.0
max: 231.0
mean: 65.91140747070312
median: 62.0
- healpix level 8
min: 4.0
max: 75.0
mean: 22.516725540161133
median: 21.0
- healpix level 9
min: 2.0
max: 29.0
mean: 8.952080726623535
median: 9.0
- healpix level 10
min: 2.0
max: 15.0
mean: 4.361577987670898
median: 4.0
??? ???

View file

@ -3,13 +3,13 @@
# pixi run era5 download # pixi run era5 download
# pixi run era5 enrich # pixi run era5 enrich
pixi run era5 spatial-agg --grid hex --level 3 pixi run era5 spatial-agg --grid hex --level 3 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 4 pixi run era5 spatial-agg --grid hex --level 4 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 5 pixi run era5 spatial-agg --grid hex --level 5 --concurrent-partitions 20
pixi run era5 spatial-agg --grid hex --level 6 pixi run era5 spatial-agg --grid hex --level 6 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 6 pixi run era5 spatial-agg --grid healpix --level 6 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 7 pixi run era5 spatial-agg --grid healpix --level 7 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 8 pixi run era5 spatial-agg --grid healpix --level 8 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 9 pixi run era5 spatial-agg --grid healpix --level 9 --concurrent-partitions 20
pixi run era5 spatial-agg --grid healpix --level 10 pixi run era5 spatial-agg --grid healpix --level 10 --concurrent-partitions 20

15
scripts/04arcticdem.sh Normal file
View file

@ -0,0 +1,15 @@
#!/bin/bash
# Run the arcticdem spatial aggregation for every grid level.
# Prerequisites (run manually once beforehand):
# pixi run arcticdem download
# pixi run arcticdem enrich

# Hexagonal grid: levels 3 through 6.
for level in 3 4 5 6; do
    pixi run arcticdem aggregate --grid hex --level "$level" --concurrent-partitions 20
done

# HEALPix grid: levels 6 through 10.
for level in 6 7 8 9 10; do
    pixi run arcticdem aggregate --grid healpix --level "$level" --concurrent-partitions 20
done

View file

@ -147,6 +147,7 @@ class _Aggregations:
dim="z", dim="z",
skipna=True, skipna=True,
).to_numpy() ).to_numpy()
return cell_data
def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray: def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray:
if isinstance(flattened, xr.DataArray): if isinstance(flattened, xr.DataArray):
@ -199,7 +200,6 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
# Split geometry in case it crossed antimeridian # Split geometry in case it crossed antimeridian
else: else:
geoms = _split_antimeridian_cell(geom) geoms = _split_antimeridian_cell(geom)
geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms] geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms]
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms)) geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
return geoms return geoms
@ -373,16 +373,45 @@ def _align_partition(
memprof.log_memory("Before reading partial raster", log=False) memprof.log_memory("Before reading partial raster", log=False)
need_to_close_raster = False need_to_close_raster = False
if raster is None: if raster is None:
assert shared_raster is not None, "Shared raster is not initialized in worker"
# print("Using shared raster in worker") # print("Using shared raster in worker")
raster = shared_raster raster = shared_raster
elif callable(raster) and not isinstance(raster, xr.Dataset): elif callable(raster):
# print("Loading raster in partition") # print("Loading raster in partition")
raster = raster() raster = raster()
need_to_close_raster = True need_to_close_raster = True
# else: # else:
# print("Using provided raster in partition") # print("Using provided raster in partition")
# Partition the raster if necessary
is_raster_in_memory = isinstance(raster, xr.Dataset) and all(
isinstance(raster[var].data, np.ndarray) for var in raster.data_vars
)
if is_raster_in_memory:
partial_raster = raster
else:
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
partial_extent = partial_extent.buffered(
raster.odc.geobox.resolution.x * pxbuffer,
raster.odc.geobox.resolution.y * pxbuffer,
) # buffer by pxbuffer pixels
with stopwatch("Cropping raster to partition extent", log=False):
try:
partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute()
except Exception as e:
print(f"Error cropping raster to partition extent: {e}")
raise e
if partial_raster.nbytes / 1e9 > 20:
print(
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
f" This may lead to out-of-memory errors."
)
memprof.log_memory("After reading partial raster", log=False)
if aggregations is None: if aggregations is None:
cell_ids = grids.convert_cell_ids(grid_partition_gdf) cell_ids = grids.convert_cell_ids(grid_partition_gdf)
if grid_partition_gdf.crs.to_epsg() == 4326: if grid_partition_gdf.crs.to_epsg() == 4326:
@ -402,28 +431,9 @@ def _align_partition(
) )
# ?: Cubic does not work with NaNs in xarray interp # ?: Cubic does not work with NaNs in xarray interp
with stopwatch("Interpolating data to grid centroids", log=False): with stopwatch("Interpolating data to grid centroids", log=False):
ongrid = raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan}) ongrid = partial_raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
memprof.log_memory("After interpolating data", log=False) memprof.log_memory("After interpolating data", log=False)
else: else:
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
partial_extent = partial_extent.buffered(
raster.odc.geobox.resolution.x * pxbuffer,
raster.odc.geobox.resolution.y * pxbuffer,
) # buffer by pxbuffer pixels
with stopwatch("Cropping raster to partition extent", log=False):
try:
partial_raster: xr.Dataset = raster.odc.crop(partial_extent, apply_mask=False).compute()
except Exception as e:
print(f"Error cropping raster to partition extent: {e}")
raise e
if partial_raster.nbytes / 1e9 > 20:
print(
f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
f" This may lead to out-of-memory errors."
)
memprof.log_memory("After reading partial raster", log=False)
others_shape = tuple( others_shape = tuple(
[raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
) )
@ -449,9 +459,8 @@ def _align_partition(
ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables") ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")
partial_raster.close() partial_raster.close()
del partial_raster del partial_raster
if need_to_close_raster: if need_to_close_raster:
raster.close() raster.close()
del raster del raster
@ -477,6 +486,27 @@ def _align_data(
concurrent_partitions: int, concurrent_partitions: int,
pxbuffer: int, pxbuffer: int,
): ):
# ? Logic for memory management of raster dataset in multiprocessing scenarios
# There are multiple possible scenarios for the raster dataset:
# Raster can be either:
# 1. A in-memory xarray.Dataset object
# 2. A callable that returns an xarray.Dataset object
# 3. A lazy xarray.Dataset object (e.g., backed by zarr or icechunk)
# Concurrent partitions can be either:
# A. 1 (no multiprocessing)
# B. >1 (multiprocessing with multiple workers in fork)
# C. >1 (multiprocessing with multiple workers in spawn/forkserver)
# This results into potential working modes:
# 1. Work on the complete raster in memory (no partitioning of the raster is needed)
# A. Pass it to the function
# B. Make a raster handle in each worker
# C. Move it into a shared memory buffer (NOT Implemented)
# 2. & 3. Work on partitions of the raster
# A. Load a partition into memory in each task
# B. & C. Make a raster handle in each worker and load partitions into memory in each task
# ! Note, this is not YET implemented, but this is how it SHOULD be implemented
partial_ongrids = [] partial_ongrids = []
if isinstance(grid_gdf, list): if isinstance(grid_gdf, list):
@ -501,18 +531,21 @@ def _align_data(
) )
partial_ongrids.append(part_ongrid) partial_ongrids.append(part_ongrid)
else: else:
is_raster_in_memory = isinstance(raster, xr.Dataset) and all(
isinstance(raster[var].data, np.ndarray) for var in raster.data_vars
)
is_mpfork = mp.get_start_method(allow_none=True) == "fork"
taskarg = None if is_raster_in_memory else raster
# For spawn or forkserver, we need to copy the raster into each worker
workerargs = (None if not is_raster_in_memory or not is_mpfork else raster,)
# For mp start method fork, we can share the raster dataset between workers # For mp start method fork, we can share the raster dataset between workers
if mp.get_start_method(allow_none=True) == "fork": if mp.get_start_method(allow_none=True) == "fork" and is_raster_in_memory:
_init_worker(raster if isinstance(raster, xr.Dataset) else None) _init_worker(raster)
initargs = (None,)
else:
# For spawn or forkserver, we need to copy the raster into each worker
initargs = (raster if isinstance(raster, xr.Dataset) else None,)
with ProcessPoolExecutor( with ProcessPoolExecutor(
max_workers=concurrent_partitions, max_workers=concurrent_partitions,
initializer=_init_worker, initializer=_init_worker,
initargs=initargs, initargs=workerargs,
) as executor: ) as executor:
futures = {} futures = {}
for i, grid_partition in enumerate(grid_partitions): for i, grid_partition in enumerate(grid_partitions):
@ -520,7 +553,7 @@ def _align_data(
executor.submit( executor.submit(
_align_partition, _align_partition,
grid_partition, grid_partition,
None if isinstance(raster, xr.Dataset) else raster, taskarg,
aggregations, aggregations,
pxbuffer, pxbuffer,
) )

View file

@ -674,6 +674,11 @@ def spatial_agg(
# ? Mask out water, since we don't want to aggregate over oceans # ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf) grid_gdf = watermask.clip_grid(grid_gdf)
grid_gdf = grid_gdf.to_crs("epsg:4326") grid_gdf = grid_gdf.to_crs("epsg:4326")
# There is a small group of invalid geometry in healpix 10, filter it out manually
# They are invalid because the watermask clipping does not handle the antimeridian well
if grid == "healpix" and level == 10:
invalid_cell_id = [3059646, 3063547]
grid_gdf = grid_gdf[~grid_gdf.cell_id.isin(invalid_cell_id)]
aggregations = { aggregations = {
"hex": { "hex": {
@ -685,7 +690,7 @@ def spatial_agg(
"healpix": { "healpix": {
6: _Aggregations.common(), 6: _Aggregations.common(),
7: _Aggregations.common(), 7: _Aggregations.common(),
8: _Aggregations.common(), 8: _Aggregations(mean=True),
9: _Aggregations(mean=True), 9: _Aggregations(mean=True),
10: "interpolate", 10: "interpolate",
}, },