leaking grid aggregations

This commit is contained in:
Tobias Hölzer 2025-11-24 22:40:17 +01:00
parent 18cc1b8601
commit 341fa7b836
5 changed files with 521 additions and 154 deletions

85
pixi.lock generated
View file

@ -543,6 +543,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl
@ -594,6 +595,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
@ -2808,7 +2810,7 @@ packages:
- pypi: ./ - pypi: ./
name: entropice name: entropice
version: 0.1.0 version: 0.1.0
sha256: d95c691c76206bf54e207fe02b100a247a3847f37135d1cdf6ee18165770ea46 sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
requires_dist: requires_dist:
- aiohttp>=3.12.11 - aiohttp>=3.12.11
- bokeh>=3.7.3 - bokeh>=3.7.3
@ -2862,6 +2864,7 @@ packages:
- s3fs>=2025.10.0,<2026 - s3fs>=2025.10.0,<2026
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial - xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
- cupy-xarray>=0.1.4,<0.2 - cupy-xarray>=0.1.4,<0.2
- memray>=1.19.1,<2
requires_python: '>=3.13,<3.14' requires_python: '>=3.13,<3.14'
editable: true editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
@ -6344,6 +6347,58 @@ packages:
- pkg:pypi/mdurl?source=hash-mapping - pkg:pypi/mdurl?source=hash-mapping
size: 14465 size: 14465
timestamp: 1733255681319 timestamp: 1733255681319
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
name: memray
version: 1.19.1
sha256: 9770b6046a8b7d7bb982c2cefb324a60a0ee7cd7c35c58c0df058355a9a54676
requires_dist:
- jinja2>=2.9
- typing-extensions ; python_full_version < '3.8'
- rich>=11.2.0
- textual>=0.41.0
- cython ; extra == 'test'
- greenlet ; python_full_version < '3.14' and extra == 'test'
- pytest ; extra == 'test'
- pytest-cov ; extra == 'test'
- ipython ; extra == 'test'
- setuptools ; extra == 'test'
- pytest-textual-snapshot ; extra == 'test'
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'test'
- packaging ; extra == 'test'
- ipython ; extra == 'docs'
- bump2version ; extra == 'docs'
- sphinx ; extra == 'docs'
- furo ; extra == 'docs'
- sphinx-argparse ; extra == 'docs'
- towncrier ; extra == 'docs'
- black ; extra == 'lint'
- flake8 ; extra == 'lint'
- isort ; extra == 'lint'
- mypy ; extra == 'lint'
- check-manifest ; extra == 'lint'
- asv ; extra == 'benchmark'
- cython ; extra == 'dev'
- greenlet ; python_full_version < '3.14' and extra == 'dev'
- pytest ; extra == 'dev'
- pytest-cov ; extra == 'dev'
- ipython ; extra == 'dev'
- setuptools ; extra == 'dev'
- pytest-textual-snapshot ; extra == 'dev'
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'dev'
- packaging ; extra == 'dev'
- black ; extra == 'dev'
- flake8 ; extra == 'dev'
- isort ; extra == 'dev'
- mypy ; extra == 'dev'
- check-manifest ; extra == 'dev'
- ipython ; extra == 'dev'
- bump2version ; extra == 'dev'
- sphinx ; extra == 'dev'
- furo ; extra == 'dev'
- sphinx-argparse ; extra == 'dev'
- towncrier ; extra == 'dev'
- asv ; extra == 'dev'
requires_python: '>=3.7.0'
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
name: mercantile name: mercantile
version: 1.2.1 version: 1.2.1
@ -8866,6 +8921,34 @@ packages:
- pkg:pypi/terminado?source=hash-mapping - pkg:pypi/terminado?source=hash-mapping
size: 22452 size: 22452
timestamp: 1710262728753 timestamp: 1710262728753
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
name: textual
version: 6.6.0
sha256: 5a9484bd15ee8a6fd8ac4ed4849fb25ee56bed2cecc7b8a83c4cd7d5f19515e5
requires_dist:
- markdown-it-py[linkify]>=2.1.0
- mdit-py-plugins
- platformdirs>=3.6.0,<5
- pygments>=2.19.2,<3.0.0
- rich>=14.2.0
- tree-sitter>=0.25.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-bash>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-css>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-go>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-html>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-java>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-javascript>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-json>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-markdown>=0.3.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-python>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-regex>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-rust>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-sql>=0.3.11 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-toml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-xml>=0.7.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-yaml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
- typing-extensions>=4.4.0,<5.0.0
requires_python: '>=3.9,<4.0'
- conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda
sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd
md5: 9d64911b31d57ca443e9f1e36b04385f md5: 9d64911b31d57ca443e9f1e36b04385f

View file

@ -57,7 +57,7 @@ dependencies = [
"xgboost>=3.1.1,<4", "xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026", "s3fs>=2025.10.0,<2026",
"xarray-spatial", "xarray-spatial",
"cupy-xarray>=0.1.4,<0.2", "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
] ]
[project.scripts] [project.scripts]

View file

@ -1,16 +1,27 @@
"""Aggregation helpers."""
import gc
import os
import sys
from collections import defaultdict
from collections.abc import Callable, Generator
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cache
from typing import Literal from typing import Literal
import cupy_xarray
import geopandas as gpd import geopandas as gpd
import numpy as np import numpy as np
import odc.geo.geobox import odc.geo.geobox
import pandas as pd
import psutil
import shapely import shapely
import shapely.ops import shapely.ops
import sklearn
import xarray as xr import xarray as xr
import xdggs import xdggs
import xvec import xvec
from rich import print
from rich.progress import track from rich.progress import track
from shapely.geometry import LineString, Polygon from shapely.geometry import LineString, Polygon
from stopuhr import stopwatch from stopuhr import stopwatch
@ -19,34 +30,100 @@ from xdggs.healpix import HealpixInfo
from entropice import grids from entropice import grids
@dataclass @dataclass(frozen=True)
class _Aggregations: class _Aggregations:
# ! The ordering is super important for this class!
mean: bool = True mean: bool = True
sum: bool = False sum: bool = False
std: bool = False std: bool = False
min: bool = False min: bool = False
max: bool = False max: bool = False
quantiles: list[float] = field(default_factory=lambda: []) median: bool = False # Will be mapped as quantile 0.5
quantiles: tuple[float] = field(default_factory=tuple)
def varnames(self, vars: list[str] | str) -> list[str]: def __post_init__(self):
if isinstance(vars, str): assert isinstance(self.quantiles, tuple), "Quantiles must be a tuple, ortherwise the class is not hashable"
vars = [vars] if self.median and 0.5 in self.quantiles:
agg_vars = [] raise ValueError("Median aggregation cannot be used together with quantile 0.5")
for var in vars:
if self.mean: @cache
agg_vars.append(f"{var}_mean") def __len__(self) -> int:
if self.sum: length = 0
agg_vars.append(f"{var}_sum") if self.mean:
if self.std: length += 1
agg_vars.append(f"{var}_std") if self.sum:
if self.min: length += 1
agg_vars.append(f"{var}_min") if self.std:
if self.max: length += 1
agg_vars.append(f"{var}_max") if self.min:
for q in self.quantiles: length += 1
q_int = int(q * 100) if self.max:
agg_vars.append(f"{var}_p{q_int}") length += 1
return agg_vars if self.median and 0.5 not in self.quantiles:
length += 1
length += len(self.quantiles)
return length
def aggnames(self) -> list[str]:
names = []
if self.mean:
names.append("mean")
if self.sum:
names.append("sum")
if self.std:
names.append("std")
if self.min:
names.append("min")
if self.max:
names.append("max")
if self.median:
names.append("median")
for q in sorted(self.quantiles): # ? Ordering is important here!
q_int = int(q * 100)
names.append(f"p{q_int}")
return names
def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
j = 0
if self.mean:
cell_data[j, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
j += 1
if self.sum:
cell_data[j, ...] = flattened_var.sum(dim="z", skipna=True).to_numpy()
j += 1
if self.std:
cell_data[j, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
j += 1
if self.min:
cell_data[j, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
j += 1
if self.max:
cell_data[j, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
j += 1
if len(self.quantiles) > 0 or self.median:
quantiles_to_compute = sorted(self.quantiles) # ? Ordering is important here!
if self.median:
quantiles_to_compute.insert(0, 0.5)
cell_data[j:, ...] = flattened_var.quantile(
q=quantiles_to_compute,
dim="z",
skipna=True,
).to_numpy()
def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray:
if isinstance(flattened, xr.DataArray):
return self._agg_cell_data_single(flattened)
others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
for i, var in enumerate(flattened.data_vars):
cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
# Transform to numpy arrays
# if flattened.cupy.is_cupy:
# cell_data = cell_data.get()
return cell_data
def _crosses_antimeridian(geom: Polygon) -> bool: def _crosses_antimeridian(geom: Polygon) -> bool:
@ -77,11 +154,11 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1 return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
@stopwatch("Getting corrected geometries", log=False) @stopwatch("Correcting geometries", log=False)
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]: def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
geom, gbox, crs = inp geom, gbox, crs = inp
# cell.geometry is a shapely Polygon # cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom): if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
geoms = [geom] geoms = [geom]
# Split geometry in case it crossed antimeridian # Split geometry in case it crossed antimeridian
else: else:
@ -92,153 +169,338 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
return geoms return geoms
@stopwatch("Correcting geometries")
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
"""Get corrected geometries for antimeridian-crossing polygons.
Args:
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
Returns:
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
"""
with ProcessPoolExecutor(max_workers=20) as executor:
cell_geometries = list(
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
)
return cell_geometries
@stopwatch("Aggregating cell data", log=False)
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
cell_data = {}
for var in flattened.data_vars:
if aggregations.mean:
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
if aggregations.sum:
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
if aggregations.std:
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
if aggregations.min:
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
if aggregations.max:
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
if len(aggregations.quantiles) > 0:
quantile_values = flattened[var].quantile(
q=aggregations.quantiles,
dim="z",
skipna=True,
)
for q, qv in zip(aggregations.quantiles, quantile_values):
q_int = int(q * 100)
cell_data[f"{var}_p{q_int}"] = qv
# Transform to numpy arrays
for key in cell_data:
cell_data[key] = cell_data[key].cupy.as_numpy().values
return cell_data
@stopwatch("Extracting cell data", log=False)
def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
    """Stack the spatial dims of one cropped cell into ``z`` and aggregate over it.

    Args:
        cropped (xr.Dataset | xr.DataArray): Data already cropped to a single cell.
        aggregations (_Aggregations): The aggregations to perform.

    Returns:
        The array returned by ``aggregations.agg_cell_data``.

    """
    has_latlon = "latitude" in cropped.dims and "longitude" in cropped.dims
    spatdims = ["latitude", "longitude"] if has_latlon else ["y", "x"]
    # NOTE: moving large cells (z.size > 3000) to cupy is currently disabled.
    stacked = cropped.stack(z=spatdims)  # noqa: PD013
    return aggregations.agg_cell_data(stacked)
@stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations):
    """Aggregate a cell that is represented by several cropped parts.

    Each part is stacked over its spatial dims, the parts are concatenated along
    ``z`` and the result is aggregated as one cell.

    Args:
        cropped_list (list[xr.Dataset | xr.DataArray]): Cropped parts of one cell (at least one).
        aggregations (_Aggregations): The aggregations to perform.

    Returns:
        The array returned by ``aggregations.agg_cell_data``.

    """
    first = cropped_list[0]
    if "latitude" in first.dims and "longitude" in first.dims:
        spatdims = ["latitude", "longitude"]
    else:
        spatdims = ["y", "x"]
    stacked_parts = [part.stack(z=spatdims) for part in cropped_list]  # noqa: PD013
    merged = xr.concat(stacked_parts, dim="z")
    return aggregations.agg_cell_data(merged)
@stopwatch("Cropping (and reading?) cell data", log=False)
def _read_cell_data(unaligned: xr.Dataset | xr.DataArray, geom: odc.geo.Geometry) -> xr.Dataset:
    """Crop ``unaligned`` to ``geom`` with masking, drop ``spatial_ref`` and compute the result.

    Args:
        unaligned (xr.Dataset | xr.DataArray): Raster data to crop.
        geom (odc.geo.Geometry): Geometry of the cell.

    Returns:
        xr.Dataset: The cropped data after ``.compute()``.

    """
    masked = unaligned.odc.crop(geom, apply_mask=True)
    without_ref: xr.Dataset = masked.drop_vars("spatial_ref")
    return without_ref.compute()
@stopwatch("Cropping (and reading?) split cell data", log=False)
def _read_split_cell_data(unaligned: xr.Dataset | xr.DataArray, geoms: list[odc.geo.Geometry]) -> list[xr.Dataset]:
    """Crop ``unaligned`` to every geometry of a split cell and compute each part.

    Args:
        unaligned (xr.Dataset | xr.DataArray): Raster data to crop.
        geoms (list[odc.geo.Geometry]): The parts of the split cell.

    Returns:
        list[xr.Dataset]: One computed crop per geometry, in input order.

    """
    # First set up every crop, then compute them one after another.
    pending: list[xr.Dataset] = []
    for geom in geoms:
        pending.append(unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref"))
    loaded = []
    for part in pending:
        loaded.append(part.compute())
    return loaded
@stopwatch("Processing cell", log=False)
def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
    """Correct, crop and aggregate a single grid cell.

    Args:
        poly (Polygon): Geometry of the cell.
        unaligned (xr.Dataset | xr.DataArray): Raster data to read from.
        aggregations (_Aggregations): The aggregations to perform.

    Returns:
        The aggregated cell data.

    Raises:
        ValueError: If the geometry correction yields no geometry.

    """
    parts = _get_corrected_geoms((poly, unaligned.odc.geobox, unaligned.odc.crs))
    if not parts:
        raise ValueError("No geometries to process")
    if len(parts) == 1:
        single = _read_cell_data(unaligned, parts[0])
        return _extract_cell_data(single, aggregations)
    pieces = _read_split_cell_data(unaligned, parts)
    return _extract_split_cell_data(pieces, aggregations)
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
    """Split the grid into spatially compact partitions via KMeans on the cell centroids.

    Args:
        grid_gdf (gpd.GeoDataFrame): Grid to partition.
        n_partitions (int): Requested number of partitions. Capped at the number of
            cells, since KMeans cannot form more clusters than there are samples.

    Yields:
        gpd.GeoDataFrame: The non-empty subsets of ``grid_gdf``, one per cluster.
            Nothing is yielded for an empty grid.

    """
    # Import the submodule explicitly instead of relying on `import sklearn` exposing it.
    from sklearn.cluster import KMeans

    # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
    n_cells = len(grid_gdf)
    if n_cells == 0:
        return
    n_clusters = min(n_partitions, n_cells)

    centroid_points = grid_gdf.geometry.centroid  # computed once, used for both coordinates
    centroids = pd.DataFrame({"x": centroid_points.x, "y": centroid_points.y})
    labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(centroids)
    for i in range(n_clusters):
        partition = grid_gdf[labels == i]
        if len(partition) > 0:
            yield partition
class _MemoryProfiler:
def __init__(self):
self.process = psutil.Process()
self._logs = defaultdict(list)
def log_memory(self, message: str, log: bool = True):
mem = self.process.memory_info().rss / 1024**3
self._logs[message].append(mem)
if log:
print(f"[MemoryProfiler] {message}: {mem:.2f} GB")
def summary(self):
summary_lines = ["[MemoryProfiler] Memory usage summary:"]
for message, mems in self._logs.items():
mems_str = [f"{m:.2f}" for m in mems]
summary_lines.append(f" {message}: min={min(mems):.2f} GB, max={max(mems):.2f} GB")
summary_lines.append(f" Measurements: {', '.join(mems_str)}")
return "\n".join(summary_lines)
# Module-level profiler slot; stays None until `_init_worker` has run in this process.
memprof = None


def _init_worker():
    """Create a `_MemoryProfiler` and store it in the module-level `memprof`.

    Passed as ``initializer`` to the process pool in `_align_data`.
    """
    global memprof
    memprof = _MemoryProfiler()
def _align_partition(
    grid_partition_gdf: gpd.GeoDataFrame,
    raster: xr.Dataset | Callable[[], xr.Dataset],
    aggregations: _Aggregations,
    pxbuffer: int,
    cell_workers: int = 60,
):
    """Aggregate the raster into every cell of one grid partition.

    Args:
        grid_partition_gdf (gpd.GeoDataFrame): The cells of this partition.
        raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
        aggregations (_Aggregations): The aggregations to perform.
        pxbuffer (int): Number of raster pixels by which the partition extent is enlarged
            before the raster is cropped to it.
        cell_workers (int, optional): Size of the inner process pool that aggregates the
            individual cells. Defaults to 60.

    Returns:
        np.ndarray: float32 array of shape ``(n_cells, n_variables, n_aggregations, *other dims)``.
            Rows of cells that could not be processed stay NaN; if the raster cannot be
            cropped to the partition extent the whole array is NaN.

    """
    # Strategy for each cell:
    # 1. Correct the geometry to account for the antimeridian
    # 2. Crop the dataset and load the data into memory
    # 3. Aggregate the data
    # 4. Store the aggregated data in the correct location in the output arrays
    # Steps 1-2 run in this process, step 3 runs in the inner process pool.
    # Author's timing notes - for the Arcticdem on Hex3 grid the avg. time per task is:
    # 1. ~0.01s
    # 2. ~1-3s
    # 3. ~10-20s
    # And for Hex6:
    # 1. unmeasurable
    # 2. ~0.1-0.3s
    # 3. ~0.05s

    # `memprof` is only set by `_init_worker`; fall back to a private profiler so this
    # function does not crash when it is called outside an initialised worker.
    profiler = memprof if memprof is not None else _MemoryProfiler()
    profiler.log_memory("Before reading partial raster", log=False)

    if callable(raster) and not isinstance(raster, xr.Dataset):
        raster = raster()

    if hasattr(raster, "__dask_graph__"):
        graph_size = len(raster.__dask_graph__())
        # NOTE: sys.getsizeof is shallow, so this is only a lower bound.
        graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
        print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")

    others_shape = tuple(raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"])
    data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
    data = np.full(data_shape, np.nan, dtype=np.float32)

    partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
    # The y resolution of a geobox is usually negative; abs() makes sure the extent is
    # enlarged (not shrunk) by pxbuffer pixels in both directions.
    partial_extent = partial_extent.buffered(
        abs(raster.odc.geobox.resolution.x) * pxbuffer,
        abs(raster.odc.geobox.resolution.y) * pxbuffer,
    )

    with stopwatch("Cropping raster to partition extent", log=False):
        try:
            partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
        except Exception as e:
            print(f"Error cropping raster to partition extent: {e}")
            raster.close()
            return data

    print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
    profiler.log_memory("After reading partial raster", log=True)

    with ProcessPoolExecutor(max_workers=cell_workers) as executor:
        futures = {}
        for i, (_, row) in enumerate(grid_partition_gdf.iterrows()):
            # ? Splitting the geometries already here and passing only the cropped data to the worker to
            # reduce pickling overhead
            geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
            if len(geoms) == 0:
                continue
            if len(geoms) == 1:
                cropped = [_read_cell_data(partial_raster, geoms[0])]
            else:
                cropped = _read_split_cell_data(partial_raster, geoms)
            futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
            # Drop the local references right after submitting
            del geoms, cropped

        for future in as_completed(futures):
            # pop() releases the dict's reference to the future whether or not it succeeded
            i = futures.pop(future)
            try:
                cell_data = future.result()
            except (SystemError, SystemExit, KeyboardInterrupt):
                raise
            except Exception as e:
                print(f"Error processing cell {i}: {e}")
                continue
            data[i, ...] = cell_data
            del cell_data

    partial_raster.close()
    raster.close()
    # `future` is deliberately not deleted here: it is unbound when no task was submitted.
    del partial_raster, raster, futures
    gc.collect()
    profiler.log_memory("After cleaning", log=True)

    print("Finished processing partition")
    print("### Stopwatch summary ###\n")
    print(stopwatch.summary())
    print("### Memory summary ###\n")
    print(profiler.summary())
    print("#########################")

    return data
@stopwatch("Aligning data with grid")
def _align_data(
    grid_gdf: gpd.GeoDataFrame,
    raster: xr.Dataset | Callable[[], xr.Dataset],
    aggregations: _Aggregations,
    n_partitions: int,
    max_workers: int,
    pxbuffer: int,
):
    """Aggregate the raster into all grid cells, one partition per worker task.

    Args:
        grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
        raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
        aggregations (_Aggregations): The aggregations to perform.
        n_partitions (int): Number of partitions to divide the grid into.
        max_workers (int): Maximum number of worker processes.
        pxbuffer (int): Pixel buffer passed on to ``_align_partition``.

    Returns:
        np.ndarray: Array whose first axis follows the row order of ``grid_gdf``.

    """
    # The partitions are spatial clusters, so their rows are not contiguous in grid_gdf.
    # A positional index lets every partition remember which rows it holds, so the
    # results can be written back in the original row order.
    positional_gdf = grid_gdf.reset_index(drop=True)
    partial_data = {}
    row_positions = {}

    with ProcessPoolExecutor(
        max_workers=max_workers,
        initializer=_init_worker,
    ) as executor:
        futures = {}
        for i, grid_partition in enumerate(_partition_grid(positional_gdf, n_partitions)):
            row_positions[i] = grid_partition.index.to_numpy()
            future = executor.submit(
                _align_partition,
                grid_partition,
                raster.copy() if isinstance(raster, xr.Dataset) else raster,
                aggregations,
                pxbuffer,
            )
            futures[future] = i

        print("Submitted all partitions, waiting for results...")

        for future in track(
            as_completed(futures),
            total=len(futures),
            description="Processing grid partitions...",
        ):
            i = futures[future]
            try:
                partial_data[i] = future.result()
            except Exception as e:
                print(f"Error processing partition {i}: {e}")
                raise
            print(f"Processed partition {i + 1}/{len(futures)}")

    order = sorted(partial_data)
    stacked = np.concatenate([partial_data[i] for i in order], axis=0)
    positions = np.concatenate([row_positions[i] for i in order])
    # Scatter the partition-ordered rows back to their position in grid_gdf.
    data = np.full_like(stacked, np.nan)
    data[positions] = stacked
    return data
def aggregate_raster_into_grid( def aggregate_raster_into_grid(
raster: xr.Dataset, raster: xr.Dataset | Callable[[], xr.Dataset],
grid_gdf: gpd.GeoDataFrame, grid_gdf: gpd.GeoDataFrame,
aggregations: _Aggregations, aggregations: _Aggregations,
grid: Literal["hex", "healpix"], grid: Literal["hex", "healpix"],
level: int, level: int,
n_partitions: int = 20,
max_workers: int = 5,
pxbuffer: int = 15,
): ):
aligned = _align_data(grid_gdf, raster, aggregations) """Aggregate raster data into grid cells.
Args:
raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
aggregations (_Aggregations): The aggregations to perform.
grid (Literal["hex", "healpix"]): The type of grid to use.
level (int): The level of the grid.
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
Returns:
xr.Dataset: Aggregated data aligned with the grid.
"""
aligned = _align_data(
grid_gdf,
raster,
aggregations,
n_partitions=n_partitions,
max_workers=max_workers,
pxbuffer=pxbuffer,
)
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
if callable(raster) and not isinstance(raster, xr.Dataset):
raster = raster()
cell_ids = grids.get_cell_ids(grid, level) cell_ids = grids.get_cell_ids(grid, level)
dims = ["cell_ids"] dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids} coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in raster.dims: for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
if dim not in ["y", "x", "latitude", "longitude"]: dims.append(dim)
dims.append(dim) coords[dim] = raster.coords[dim]
coords[dim] = raster.coords[dim]
ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")
data_vars = {var: (dims, values) for var, values in aligned.items()}
ongrid = xr.Dataset(
data_vars,
coords=coords,
)
gridinfo = { gridinfo = {
"grid_name": "h3" if grid == "hex" else grid, "grid_name": "h3" if grid == "hex" else grid,
"level": level, "level": level,
@ -247,8 +509,7 @@ def aggregate_raster_into_grid(
gridinfo["indexing_scheme"] = "nested" gridinfo["indexing_scheme"] = "nested"
ongrid.cell_ids.attrs = gridinfo ongrid.cell_ids.attrs = gridinfo
for var in raster.data_vars: for var in raster.data_vars:
for v in aggregations.varnames(var): ongrid[var].attrs = raster[var].attrs
ongrid[v].attrs = raster[var].attrs
ongrid = xdggs.decode(ongrid) ongrid = xdggs.decode(ongrid)
return ongrid return ongrid

View file

@ -77,7 +77,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
return feature.set(mean_dict) return feature.set(mean_dict)
# Process grid in batches of 100 # Process grid in batches of 100
batch_size = 100 batch_size = 50
all_results = [] all_results = []
n_batches = len(grid_gdf) // batch_size n_batches = len(grid_gdf) // batch_size
for batch_num, batch_grid in track( for batch_num, batch_grid in track(

View file

@ -1,4 +1,5 @@
import datetime import datetime
import multiprocessing as mp
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Literal from typing import Literal
@ -9,6 +10,7 @@ import cupyx.scipy.signal
import cyclopts import cyclopts
import dask.array import dask.array
import dask.distributed as dd import dask.distributed as dd
import icechunk
import icechunk.xarray import icechunk.xarray
import numpy as np import numpy as np
import smart_geocubes import smart_geocubes
@ -301,18 +303,23 @@ def enrich():
print("Enrichment complete.") print("Enrichment complete.")
def _open_adem():
    """Open the ArcticDEM 32m data as xarray and check its dims and CRS.

    Returns:
        The opened dataset, with dims exactly ``x`` and ``y`` and CRS ``EPSG:3413``.

    """
    adem = smart_geocubes.ArcticDEM32m(get_arcticdem_stores()).open_xarray()
    assert set(adem.dims) == {"x", "y"}
    assert adem.odc.crs == "EPSG:3413"
    return adem
@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
mp.set_start_method("forkserver", force=True)
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
aggregations = _Aggregations( aggregations = _Aggregations(
mean=True, mean=True,
@ -320,13 +327,29 @@ def aggregate(grid: Literal["hex", "healpix"], level: int):
std=True, std=True,
min=True, min=True,
max=True, max=True,
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99], median=True,
quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
)
aggregated = aggregate_raster_into_grid(
_open_adem,
grid_gdf,
aggregations,
grid,
level,
max_workers=max_workers,
n_partitions=200, # Results in 10GB large patches
pxbuffer=30,
) )
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
store = get_arcticdem_stores(grid, level) store = get_arcticdem_stores(grid, level)
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"): with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
print("Aggregation complete.")
print("### Finished ArcticDEM processing ###")
stopwatch.summary()
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()