Fix memory leak in aggregations
parent 341fa7b836
commit 98314fe8b3
7 changed files with 327 additions and 200 deletions
Processing Documentation.md (new file, 74 lines)
# Processing Documentation

This document records how long each processing step took and how much memory and compute (CPU & GPU) it needed.

| Grid  | ArcticDEM | Era5 | AlphaEarth | Darts |
| ----- | --------- | ---- | ---------- | ----- |
| Hex3  | [ ]       | [/]  | [ ]        | [/]   |
| Hex4  | [ ]       | [/]  | [ ]        | [/]   |
| Hex5  | [ ]       | [/]  | [ ]        | [/]   |
| Hex6  | [ ]       | [ ]  | [ ]        | [ ]   |
| Hpx6  | [x]       | [/]  | [ ]        | [/]   |
| Hpx7  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx8  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx9  | [ ]       | [/]  | [ ]        | [/]   |
| Hpx10 | [ ]       | [ ]  | [ ]        | [ ]   |

## Grid creation

The creation of grids did not take up any significant amount of memory or compute.
The time needed to create a grid ranged from a few seconds for the smaller levels up to a few minutes for the higher levels.

## DARTS

As with grid creation, no significant amount of memory, compute, or time was needed.

## ArcticDEM

The download took around 8 h with a memory usage of about 10 GB and no strong limitations on compute.
The resulting icechunk Zarr datacube was approx. 160 GB on disk, which corresponds to approx. 270 GB in memory when fully loaded.

The enrichment took around 2 h on a single A100 GPU node (40 GB) with a local dask cluster consisting of 7 processes, each using 2 threads and 30 GB of memory, for a total of 210 GB of memory.
These settings can easily be changed to consume less memory by reducing the number of processes or threads.
More processes or threads could not be used, since the GPU would otherwise run out of memory.

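The cluster settings described above can be sketched as follows; the `LocalCluster` call shown in the comment is a hedged illustration of how such a cluster is typically started, not the repository's actual setup code:

```python
# Sketch of the local dask cluster budget described above (7 processes,
# 2 threads each, 30 GB per process). These numbers are taken from the text.
n_workers = 7
threads_per_worker = 2
memory_limit_gb = 30

# Total memory budget of the cluster.
total_memory_gb = n_workers * memory_limit_gb

# With dask.distributed installed, an equivalent cluster would be started
# roughly like this (illustrative, not the actual setup code):
# from dask.distributed import Client, LocalCluster
# cluster = LocalCluster(n_workers=7, threads_per_worker=2, memory_limit="30GB")
# client = Client(cluster)
print(total_memory_gb)
```

Reducing `n_workers` lowers the total budget linearly, which is the knob the text suggests for memory-constrained runs.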
### Spatial aggregations into grids

All spatial aggregations relied heavily on CPU compute, since CuPy lacks support for `nanquantile`, and for the higher-resolution grids the number of pixels to reduce per cell was too small to overcome the data-movement overhead of using a GPU.

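The per-cell NaN-aware reduction described above can be illustrated with NumPy, which provides the `nanquantile` that CuPy lacks (a minimal sketch, not the repository's code):

```python
import numpy as np

# All valid pixels of one grid cell, flattened; NaNs mark masked pixels.
pixels = np.array([1.0, 2.0, np.nan, 4.0, 5.0], dtype=np.float32)

# NaN-aware quantiles over the flattened pixels; this is the reduction that
# has no CuPy equivalent and therefore runs on the CPU.
stats = np.nanquantile(pixels, q=[0.25, 0.5, 0.75])
# The median (q=0.5) over the valid pixels [1, 2, 4, 5] is 3.0.
```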
The aggregations scale through the number of concurrent processes (specified by `--concurrent_partitions`), accumulating linearly more memory with higher parallelism.

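The linear scaling can be sketched as a back-of-the-envelope memory model; the ~7.5 GB per concurrent partition is an illustrative figure inferred from the measured runs (~300 GB at 40 processes), not a measured constant:

```python
# Hedged memory model: memory grows linearly with the number of concurrent
# partition processes. gb_per_partition is an assumption, not a measurement.
def expected_memory_gb(concurrent_partitions: int, gb_per_partition: float = 7.5) -> float:
    return concurrent_partitions * gb_per_partition

# 40 concurrent partitions would then need about 300 GB.
print(expected_memory_gb(40))
```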
| grid  | time   | memory | processes |
| ----- | ------ | ------ | --------- |
| Hex3  |        |        |           |
| Hex4  |        |        |           |
| Hex5  |        |        |           |
| Hex6  |        |        |           |
| Hpx6  | 37 min | ~300GB | 40        |
| Hpx7  |        |        |           |
| Hpx8  |        |        |           |
| Hpx9  | 25 min | ~300GB | 40        |
| Hpx10 | 34 min | ~300GB | 40        |

## Alpha Earth

The download was heavily limited by the scale of the input data, which is ~10 m in the original dataset.
A 10 m scale was not computationally feasible for the Google Earth Engine servers, so each grid and level used a different scale to aggregate and download the data.
Each scale was chosen so that each grid cell contained around 10000 px from which to estimate the aggregations.

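The pixel-count heuristic described above can be sketched as follows; the formula and the example cell area are illustrative assumptions, only the ~10000 px target comes from the text:

```python
import math

# Hedged sketch: pick a pixel scale (in meters) so that a grid cell of a
# given area yields roughly 10,000 pixels for the aggregation.
def scale_for_cell(cell_area_m2: float, target_px: int = 10_000) -> float:
    return math.sqrt(cell_area_m2 / target_px)

# An illustrative cell of 2.56e9 m^2 (~51 km x 51 km) gives a scale of ~506 m.
print(scale_for_cell(2.56e9))
```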
| grid  | time    | scale |
| ----- | ------- | ----- |
| Hex3  |         | 1600  |
| Hex4  |         | 600   |
| Hex5  |         | 240   |
| Hex6  |         | 90    |
| Hpx6  | 58 min  | 1600  |
| Hpx7  | 3:16 h  | 800   |
| Hpx8  | 13:19 h | 400   |
| Hpx9  |         | 200   |
| Hpx10 |         | 100   |

## Era5

???

pixi.lock (generated, 37 lines changed)
@ -475,6 +475,7 @@ environments:
 - pypi: https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/55/d2/507fd0ee4dd574d2bdbdeac5df83f39d2cae1ffe97d4622cca6f6bab39f1/botocore-1.40.70-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/bd/22/05555a9752357e24caa1cd92324d1a7fdde6386aab162fcc451f8f8eedc2/bottleneck-1.6.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
@ -609,6 +610,7 @@ environments:
 - pypi: https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl
 - pypi: git+https://github.com/davbyr/xAnimate#750e03e480db309407e09f4ffe5f49522a4c4f9b
 - pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
+- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/65/ad/8f9ff43ff49ef02c7b8202a42c32a1fe8de1276bba0e6f55609e19ff7585/xdggs-0.2.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/56/b0/e3efafd9c97ed931f6453bd71aa8feaffc9217e6121af65fda06cf32f608/xgboost-3.1.1-py3-none-manylinux_2_28_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/05/b9/b6a9cf72aef69c3e6db869dcc130e19452a658366dac9377f9cd32a76b80/xproj-0.2.1-py3-none-any.whl
@ -1410,6 +1412,13 @@ packages:
 - pkg:pypi/bokeh?source=compressed-mapping
   size: 5027028
   timestamp: 1762557204752
+- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
+  name: boost-histogram
+  version: 1.6.1
+  sha256: c5700e53bf2d3d006f71610f71fc592d88dab3279dad300e8178dc084018b177
+  requires_dist:
+  - numpy
+  requires_python: '>=3.9'
 - pypi: https://files.pythonhosted.org/packages/f3/cf/e24d08b37cd318754a8e94906c8b34b88676899aad1907ff6942311f13c4/boto3-1.40.70-py3-none-any.whl
   name: boto3
   version: 1.40.70
@ -2810,7 +2819,7 @@ packages:
 - pypi: ./
   name: entropice
   version: 0.1.0
-  sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
+  sha256: d22e8659bedd1389a563f9cc66c579cad437d279597fa5b21126dda3bb856a30
   requires_dist:
   - aiohttp>=3.12.11
   - bokeh>=3.7.3
@ -2865,6 +2874,7 @@ packages:
   - xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
   - cupy-xarray>=0.1.4,<0.2
   - memray>=1.19.1,<2
+  - xarray-histogram>=0.2.2,<0.3
   requires_python: '>=3.13,<3.14'
   editable: true
 - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
@ -9373,6 +9383,31 @@ packages:
 - pkg:pypi/xarray?source=hash-mapping
   size: 989507
   timestamp: 1763464377176
+- pypi: https://files.pythonhosted.org/packages/14/38/d1a8b0c8b7749fde76daa12ec3e63aa052cf37cacc2e9715377ce0197a99/xarray_histogram-0.2.2-py3-none-any.whl
+  name: xarray-histogram
+  version: 0.2.2
+  sha256: 90682f9575131e5ea5d96feba5f8679e8c3fae039018d5e38f497696d55ccd5c
+  requires_dist:
+  - boost-histogram
+  - numpy
+  - xarray
+  - dask ; extra == 'full'
+  - scipy ; extra == 'full'
+  - dask ; extra == 'dev'
+  - scipy ; extra == 'dev'
+  - sphinx ; extra == 'dev'
+  - sphinx-book-theme ; extra == 'dev'
+  - ruff ; extra == 'dev'
+  - mypy>=1.5 ; extra == 'dev'
+  - pytest>=7.4 ; extra == 'dev'
+  - coverage ; extra == 'dev'
+  - pytest-cov ; extra == 'dev'
+  - dask ; extra == 'tests'
+  - scipy ; extra == 'tests'
+  - pytest>=7.4 ; extra == 'tests'
+  - coverage ; extra == 'tests'
+  - pytest-cov ; extra == 'tests'
+  requires_python: '>=3.11'
 - pypi: git+https://github.com/relativityhd/xarray-spatial#3a3120981dc910cbfc824bd03d1c1f8637efaf2d
   name: xarray-spatial
   version: 0.1.dev519+g3a3120981
@ -57,7 +57,7 @@ dependencies = [
     "xgboost>=3.1.1,<4",
     "s3fs>=2025.10.0,<2026",
     "xarray-spatial",
-    "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
+    "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3",
 ]

 [project.scripts]

@ -2,7 +2,6 @@
 import gc
 import os
-import sys
 from collections import defaultdict
 from collections.abc import Callable, Generator
 from concurrent.futures import ProcessPoolExecutor, as_completed

@ -33,6 +32,7 @@ from entropice import grids
 @dataclass(frozen=True)
 class _Aggregations:
     # ! The ordering is super important for this class!
+    _common: bool = False  # If true, use _agg_cell_data_single_common and ignore other flags
     mean: bool = True
     sum: bool = False
     std: bool = False

@ -46,8 +46,14 @@ class _Aggregations:
         if self.median and 0.5 in self.quantiles:
             raise ValueError("Median aggregation cannot be used together with quantile 0.5")

+    @classmethod
+    def common(cls):
+        return cls(_common=True)
+
     @cache
     def __len__(self) -> int:
+        if self._common:
+            return 11
         length = 0
         if self.mean:
             length += 1

@ -65,6 +71,8 @@ class _Aggregations:
         return length

     def aggnames(self) -> list[str]:
+        if self._common:
+            return ["mean", "std", "min", "max", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
         names = []
         if self.mean:
             names.append("mean")

@ -83,7 +91,24 @@ class _Aggregations:
             names.append(f"p{q_int}")
         return names

+    def _agg_cell_data_single_common(self, flattened_var: xr.DataArray) -> np.ndarray:
+        others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
+        cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
+        cell_data[0, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
+        cell_data[1, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
+        cell_data[2, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
+        cell_data[3, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
+        quantiles_to_compute = [0.5, 0.01, 0.05, 0.25, 0.75, 0.95, 0.99]  # ? Ordering is important here!
+        cell_data[4:, ...] = flattened_var.quantile(
+            q=quantiles_to_compute,
+            dim="z",
+            skipna=True,
+        ).to_numpy()
+        return cell_data
+
     def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
+        if self._common:
+            return self._agg_cell_data_single_common(flattened_var)
         others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
         cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
         j = 0

@ -106,6 +131,12 @@ class _Aggregations:
         quantiles_to_compute = sorted(self.quantiles)  # ? Ordering is important here!
         if self.median:
             quantiles_to_compute.insert(0, 0.5)

+        # Potential way to use cupy:
+        # cp.apply_along_axis(
+        #     lambda x: cp.quantile(x[~cp.isnan(x)], q=quantiles), axis=0,  # arr=data_gpu_raw
+        # )
+
         cell_data[j:, ...] = flattened_var.quantile(
             q=quantiles_to_compute,
             dim="z",

@ -173,7 +204,7 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
 def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
     spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
     flattened = cropped.stack(z=spatdims)  # noqa: PD013
-    # if flattened.z.size > 3000:
+    # if flattened.z.size > 1000000:
     #     flattened = flattened.cupy.as_cupy()
     cell_data = aggregations.agg_cell_data(flattened)
     return cell_data

@ -187,6 +218,8 @@ def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggr
         else ["y", "x"]
     )
     flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z")  # noqa: PD013
+    # if flattened.z.size > 1000000:
+    #     flattened = flattened.cupy.as_cupy()
     cell_data = aggregations.agg_cell_data(flattened)
     return cell_data

@ -218,7 +251,15 @@ def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregati


 def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
-    # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
+    if grid_gdf.crs.to_epsg() == 4326:
+        crosses_antimeridian = grid_gdf.geometry.apply(_crosses_antimeridian)
+    else:
+        crosses_antimeridian = pd.Series([False] * len(grid_gdf), index=grid_gdf.index)
+
+    if crosses_antimeridian.any():
+        grid_gdf_am = grid_gdf[crosses_antimeridian]
+        grid_gdf = grid_gdf[~crosses_antimeridian]
+
     # Simple partitioning by splitting the GeoDataFrame into n_partitions parts
     centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
     labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)

@ -226,6 +267,9 @@ def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[
         partition = grid_gdf[labels == i]
         yield partition

+    if crosses_antimeridian.any():
+        yield grid_gdf_am
+

 class _MemoryProfiler:
     def __init__(self):

@ -279,20 +323,20 @@ def _align_partition(
     # 1. unmeasurable
     # 2. ~0.1-0.3s
     # 3. ~0.05s
+    # => There is a shift towards step 2 being the bottleneck for higher resolution grids, thus a simple loop becomes
+    #    faster than a processpoolexecutor

     memprof.log_memory("Before reading partial raster", log=False)

     if callable(raster) and not isinstance(raster, xr.Dataset):
         raster = raster()
-    if hasattr(raster, "__dask_graph__"):
-        graph_size = len(raster.__dask_graph__())
-        graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
-        print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")
+        need_to_close_raster = True
+    else:
+        need_to_close_raster = False

     others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
-    data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
-    data = np.full(data_shape, np.nan, dtype=np.float32)
+    ongrid_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
+    ongrid = np.full(ongrid_shape, np.nan, dtype=np.float32)

     partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
     partial_extent = partial_extent.buffered(

@ -304,57 +348,42 @@ def _align_partition(
         partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
     except Exception as e:
         print(f"Error cropping raster to partition extent: {e}")
-        return data
+        return ongrid

-    print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
-    memprof.log_memory("After reading partial raster", log=True)
+    if partial_raster.nbytes / 1e9 > 20:
+        print(
+            f"{os.getpid()}: WARNING! Partial raster size is larger than 20GB:"
+            f" {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)."
+            f" This may lead to out-of-memory errors."
+        )
+    memprof.log_memory("After reading partial raster", log=False)

-    with ProcessPoolExecutor(
-        max_workers=60,
-    ) as executor:
-        futures = {}
-        for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
-            # ? Splitting the geometries already here and passing only the cropped data to the worker to
-            # reduce pickling overhead
-            geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
-            if len(geoms) == 0:
-                continue
-            elif len(geoms) == 1:
-                cropped = [_read_cell_data(partial_raster, geoms[0])]
-            else:
-                cropped = _read_split_cell_data(partial_raster, geoms)
-            futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
-            # Clean up immediately after submitting
-            del geoms, cropped  # Added this line
-
-        for future in as_completed(futures):
-            i = futures[future]
-            try:
-                cell_data = future.result()
-            except (SystemError, SystemExit, KeyboardInterrupt) as e:
-                raise e
-            except Exception as e:
-                print(f"Error processing cell {i}: {e}")
-                continue
-            data[i, ...] = cell_data
-            del cell_data, futures[future]
-
-    # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
-    #     try:
-    #         cell_data = _process_geom(row.geometry, partial_raster, aggregations)
-    #     except (SystemError, SystemExit, KeyboardInterrupt) as e:
-    #         raise e
-    #     except Exception as e:
-    #         print(f"Error processing cell {row['cell_id']}: {e}")
-    #         continue
-    #     data[i, ...] = cell_data
+    for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
+        try:
+            cell_data = _process_geom(row.geometry, partial_raster, aggregations)
+        except (SystemError, SystemExit, KeyboardInterrupt) as e:
+            raise e
+        except Exception as e:
+            print(f"Error processing cell {row['cell_id']}: {e}")
+            continue
+        ongrid[i, ...] = cell_data
+
+    cell_ids = grids.convert_cell_ids(grid_partition_gdf)
+    dims = ["cell_ids", "variables", "aggregations"]
+    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
+    for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
+        dims.append(dim)
+        coords[dim] = raster.coords[dim]
+
+    ongrid = xr.DataArray(ongrid, dims=dims, coords=coords).to_dataset("variables")

     partial_raster.close()
-    raster.close()
-    del partial_raster, raster, future, futures
+    del partial_raster
+    if need_to_close_raster:
+        raster.close()
+        del raster
     gc.collect()
-    memprof.log_memory("After cleaning", log=True)
+    memprof.log_memory("After cleaning", log=False)

     print("Finished processing partition")
     print("### Stopwatch summary ###\n")

@ -362,45 +391,8 @@ def _align_partition(
     print("### Memory summary ###\n")
     print(memprof.summary())
     print("#########################")
-    # with ProcessPoolExecutor(
-    #     max_workers=max_workers,
-    #     initializer=_init_partial_raster_global,
-    #     initargs=(raster, grid_partition_gdf, pxbuffer),
-    # ) as executor:
-    #     futures = {
-    #         executor.submit(_process_geom_global, row.geometry, aggregations): i
-    #         for i, (idx, row) in enumerate(grid_partition_gdf.iterrows())
-    #     }
-
-    #     others_shape = tuple(
-    #         [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
-    #     )
-    #     data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape)
-    #     print(f"Allocating partial data array of shape {data_shape}")
-    #     data = np.full(data_shape, np.nan, dtype=np.float32)
-
-    #     lap = time.time()
-    #     for future in track(
-    #         as_completed(futures),
-    #         total=len(futures),
-    #         description="Spatially aggregating ERA5 data...",
-    #     ):
-    #         print(f"time since last cell: {time.time() - lap:.2f}s")
-    #         i = futures[future]
-    #         try:
-    #             cell_data = future.result()
-    #             data[i, ...] = cell_data
-    #         except Exception as e:
-    #             print(f"Error processing cell {i}: {e}")
-    #             continue
-    #         finally:
-    #             # Try to free memory
-    #             del futures[future], future
-    #             if "cell_data" in locals():
-    #                 del cell_data
-    #             gc.collect()
-    #             lap = time.time()
-    return data
+    return ongrid


 @stopwatch("Aligning data with grid")

@ -409,47 +401,66 @@ def _align_data(
     raster: xr.Dataset | Callable[[], xr.Dataset],
     aggregations: _Aggregations,
     n_partitions: int,
-    max_workers: int,
+    concurrent_partitions: int,
     pxbuffer: int,
 ):
-    partial_data = {}
+    partial_ongrids = []

-    with ProcessPoolExecutor(
-        max_workers=max_workers,
-        initializer=_init_worker,
-        # initializer=_init_raster_global,
-        # initargs=(raster,),
-    ) as executor:
-        futures = {}
-        for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
-            futures[
-                executor.submit(
-                    _align_partition,
-                    grid_partition,
-                    raster.copy() if isinstance(raster, xr.Dataset) else raster,
-                    aggregations,
-                    pxbuffer,
-                )
-            ] = i
+    _init_worker()
+
+    if n_partitions < concurrent_partitions:
+        print(f"Adjusting concurrent_partitions from {concurrent_partitions} to {n_partitions}")
+        concurrent_partitions = n_partitions
+
+    if concurrent_partitions <= 1:
+        for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
+            print(f"Processing partition {i + 1}/{n_partitions} with {len(grid_partition)} cells")
+            part_ongrid = _align_partition(
+                grid_partition,
+                raster,  # .copy() if isinstance(raster, xr.Dataset) else raster,
+                aggregations,
+                pxbuffer,
+            )
+            partial_ongrids.append(part_ongrid)
+    else:
+        with ProcessPoolExecutor(
+            max_workers=concurrent_partitions,
+            initializer=_init_worker,
+            # initializer=_init_raster_global,
+            # initargs=(raster,),
+        ) as executor:
+            futures = {}
+            for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
+                futures[
+                    executor.submit(
+                        _align_partition,
+                        grid_partition,
+                        raster.copy() if isinstance(raster, xr.Dataset) else raster,
+                        aggregations,
+                        pxbuffer,
+                    )
+                ] = i
+                if i == 6:
+                    print("Breaking after 3 partitions for testing purposes")

-        print("Submitted all partitions, waiting for results...")
+            print("Submitted all partitions, waiting for results...")

-        for future in track(
-            as_completed(futures),
-            total=len(futures),
-            description="Processing grid partitions...",
-        ):
-            try:
-                i = futures[future]
-                print(f"Processed partition {i + 1}/{len(futures)}")
-                part_data = future.result()
-                partial_data[i] = part_data
-            except Exception as e:
-                print(f"Error processing partition {i}: {e}")
-                raise e
+            for future in track(
+                as_completed(futures),
+                total=len(futures),
+                description="Processing grid partitions...",
+            ):
+                try:
+                    i = futures[future]
+                    print(f"Processed partition {i + 1}/{len(futures)}")
+                    part_ongrid = future.result()
+                    partial_ongrids.append(part_ongrid)
+                except Exception as e:
+                    print(f"Error processing partition {i}: {e}")
+                    raise e

-    data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
-    return data
+    ongrid = xr.concat(partial_ongrids, dim="cell_ids")
+    return ongrid


 def aggregate_raster_into_grid(

@ -459,7 +470,7 @@ aggregate_raster_into_grid(
     grid: Literal["hex", "healpix"],
     level: int,
     n_partitions: int = 20,
-    max_workers: int = 5,
+    concurrent_partitions: int = 5,
     pxbuffer: int = 15,
 ):
     """Aggregate raster data into grid cells.

@ -471,19 +482,20 @@ def aggregate_raster_into_grid(
|
||||||
grid (Literal["hex", "healpix"]): The type of grid to use.
|
grid (Literal["hex", "healpix"]): The type of grid to use.
|
||||||
level (int): The level of the grid.
|
level (int): The level of the grid.
|
||||||
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
|
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||||
max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
|
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
|
||||||
|
Defaults to 5.
|
||||||
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
|
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
xr.Dataset: Aggregated data aligned with the grid.
|
xr.Dataset: Aggregated data aligned with the grid.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
aligned = _align_data(
|
ongrid = _align_data(
|
||||||
grid_gdf,
|
grid_gdf,
|
||||||
raster,
|
raster,
|
||||||
aggregations,
|
aggregations,
|
||||||
n_partitions=n_partitions,
|
n_partitions=n_partitions,
|
||||||
max_workers=max_workers,
|
concurrent_partitions=concurrent_partitions,
|
||||||
pxbuffer=pxbuffer,
|
pxbuffer=pxbuffer,
|
||||||
)
|
)
|
||||||
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
|
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
|
||||||
|
|
@ -491,16 +503,6 @@ def aggregate_raster_into_grid(
|
||||||
if callable(raster) and not isinstance(raster, xr.Dataset):
|
if callable(raster) and not isinstance(raster, xr.Dataset):
|
||||||
raster = raster()
|
raster = raster()
|
||||||
|
|
||||||
cell_ids = grids.get_cell_ids(grid, level)
|
|
||||||
|
|
||||||
dims = ["cell_ids", "variables", "aggregations"]
|
|
||||||
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
|
|
||||||
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
|
|
||||||
dims.append(dim)
|
|
||||||
coords[dim] = raster.coords[dim]
|
|
||||||
|
|
||||||
ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")
|
|
||||||
|
|
||||||
gridinfo = {
|
gridinfo = {
|
||||||
"grid_name": "h3" if grid == "hex" else grid,
|
"grid_name": "h3" if grid == "hex" else grid,
|
||||||
"level": level,
|
"level": level,
|
||||||
|
|
|
||||||
|
|
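The `if callable(raster) and not isinstance(raster, xr.Dataset)` check above lets callers pass either an already-opened dataset or a zero-argument factory such as `_open_adem`, so each worker can open its own handle instead of pickling a large in-memory object. A hypothetical standalone sketch of that accept-object-or-factory pattern (using a plain `dict` as a stand-in for `xr.Dataset`):

```python
from typing import Callable, Union


def resolve(source: Union[dict, Callable[[], dict]]) -> dict:
    """Accept either a loaded object or a zero-argument factory that opens it.

    Mirrors the `raster = raster()` check above: if the caller handed us a
    factory function rather than the object itself, call it once to obtain
    the object.
    """
    if callable(source) and not isinstance(source, dict):
        return source()
    return source
```

Both `resolve({"elevation": 1})` and `resolve(lambda: {"elevation": 1})` then yield the same dictionary.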
@@ -10,6 +10,7 @@ import cupyx.scipy.signal
 import cyclopts
 import dask.array
 import dask.distributed as dd
+import geopandas as gpd
 import icechunk
 import icechunk.xarray
 import numpy as np
@@ -48,7 +49,7 @@ def download():
 
 
 @dataclass
-class KernelFactory:
+class _KernelFactory:
     res: int
     size_px: int
 
@@ -77,14 +78,14 @@ class KernelFactory:
         return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))
 
 
-def tpi_cupy(chunk, kernels: KernelFactory):
+def tpi_cupy(chunk, kernels: _KernelFactory):
     kernel = kernels.ring()
     kernel = kernel / cp.nansum(kernel)
     tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
     return tpi
 
 
-def tri_cupy(chunk, kernels: KernelFactory):
+def tri_cupy(chunk, kernels: _KernelFactory):
     kernel = kernels.tri()
     c2 = chunk**2
     focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
@@ -93,7 +94,7 @@ def tri_cupy(chunk, kernels: KernelFactory):
     return tri
 
 
-def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
+def ruggedness_cupy(chunk, slope, aspect, kernels: _KernelFactory):
     slope_rad = slope * (cp.pi / 180)
     aspect_rad = aspect * (cp.pi / 180)
     aspect_rad = cp.where(aspect_rad == -1, 0, aspect_rad)
@@ -150,9 +151,9 @@ def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
 
 def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
     res = 32  # 32m resolution
-    small_kernels = KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
-    medium_kernels = KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
-    large_kernels = KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
+    small_kernels = _KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
+    medium_kernels = _KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
+    large_kernels = _KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
 
     # Check if there is data in the chunk
     if np.all(np.isnan(chunk)):
 
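`tpi_cupy` above computes the Topographic Position Index: each cell's value minus the mean of a ring of neighbors, so positive values mark local highs and negative values local lows. A CPU-only NumPy sketch of the same idea for a fixed 3x3 ring (the real code builds larger, resolution-aware kernels via `_KernelFactory` and convolves on the GPU):

```python
import numpy as np


def tpi(dem: np.ndarray) -> np.ndarray:
    """TPI of interior cells: value minus the mean of the 8 surrounding cells.

    Border cells are left at 0 for simplicity; the convolution-based version
    in the diff handles borders via its `mode="same"` padding instead.
    """
    out = np.zeros_like(dem)
    # Sum of the 8-neighborhood for every interior cell, via shifted slices.
    nbr_sum = (
        dem[:-2, :-2] + dem[:-2, 1:-1] + dem[:-2, 2:]
        + dem[1:-1, :-2] + dem[1:-1, 2:]
        + dem[2:, :-2] + dem[2:, 1:-1] + dem[2:, 2:]
    )
    out[1:-1, 1:-1] = dem[1:-1, 1:-1] - nbr_sum / 8.0
    return out
```

On a flat surface the TPI is zero everywhere, while an isolated peak gets a positive TPI and drags its neighbors' values negative.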
@@ -306,49 +307,61 @@ def enrich():
 def _open_adem():
     arcticdem_store = get_arcticdem_stores()
     accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
+    # config = icechunk.RepositoryConfig.default()
+    # config.caching = icechunk.CachingConfig(
+    #     num_snapshot_nodes=None,
+    #     num_chunk_refs=None,
+    #     num_transaction_changes=None,
+    #     num_bytes_attributes=None,
+    #     num_bytes_chunks=None,
+    # )
+    # repo = icechunk.Repository.open(accessor.repo.storage, config=config)
+    # session = repo.readonly_session("main")
+    # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
     adem = accessor.open_xarray()
+    # Don't use the datamask
+    adem = adem.drop_vars("datamask")
     assert {"x", "y"} == set(adem.dims)
     assert adem.odc.crs == "EPSG:3413"
     return adem
 
 
 @cli.command()
-def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
+def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20):
     mp.set_start_method("forkserver", force=True)
-    with stopwatch(f"Loading {grid} grid at level {level}"):
-        grid_gdf = grids.open(grid, level)
-        # ? Mask out water, since we don't want to aggregate over oceans
-        grid_gdf = watermask.clip_grid(grid_gdf)
-
-    aggregations = _Aggregations(
-        mean=True,
-        sum=False,
-        std=True,
-        min=True,
-        max=True,
-        median=True,
-        quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
-    )
-
-    aggregated = aggregate_raster_into_grid(
-        _open_adem,
-        grid_gdf,
-        aggregations,
-        grid,
-        level,
-        max_workers=max_workers,
-        n_partitions=200,  # Results in 10GB large patches
-        pxbuffer=30,
-    )
-    store = get_arcticdem_stores(grid, level)
-    with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
-        aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
-
-    print("Aggregation complete.")
-
-    print("### Finished ArcticDEM processing ###")
-    stopwatch.summary()
+    with (
+        dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
+    ):
+        print(cluster)
+        with stopwatch(f"Loading {grid} grid at level {level}"):
+            grid_gdf = grids.open(grid, level)
+            # ? Mask out water, since we don't want to aggregate over oceans
+            grid_gdf = watermask.clip_grid(grid_gdf)
+            # ? Mask out cells that do not intersect with ArcticDEM extent
+            arcticdem_store = get_arcticdem_stores()
+            accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
+            extent_info = gpd.read_parquet(accessor._aux_dir / "ArcticDEM_Mosaic_Index_v4_1_32m.parquet")
+            grid_gdf = grid_gdf[grid_gdf.intersects(extent_info.geometry.union_all())]
+
+        aggregations = _Aggregations.common()
+        aggregated = aggregate_raster_into_grid(
+            _open_adem,
+            grid_gdf,
+            aggregations,
+            grid,
+            level,
+            concurrent_partitions=concurrent_partitions,
+            n_partitions=400,  # Results in ~10GB large patches
+            pxbuffer=30,
+        )
+        store = get_arcticdem_stores(grid, level)
+        with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
+            aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
+
+        print("Aggregation complete.")
+
+        print("### Finished ArcticDEM processing ###")
+        stopwatch.summary()
 
 
 if __name__ == "__main__":
 
@@ -677,25 +677,7 @@ def spatial_agg(
     assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
     unaligned = _correct_longs(unaligned)
 
-    # with stopwatch("Precomputing cell geometries"):
-    #     with ProcessPoolExecutor(max_workers=20) as executor:
-    #         cell_geometries = list(
-    #             executor.map(
-    #                 _get_corrected_geoms,
-    #                 [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
-    #             )
-    #         )
-    # data = _align_data(cell_geometries, unaligned)
-    # aggregated = _create_aligned(unaligned, data, grid, level)
-
-    aggregations = _Aggregations(
-        mean=True,
-        sum=False,
-        std=True,
-        min=True,
-        max=True,
-        quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
-    )
+    aggregations = _Aggregations.common()
     aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
 
     aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
 
@@ -66,6 +66,27 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
     return cell_ids
 
 
+def convert_cell_ids(grid_gdf: gpd.GeoDataFrame):
+    """Convert cell IDs in a GeoDataFrame to an xdggs-compatible format.
+
+    Args:
+        grid_gdf (gpd.GeoDataFrame): The input grid GeoDataFrame.
+
+    Returns:
+        pd.Series: The converted cell IDs.
+
+    """
+    if "cell_id" not in grid_gdf.columns:
+        raise ValueError("The GeoDataFrame must contain a 'cell_id' column.")
+
+    # Check if cell IDs are hex strings (for H3)
+    if isinstance(grid_gdf["cell_id"].iloc[0], str):
+        grid_gdf = grid_gdf.copy()
+        grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda cid: int(cid, 16))
+
+    return grid_gdf["cell_id"]
+
+
 def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
     hex_batch = []
     hex_id_batch = []