leaking grid aggregations
This commit is contained in:
parent
18cc1b8601
commit
341fa7b836
5 changed files with 521 additions and 154 deletions
85
pixi.lock
generated
85
pixi.lock
generated
|
|
@ -543,6 +543,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl
|
||||||
|
|
@ -594,6 +595,7 @@ environments:
|
||||||
- pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
|
||||||
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
|
||||||
|
|
@ -2808,7 +2810,7 @@ packages:
|
||||||
- pypi: ./
|
- pypi: ./
|
||||||
name: entropice
|
name: entropice
|
||||||
version: 0.1.0
|
version: 0.1.0
|
||||||
sha256: d95c691c76206bf54e207fe02b100a247a3847f37135d1cdf6ee18165770ea46
|
sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
|
||||||
requires_dist:
|
requires_dist:
|
||||||
- aiohttp>=3.12.11
|
- aiohttp>=3.12.11
|
||||||
- bokeh>=3.7.3
|
- bokeh>=3.7.3
|
||||||
|
|
@ -2862,6 +2864,7 @@ packages:
|
||||||
- s3fs>=2025.10.0,<2026
|
- s3fs>=2025.10.0,<2026
|
||||||
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
|
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
|
||||||
- cupy-xarray>=0.1.4,<0.2
|
- cupy-xarray>=0.1.4,<0.2
|
||||||
|
- memray>=1.19.1,<2
|
||||||
requires_python: '>=3.13,<3.14'
|
requires_python: '>=3.13,<3.14'
|
||||||
editable: true
|
editable: true
|
||||||
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
|
||||||
|
|
@ -6344,6 +6347,58 @@ packages:
|
||||||
- pkg:pypi/mdurl?source=hash-mapping
|
- pkg:pypi/mdurl?source=hash-mapping
|
||||||
size: 14465
|
size: 14465
|
||||||
timestamp: 1733255681319
|
timestamp: 1733255681319
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
|
name: memray
|
||||||
|
version: 1.19.1
|
||||||
|
sha256: 9770b6046a8b7d7bb982c2cefb324a60a0ee7cd7c35c58c0df058355a9a54676
|
||||||
|
requires_dist:
|
||||||
|
- jinja2>=2.9
|
||||||
|
- typing-extensions ; python_full_version < '3.8'
|
||||||
|
- rich>=11.2.0
|
||||||
|
- textual>=0.41.0
|
||||||
|
- cython ; extra == 'test'
|
||||||
|
- greenlet ; python_full_version < '3.14' and extra == 'test'
|
||||||
|
- pytest ; extra == 'test'
|
||||||
|
- pytest-cov ; extra == 'test'
|
||||||
|
- ipython ; extra == 'test'
|
||||||
|
- setuptools ; extra == 'test'
|
||||||
|
- pytest-textual-snapshot ; extra == 'test'
|
||||||
|
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'test'
|
||||||
|
- packaging ; extra == 'test'
|
||||||
|
- ipython ; extra == 'docs'
|
||||||
|
- bump2version ; extra == 'docs'
|
||||||
|
- sphinx ; extra == 'docs'
|
||||||
|
- furo ; extra == 'docs'
|
||||||
|
- sphinx-argparse ; extra == 'docs'
|
||||||
|
- towncrier ; extra == 'docs'
|
||||||
|
- black ; extra == 'lint'
|
||||||
|
- flake8 ; extra == 'lint'
|
||||||
|
- isort ; extra == 'lint'
|
||||||
|
- mypy ; extra == 'lint'
|
||||||
|
- check-manifest ; extra == 'lint'
|
||||||
|
- asv ; extra == 'benchmark'
|
||||||
|
- cython ; extra == 'dev'
|
||||||
|
- greenlet ; python_full_version < '3.14' and extra == 'dev'
|
||||||
|
- pytest ; extra == 'dev'
|
||||||
|
- pytest-cov ; extra == 'dev'
|
||||||
|
- ipython ; extra == 'dev'
|
||||||
|
- setuptools ; extra == 'dev'
|
||||||
|
- pytest-textual-snapshot ; extra == 'dev'
|
||||||
|
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'dev'
|
||||||
|
- packaging ; extra == 'dev'
|
||||||
|
- black ; extra == 'dev'
|
||||||
|
- flake8 ; extra == 'dev'
|
||||||
|
- isort ; extra == 'dev'
|
||||||
|
- mypy ; extra == 'dev'
|
||||||
|
- check-manifest ; extra == 'dev'
|
||||||
|
- ipython ; extra == 'dev'
|
||||||
|
- bump2version ; extra == 'dev'
|
||||||
|
- sphinx ; extra == 'dev'
|
||||||
|
- furo ; extra == 'dev'
|
||||||
|
- sphinx-argparse ; extra == 'dev'
|
||||||
|
- towncrier ; extra == 'dev'
|
||||||
|
- asv ; extra == 'dev'
|
||||||
|
requires_python: '>=3.7.0'
|
||||||
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
|
||||||
name: mercantile
|
name: mercantile
|
||||||
version: 1.2.1
|
version: 1.2.1
|
||||||
|
|
@ -8866,6 +8921,34 @@ packages:
|
||||||
- pkg:pypi/terminado?source=hash-mapping
|
- pkg:pypi/terminado?source=hash-mapping
|
||||||
size: 22452
|
size: 22452
|
||||||
timestamp: 1710262728753
|
timestamp: 1710262728753
|
||||||
|
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
|
||||||
|
name: textual
|
||||||
|
version: 6.6.0
|
||||||
|
sha256: 5a9484bd15ee8a6fd8ac4ed4849fb25ee56bed2cecc7b8a83c4cd7d5f19515e5
|
||||||
|
requires_dist:
|
||||||
|
- markdown-it-py[linkify]>=2.1.0
|
||||||
|
- mdit-py-plugins
|
||||||
|
- platformdirs>=3.6.0,<5
|
||||||
|
- pygments>=2.19.2,<3.0.0
|
||||||
|
- rich>=14.2.0
|
||||||
|
- tree-sitter>=0.25.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-bash>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-css>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-go>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-html>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-java>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-javascript>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-json>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-markdown>=0.3.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-python>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-regex>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-rust>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-sql>=0.3.11 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-toml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-xml>=0.7.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- tree-sitter-yaml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
|
||||||
|
- typing-extensions>=4.4.0,<5.0.0
|
||||||
|
requires_python: '>=3.9,<4.0'
|
||||||
- conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda
|
- conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda
|
||||||
sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd
|
sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd
|
||||||
md5: 9d64911b31d57ca443e9f1e36b04385f
|
md5: 9d64911b31d57ca443e9f1e36b04385f
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ dependencies = [
|
||||||
"xgboost>=3.1.1,<4",
|
"xgboost>=3.1.1,<4",
|
||||||
"s3fs>=2025.10.0,<2026",
|
"s3fs>=2025.10.0,<2026",
|
||||||
"xarray-spatial",
|
"xarray-spatial",
|
||||||
"cupy-xarray>=0.1.4,<0.2",
|
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,27 @@
|
||||||
|
"""Aggregation helpers."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cache
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cupy_xarray
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import odc.geo.geobox
|
import odc.geo.geobox
|
||||||
|
import pandas as pd
|
||||||
|
import psutil
|
||||||
import shapely
|
import shapely
|
||||||
import shapely.ops
|
import shapely.ops
|
||||||
|
import sklearn
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
import xvec
|
import xvec
|
||||||
|
from rich import print
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from shapely.geometry import LineString, Polygon
|
from shapely.geometry import LineString, Polygon
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
@ -19,34 +30,100 @@ from xdggs.healpix import HealpixInfo
|
||||||
from entropice import grids
|
from entropice import grids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class _Aggregations:
|
class _Aggregations:
|
||||||
|
# ! The ordering is super important for this class!
|
||||||
mean: bool = True
|
mean: bool = True
|
||||||
sum: bool = False
|
sum: bool = False
|
||||||
std: bool = False
|
std: bool = False
|
||||||
min: bool = False
|
min: bool = False
|
||||||
max: bool = False
|
max: bool = False
|
||||||
quantiles: list[float] = field(default_factory=lambda: [])
|
median: bool = False # Will be mapped as quantile 0.5
|
||||||
|
quantiles: tuple[float] = field(default_factory=tuple)
|
||||||
|
|
||||||
def varnames(self, vars: list[str] | str) -> list[str]:
|
def __post_init__(self):
|
||||||
if isinstance(vars, str):
|
assert isinstance(self.quantiles, tuple), "Quantiles must be a tuple, ortherwise the class is not hashable"
|
||||||
vars = [vars]
|
if self.median and 0.5 in self.quantiles:
|
||||||
agg_vars = []
|
raise ValueError("Median aggregation cannot be used together with quantile 0.5")
|
||||||
for var in vars:
|
|
||||||
|
@cache
|
||||||
|
def __len__(self) -> int:
|
||||||
|
length = 0
|
||||||
if self.mean:
|
if self.mean:
|
||||||
agg_vars.append(f"{var}_mean")
|
length += 1
|
||||||
if self.sum:
|
if self.sum:
|
||||||
agg_vars.append(f"{var}_sum")
|
length += 1
|
||||||
if self.std:
|
if self.std:
|
||||||
agg_vars.append(f"{var}_std")
|
length += 1
|
||||||
if self.min:
|
if self.min:
|
||||||
agg_vars.append(f"{var}_min")
|
length += 1
|
||||||
if self.max:
|
if self.max:
|
||||||
agg_vars.append(f"{var}_max")
|
length += 1
|
||||||
for q in self.quantiles:
|
if self.median and 0.5 not in self.quantiles:
|
||||||
|
length += 1
|
||||||
|
length += len(self.quantiles)
|
||||||
|
return length
|
||||||
|
|
||||||
|
def aggnames(self) -> list[str]:
|
||||||
|
names = []
|
||||||
|
if self.mean:
|
||||||
|
names.append("mean")
|
||||||
|
if self.sum:
|
||||||
|
names.append("sum")
|
||||||
|
if self.std:
|
||||||
|
names.append("std")
|
||||||
|
if self.min:
|
||||||
|
names.append("min")
|
||||||
|
if self.max:
|
||||||
|
names.append("max")
|
||||||
|
if self.median:
|
||||||
|
names.append("median")
|
||||||
|
for q in sorted(self.quantiles): # ? Ordering is important here!
|
||||||
q_int = int(q * 100)
|
q_int = int(q * 100)
|
||||||
agg_vars.append(f"{var}_p{q_int}")
|
names.append(f"p{q_int}")
|
||||||
return agg_vars
|
return names
|
||||||
|
|
||||||
|
def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
|
||||||
|
others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
|
||||||
|
cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
|
||||||
|
j = 0
|
||||||
|
if self.mean:
|
||||||
|
cell_data[j, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
|
||||||
|
j += 1
|
||||||
|
if self.sum:
|
||||||
|
cell_data[j, ...] = flattened_var.sum(dim="z", skipna=True).to_numpy()
|
||||||
|
j += 1
|
||||||
|
if self.std:
|
||||||
|
cell_data[j, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
|
||||||
|
j += 1
|
||||||
|
if self.min:
|
||||||
|
cell_data[j, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
|
||||||
|
j += 1
|
||||||
|
if self.max:
|
||||||
|
cell_data[j, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
|
||||||
|
j += 1
|
||||||
|
if len(self.quantiles) > 0 or self.median:
|
||||||
|
quantiles_to_compute = sorted(self.quantiles) # ? Ordering is important here!
|
||||||
|
if self.median:
|
||||||
|
quantiles_to_compute.insert(0, 0.5)
|
||||||
|
cell_data[j:, ...] = flattened_var.quantile(
|
||||||
|
q=quantiles_to_compute,
|
||||||
|
dim="z",
|
||||||
|
skipna=True,
|
||||||
|
).to_numpy()
|
||||||
|
|
||||||
|
def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray:
|
||||||
|
if isinstance(flattened, xr.DataArray):
|
||||||
|
return self._agg_cell_data_single(flattened)
|
||||||
|
|
||||||
|
others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
|
||||||
|
cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
|
||||||
|
for i, var in enumerate(flattened.data_vars):
|
||||||
|
cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
|
||||||
|
# Transform to numpy arrays
|
||||||
|
# if flattened.cupy.is_cupy:
|
||||||
|
# cell_data = cell_data.get()
|
||||||
|
return cell_data
|
||||||
|
|
||||||
|
|
||||||
def _crosses_antimeridian(geom: Polygon) -> bool:
|
def _crosses_antimeridian(geom: Polygon) -> bool:
|
||||||
|
|
@ -77,11 +154,11 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||||
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Getting corrected geometries", log=False)
|
@stopwatch("Correcting geometries", log=False)
|
||||||
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
|
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
|
||||||
geom, gbox, crs = inp
|
geom, gbox, crs = inp
|
||||||
# cell.geometry is a shapely Polygon
|
# cell.geometry is a shapely Polygon
|
||||||
if not _crosses_antimeridian(geom):
|
if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
|
||||||
geoms = [geom]
|
geoms = [geom]
|
||||||
# Split geometry in case it crossed antimeridian
|
# Split geometry in case it crossed antimeridian
|
||||||
else:
|
else:
|
||||||
|
|
@ -92,153 +169,338 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
|
||||||
return geoms
|
return geoms
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Correcting geometries")
|
|
||||||
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
|
|
||||||
"""Get corrected geometries for antimeridian-crossing polygons.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
|
|
||||||
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
|
|
||||||
|
|
||||||
"""
|
|
||||||
with ProcessPoolExecutor(max_workers=20) as executor:
|
|
||||||
cell_geometries = list(
|
|
||||||
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
|
|
||||||
)
|
|
||||||
return cell_geometries
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Aggregating cell data", log=False)
|
|
||||||
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
|
|
||||||
cell_data = {}
|
|
||||||
for var in flattened.data_vars:
|
|
||||||
if aggregations.mean:
|
|
||||||
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
|
|
||||||
if aggregations.sum:
|
|
||||||
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
|
|
||||||
if aggregations.std:
|
|
||||||
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
|
|
||||||
if aggregations.min:
|
|
||||||
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
|
|
||||||
if aggregations.max:
|
|
||||||
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
|
|
||||||
if len(aggregations.quantiles) > 0:
|
|
||||||
quantile_values = flattened[var].quantile(
|
|
||||||
q=aggregations.quantiles,
|
|
||||||
dim="z",
|
|
||||||
skipna=True,
|
|
||||||
)
|
|
||||||
for q, qv in zip(aggregations.quantiles, quantile_values):
|
|
||||||
q_int = int(q * 100)
|
|
||||||
cell_data[f"{var}_p{q_int}"] = qv
|
|
||||||
# Transform to numpy arrays
|
|
||||||
for key in cell_data:
|
|
||||||
cell_data[key] = cell_data[key].cupy.as_numpy().values
|
|
||||||
return cell_data
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting cell data", log=False)
|
@stopwatch("Extracting cell data", log=False)
|
||||||
def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
|
def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
|
||||||
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
|
||||||
cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
|
flattened = cropped.stack(z=spatdims) # noqa: PD013
|
||||||
flattened = cropped.stack(z=spatdims)
|
# if flattened.z.size > 3000:
|
||||||
if flattened.z.size > 3000:
|
# flattened = flattened.cupy.as_cupy()
|
||||||
flattened = flattened.cupy.as_cupy()
|
cell_data = aggregations.agg_cell_data(flattened)
|
||||||
cell_data = _agg_cell_data(flattened, aggregations)
|
|
||||||
return cell_data
|
return cell_data
|
||||||
|
|
||||||
# with np.errstate(divide="ignore", invalid="ignore"):
|
|
||||||
# cell_data = cropped.mean(spatdims)
|
|
||||||
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting split cell data", log=False)
|
@stopwatch("Extracting split cell data", log=False)
|
||||||
def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
|
def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations):
|
||||||
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
spatdims = (
|
||||||
cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
|
["latitude", "longitude"]
|
||||||
flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
|
if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
|
||||||
cell_data = _agg_cell_data(flattened, aggregations)
|
else ["y", "x"]
|
||||||
|
)
|
||||||
|
flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z") # noqa: PD013
|
||||||
|
cell_data = aggregations.agg_cell_data(flattened)
|
||||||
return cell_data
|
return cell_data
|
||||||
|
|
||||||
# partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
|
|
||||||
# with np.errstate(divide="ignore", invalid="ignore"):
|
@stopwatch("Cropping (and reading?) cell data", log=False)
|
||||||
# partial_means = [part.sum(dim=spatdims) for part in parts]
|
def _read_cell_data(unaligned: xr.Dataset | xr.DataArray, geom: odc.geo.Geometry) -> xr.Dataset:
|
||||||
# n = xr.concat(partial_counts, dim="part").sum("part")
|
cropped: xr.Dataset = unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref")
|
||||||
# cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
return cropped.compute()
|
||||||
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
|
||||||
|
|
||||||
|
@stopwatch("Cropping (and reading?) split cell data", log=False)
|
||||||
|
def _read_split_cell_data(unaligned: xr.Dataset | xr.DataArray, geoms: list[odc.geo.Geometry]) -> list[xr.Dataset]:
|
||||||
|
cropped: list[xr.Dataset] = [unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref") for geom in geoms]
|
||||||
|
return [c.compute() for c in cropped]
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Processing cell", log=False)
|
||||||
|
def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
|
||||||
|
geoms = _get_corrected_geoms((poly, unaligned.odc.geobox, unaligned.odc.crs))
|
||||||
|
if len(geoms) == 0:
|
||||||
|
raise ValueError("No geometries to process")
|
||||||
|
elif len(geoms) == 1:
|
||||||
|
cropped = _read_cell_data(unaligned, geoms[0])
|
||||||
|
cell_data = _extract_cell_data(cropped, aggregations)
|
||||||
|
else:
|
||||||
|
cropped = _read_split_cell_data(unaligned, geoms)
|
||||||
|
cell_data = _extract_split_cell_data(cropped, aggregations)
|
||||||
|
return cell_data
|
||||||
|
|
||||||
|
|
||||||
|
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
|
||||||
|
# ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
|
||||||
|
# Simple partitioning by splitting the GeoDataFrame into n_partitions parts
|
||||||
|
centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
|
||||||
|
labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
|
||||||
|
for i in range(n_partitions):
|
||||||
|
partition = grid_gdf[labels == i]
|
||||||
|
yield partition
|
||||||
|
|
||||||
|
|
||||||
|
class _MemoryProfiler:
|
||||||
|
def __init__(self):
|
||||||
|
self.process = psutil.Process()
|
||||||
|
self._logs = defaultdict(list)
|
||||||
|
|
||||||
|
def log_memory(self, message: str, log: bool = True):
|
||||||
|
mem = self.process.memory_info().rss / 1024**3
|
||||||
|
self._logs[message].append(mem)
|
||||||
|
if log:
|
||||||
|
print(f"[MemoryProfiler] {message}: {mem:.2f} GB")
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
summary_lines = ["[MemoryProfiler] Memory usage summary:"]
|
||||||
|
for message, mems in self._logs.items():
|
||||||
|
mems_str = [f"{m:.2f}" for m in mems]
|
||||||
|
summary_lines.append(f" {message}: min={min(mems):.2f} GB, max={max(mems):.2f} GB")
|
||||||
|
summary_lines.append(f" Measurements: {', '.join(mems_str)}")
|
||||||
|
return "\n".join(summary_lines)
|
||||||
|
|
||||||
|
|
||||||
|
memprof = None
|
||||||
|
|
||||||
|
|
||||||
|
def _init_worker():
|
||||||
|
global memprof
|
||||||
|
memprof = _MemoryProfiler()
|
||||||
|
|
||||||
|
|
||||||
|
def _align_partition(
|
||||||
|
grid_partition_gdf: gpd.GeoDataFrame,
|
||||||
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
|
aggregations: _Aggregations,
|
||||||
|
pxbuffer: int,
|
||||||
|
):
|
||||||
|
# Strategy for each cell:
|
||||||
|
# 1. Correct the geometry to account for the antimeridian
|
||||||
|
# 2. Crop the dataset and load the data into memory
|
||||||
|
# 3. Aggregate the data
|
||||||
|
# 4. Store the aggregated data in the correct location in the output arrays
|
||||||
|
|
||||||
|
# Steps 1-3 are done in different Processes, since all of these steps are CPU bound
|
||||||
|
# Even the reading part is CPU bound, since it involves applying masks and cropping the data
|
||||||
|
# Also, the backend (icechunk or zarr) can do some computations, like combining chunks together
|
||||||
|
|
||||||
|
# For the Arcticdem on Hex3 grid the avg. time per task is:
|
||||||
|
# 1. ~0.01s
|
||||||
|
# 2. ~1-3s
|
||||||
|
# 3. ~10-20s
|
||||||
|
# And for Hex6:
|
||||||
|
# 1. unmeasurable
|
||||||
|
# 2. ~0.1-0.3s
|
||||||
|
# 3. ~0.05s
|
||||||
|
|
||||||
|
memprof.log_memory("Before reading partial raster", log=False)
|
||||||
|
|
||||||
|
if callable(raster) and not isinstance(raster, xr.Dataset):
|
||||||
|
raster = raster()
|
||||||
|
|
||||||
|
if hasattr(raster, "__dask_graph__"):
|
||||||
|
graph_size = len(raster.__dask_graph__())
|
||||||
|
graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
|
||||||
|
print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")
|
||||||
|
|
||||||
|
others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
|
||||||
|
data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
|
||||||
|
data = np.full(data_shape, np.nan, dtype=np.float32)
|
||||||
|
|
||||||
|
partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
|
||||||
|
partial_extent = partial_extent.buffered(
|
||||||
|
raster.odc.geobox.resolution.x * pxbuffer,
|
||||||
|
raster.odc.geobox.resolution.y * pxbuffer,
|
||||||
|
) # buffer by pxbuffer pixels
|
||||||
|
with stopwatch("Cropping raster to partition extent", log=False):
|
||||||
|
try:
|
||||||
|
partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error cropping raster to partition extent: {e}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
|
||||||
|
memprof.log_memory("After reading partial raster", log=True)
|
||||||
|
|
||||||
|
with ProcessPoolExecutor(
|
||||||
|
max_workers=60,
|
||||||
|
) as executor:
|
||||||
|
futures = {}
|
||||||
|
for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
|
||||||
|
# ? Splitting the geometries already here and passing only the cropped data to the worker to
|
||||||
|
# reduce pickling overhead
|
||||||
|
geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
|
||||||
|
if len(geoms) == 0:
|
||||||
|
continue
|
||||||
|
elif len(geoms) == 1:
|
||||||
|
cropped = [_read_cell_data(partial_raster, geoms[0])]
|
||||||
|
else:
|
||||||
|
cropped = _read_split_cell_data(partial_raster, geoms)
|
||||||
|
futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
|
||||||
|
|
||||||
|
# Clean up immediately after submitting
|
||||||
|
del geoms, cropped # Added this line
|
||||||
|
|
||||||
|
for future in as_completed(futures):
|
||||||
|
i = futures[future]
|
||||||
|
try:
|
||||||
|
cell_data = future.result()
|
||||||
|
except (SystemError, SystemExit, KeyboardInterrupt) as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing cell {i}: {e}")
|
||||||
|
continue
|
||||||
|
data[i, ...] = cell_data
|
||||||
|
del cell_data, futures[future]
|
||||||
|
|
||||||
|
# for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
|
||||||
|
# try:
|
||||||
|
# cell_data = _process_geom(row.geometry, partial_raster, aggregations)
|
||||||
|
# except (SystemError, SystemExit, KeyboardInterrupt) as e:
|
||||||
|
# raise e
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Error processing cell {row['cell_id']}: {e}")
|
||||||
|
# continue
|
||||||
|
# data[i, ...] = cell_data
|
||||||
|
|
||||||
|
partial_raster.close()
|
||||||
|
raster.close()
|
||||||
|
del partial_raster, raster, future, futures
|
||||||
|
gc.collect()
|
||||||
|
memprof.log_memory("After cleaning", log=True)
|
||||||
|
|
||||||
|
print("Finished processing partition")
|
||||||
|
print("### Stopwatch summary ###\n")
|
||||||
|
print(stopwatch.summary())
|
||||||
|
print("### Memory summary ###\n")
|
||||||
|
print(memprof.summary())
|
||||||
|
print("#########################")
|
||||||
|
# with ProcessPoolExecutor(
|
||||||
|
# max_workers=max_workers,
|
||||||
|
# initializer=_init_partial_raster_global,
|
||||||
|
# initargs=(raster, grid_partition_gdf, pxbuffer),
|
||||||
|
# ) as executor:
|
||||||
|
# futures = {
|
||||||
|
# executor.submit(_process_geom_global, row.geometry, aggregations): i
|
||||||
|
# for i, (idx, row) in enumerate(grid_partition_gdf.iterrows())
|
||||||
|
# }
|
||||||
|
|
||||||
|
# others_shape = tuple(
|
||||||
|
# [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
|
||||||
|
# )
|
||||||
|
# data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape)
|
||||||
|
# print(f"Allocating partial data array of shape {data_shape}")
|
||||||
|
# data = np.full(data_shape, np.nan, dtype=np.float32)
|
||||||
|
|
||||||
|
# lap = time.time()
|
||||||
|
# for future in track(
|
||||||
|
# as_completed(futures),
|
||||||
|
# total=len(futures),
|
||||||
|
# description="Spatially aggregating ERA5 data...",
|
||||||
|
# ):
|
||||||
|
# print(f"time since last cell: {time.time() - lap:.2f}s")
|
||||||
|
# i = futures[future]
|
||||||
|
# try:
|
||||||
|
# cell_data = future.result()
|
||||||
|
# data[i, ...] = cell_data
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Error processing cell {i}: {e}")
|
||||||
|
# continue
|
||||||
|
# finally:
|
||||||
|
# # Try to free memory
|
||||||
|
# del futures[future], future
|
||||||
|
# if "cell_data" in locals():
|
||||||
|
# del cell_data
|
||||||
|
# gc.collect()
|
||||||
|
# lap = time.time()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Aligning data with grid")
|
@stopwatch("Aligning data with grid")
|
||||||
def _align_data(
|
def _align_data(
|
||||||
grid_gdf: gpd.GeoDataFrame,
|
grid_gdf: gpd.GeoDataFrame,
|
||||||
unaligned: xr.Dataset,
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
aggregations: _Aggregations,
|
aggregations: _Aggregations,
|
||||||
) -> dict[str, np.ndarray]:
|
n_partitions: int,
|
||||||
# Persist the dataset, as all the operations here MUST NOT be lazy
|
max_workers: int,
|
||||||
unaligned = unaligned.load()
|
pxbuffer: int,
|
||||||
|
|
||||||
cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
|
|
||||||
other_dims_shape = tuple(
|
|
||||||
[unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
|
|
||||||
)
|
|
||||||
data_shape = (len(cell_geometries), *other_dims_shape)
|
|
||||||
data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
|
|
||||||
with ProcessPoolExecutor(max_workers=10) as executor:
|
|
||||||
futures = {}
|
|
||||||
for i, geoms in track(
|
|
||||||
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
|
||||||
):
|
):
|
||||||
if len(geoms) == 0:
|
partial_data = {}
|
||||||
continue
|
|
||||||
elif len(geoms) == 1:
|
with ProcessPoolExecutor(
|
||||||
geom = geoms[0]
|
max_workers=max_workers,
|
||||||
# Reduce the amount of data needed to be sent to the worker
|
initializer=_init_worker,
|
||||||
# Since we don't mask the data, only isel operations are done here
|
# initializer=_init_raster_global,
|
||||||
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
# initargs=(raster,),
|
||||||
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
) as executor:
|
||||||
fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
|
futures = {}
|
||||||
else:
|
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
|
||||||
# Same as above but for multiple parts
|
futures[
|
||||||
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
executor.submit(
|
||||||
fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
|
_align_partition,
|
||||||
futures[fut] = i
|
grid_partition,
|
||||||
|
raster.copy() if isinstance(raster, xr.Dataset) else raster,
|
||||||
|
aggregations,
|
||||||
|
pxbuffer,
|
||||||
|
)
|
||||||
|
] = i
|
||||||
|
|
||||||
|
print("Submitted all partitions, waiting for results...")
|
||||||
|
|
||||||
for future in track(
|
for future in track(
|
||||||
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
as_completed(futures),
|
||||||
|
total=len(futures),
|
||||||
|
description="Processing grid partitions...",
|
||||||
):
|
):
|
||||||
|
try:
|
||||||
i = futures[future]
|
i = futures[future]
|
||||||
cell_data = future.result()
|
print(f"Processed partition {i + 1}/{len(futures)}")
|
||||||
for var in unaligned.data_vars:
|
part_data = future.result()
|
||||||
data[var][i, :] = cell_data[var]
|
partial_data[i] = part_data
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing partition {i}: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def aggregate_raster_into_grid(
|
def aggregate_raster_into_grid(
|
||||||
raster: xr.Dataset,
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
grid_gdf: gpd.GeoDataFrame,
|
grid_gdf: gpd.GeoDataFrame,
|
||||||
aggregations: _Aggregations,
|
aggregations: _Aggregations,
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Literal["hex", "healpix"],
|
||||||
level: int,
|
level: int,
|
||||||
|
n_partitions: int = 20,
|
||||||
|
max_workers: int = 5,
|
||||||
|
pxbuffer: int = 15,
|
||||||
):
|
):
|
||||||
aligned = _align_data(grid_gdf, raster, aggregations)
|
"""Aggregate raster data into grid cells.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
|
||||||
|
grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
|
||||||
|
aggregations (_Aggregations): The aggregations to perform.
|
||||||
|
grid (Literal["hex", "healpix"]): The type of grid to use.
|
||||||
|
level (int): The level of the grid.
|
||||||
|
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||||
|
max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
|
||||||
|
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xr.Dataset: Aggregated data aligned with the grid.
|
||||||
|
|
||||||
|
"""
|
||||||
|
aligned = _align_data(
|
||||||
|
grid_gdf,
|
||||||
|
raster,
|
||||||
|
aggregations,
|
||||||
|
n_partitions=n_partitions,
|
||||||
|
max_workers=max_workers,
|
||||||
|
pxbuffer=pxbuffer,
|
||||||
|
)
|
||||||
|
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
|
||||||
|
|
||||||
|
if callable(raster) and not isinstance(raster, xr.Dataset):
|
||||||
|
raster = raster()
|
||||||
|
|
||||||
cell_ids = grids.get_cell_ids(grid, level)
|
cell_ids = grids.get_cell_ids(grid, level)
|
||||||
|
|
||||||
dims = ["cell_ids"]
|
dims = ["cell_ids", "variables", "aggregations"]
|
||||||
coords = {"cell_ids": cell_ids}
|
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
|
||||||
for dim in raster.dims:
|
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
|
||||||
if dim not in ["y", "x", "latitude", "longitude"]:
|
|
||||||
dims.append(dim)
|
dims.append(dim)
|
||||||
coords[dim] = raster.coords[dim]
|
coords[dim] = raster.coords[dim]
|
||||||
|
|
||||||
data_vars = {var: (dims, values) for var, values in aligned.items()}
|
ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")
|
||||||
ongrid = xr.Dataset(
|
|
||||||
data_vars,
|
|
||||||
coords=coords,
|
|
||||||
)
|
|
||||||
gridinfo = {
|
gridinfo = {
|
||||||
"grid_name": "h3" if grid == "hex" else grid,
|
"grid_name": "h3" if grid == "hex" else grid,
|
||||||
"level": level,
|
"level": level,
|
||||||
|
|
@ -247,8 +509,7 @@ def aggregate_raster_into_grid(
|
||||||
gridinfo["indexing_scheme"] = "nested"
|
gridinfo["indexing_scheme"] = "nested"
|
||||||
ongrid.cell_ids.attrs = gridinfo
|
ongrid.cell_ids.attrs = gridinfo
|
||||||
for var in raster.data_vars:
|
for var in raster.data_vars:
|
||||||
for v in aggregations.varnames(var):
|
ongrid[var].attrs = raster[var].attrs
|
||||||
ongrid[v].attrs = raster[var].attrs
|
|
||||||
|
|
||||||
ongrid = xdggs.decode(ongrid)
|
ongrid = xdggs.decode(ongrid)
|
||||||
return ongrid
|
return ongrid
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
return feature.set(mean_dict)
|
return feature.set(mean_dict)
|
||||||
|
|
||||||
# Process grid in batches of 100
|
# Process the grid in batches of `batch_size`
|
||||||
batch_size = 100
|
batch_size = 50
|
||||||
all_results = []
|
all_results = []
|
||||||
n_batches = len(grid_gdf) // batch_size
|
n_batches = len(grid_gdf) // batch_size
|
||||||
for batch_num, batch_grid in track(
|
for batch_num, batch_grid in track(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import multiprocessing as mp
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
@ -9,6 +10,7 @@ import cupyx.scipy.signal
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import dask.array
|
import dask.array
|
||||||
import dask.distributed as dd
|
import dask.distributed as dd
|
||||||
|
import icechunk
|
||||||
import icechunk.xarray
|
import icechunk.xarray
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import smart_geocubes
|
import smart_geocubes
|
||||||
|
|
@ -301,18 +303,23 @@ def enrich():
|
||||||
print("Enrichment complete.")
|
print("Enrichment complete.")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
def _open_adem():
|
||||||
def aggregate(grid: Literal["hex", "healpix"], level: int):
|
|
||||||
with stopwatch(f"Loading {grid} grid at level {level}"):
|
|
||||||
grid_gdf = grids.open(grid, level)
|
|
||||||
# ? Mask out water, since we don't want to aggregate over oceans
|
|
||||||
grid_gdf = watermask.clip_grid(grid_gdf)
|
|
||||||
|
|
||||||
arcticdem_store = get_arcticdem_stores()
|
arcticdem_store = get_arcticdem_stores()
|
||||||
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||||
adem = accessor.open_xarray()
|
adem = accessor.open_xarray()
|
||||||
assert {"x", "y"} == set(adem.dims)
|
assert {"x", "y"} == set(adem.dims)
|
||||||
assert adem.odc.crs == "EPSG:3413"
|
assert adem.odc.crs == "EPSG:3413"
|
||||||
|
return adem
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
|
||||||
|
mp.set_start_method("forkserver", force=True)
|
||||||
|
|
||||||
|
with stopwatch(f"Loading {grid} grid at level {level}"):
|
||||||
|
grid_gdf = grids.open(grid, level)
|
||||||
|
# ? Mask out water, since we don't want to aggregate over oceans
|
||||||
|
grid_gdf = watermask.clip_grid(grid_gdf)
|
||||||
|
|
||||||
aggregations = _Aggregations(
|
aggregations = _Aggregations(
|
||||||
mean=True,
|
mean=True,
|
||||||
|
|
@ -320,13 +327,29 @@ def aggregate(grid: Literal["hex", "healpix"], level: int):
|
||||||
std=True,
|
std=True,
|
||||||
min=True,
|
min=True,
|
||||||
max=True,
|
max=True,
|
||||||
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
|
median=True,
|
||||||
|
quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated = aggregate_raster_into_grid(
|
||||||
|
_open_adem,
|
||||||
|
grid_gdf,
|
||||||
|
aggregations,
|
||||||
|
grid,
|
||||||
|
level,
|
||||||
|
max_workers=max_workers,
|
||||||
|
n_partitions=200, # Results in 10GB large patches
|
||||||
|
pxbuffer=30,
|
||||||
)
|
)
|
||||||
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
|
|
||||||
store = get_arcticdem_stores(grid, level)
|
store = get_arcticdem_stores(grid, level)
|
||||||
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
|
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
|
||||||
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||||
|
|
||||||
|
print("Aggregation complete.")
|
||||||
|
|
||||||
|
print("### Finished ArcticDEM processing ###")
|
||||||
|
stopwatch.summary()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue