Fix memory leak in aggregations
This commit is contained in:
parent
341fa7b836
commit
98314fe8b3
7 changed files with 327 additions and 200 deletions
74
Processing Documentation.md
Normal file
@ -0,0 +1,74 @@
# Processing Documentation

This document records how long each processing step took and how much memory and compute (CPU & GPU) it needed.

| Grid  | ArcticDEM | Era5 | AlphaEarth | Darts |
| ----- | --------- | ---- | ---------- | ----- |
| Hex3  | [ ]       | [/]  | [ ]        | [/]   |
| Hex4  | [ ]       | [/]  | [ ]        | [/]   |
| Hex5  | [ ]       | [/]  | [ ]        | [/]   |
| Hex6  | [ ]       | [ ]  | [ ]        | [ ]   |
| Hpx6  | [x]       | [/]  | [ ]        | [/]   |
| Hpx7  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx8  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx9  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx10 | [ ]       | [ ]  | [ ]        | [ ]   |
## Grid creation

Creating the grids did not take any significant amount of memory or compute.
The time to create a grid ranged from a few seconds for the smaller levels up to a few minutes for the higher levels.
## DARTS

As with grid creation, no significant amount of memory, compute, or time was needed.
## ArcticDEM

The download took around 8 h with a memory usage of about 10 GB and no strong limitations on compute.
The resulting icechunk zarr datacube was approx. 160 GB on disk, which corresponds to approx. 270 GB in memory when loaded.

The enrichment took around 2 h on a single A100 GPU node (40 GB) with a local dask cluster of 7 processes, each using 2 threads and 30 GB of memory, for a total of 210 GB of memory.
These settings can easily be changed to consume less memory by reducing the number of processes or threads.
More processes or threads could not be used, to ensure that the GPU does not run out of memory.
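The cluster layout described above can be sketched roughly as follows. This is a minimal sketch, not the project's actual launch code: the worker counts and limits are the values quoted above, and other hardware will need different settings.

```python
# Worker layout quoted above: 7 processes x 2 threads, 30 GB each (~210 GB total),
# leaving the 40 GB of the A100 for the enrichment kernels themselves.
N_WORKERS = 7
THREADS_PER_WORKER = 2
MEMORY_LIMIT_GB = 30

if __name__ == "__main__":
    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(
        n_workers=N_WORKERS,
        threads_per_worker=THREADS_PER_WORKER,
        memory_limit=f"{MEMORY_LIMIT_GB}GB",
    )
    client = Client(cluster)
    print(client.dashboard_link)
```

Reducing `N_WORKERS` or `MEMORY_LIMIT_GB` lowers the total footprint linearly, as noted above.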
### Spatial aggregations into grids

All spatial aggregations relied heavily on CPU compute, since CuPy lacks support for nanquantile,
and for the higher-resolution grids the number of pixels to reduce was too small to overcome the data-movement overhead of using a GPU.

The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`), accumulating linearly more memory with more parallel computation.
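The nanquantile limitation can be seen in a minimal example. On the CPU, `numpy.nanquantile` skips nodata pixels directly; CuPy offers `cupy.quantile` but no NaN-aware counterpart, so NaN handling would have to be done by hand on the GPU (the values below are illustrative, not from the datasets):

```python
import numpy as np

# Per-cell pixel values with nodata encoded as NaN.
pixels = np.array([1.0, 2.0, np.nan, 4.0, 8.0], dtype=np.float32)

# NaN-aware quantiles on the CPU: the NaN is simply ignored,
# so the quantiles are computed over [1, 2, 4, 8].
q = np.nanquantile(pixels, [0.25, 0.5, 0.75])
print(q)  # [1.75 3.   5.  ]
```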
| grid  | time   | memory | processes |
| ----- | ------ | ------ | --------- |
| Hex3  |        |        |           |
| Hex4  |        |        |           |
| Hex5  |        |        |           |
| Hex6  |        |        |           |
| Hpx6  | 37 min | ~300GB | 40        |
| Hpx7  |        |        |           |
| Hpx8  |        |        |           |
| Hpx9  | 25 min | ~300GB | 40        |
| Hpx10 | 34 min | ~300GB | 40        |
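The ~300 GB figures above are consistent with the linear scaling described earlier: each concurrent partition holds its own raster crop in memory. A rough model (a sketch; the ~7.5 GB per-partition footprint is an illustrative value, not a measured one):

```python
def estimated_peak_memory_gb(concurrent_partitions: int, partition_gb: float) -> float:
    """Linear memory model: each concurrent partition holds its own raster crop."""
    return concurrent_partitions * partition_gb

# 40 concurrent partitions at ~7.5 GB each lands near the ~300 GB observed above.
print(estimated_peak_memory_gb(40, 7.5))  # 300.0
```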
## Alpha Earth

The download was heavily limited by the scale of the input data, which is ~10 m in the original dataset.
A scale of 10 m was not computationally feasible for the Google Earth Engine servers, so each grid and level used a different scale to aggregate and download the data.
Each scale was chosen so that each grid cell had around 10000 px to estimate the aggregations from.
| grid  | time    | scale |
| ----- | ------- | ----- |
| Hex3  |         | 1600  |
| Hex4  |         | 600   |
| Hex5  |         | 240   |
| Hex6  |         | 90    |
| Hpx6  | 58 min  | 1600  |
| Hpx7  | 3:16 h  | 800   |
| Hpx8  | 13:19 h | 400   |
| Hpx9  |         | 200   |
| Hpx10 |         | 100   |
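The scale choice can be approximated from the cell area: for a target of ~10000 px per cell, the ground sampling distance is roughly the square root of area over pixel count. This is a sketch with an illustrative cell area; the scales in the table above were evidently rounded and adjusted per grid, so they do not match this formula exactly.

```python
import math

TARGET_PX = 10_000  # target number of pixels per grid cell

def scale_for_cell(cell_area_m2: float) -> float:
    """Ground sampling distance (m) that yields ~TARGET_PX pixels per cell."""
    return math.sqrt(cell_area_m2 / TARGET_PX)

# Illustrative: a cell of ~2.6e9 m^2 would call for a scale of ~510 m,
# the same order of magnitude as the scales listed above.
print(round(scale_for_cell(2.6e9)))  # 510
```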
## Era5

???
37
pixi.lock
generated
@ -475,6 +475,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/55/d2/507fd0ee4dd574d2bdbdeac5df83f39d2cae1ffe97d4622cca6f6bab39f1/botocore-1.40.70-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/bd/22/05555a9752357e24caa1cd92324d1a7fdde6386aab162fcc451f8f8eedc2/bottleneck-1.6.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
@ -609,6 +610,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
- pypi: git+https://github.com/davbyr/xAnimate#750e03e480db309407e09f4ffe5f49522a4c4f9b
- pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/65/ad/8f9ff43ff49ef02c7b8202a42c32a1fe8de1276bba0e6f55609e19ff7585/xdggs-0.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/56/b0/e3efafd9c97ed931f6453bd71aa8feaffc9217e6121af65fda06cf32f608/xgboost-3.1.1-py3-none-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/05/b9/b6a9cf72aef69c3e6db869dcc130e19452a658366dac9377f9cd32a76b80/xproj-0.2.1-py3-none-any.whl
@ -1410,6 +1412,13 @@ packages:
  - pkg:pypi/bokeh?source=compressed-mapping
  size: 5027028
  timestamp: 1762557204752
- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
  name: boost-histogram
  version: 1.6.1
  sha256: c5700e53bf2d3d006f71610f71fc592d88dab3279dad300e8178dc084018b177
  requires_dist:
  - numpy
  requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
  name: boto3
  version: 1.40.70
@ -2810,7 +2819,7 @@ packages:
|
|||
- pypi: ./
|
||||
name: entropice
|
||||
version: 0.1.0
|
||||
sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
|
||||
sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30
|
||||
requires_dist:
|
||||
- aiohttp>=3.12.11
|
||||
- bokeh>=3.7.3
|
||||
|
|
@ -2865,6 +2874,7 @@ packages:
|
|||
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
|
||||
- cupy-xarray>=0.1.4,<0.2
|
||||
- memray>=1.19.1,<2
|
||||
- xarray-histogram>=0.2.2,<0.3
|
||||
requires_python: '>=3.13,<3.14'
|
||||
editable: true
|
||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||
|
|
@ -9373,6 +9383,31 @@ packages:
  - pkg:pypi/xarray?source=hash-mapping
  size: 989507
  timestamp: 1763464377176
- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
  name: xarray-histogram
  version: 0.2.2
  sha256: 90682f9575131e5ea5d96feba5f8679e8c3fae039018d5e38f497696d55ccd5c
  requires_dist:
  - boost-histogram
  - numpy
  - xarray
  - dask ; extra == 'full'
  - scipy ; extra == 'full'
  - dask ; extra == 'dev'
  - scipy ; extra == 'dev'
  - sphinx ; extra == 'dev'
  - sphinx-book-theme ; extra == 'dev'
  - ruff ; extra == 'dev'
  - mypy>=1.5 ; extra == 'dev'
  - pytest>=7.4 ; extra == 'dev'
  - coverage ; extra == 'dev'
  - pytest-cov ; extra == 'dev'
  - dask ; extra == 'tests'
  - scipy ; extra == 'tests'
  - pytest>=7.4 ; extra == 'tests'
  - coverage ; extra == 'tests'
  - pytest-cov ; extra == 'tests'
  requires_python: '>=3.11'
- pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
  name: xarray-spatial
  version: 0.1.dev519+g3a3120981
@ -57,7 +57,7 @@ dependencies = [
    "xgboost>=3.1.1,<4",
    "s3fs>=2025.10.0,<2026",
    "xarray-spatial",
    "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
    "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3",
]

[project.scripts]
@ -2,7 +2,6 @@

import gc
import os
import sys
from collections import defaultdict
from collections.abc import Callable, Generator
from concurrent.futures import ProcessPoolExecutor, as_completed
@ -33,6 +32,7 @@ from entropice import grids
@dataclass(frozen=True)
class _Aggregations:
    # ! The ordering is super important for this class!
    _common: bool = False  # If true, use _agg_cell_data_single_common and ignore other flags
    mean: bool = True
    sum: bool = False
    std: bool = False
@ -46,8 +46,14 @@ class _Aggregations:
        if self.median and 0.5 in self.quantiles:
            raise ValueError("Median aggregation cannot be used together with quantile 0.5")

    @classmethod
    def common(cls):
        return cls(_common=True)

    @cache
    def __len__(self) -> int:
        if self._common:
            return 11
        length = 0
        if self.mean:
            length += 1
@ -65,6 +71,8 @@ class _Aggregations:
|
|||
return length
|
||||
|
||||
def aggnames(self) -> list[str]:
|
||||
if self._common:
|
||||
return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
|
||||
names = []
|
||||
if self.mean:
|
||||
names.append("mean")
|
||||
|
|
@ -83,7 +91,24 @@ class _Aggregations:
            names.append(f"p{q_int}")
        return names

    def _agg_cell_data_single_common(self, flattened_var: xr.DataArray) -> np.ndarray:
        others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
        cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
        cell_data[0, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
        cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
        cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
        cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
        quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99]  # ? Ordering is important here!
        cell_data[4:, ...] = flattened_var.quantile(
            q=quantiles_to_compute,
            dim="z",
            skipna=True,
        ).to_numpy()
        return cell_data

    def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
        if self._common:
            return self._agg_cell_data_single_common(flattened_var)
        others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
        cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
        j = 0
@ -106,6 +131,12 @@ class _Aggregations:
        quantiles_to_compute = sorted(self.quantiles)  # ? Ordering is important here!
        if self.median:
            quantiles_to_compute.insert(0, 0.5)

        # Potential way to use cupy:
        # cp.apply_along_axis(
        #     lambda x: cp.quantile(x[~cp.isnan(x)], q=quantiles), axis=0,  # arr=data_gpu_raw
        # )

        cell_data[j:, ...] = flattened_var.quantile(
            q=quantiles_to_compute,
            dim="z",
@ -173,7 +204,7 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
    spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
    flattened = cropped.stack(z=spatdims)  # noqa: PD013
    # if flattened.z.size > 3000:
    # if flattened.z.size > 1000000:
    #     flattened = flattened.cupy.as_cupy()
    cell_data = aggregations.agg_cell_data(flattened)
    return cell_data
@ -187,6 +218,8 @@ def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggr
        else ["y", "x"]
    )
    flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z")  # noqa: PD013
    # if flattened.z.size > 1000000:
    #     flattened = flattened.cupy.as_cupy()
    cell_data = aggregations.agg_cell_data(flattened)
    return cell_data
@ -218,7 +251,15 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati


def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
    # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
    if grid_gdf.crs.to_epsg() == 4326:
        crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian)
    else:
        crosses_antimeridian = pd.Series([False] * len(grid_gdf), index=grid_gdf.index)

    if crosses_antimeridian.any():
        grid_gdf_am = grid_gdf[crosses_antimeridian]
        grid_gdf = grid_gdf[~crosses_antimeridian]

    # Simple partitioning by splitting the GeoDataFrame into n_partitions parts
    centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
    labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
@ -226,6 +267,9 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[
        partition = grid_gdf[labels == i]
        yield partition

    if crosses_antimeridian.any():
        yield grid_gdf_am


class _MemoryProfiler:
    def __init__(self):
@ -279,20 +323,20 @@ def _align_partition(
    # 1. unmeasurable
    # 2. ~0.1-0.3s
    # 3. ~0.05s
    # => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
    #    faster than a processpoolexecutor

    memprof.log_memory("Before reading partial raster", log=False)

    if callable(raster) and not isinstance(raster, xr.Dataset):
        raster = raster()

    if hasattr(raster, "__dask_graph__"):
        graph_size = len(raster.__dask_graph__())
        graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
        print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")
        need_to_close_raster = True
    else:
        need_to_close_raster = False

    others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
    data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
    data = np.full(data_shape, np.nan, dtype=np.float32)
    ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
    ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)

    partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
    partial_extent = partial_extent.buffered(
@ -304,57 +348,42 @@ def _align_partition(
        partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
    except Exception as e:
        print(f"Error cropping raster to partition extent: {e}")
        return data
        return ongrid

    print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
    memprof.log_memory("After reading partial raster", log=True)
    if partial_raster.nbytes / 1e9 > 20:
        print(
            f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
            f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
            f" This may lead to out-of-memory errors."
        )
    memprof.log_memory("After reading partial raster", log=False)

    with ProcessPoolExecutor(
        max_workers=60,
    ) as executor:
        futures = {}
        for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
            # ? Splitting the geometries already here and passing only the cropped data to the worker to
            #   reduce pickling overhead
            geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
            if len(geoms) == 0:
                continue
            elif len(geoms) == 1:
                cropped = [_read_cell_data(partial_raster, geoms[0])]
            else:
                cropped = _read_split_cell_data(partial_raster, geoms)
            futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i

            # Clean up immediately after submitting
            del geoms, cropped  # Added this line

        for future in as_completed(futures):
            i = futures[future]
            try:
                cell_data = future.result()
                cell_data = _process_geom(row.geometry, partial_raster, aggregations)
            except (SystemError, SystemExit, KeyboardInterrupt) as e:
                raise e
            except Exception as e:
                print(f"Error processing cell {i}: {e}")
                print(f"Error processing cell {row['cell_id']}: {e}")
                continue
            data[i, ...] = cell_data
            del cell_data, futures[future]
            ongrid[i, ...] = cell_data

    # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
    #     try:
    #         cell_data = _process_geom(row.geometry, partial_raster, aggregations)
    #     except (SystemError, SystemExit, KeyboardInterrupt) as e:
    #         raise e
    #     except Exception as e:
    #         print(f"Error processing cell {row['cell_id']}: {e}")
    #         continue
    #     data[i, ...] = cell_data
    cell_ids = grids.convert_cell_ids(grid_partition_gdf)
    dims = ["cell_ids", "variables", "aggregations"]
    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
    for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
        dims.append(dim)
        coords[dim] = raster.coords[dim]

    ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")

    partial_raster.close()
    del partial_raster
    if need_to_close_raster:
        raster.close()
    del partial_raster, raster, future, futures
    del raster
    gc.collect()
    memprof.log_memory("After cleaning", log=True)
    memprof.log_memory("After cleaning", log=False)

    print("Finished processing partition")
    print("### Stopwatch summary ###\n")
@ -362,45 +391,8 @@ def _align_partition(
    print("### Memory summary ###\n")
    print(memprof.summary())
    print("#########################")
    # with ProcessPoolExecutor(
    #     max_workers=max_workers,
    #     initializer=_init_partial_raster_global,
    #     initargs=(raster, grid_partition_gdf, pxbuffer),
    # ) as executor:
    #     futures = {
    #         executor.submit(_process_geom_global, row.geometry, aggregations): i
    #         for i, (idx, row) in enumerate(grid_partition_gdf.iterrows())
    #     }

    #     others_shape = tuple(
    #         [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
    #     )
    #     data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape)
    #     print(f"Allocating partial data array of shape {data_shape}")
    #     data = np.full(data_shape, np.nan, dtype=np.float32)

    #     lap = time.time()
    #     for future in track(
    #         as_completed(futures),
    #         total=len(futures),
    #         description="Spatially aggregating ERA5 data...",
    #     ):
    #         print(f"time since last cell: {time.time() - lap:.2f}s")
    #         i = futures[future]
    #         try:
    #             cell_data = future.result()
    #             data[i, ...] = cell_data
    #         except Exception as e:
    #             print(f"Error processing cell {i}: {e}")
    #             continue
    #         finally:
    #             # Try to free memory
    #             del futures[future], future
    #             if "cell_data" in locals():
    #                 del cell_data
    #             gc.collect()
    #             lap = time.time()
    return data
    return ongrid


@stopwatch("Aligning data with grid")
@ -409,13 +401,30 @@ def _align_data(
    raster: xr.Dataset | Callable[[], xr.Dataset],
    aggregations: _Aggregations,
    n_partitions: int,
    max_workers: int,
    concurrent_partitions: int,
    pxbuffer: int,
):
    partial_data = {}
    partial_ongrids = []

    _init_worker()

    if n_partitions < concurrent_partitions:
        print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}")
        concurrent_partitions = n_partitions

    if concurrent_partitions <= 1:
        for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
            print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
            part_ongrid = _align_partition(
                grid_partition,
                raster,  # .copy() if isinstance(raster, xr.Dataset) else raster,
                aggregations,
                pxbuffer,
            )
            partial_ongrids.append(part_ongrid)
    else:
        with ProcessPoolExecutor(
            max_workers=max_workers,
            max_workers=concurrent_partitions,
            initializer=_init_worker,
            # initializer=_init_raster_global,
            # initargs=(raster,),
@ -431,6 +440,8 @@ def _align_data(
                        pxbuffer,
                    )
                ] = i
                if i == 6:
                    print("Breaking after 3 partitions for testing purposes")

            print("Submitted all partitions, waiting for results...")
@ -442,14 +453,14 @@ def _align_data(
                try:
                    i = futures[future]
                    print(f"Processed partition {i + 1}/{len(futures)}")
                    part_data = future.result()
                    partial_data[i] = part_data
                    part_ongrid = future.result()
                    partial_ongrids.append(part_ongrid)
                except Exception as e:
                    print(f"Error processing partition {i}: {e}")
                    raise e

    data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
    return data
    ongrid = xr.concat(partial_ongrids, dim="cell_ids")
    return ongrid


def aggregate_raster_into_grid(
@ -459,7 +470,7 @@ def aggregate_raster_into_grid(
    grid: Literal["hex", "healpix"],
    level: int,
    n_partitions: int = 20,
    max_workers: int = 5,
    concurrent_partitions: int = 5,
    pxbuffer: int = 15,
):
    """Aggregate raster data into grid cells.
@ -471,19 +482,20 @@ def aggregate_raster_into_grid(
        grid (Literal["hex", "healpix"]): The type of grid to use.
        level (int): The level of the grid.
        n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
        max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
        concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
            Defaults to 5.
        pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.

    Returns:
        xr.Dataset: Aggregated data aligned with the grid.

    """
    aligned = _align_data(
    ongrid = _align_data(
        grid_gdf,
        raster,
        aggregations,
        n_partitions=n_partitions,
        max_workers=max_workers,
        concurrent_partitions=concurrent_partitions,
        pxbuffer=pxbuffer,
    )
    # Dims of aligned: (cell_ids, variables, aggregations, other dims...)
@ -491,16 +503,6 @@ def aggregate_raster_into_grid(
    if callable(raster) and not isinstance(raster, xr.Dataset):
        raster = raster()

    cell_ids = grids.get_cell_ids(grid, level)

    dims = ["cell_ids", "variables", "aggregations"]
    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
    for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
        dims.append(dim)
        coords[dim] = raster.coords[dim]

    ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")

    gridinfo = {
        "grid_name": "h3" if grid == "hex" else grid,
        "level": level,
@ -10,6 +10,7 @@ import cupyx.scipy.signal
import cyclopts
import dask.array
import dask.distributed as dd
import geopandas as gpd
import icechunk
import icechunk.xarray
import numpy as np
@ -48,7 +49,7 @@ def download():


@dataclass
class KernelFactory:
class _KernelFactory:
    res: int
    size_px: int
@ -77,14 +78,14 @@ class KernelFactory:
        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))


def tpi_cupy(chunk, kernels: KernelFactory):
def tpi_cupy(chunk, kernels: _KernelFactory):
    kernel = kernels.ring()
    kernel = kernel / cp.nansum(kernel)
    tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
    return tpi


def tri_cupy(chunk, kernels: KernelFactory):
def tri_cupy(chunk, kernels: _KernelFactory):
    kernel = kernels.tri()
    c2 = chunk**2
    focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
@ -93,7 +94,7 @@ def tri_cupy(chunk, kernels: KernelFactory):
    return tri


def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory):
    slope_rad = slope * (cp.pi / 180)
    aspect_rad = aspect * (cp.pi / 180)
    aspect_rad = cp.where(aspect_rad == -1, 0, aspect_rad)
@ -150,9 +151,9 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->

def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
    res = 32  # 32m resolution
    small_kernels = KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
    medium_kernels = KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
    large_kernels = KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
    small_kernels = _KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
    medium_kernels = _KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
    large_kernels = _KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)

    # Check if there is data in the chunk
    if np.all(np.isnan(chunk)):
@ -306,39 +307,51 @@ def enrich():
def _open_adem():
    arcticdem_store = get_arcticdem_stores()
    accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
    # config = icechunk.RepositoryConfig.default()
    # config.caching = icechunk.CachingConfig(
    #     num_snapshot_nodes=None,
    #     num_chunk_refs=None,
    #     num_transaction_changes=None,
    #     num_bytes_attributes=None,
    #     num_bytes_chunks=None,
    # )
    # repo = icechunk.Repository.open(accessor.repo.storage, config=config)
    # session = repo.readonly_session("main")
    # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
    adem = accessor.open_xarray()
    # Don't use the datamask
    adem = adem.drop_vars("datamask")
    assert {"x", "y"} == set(adem.dims)
    assert adem.odc.crs == "EPSG:3413"
    return adem


@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20):
    mp.set_start_method("forkserver", force=True)

    with (
        dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
    ):
        print(cluster)
        with stopwatch(f"Loading {grid} grid at level {level}"):
            grid_gdf = grids.open(grid, level)
            # ? Mask out water, since we don't want to aggregate over oceans
            grid_gdf = watermask.clip_grid(grid_gdf)
            # ? Mask out cells that do not intersect with ArcticDEM extent
            arcticdem_store = get_arcticdem_stores()
            accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
            extent_info = gpd.read_parquet(accessor._aux_dir / "ArcticDEM_Mosaic_Index_v4_1_32m.parquet")
            grid_gdf = grid_gdf[grid_gdf.intersects(extent_info.geometry.union_all())]

        aggregations = _Aggregations(
            mean=True,
            sum=False,
            std=True,
            min=True,
            max=True,
            median=True,
            quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
        )

        aggregations = _Aggregations.common()
        aggregated = aggregate_raster_into_grid(
            _open_adem,
            grid_gdf,
            aggregations,
            grid,
            level,
            max_workers=max_workers,
            n_partitions=200,  # Results in 10GB large patches
            concurrent_partitions=concurrent_partitions,
            n_partitions=400,  # 400 # Results in ~10GB large patches
            pxbuffer=30,
        )
        store = get_arcticdem_stores(grid, level)
@ -677,25 +677,7 @@ def spatial_agg(
    assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
    unaligned = _correct_longs(unaligned)

    # with stopwatch("Precomputing cell geometries"):
    #     with ProcessPoolExecutor(max_workers=20) as executor:
    #         cell_geometries = list(
    #             executor.map(
    #                 _get_corrected_geoms,
    #                 [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
    #             )
    #         )
    #     data = _align_data(cell_geometries, unaligned)
    #     aggregated = _create_aligned(unaligned, data, grid, level)

    aggregations = _Aggregations(
        mean=True,
        sum=False,
        std=True,
        min=True,
        max=True,
        quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
    )
    aggregations = _Aggregations.common()
    aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)

    aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
@ -66,6 +66,27 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
    return cell_ids


def convert_cell_ids(grid_gdf: gpd.GeoDataFrame):
    """Convert cell IDs in a GeoDataFrame to xdggs-compatible format.

    Args:
        grid_gdf (gpd.GeoDataFrame): The input grid GeoDataFrame.

    Returns:
        pd.Series: The converted cell IDs.

    """
    if "cell_id" not in grid_gdf.columns:
        raise ValueError("The GeoDataFrame must contain a 'cell_id' column.")

    # Check if cell IDs are hex strings (for H3)
    if isinstance(grid_gdf["cell_id"].iloc[0], str):
        grid_gdf = grid_gdf.copy()
        grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda cid: int(cid, 16))

    return grid_gdf["cell_id"]


def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
    hex_batch = []
    hex_id_batch = []